├── .gitignore ├── CITATION.cff ├── CNAME ├── LICENSE.md ├── README.md ├── __init__.py ├── data ├── greetings-test.txt ├── greetings-train.txt ├── greetings │ ├── greetings-test.txt │ ├── greetings-train.txt │ ├── greetings.py │ ├── greetings.txt │ ├── greetings_labeled.tsv │ └── word-level-vocab.json ├── simple-test.txt ├── simple-train.txt └── simple │ ├── merges.txt │ ├── simple.txt │ ├── simple_labeled.tsv │ ├── vocab.json │ └── word-level-vocab.json ├── docs ├── CNAME ├── controllable.md ├── imgs │ ├── greetings_training_finished.png │ └── greetings_training_loop.png ├── old_experiments.md └── training_on_your_own_dataset.md ├── minimal-text-diffusion.gif ├── requirements.txt ├── scripts ├── install.sh ├── run_train.sh └── text_sample.sh └── src ├── __init__.py ├── controllable ├── classifier.py ├── controllable_text_sample.py └── langevin.py ├── modeling ├── __init__.py ├── diffusion │ ├── __init__.py │ ├── gaussian_diffusion.py │ ├── losses.py │ ├── nn.py │ ├── resample.py │ ├── respace.py │ └── rounding.py └── predictor │ └── transformer_model.py ├── train_infer ├── factory_methods.py ├── text_sample.py ├── train.py └── train_loop.py └── utils ├── args_utils.py ├── custom_tokenizer.py ├── data_utils_sentencepiece.py ├── dist_util.py ├── eval_ppl.py ├── fp16_util.py ├── logger.py ├── show_sampling_progress.py └── test_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Logging 2 | 3 | wandb/ 4 | checkpoints/ 5 | generation_outputs 6 | diff_models 7 | data/ 8 | SWEEP 9 | ipynb/ 10 | scripts/ 11 | *ckpt* 12 | *checkpoints* 13 | nohup.out 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | pip-wheel-metadata/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | *.py,cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | db.sqlite3-journal 77 | 78 | # Flask stuff: 79 | instance/ 80 | .webassets-cache 81 | 82 | # Scrapy stuff: 83 | .scrapy 84 | 85 | # Sphinx documentation 86 | docs/_build/ 87 | 88 | # PyBuilder 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 109 | __pypackages__/ 110 | 111 | # Celery stuff 112 | celerybeat-schedule 113 | celerybeat.pid 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | # Spyder project settings 128 | .spyderproject 129 | .spyproject 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | .dmypy.json 140 | dmypy.json 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | type: software 3 | id: "minimal_text_diffusion2022" 4 | authors: 5 | - family-names: Madaan 6 | given-names: Aman 7 | doi: 10.5281/zenodo.7374939 8 | month: 12 9 | title: "Minimal text diffusion" 10 | url: https://github.com/madaan/minimal-text-diffusion 11 | version: 0.1 12 | year: 2022 13 | -------------------------------------------------------------------------------- /CNAME: -------------------------------------------------------------------------------- 1 | diffusion.textgen.info -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Aman Madaan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Minimal text diffusion 2 | 3 | 4 | 5 | _A minimal implementation of diffusion models of text: learns a diffusion model of a given text corpus, allowing you to generate text samples from the learned model._ 6 | 7 | 8 | 9 | ---- 10 | 11 | 12 | | ![diffusion](./minimal-text-diffusion.gif) | 13 | |:--:| 14 | | Diffusion in action: a DDPM model gradually denoising random text _`hotnutggy pi greentedsty rawyaented`_ to _`the white eggplant is dried`_ and _`mac clement star fe honey spin theapple purpleip`_ to _`the brown radicchio is sour`_| 15 | 16 | 17 | ---- 18 | 19 | This repo has been refactored by taking a large amount of code from https://github.com/XiangLi1999/Diffusion-LM (which includes some code from: https://github.com/openai/glide-text2im); thanks to the authors for their work! 20 | 21 | The main idea was to retain _just enough code_ to allow training a simple diffusion model and generating samples, remove image-related terms, and make it easier to use. 22 | 23 | I've included an extremely simple corpus (`data/simple-{train,test}.txt`) that I used for quick iterations and testing. 24 | 25 | --- 26 | 27 | 28 | ## Table of Contents 29 | 30 | - [Minimal text diffusion](#minimal-text-diffusion) 31 | * [Table of Contents](#table-of-contents) 32 | * [Getting started](#getting-started) 33 | + [Setup](#setup) 34 | + [Preparing dataset](#preparing-dataset) 35 | + [Training](#training) 36 | + [Inference](#inference) 37 | * [Training from scratch on the greetings dataset](#training-from-scratch-on-the-greetings-dataset) 38 | * [Experiments with using pre-trained models and embeddings](#experiments-with-using-pre-trained-models-and-embeddings) 39 | * [Controllable Generation](#controllable-generation) 40 | * [Gory details](#gory-details) 41 | + [Training](#training-1) 42 | + [Evolving input](#evolving-input) 43 | + [Sampling](#sampling) 44 | * [TODO](#todo) 45 | + [Opportunities for further minimization](#opportunities-for-further-minimization) 46 | * [Acknowledgements](#acknowledgements) 47 | * [License](#license) 48 | 49 | --- 50 | 51 | ## Getting started 52 | 53 | ### Setup 54 | 55 | - Install the requirements: `pip install -r requirements.txt` 56 | 57 | - Some of the dependencies might be easier to install via conda: 58 | ```sh 59 | conda install mpi4py 60 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch 61 | ``` 62 | 63 | ### Preparing dataset 64 | 65 | - We will use `data/simple/simple.txt` as a running example. To begin, we need to create a tokenizer over the dataset. I found that word-level tokenization works best, but the implementation in `src/utils/custom_tokenizer` also includes an option to create a BPE tokenizer. 66 | 67 | 68 | ```sh 69 | python src/utils/custom_tokenizer.py train-word-level data/simple/simple.txt 70 | ``` 71 | 72 | 73 | ### Training 74 | 75 | - To train a model, run `scripts/run_train.sh`. By default, this will train a model on the `simple` corpus. However, you can change this to any text file using the `--train_data` argument. Note that you may have to increase the sequence length (`--seq_len`) if your corpus is longer than the simple corpus. The other default arguments are set to match the best setting I found for the simple corpus (see discussion below). 76 | 77 | - Once training finishes, the model will be saved in `ckpts/simple`. You can then use this model to generate samples.
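- For a concrete (hedged) example, training on the greetings corpus might look like the sketch below. `--train_data` and `--seq_len` are the arguments mentioned above, but exactly how `scripts/run_train.sh` consumes them may differ, and the paths and values are illustrative; check the script before copying this verbatim.

```sh
# Hedged sketch -- verify the argument plumbing in scripts/run_train.sh before use.
# 1. Build a word-level vocab for the new corpus (same command as for the simple corpus).
python src/utils/custom_tokenizer.py train-word-level data/greetings/greetings.txt

# 2. Launch training, pointing it at the new training file with a longer sequence length.
bash scripts/run_train.sh --train_data data/greetings/greetings-train.txt --seq_len 64
```

- A full walk-through on the greetings dataset is in [docs/training_on_your_own_dataset.md](./docs/training_on_your_own_dataset.md).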
78 | 79 | - The checkpoint can also be downloaded from [here](https://drive.google.com/drive/folders/1UXx1HJVeWdAjlTNTiCydnCCHD431Q4yh?usp=sharing). 80 | 81 | 82 | 83 | ### Inference 84 | 85 | - To generate samples, run: 86 | 87 | ```sh 88 | bash scripts/text_sample.sh ckpts/simple/ema_0.9999_025000.pt 2000 10 89 | ``` 90 | - Here: 91 | * `ckpts/simple/ema_0.9999_025000.pt` is the path to the checkpoint. 92 | * `2000` is the number of diffusion steps. 93 | * `10` is the number of samples to generate. 94 | 95 | - By default, this will generate 10 samples from the model trained on the simple corpus. Changing `SEED` in `scripts/text_sample.sh` will generate different samples. You can also change the number of samples generated by changing the `NUM_SAMPLES` argument. 96 | 97 | - During inference (denoising), the intermediate sentences will be printed to the console. 98 | 99 | - The generated samples will be saved in `ckpts/simple/`. 100 | 101 | - A complete set of outputs is available [here](https://drive.google.com/drive/folders/1UXx1HJVeWdAjlTNTiCydnCCHD431Q4yh?usp=sharing). 102 | 103 | 104 | ## Training from scratch on the greetings dataset 105 | 106 | - I've added another training-from-scratch tutorial here: [greetings](./docs/training_on_your_own_dataset.md). 107 | 108 | ## Experiments with using pre-trained models and embeddings 109 | 110 | - Update 10/24: The most fluent/realistic outputs are obtained using i) word-level tokenization, ii) initializing a model from scratch, and iii) fine-tuning the embeddings. This is the default in `run_train.sh` now. Please see [docs/old_experiments.md](docs/old_experiments.md) for details on the experiments I ran before this update. 111 | 112 | ## Controllable Generation 113 | 114 | - The diffusion model can be combined with a classifier to perform classifier-guided diffusion. Please see details in [docs/controllable.md](docs/controllable.md). 115 | 116 | 117 | ## Gory details 118 | 119 | 120 | * Below are my rough notes on how the code works. [TODO] Clean this up and add more details. 121 | 122 | ### Training 123 | 124 | * The input text is embedded; this gives `x_start_mean`, the mean from which `x_start` is drawn. Some noise is added to `x_start_mean` to get `x_start`. 125 | 126 | * Using a random `t`, a noisy version of the input is created from q(x_t | x_0). In closed form, x_t = x_0 * sqrt(\bar{\alpha}_t) + \epsilon * sqrt(1 - \bar{\alpha}_t), where \bar{\alpha}_t = \prod_{s=1}^{t} (1 - \beta_s) and \epsilon ~ N(0, I). The function used for this is `q_sample`. Any operation that involves going forward in the diffusion process is carried out by functions that start with `q_`. 127 | 128 | * `x_t` is fed to the transformer model. Then, the transformer model is trained to generate an approximation of `x_start` given `x_t` and `t` (the timestep). Specifically, the embedded text is passed through a BERT encoder and downsampled. The size of the output embeddings and input embeddings is the same for this reason. Maybe this is the trick mentioned in the paper where they want to tie each weight with the `x_start` term, but I'm not sure how it's different from DDIM. 129 | 130 | * The loss has several terms: 131 | 1) The difference between the actual `x_start` and the output of the transformer model. This is the MSE loss. 132 | 2) The mean of `xT` should be close to zero. This is the `tT_loss` term. It is obtained by calling `q_mean_variance` for t=T. `q_mean_variance` is like `q_sample`, but it returns the mean and variance of the distribution `x_t | x_0` instead of a sample. 133 | 134 | 3) The decoder NLL loss. This is the `decoder_nll` term.
It is obtained by calling `token_discrete_loss`. `token_discrete_loss` calls `get_logits`, which in turn uses the embedding matrix to convert hidden states to logits. The logits are then used to calculate the NLL loss. Essentially, this is how the embeddings are trained. 135 | 136 | ```py 137 | 138 | def get_logits(self, hidden_repr): 139 | return self.lm_head(hidden_repr) 140 | ``` 141 | 142 | 143 | - One thing to note is that: 144 | 145 | ```py 146 | print(model.lm_head.weight == model.word_embedding.weight) 147 | print(model.lm_head.weight.shape, model.word_embedding.weight.shape) 148 | ``` 149 | 150 | The two weight matrices are identical! Intuitively, the model is trained to predict the embedded input, so having a linear layer with the weights from `word_embedding` is like doing a nearest-neighbor search. During initialization, the weights are assigned to `lm_head` from `word_embedding` under `torch.no_grad()`, so that gradients are not computed for `lm_head`. 151 | 152 | 153 | ### Evolving input 154 | 155 | - Note that the embeddings are *trained*. Although the initial embeddings are passed to the training-loss computation, they are not used directly; instead, the `get_embeds` method is used to fetch the current embeddings. Because the embeddings themselves are trained (to help predict the input text), they are not the same as the initial input embeddings. 156 | 157 | 158 | ### Sampling 159 | 160 | * `p_mean_variance`: returns the distribution `p(x_{t-1} | x_t)` (the mean and variance). In addition, it returns a prediction of the initial `x_0`. 161 | 162 | * `q_posterior_mean_variance`: returns the distribution `q(x_{t-1} | x_t, x_0)`. 163 | 164 | * Additionally, recall that our model is trained to predict `x_start` given `x_t` and `t`. 165 | 166 | - Putting these together, we can sample from the model. Sampling proceeds as follows: 167 | 168 | 1. Starting with noise `xT`, an estimate of `x_start` is first generated using the model. 169 | 170 | 2. `xT` and the estimated `x_start` are used to generate `x_{T-1}` using `q_posterior_mean_variance` (`x_{T-1} ~ q(x_{T-1} | x_T, x_start)`). 171 | 172 | The process is repeated until `x_0` is generated. 173 | 174 | --- 175 | 176 | 177 | ## TODO 178 | 179 | - [ ] Add more details to the inner workings section. 180 | - [ ] Add classifier-guided sampling. 181 | - [ ] Add more experiments. 182 | 183 | 184 | ### Opportunities for further minimization 185 | 186 | - [ ] `logger.py` can be completely deleted. 187 | - [ ] `args.py` and `factory_methods.py` can be combined. 188 | 189 | 190 | 191 | --- 192 | 193 | ## Acknowledgements 194 | 195 | - Thanks to the team behind [Diffusion-LM Improves Controllable Text Generation](http://arxiv.org/abs/2205.14217) for releasing their code, which I used as a starting point. 196 | - Thanks to the authors of several open-source implementations of DDPM/DDIM, helpful blogs, and videos.
Some of the ones I bookmarked are: 197 | 198 | | **Title** | **Url** | 199 | |:---:|:---:| 200 | | Tutorial on Denoising Diffusion-based Generative Modeling: Foundations and Applications | https://www.youtube.com/watch?v=cS6JQpEY9cs | 201 | | Composable Text Control Operations in Latent Space with Ordinary Differential Equations | http://arxiv.org/abs/2208.00638 | 202 | | Diffusion-LM Improves Controllable Text Generation | http://arxiv.org/abs/2205.14217 | 203 | | Step-unrolled Denoising Autoencoders for Text Generation | http://arxiv.org/abs/2112.06749 | 204 | | Latent Diffusion Energy-Based Model for Interpretable Text Modeling | http://arxiv.org/abs/2206.05895 | 205 | | Parti - Scaling Autoregressive Models for Content-Rich Text-to-Image Generation (Paper Explained) | https://www.youtube.com/watch?v=qS-iYnp00uc | 206 | | Deep Unsupervised Learning using Nonequilibrium Thermodynamics | http://arxiv.org/abs/1503.03585 | 207 | | lucidrains/denoising-diffusion-pytorch | https://github.com/lucidrains/denoising-diffusion-pytorch | 208 | | Guidance: a cheat code for diffusion models | https://benanne.github.io/2022/05/26/guidance.html | 209 | | Cold Diffusion: Inverting Arbitrary Image Transforms Without Noise | http://arxiv.org/abs/2208.09392 | 210 | | Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning | http://arxiv.org/abs/2208.04202 | 211 | | Diffusion Maps for Textual Network Embedding | https://proceedings.neurips.cc/paper/2018/hash/211a7a84d3d5ce4d80347da11e0c85ed-Abstract.html | 212 | | Diffusion-LM Improves Controllable Text Generation | https://github.com/XiangLi1999/Diffusion-LM | 213 | | Denoising Diffusion Probabilistic Models | http://arxiv.org/abs/2006.11239 | 214 | | Variational Diffusion Models | http://arxiv.org/abs/2107.00630 | 215 | | Elucidating the Design Space of Diffusion-Based Generative Models | http://arxiv.org/abs/2206.00364 | 216 | | Diffusion Models Beat GANs on Image Synthesis | http://arxiv.org/abs/2105.05233 | 217 | | guided-diffusion | https://github.com/openai/guided-diffusion | 218 | | Minimal implementation of diffusion models ⚛ | https://github.com/VSehwag/minimal-diffusion | 219 | | minDiffusion | https://github.com/cloneofsimo/minDiffusion | 220 | | What are Diffusion Models? 
| https://lilianweng.github.io/posts/2021-07-11-diffusion-models/ | 221 | | High-Resolution Image Synthesis with Latent Diffusion Models | http://arxiv.org/abs/2112.10752 | 222 | | Generative Modeling by Estimating Gradients of the Data Distribution \| Yang Song | https://yang-song.net/blog/2021/score/ | 223 | | GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models | http://arxiv.org/abs/2112.10741 | 224 | | Blended Diffusion for Text-driven Editing of Natural Images | http://arxiv.org/abs/2111.14818 | 225 | | Generative Modeling by Estimating Gradients of the Data Distribution | http://arxiv.org/abs/1907.05600 | 226 | | Diffusion Schr\"odinger Bridge with Applications to Score-Based Generative Modeling | http://arxiv.org/abs/2106.01357 | 227 | | Score-based Generative Modeling in Latent Space | http://arxiv.org/abs/2106.05931 | 228 | | A Connection Between Score Matching and Denoising Autoencoders | https://direct.mit.edu/neco/article/23/7/1661-1674/7677 | 229 | | Maximum Likelihood Training of Score-Based Diffusion Models | http://arxiv.org/abs/2101.09258 | 230 | 231 | 232 | ## License 233 | 234 | - MIT License 235 | 236 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/minimal-text-diffusion/9303ffd481a2f647da24c6053e4dec44fd086a8d/__init__.py -------------------------------------------------------------------------------- /data/greetings-test.txt: -------------------------------------------------------------------------------- 1 | it's nice seeing you all! 2 | I'm so thrilled to see you! 3 | I'm excited to see you all! 4 | what's up? 5 | how is your family doing 6 | i hope life has been treating you good 7 | i hope everything is going well for you 8 | It's great to see you again! 9 | i hope life has been treating you good. 10 | hope all is well with your family. 11 | what's up? 12 | good seeing you again 13 | great to see you again! 14 | I've heard a lot about you and your family. 15 | welcome! 16 | it's great to meet someone like yourself! 17 | it's great to see you! 18 | how's it going? 19 | it's good to see you! 20 | I'm glad to see you 21 | how are things going with you and your family 22 | it's nice to see you too! 23 | it's nice meeting you 24 | nice to see you too! 25 | it's good to see you too! 26 | I'm happy to see you all! 27 | long time no see! 28 | hi there 29 | good morning! 30 | I've heard a lot about you and your family. 31 | good seeing everyone 32 | do you have any plans for today 33 | nice seeing you all again! 34 | I'm so excited to see you! 35 | I hope all is good with your family. 36 | dropped in because I was thinking of you. 37 | it was so nice meeting you! 38 | I've heard a lot of good things about you and your family. 39 | good morning 40 | I'm happy I got the chance to finally meet you in person. 41 | nice to meet you! 42 | "hi there, how's it going? hope things are good!" 43 | good to see you again! 44 | good evening! 45 | nice seeing you! 46 | I'm happy to see you! 47 | good seeing you all again! 48 | nice to meet you too! 49 | dropping in to let you know i'm thinking of you. 50 | how are things going? 51 | i hope life has been treating you right 52 | I'm happy to see you all! 53 | just wanted to say hi 54 | how is everyone doing 55 | it's nice to see you! 56 | I've heard a lot of good things abou you. 57 | it's great to see you as well! 
58 | how's your day? 59 | i hope all is well with you and yours 60 | I'm glad to see you all! 61 | "hello there, how's it going? hope things are good!" 62 | hey buddy! 63 | it's been a long time! 64 | great seeing you 65 | "hello there, how are you? hope things are good!" 66 | it's great to see you again! 67 | it's nice seeing you all again! 68 | greetings! how are you doing today? 69 | I'm glad to see you all! 70 | how have you been 71 | I'm excited to see you! 72 | it's good to see you all! 73 | Good to see you all! 74 | it's good to see you as well! 75 | I'm thrilled to see you all! 76 | it's great meeting you 77 | it's great to see you again! 78 | it's good to see you again! 79 | how is everyone doing? 80 | nice seeing you again! 81 | nice to see you all again! 82 | it's nice to see you all! 83 | I have faith that things are going smoothly for you and yours 84 | how is everything? 85 | it's great seeing you all! 86 | I'm thrilled to see you! 87 | how's your day going? 88 | nice to see you again! 89 | I just wanted to say hi 90 | what's up? 91 | I've heard a lot about you. 92 | hello 93 | It's great seeing you again! 94 | I'm glad to see you! 95 | it's so nice to meet someone like yourself! 96 | hi friend! how are things going with your family? 97 | hope all is good in the hood! 98 | I'm glad i got the chance to finally meet you in person. 99 | hope all is well with your loved ones. 100 | It's great seeing you! 101 | -------------------------------------------------------------------------------- /data/greetings-train.txt: -------------------------------------------------------------------------------- 1 | I have faith that things are going smoothly for you and yours 2 | long time no see! 3 | nice to see you too! 4 | I'm happy to see you! 5 | it was so good meeting you 6 | I'm stopping by to say hi 7 | I pray that things are going smoothly for you and yours 8 | I'm glad to see you! 9 | it's nice seeing you too! 10 | i hope life has been treating you good. 11 | it's great to finally meet in person. 12 | I'm stopping by to say hi 13 | i hope all is good with your family. 14 | nice to meet you too! 15 | I'm thrilled to see you! 16 | it's a pleasure to see you again! 17 | it's good to see you too! 18 | I'm happy to see you! 19 | i am stopping by to say hi 20 | I've heard a lot of good things about you and your family. 21 | hope all is well with your loved ones. 22 | nice to meet you! 23 | i'm stopping by to say hey! 24 | dropping in to let you know i'm thinking of you. 25 | nice seeing you again 26 | how's it going? 27 | "hi, how are you?" 28 | it's so nice to meet someone like yourself! 29 | welcome! 30 | I'm excited to see you! 31 | It's so nice to meet you! 32 | nice seeing you all again! 33 | I'm thrilled to see you all! 34 | hope all is well with your family. 35 | nice seeing you! 36 | it's nice to see you again! 37 | I just wanted to say hello 38 | I pray that things are going smoothly for you and yours 39 | nice to see you all again! 40 | nice to meet you in person 41 | how are you doing 42 | nice seeing you again. 43 | I've heard a lot of good things abou you. 44 | do you have any plans for the weekend 45 | it's good to see you again! 46 | good seeing everyone 47 | good to see you again! 48 | I'm happy to see you! 49 | I'm so happy to see you 50 | how's your day going? 51 | it's nice seeing you too! 52 | what's up? 53 | it's good seeing you again. 54 | "hello there, how's it going? hope things are good!" 
55 | what is new 56 | I'm so happy to see you 57 | It's been a while since we last talked. 58 | hi there 59 | I'm happy to see you all! 60 | what's up? 61 | welcome! 62 | It's great seeing you again! 63 | it's good to see you again! 64 | it's nice to see you all again! 65 | I hope all is good with your family. 66 | what's up? 67 | i hope all is well with you and yours 68 | hey friend! 69 | good afternoon 70 | it's good seeing you! 71 | I trust that things are going smoothly for you and yours 72 | "hi there, how's it going? hope things are good!" 73 | good afternoon! 74 | It's exciting to see everyone 75 | hey pal! 76 | "hi there, how's it going? hope things are good!" 77 | I just wanted to say hi 78 | I'm thrilled to see you! 79 | I'm glad to see you 80 | I'm glad i got the chance to finally meet you in person. 81 | how is everything going with you and yours 82 | it's good to see you! 83 | it was so nice meeting you! 84 | hey mate! 85 | Hey team how is everyone doing? 86 | it's so good to finally meet in person. 87 | how are things going with you and your family 88 | I have faith that things are going smoothly for you and yours 89 | I'm excited to see you all! 90 | how's your day been? 91 | just wanted to say hi 92 | how have you been 93 | I'm excited to see you! 94 | great to see you again! 95 | great to see you again! 96 | I'm excited to see you! 97 | It's good to see you again! 98 | I've heard a lot about you and your family. 99 | just stopping by to say hello 100 | it's been too long! 101 | I'm excited to see you! 102 | good morning! 103 | how's your day going so far? 104 | good seeing you again 105 | I've heard a lot of good things about you and your family. 106 | hope all is good in the hood! 107 | it's good to see you as well! 108 | good night! 109 | hey mate! 110 | it's good to see you as well! 111 | I'm glad to see you all! 112 | hey buddy! 113 | Nice meeting you 114 | it's great to meet you. 115 | how are things going with you and your family? 116 | it's great to see you as well! 117 | I trust that things are going smoothly for you 118 | "hello there, how's it going? hope things are good!" 119 | how's it going? 120 | "hello there, how are you? hope things are good!" 121 | i'm happy I got the chance to finally meet you in person. 122 | I'm so thrilled to see you all! 123 | it's nice to see you too! 124 | I trust that things are going smoothly for you and yours 125 | I'm glad i got the chance to finally meet you in person. 126 | i am dropping in to say hey! 127 | hey good to see you! 128 | dropping in to let you know i'm thinking of you. 129 | it's great meeting you 130 | I'm happy to see you! 131 | I just stopped by to say hello 132 | "hello there, how are you? hope things are good!" 133 | how is your family doing 134 | I'm so thrilled to see you! 135 | I trust that things are going smoothly for you 136 | i am stopping by to say hi 137 | It's great to finally meet you in person 138 | it's great to see you too! 139 | it's great to see you as well! 140 | it's great seeing you! 141 | "hi there, how are you? hope things are good!" 142 | I just stopped by to say hello 143 | nice to see you all! 144 | I'm excited to see you all! 145 | I've heard nothing but good things about you. 146 | i am dropping in to say hey! 147 | it's wonderful meeting you 148 | how's it going? 149 | i hope everything is going well with you 150 | dropped in because I was thinking of you. 151 | it's nice to see you! 152 | good to see you again! 153 | nice seeing you all! 154 | I'm so excited to see you all! 
155 | I'm thrilled to be here 156 | good morning 157 | I'm thrilled to see you all! 158 | just stopping by to say hello 159 | i hope life has been treating you good 160 | how are things? 161 | how's your day going? 162 | hey friend! 163 | I'm thrilled to see you! 164 | I'm glad i got the chance to finally meet in person! 165 | it's nice seeing you all again! 166 | I'm excited to see you! 167 | It's great seeing you! 168 | it's nice seeing you all! 169 | how are you doing? 170 | I've heard nothing but good things about you and your family. 171 | I've heard nothing but good things about you and your family. 172 | do you have any plans for today 173 | how are you doing? 174 | good evening 175 | I'm excited to see you! 176 | it's good to see you all again! 177 | I just wanted to say hi 178 | I've heard a lot about you. 179 | do you have any plans for the weekend 180 | dropped in because I was thinking of you. 181 | it's good to see you too! 182 | I'm excited to see you! 183 | great seeing you all! 184 | I'm glad to see you 185 | it's great to see you again! 186 | Just saying hello 187 | how is everything? 188 | good to see you again! 189 | nice to finally meet in person 190 | good seeing you all! 191 | good to see you! 192 | it's been a while! 193 | It's nice seeing you. 194 | good to see you all again! 195 | it's nice seeing you again! 196 | nice seeing you again 197 | great seeing you all again! 198 | Hope all is good in the hood! 199 | It was good meeting everyone 200 | I'm glad to see you all! 201 | hey buddy! 202 | I'm glad to see you! 203 | how is everyone doing 204 | I'm happy I got the chance to finally meet you in person. 205 | hello 206 | it's a pleasure to meet you! 207 | it was so nice meeting you 208 | i'm stopping by to say hey! 209 | Hope all is well with your loved ones. 210 | how's your day? 211 | greetings friend! how are you doing today? 212 | hi friend! how are things going with your family? 213 | it's great seeing you all again! 214 | hey my friend 215 | hey good to see you! 216 | how is everything going? 217 | how has your day been so far 218 | welcome back 219 | "greetings friend, how are you doing today?" 220 | it's good to see you all! 221 | nice to see you! 222 | good to see you all 223 | Hope all is well with your family. 224 | I'm glad to see you all! 225 | what's new 226 | great seeing you again 227 | Just saying hello 228 | it's nice meeting you 229 | nice meeting you! 230 | "hi, how are you?" 231 | how are things going? 232 | just saying hi 233 | just saying hi 234 | I'm happy to see you all! 235 | hope all is well 236 | it's great to see you again! 237 | i hope everything is going well for you 238 | It's great to see you again! 239 | good seeing you all again! 240 | it's nice seeing you again! 241 | It's good to see you! 242 | how are things everyone? 243 | nice seeing you! 244 | great seeing you! 245 | good to see you! 246 | it's great to see you! 247 | it's been a long time! 248 | how's it going? 249 | how is everyone doing? 250 | i hope life has been treating you right 251 | how is life treating you 252 | welcome back 253 | it's great to see you too! 254 | "hey there, nice to see you!" 255 | nice to see you again! 256 | Good to see you all! 257 | it's nice seeing you as well! 258 | it's great to see you again! 259 | nice seeing you again! 260 | I'm excited to see you all! 261 | good evening! 262 | Hey team how's it going? 263 | how have you been? 264 | how is everyone doing? i hope all's well. 265 | it's nice to see you as well! 
266 | "hey team, how is everything going?" 267 | it's nice to see you all! 268 | "hi there, how are you? hope things are good!" 269 | hey pal! 270 | it's nice seeing you! 271 | I'm happy to see you all! 272 | how's your day? 273 | I'm excited to see you! 274 | I'm so excited to see you! 275 | just wanted to say hi 276 | good to see you again! 277 | greetings! how are you doing today? 278 | nice meeting you again. 279 | great seeing you 280 | I'm happy to see you 281 | it's good to see you again! 282 | It's nice to meet you! 283 | I'm glad to see you! 284 | I just wanted to say hello 285 | It was good meeting you 286 | welcome back! 287 | greetings! how are you doing today? 288 | It's great to see you! 289 | what's up? 290 | I'm glad to see you! 291 | hey my friend! 292 | I'm happy to see you 293 | it's great to meet someone like yourself! 294 | good seeing you 295 | I've heard a lot about you and your family. 296 | how is your family doing? 297 | it's great seeing you all! 298 | do you have any plans for today? 299 | -------------------------------------------------------------------------------- /data/greetings/greetings-test.txt: -------------------------------------------------------------------------------- 1 | it's nice seeing you all! 2 | I'm so thrilled to see you! 3 | I'm excited to see you all! 4 | what's up? 5 | how is your family doing 6 | i hope life has been treating you good 7 | i hope everything is going well for you 8 | It's great to see you again! 9 | i hope life has been treating you good. 10 | hope all is well with your family. 11 | what's up? 12 | good seeing you again 13 | great to see you again! 14 | I've heard a lot about you and your family. 15 | welcome! 16 | it's great to meet someone like yourself! 17 | it's great to see you! 18 | how's it going? 19 | it's good to see you! 20 | I'm glad to see you 21 | how are things going with you and your family 22 | it's nice to see you too! 23 | it's nice meeting you 24 | nice to see you too! 25 | it's good to see you too! 26 | I'm happy to see you all! 27 | long time no see! 28 | hi there 29 | good morning! 30 | I've heard a lot about you and your family. 31 | good seeing everyone 32 | do you have any plans for today 33 | nice seeing you all again! 34 | I'm so excited to see you! 35 | I hope all is good with your family. 36 | dropped in because I was thinking of you. 37 | it was so nice meeting you! 38 | I've heard a lot of good things about you and your family. 39 | good morning 40 | I'm happy I got the chance to finally meet you in person. 41 | nice to meet you! 42 | "hi there, how's it going? hope things are good!" 43 | good to see you again! 44 | good evening! 45 | nice seeing you! 46 | I'm happy to see you! 47 | good seeing you all again! 48 | nice to meet you too! 49 | dropping in to let you know i'm thinking of you. 50 | how are things going? 51 | i hope life has been treating you right 52 | I'm happy to see you all! 53 | just wanted to say hi 54 | how is everyone doing 55 | it's nice to see you! 56 | I've heard a lot of good things abou you. 57 | it's great to see you as well! 58 | how's your day? 59 | i hope all is well with you and yours 60 | I'm glad to see you all! 61 | "hello there, how's it going? hope things are good!" 62 | hey buddy! 63 | it's been a long time! 64 | great seeing you 65 | "hello there, how are you? hope things are good!" 66 | it's great to see you again! 67 | it's nice seeing you all again! 68 | greetings! how are you doing today? 69 | I'm glad to see you all! 
70 | how have you been 71 | I'm excited to see you! 72 | it's good to see you all! 73 | Good to see you all! 74 | it's good to see you as well! 75 | I'm thrilled to see you all! 76 | it's great meeting you 77 | it's great to see you again! 78 | it's good to see you again! 79 | how is everyone doing? 80 | nice seeing you again! 81 | nice to see you all again! 82 | it's nice to see you all! 83 | I have faith that things are going smoothly for you and yours 84 | how is everything? 85 | it's great seeing you all! 86 | I'm thrilled to see you! 87 | how's your day going? 88 | nice to see you again! 89 | I just wanted to say hi 90 | what's up? 91 | I've heard a lot about you. 92 | hello 93 | It's great seeing you again! 94 | I'm glad to see you! 95 | it's so nice to meet someone like yourself! 96 | hi friend! how are things going with your family? 97 | hope all is good in the hood! 98 | I'm glad i got the chance to finally meet you in person. 99 | hope all is well with your loved ones. 100 | It's great seeing you! 101 | -------------------------------------------------------------------------------- /data/greetings/greetings-train.txt: -------------------------------------------------------------------------------- 1 | I have faith that things are going smoothly for you and yours 2 | long time no see! 3 | nice to see you too! 4 | I'm happy to see you! 5 | it was so good meeting you 6 | I'm stopping by to say hi 7 | I pray that things are going smoothly for you and yours 8 | I'm glad to see you! 9 | it's nice seeing you too! 10 | i hope life has been treating you good. 11 | it's great to finally meet in person. 12 | I'm stopping by to say hi 13 | i hope all is good with your family. 14 | nice to meet you too! 15 | I'm thrilled to see you! 16 | it's a pleasure to see you again! 17 | it's good to see you too! 18 | I'm happy to see you! 19 | i am stopping by to say hi 20 | I've heard a lot of good things about you and your family. 21 | hope all is well with your loved ones. 22 | nice to meet you! 23 | i'm stopping by to say hey! 24 | dropping in to let you know i'm thinking of you. 25 | nice seeing you again 26 | how's it going? 27 | "hi, how are you?" 28 | it's so nice to meet someone like yourself! 29 | welcome! 30 | I'm excited to see you! 31 | It's so nice to meet you! 32 | nice seeing you all again! 33 | I'm thrilled to see you all! 34 | hope all is well with your family. 35 | nice seeing you! 36 | it's nice to see you again! 37 | I just wanted to say hello 38 | I pray that things are going smoothly for you and yours 39 | nice to see you all again! 40 | nice to meet you in person 41 | how are you doing 42 | nice seeing you again. 43 | I've heard a lot of good things abou you. 44 | do you have any plans for the weekend 45 | it's good to see you again! 46 | good seeing everyone 47 | good to see you again! 48 | I'm happy to see you! 49 | I'm so happy to see you 50 | how's your day going? 51 | it's nice seeing you too! 52 | what's up? 53 | it's good seeing you again. 54 | "hello there, how's it going? hope things are good!" 55 | what is new 56 | I'm so happy to see you 57 | It's been a while since we last talked. 58 | hi there 59 | I'm happy to see you all! 60 | what's up? 61 | welcome! 62 | It's great seeing you again! 63 | it's good to see you again! 64 | it's nice to see you all again! 65 | I hope all is good with your family. 66 | what's up? 67 | i hope all is well with you and yours 68 | hey friend! 69 | good afternoon 70 | it's good seeing you! 
71 | I trust that things are going smoothly for you and yours 72 | "hi there, how's it going? hope things are good!" 73 | good afternoon! 74 | It's exciting to see everyone 75 | hey pal! 76 | "hi there, how's it going? hope things are good!" 77 | I just wanted to say hi 78 | I'm thrilled to see you! 79 | I'm glad to see you 80 | I'm glad i got the chance to finally meet you in person. 81 | how is everything going with you and yours 82 | it's good to see you! 83 | it was so nice meeting you! 84 | hey mate! 85 | Hey team how is everyone doing? 86 | it's so good to finally meet in person. 87 | how are things going with you and your family 88 | I have faith that things are going smoothly for you and yours 89 | I'm excited to see you all! 90 | how's your day been? 91 | just wanted to say hi 92 | how have you been 93 | I'm excited to see you! 94 | great to see you again! 95 | great to see you again! 96 | I'm excited to see you! 97 | It's good to see you again! 98 | I've heard a lot about you and your family. 99 | just stopping by to say hello 100 | it's been too long! 101 | I'm excited to see you! 102 | good morning! 103 | how's your day going so far? 104 | good seeing you again 105 | I've heard a lot of good things about you and your family. 106 | hope all is good in the hood! 107 | it's good to see you as well! 108 | good night! 109 | hey mate! 110 | it's good to see you as well! 111 | I'm glad to see you all! 112 | hey buddy! 113 | Nice meeting you 114 | it's great to meet you. 115 | how are things going with you and your family? 116 | it's great to see you as well! 117 | I trust that things are going smoothly for you 118 | "hello there, how's it going? hope things are good!" 119 | how's it going? 120 | "hello there, how are you? hope things are good!" 121 | i'm happy I got the chance to finally meet you in person. 122 | I'm so thrilled to see you all! 123 | it's nice to see you too! 124 | I trust that things are going smoothly for you and yours 125 | I'm glad i got the chance to finally meet you in person. 126 | i am dropping in to say hey! 127 | hey good to see you! 128 | dropping in to let you know i'm thinking of you. 129 | it's great meeting you 130 | I'm happy to see you! 131 | I just stopped by to say hello 132 | "hello there, how are you? hope things are good!" 133 | how is your family doing 134 | I'm so thrilled to see you! 135 | I trust that things are going smoothly for you 136 | i am stopping by to say hi 137 | It's great to finally meet you in person 138 | it's great to see you too! 139 | it's great to see you as well! 140 | it's great seeing you! 141 | "hi there, how are you? hope things are good!" 142 | I just stopped by to say hello 143 | nice to see you all! 144 | I'm excited to see you all! 145 | I've heard nothing but good things about you. 146 | i am dropping in to say hey! 147 | it's wonderful meeting you 148 | how's it going? 149 | i hope everything is going well with you 150 | dropped in because I was thinking of you. 151 | it's nice to see you! 152 | good to see you again! 153 | nice seeing you all! 154 | I'm so excited to see you all! 155 | I'm thrilled to be here 156 | good morning 157 | I'm thrilled to see you all! 158 | just stopping by to say hello 159 | i hope life has been treating you good 160 | how are things? 161 | how's your day going? 162 | hey friend! 163 | I'm thrilled to see you! 164 | I'm glad i got the chance to finally meet in person! 165 | it's nice seeing you all again! 166 | I'm excited to see you! 167 | It's great seeing you! 
168 | it's nice seeing you all! 169 | how are you doing? 170 | I've heard nothing but good things about you and your family. 171 | I've heard nothing but good things about you and your family. 172 | do you have any plans for today 173 | how are you doing? 174 | good evening 175 | I'm excited to see you! 176 | it's good to see you all again! 177 | I just wanted to say hi 178 | I've heard a lot about you. 179 | do you have any plans for the weekend 180 | dropped in because I was thinking of you. 181 | it's good to see you too! 182 | I'm excited to see you! 183 | great seeing you all! 184 | I'm glad to see you 185 | it's great to see you again! 186 | Just saying hello 187 | how is everything? 188 | good to see you again! 189 | nice to finally meet in person 190 | good seeing you all! 191 | good to see you! 192 | it's been a while! 193 | It's nice seeing you. 194 | good to see you all again! 195 | it's nice seeing you again! 196 | nice seeing you again 197 | great seeing you all again! 198 | Hope all is good in the hood! 199 | It was good meeting everyone 200 | I'm glad to see you all! 201 | hey buddy! 202 | I'm glad to see you! 203 | how is everyone doing 204 | I'm happy I got the chance to finally meet you in person. 205 | hello 206 | it's a pleasure to meet you! 207 | it was so nice meeting you 208 | i'm stopping by to say hey! 209 | Hope all is well with your loved ones. 210 | how's your day? 211 | greetings friend! how are you doing today? 212 | hi friend! how are things going with your family? 213 | it's great seeing you all again! 214 | hey my friend 215 | hey good to see you! 216 | how is everything going? 217 | how has your day been so far 218 | welcome back 219 | "greetings friend, how are you doing today?" 220 | it's good to see you all! 221 | nice to see you! 222 | good to see you all 223 | Hope all is well with your family. 224 | I'm glad to see you all! 225 | what's new 226 | great seeing you again 227 | Just saying hello 228 | it's nice meeting you 229 | nice meeting you! 230 | "hi, how are you?" 231 | how are things going? 232 | just saying hi 233 | just saying hi 234 | I'm happy to see you all! 235 | hope all is well 236 | it's great to see you again! 237 | i hope everything is going well for you 238 | It's great to see you again! 239 | good seeing you all again! 240 | it's nice seeing you again! 241 | It's good to see you! 242 | how are things everyone? 243 | nice seeing you! 244 | great seeing you! 245 | good to see you! 246 | it's great to see you! 247 | it's been a long time! 248 | how's it going? 249 | how is everyone doing? 250 | i hope life has been treating you right 251 | how is life treating you 252 | welcome back 253 | it's great to see you too! 254 | "hey there, nice to see you!" 255 | nice to see you again! 256 | Good to see you all! 257 | it's nice seeing you as well! 258 | it's great to see you again! 259 | nice seeing you again! 260 | I'm excited to see you all! 261 | good evening! 262 | Hey team how's it going? 263 | how have you been? 264 | how is everyone doing? i hope all's well. 265 | it's nice to see you as well! 266 | "hey team, how is everything going?" 267 | it's nice to see you all! 268 | "hi there, how are you? hope things are good!" 269 | hey pal! 270 | it's nice seeing you! 271 | I'm happy to see you all! 272 | how's your day? 273 | I'm excited to see you! 274 | I'm so excited to see you! 275 | just wanted to say hi 276 | good to see you again! 277 | greetings! how are you doing today? 278 | nice meeting you again. 
279 | great seeing you 280 | I'm happy to see you 281 | it's good to see you again! 282 | It's nice to meet you! 283 | I'm glad to see you! 284 | I just wanted to say hello 285 | It was good meeting you 286 | welcome back! 287 | greetings! how are you doing today? 288 | It's great to see you! 289 | what's up? 290 | I'm glad to see you! 291 | hey my friend! 292 | I'm happy to see you 293 | it's great to meet someone like yourself! 294 | good seeing you 295 | I've heard a lot about you and your family. 296 | how is your family doing? 297 | it's great seeing you all! 298 | do you have any plans for today? 299 | -------------------------------------------------------------------------------- /data/greetings/greetings.py: -------------------------------------------------------------------------------- 1 | # list of 1000 greetings in Python 2 | greetings = [ 3 | "hi, how are you?", 4 | "hey good to see you!", 5 | "how's it going?", 6 | "what's up?", 7 | "how's your day?", 8 | "how's your day going?", 9 | "how's your day been?", 10 | "how's your day going so far?", 11 | "how is everyone doing?", 12 | "how are you doing?", 13 | "good to see you!", 14 | "good to see you again!", 15 | "nice to meet you!", 16 | "nice to see you!", 17 | "it's nice to see you again!", 18 | "it's a pleasure to meet you!", 19 | "it's a pleasure to see you again!", 20 | "nice to see you again!", 21 | "nice to see you too!", 22 | "nice to meet you too!", 23 | "long time no see!", 24 | "it's been a while!", 25 | "it's been a long time!", 26 | "it's been too long!", 27 | "welcome back!", 28 | "great seeing you!", 29 | "good to see you again!", 30 | "it's good to see you again!", 31 | "it's good to see you too!", 32 | "it's good to see you as well!", 33 | "it's great to see you again!", 34 | "it's great to see you too!", 35 | "it's great to see you as well!", 36 | "it's nice seeing you again!", 37 | "it's nice seeing you too!", 38 | "good to see you all", 39 | "good to see you all again!", 40 | "it's good to see you all!", 41 | "it's good to see you all again!", 42 | "it's nice to see you all!", 43 | "it's nice to see you all again!", 44 | "it's nice seeing you all!", 45 | "it's nice seeing you all again!", 46 | "nice to see you all!", 47 | "nice to see you all again!", 48 | "nice seeing you all!", 49 | "nice seeing you all again!", 50 | "good seeing you all!", 51 | "good seeing you all again!", 52 | "great seeing you all!", 53 | "great seeing you all again!", 54 | "it's great seeing you all!", 55 | "it's great seeing you all again!", 56 | "good to see you!", 57 | "it's good to see you!", 58 | "it's good to see you too!", 59 | "it's good to see you as well!", 60 | "it's nice to see you!", 61 | "it's nice to see you too!", 62 | "it's nice to see you as well!", 63 | "it's great to see you!", 64 | "it's great to see you too!", 65 | "it's great to see you as well!", 66 | "it's nice seeing you!", 67 | "it's nice seeing you too!", 68 | "it's nice seeing you as well!", 69 | "it's great seeing you!", 70 | "I'm thrilled to be here", 71 | "It's exciting to see everyone", 72 | "I'm so excited to see you all!", 73 | "I'm so excited to see you!", 74 | "I'm thrilled to see you all!", 75 | "I'm thrilled to see you!", 76 | "I'm so thrilled to see you all!", 77 | "I'm so thrilled to see you!", 78 | "I'm excited to see you all!", 79 | "I'm excited to see you!", 80 | "I'm excited to see you!", 81 | "I'm thrilled to see you all!", 82 | "I'm thrilled to see you!", 83 | "I'm thrilled to see you!", 84 | "I'm excited to see you all!", 85 | "I'm 
excited to see you!", 86 | "I'm excited to see you!", 87 | "I'm excited to see you all!", 88 | "I'm excited to see you!", 89 | "I'm excited to see you!", 90 | "I'm happy to see you!", 91 | "I'm happy to see you all!", 92 | "I'm happy to see you all!", 93 | "I'm happy to see you all!", 94 | "I'm happy to see you!", 95 | "I'm happy to see you!", 96 | "I'm happy to see you!", 97 | "I'm glad to see you!", 98 | "I'm glad to see you all!", 99 | "I'm glad to see you all!", 100 | "I'm glad to see you all!", 101 | "I'm glad to see you!", 102 | "I'm glad to see you!", 103 | "I'm glad to see you!", 104 | "nice seeing you again!", 105 | "it's nice seeing you again!", 106 | "how are things going?", 107 | "how are things?", 108 | "what's up?", 109 | "hi there", 110 | "hello", 111 | "good morning", 112 | "good afternoon", 113 | "good evening", 114 | "greetings friend! how are you doing today?", 115 | "greetings! how are you doing today?", 116 | "how are you doing?", 117 | "how's it going?", 118 | "how is everything going?", 119 | "how is everything?", 120 | "how is life treating you", 121 | "how has your day been so far", 122 | "hi there, how are you? hope things are good!", 123 | "hello there, how are you? hope things are good!", 124 | "hi there, how's it going? hope things are good!", 125 | "hello there, how's it going? hope things are good!", 126 | "just saying hi", 127 | "just wanted to say hi", 128 | "I just wanted to say hi", 129 | "Just saying hello", 130 | "I just wanted to say hello", 131 | "just stopping by to say hello", 132 | "I just stopped by to say hello", 133 | "i am stopping by to say hi", 134 | "I'm stopping by to say hi", 135 | "i'm stopping by to say hey!", 136 | "i am dropping in to say hey!", 137 | "dropping in to let you know i'm thinking of you.", 138 | "dropped in because I was thinking of you.", 139 | "hey friend!", 140 | "hey buddy!", 141 | "hey pal!", 142 | "hey mate!", 143 | "hey my friend!", 144 | "what is new", 145 | "how have you been", 146 | "how are things going with you and your family", 147 | "how is your family doing", 148 | "do you have any plans for today", 149 | "do you have any plans for the weekend", 150 | "what's new", 151 | "what's up?", 152 | "how have you been?", 153 | "how are things going with you and your family?", 154 | "how is your family doing?", 155 | "do you have any plans for today?", 156 | "do you have any plans for the weekend", 157 | "hope all is well", 158 | "i hope all is well with you and yours", 159 | "i hope life has been treating you good", 160 | "i hope life has been treating you right", 161 | "i hope everything is going well for you", 162 | "i hope everything is going well with you", 163 | "i hope life has been treating you good.", 164 | "i hope all is good with your family.", 165 | "hope all is well with your family.", 166 | "hope all is well with your loved ones.", 167 | "hope all is good in the hood!", 168 | "I trust that things are going smoothly for you", 169 | "I trust that things are going smoothly for you and yours", 170 | "I have faith that things are going smoothly for you and yours", 171 | "I pray that things are going smoothly for you and yours", 172 | "I hope all is good with your family.", 173 | "Hope all is well with your family.", 174 | "Hope all is well with your loved ones.", 175 | "Hope all is good in the hood!", 176 | "I trust that things are going smoothly for you", 177 | "I trust that things are going smoothly for you and yours", 178 | "I have faith that things are going smoothly for you and yours", 179 | "I pray that 
things are going smoothly for you and yours", 180 | "welcome!", 181 | "welcome back", 182 | "it's good to see you again!", 183 | "good to see you again!", 184 | "it's great to see you again!", 185 | "great to see you again!", 186 | "nice seeing you!", 187 | "nice seeing you again", 188 | "I'm glad to see you", 189 | "I'm happy to see you", 190 | "I'm so happy to see you", 191 | "it's so good to finally meet in person.", 192 | "I've heard a lot about you.", 193 | "I've heard a lot of good things abou you.", 194 | "I've heard nothing but good things about you.", 195 | "I've heard a lot about you and your family.", 196 | "I've heard a lot of good things about you and your family.", 197 | "I've heard nothing but good things about you and your family.", 198 | "nice to meet you in person", 199 | "nice to finally meet in person", 200 | "it's nice meeting you", 201 | "it's great meeting you", 202 | "it's wonderful meeting you", 203 | "it's so nice to meet someone like yourself!", 204 | "it's great to meet someone like yourself!", 205 | "it was so nice meeting you", 206 | "it was so good meeting you", 207 | "I'm glad i got the chance to finally meet you in person.", 208 | "i'm happy I got the chance to finally meet you in person.", 209 | "I've heard a lot about you and your family.", 210 | "I've heard a lot of good things about you and your family.", 211 | "I've heard nothing but good things about you and your family.", 212 | "hi, how are you?", 213 | "hey good to see you!", 214 | "how's it going?", 215 | "what's up?", 216 | "how's your day?", 217 | "how's your day going?", 218 | "how is everyone doing", 219 | "good morning!", 220 | "good afternoon!", 221 | "good evening!", 222 | "good night!", 223 | "greetings friend, how are you doing today?", 224 | "greetings! how are you doing today?", 225 | "how are you doing", 226 | "how's it going?", 227 | "how is everything going with you and yours", 228 | "hi there, how are you? hope things are good!", 229 | "hello there, how are you? hope things are good!", 230 | "hi there, how's it going? hope things are good!", 231 | "hello there, how's it going? 
hope things are good!", 232 | "just saying hi", 233 | "just wanted to say hi", 234 | "I just wanted to say hi", 235 | "Just saying hello", 236 | "I just wanted to say hello", 237 | "just stopping by to say hello", 238 | "I just stopped by to say hello", 239 | "i am stopping by to say hi", 240 | "I'm stopping by to say hi", 241 | "i'm stopping by to say hey!", 242 | "i am dropping in to say hey!", 243 | "dropping in to let you know i'm thinking of you.", 244 | "dropped in because I was thinking of you.", 245 | "hey friend!", 246 | "hey buddy!", 247 | "hey pal!", 248 | "hey mate!", 249 | "hey my friend", 250 | "welcome!", 251 | "welcome back", 252 | "it's good to see you again!", 253 | "good to see you again!", 254 | "it's great to see you again!", 255 | "great to see you again!", 256 | "nice seeing you!", 257 | "nice seeing you again", 258 | "I'm glad to see you", 259 | "I'm happy to see you", 260 | "I'm so happy to see you", 261 | "it's good seeing you!", 262 | "It's great seeing you!", 263 | "It's great to see you!", 264 | "It's good to see you!", 265 | "good seeing you", 266 | "great seeing you", 267 | "nice meeting you again.", 268 | "nice seeing you again.", 269 | "it's good seeing you again.", 270 | "It's great seeing you again!", 271 | "It's great to see you again!", 272 | "It's good to see you again!", 273 | "good seeing you again", 274 | "great seeing you again", 275 | "nice meeting you!", 276 | "It's nice to meet you!", 277 | "It's so nice to meet you!", 278 | "it's great to finally meet in person.", 279 | "It's great to finally meet you in person", 280 | "it was so nice meeting you!", 281 | "I'm glad i got the chance to finally meet you in person.", 282 | "I'm happy I got the chance to finally meet you in person.", 283 | "it's great to meet you.", 284 | "It was good meeting you", 285 | "I'm glad i got the chance to finally meet in person!", 286 | "I'm excited to see you!", 287 | "It's nice seeing you.", 288 | "Nice meeting you", 289 | "Hey team how is everyone doing?", 290 | "Hey team how's it going?", 291 | "Good to see you all!", 292 | "It was good meeting everyone", 293 | "It's been a while since we last talked.", 294 | "I'm excited to see you!", 295 | "good seeing everyone", 296 | "hey team, how is everything going?", 297 | "how are things everyone?", 298 | "hey there, nice to see you!", 299 | "hi friend! how are things going with your family?", 300 | "how is everyone doing? i hope all's well.", 301 | ] 302 | 303 | import pandas as pd 304 | data = pd.DataFrame({"greeting": greetings}) 305 | data = data.sample(frac=1).reset_index(drop=True) 306 | data.to_csv("greetings.txt", index=False) -------------------------------------------------------------------------------- /data/greetings/greetings.txt: -------------------------------------------------------------------------------- 1 | I have faith that things are going smoothly for you and yours 2 | long time no see! 3 | nice to see you too! 4 | I'm happy to see you! 5 | it was so good meeting you 6 | I'm stopping by to say hi 7 | I pray that things are going smoothly for you and yours 8 | I'm glad to see you! 9 | it's nice seeing you too! 10 | i hope life has been treating you good. 11 | it's great to finally meet in person. 12 | I'm stopping by to say hi 13 | i hope all is good with your family. 14 | nice to meet you too! 15 | I'm thrilled to see you! 16 | it's a pleasure to see you again! 17 | it's good to see you too! 18 | I'm happy to see you! 
19 | i am stopping by to say hi 20 | I've heard a lot of good things about you and your family. 21 | hope all is well with your loved ones. 22 | nice to meet you! 23 | i'm stopping by to say hey! 24 | dropping in to let you know i'm thinking of you. 25 | nice seeing you again 26 | how's it going? 27 | "hi, how are you?" 28 | it's so nice to meet someone like yourself! 29 | welcome! 30 | I'm excited to see you! 31 | It's so nice to meet you! 32 | nice seeing you all again! 33 | I'm thrilled to see you all! 34 | hope all is well with your family. 35 | nice seeing you! 36 | it's nice to see you again! 37 | I just wanted to say hello 38 | I pray that things are going smoothly for you and yours 39 | nice to see you all again! 40 | nice to meet you in person 41 | how are you doing 42 | nice seeing you again. 43 | I've heard a lot of good things abou you. 44 | do you have any plans for the weekend 45 | it's good to see you again! 46 | good seeing everyone 47 | good to see you again! 48 | I'm happy to see you! 49 | I'm so happy to see you 50 | how's your day going? 51 | it's nice seeing you too! 52 | what's up? 53 | it's good seeing you again. 54 | "hello there, how's it going? hope things are good!" 55 | what is new 56 | I'm so happy to see you 57 | It's been a while since we last talked. 58 | hi there 59 | I'm happy to see you all! 60 | what's up? 61 | welcome! 62 | It's great seeing you again! 63 | it's good to see you again! 64 | it's nice to see you all again! 65 | I hope all is good with your family. 66 | what's up? 67 | i hope all is well with you and yours 68 | hey friend! 69 | good afternoon 70 | it's good seeing you! 71 | I trust that things are going smoothly for you and yours 72 | "hi there, how's it going? hope things are good!" 73 | good afternoon! 74 | It's exciting to see everyone 75 | hey pal! 76 | "hi there, how's it going? hope things are good!" 77 | I just wanted to say hi 78 | I'm thrilled to see you! 79 | I'm glad to see you 80 | I'm glad i got the chance to finally meet you in person. 81 | how is everything going with you and yours 82 | it's good to see you! 83 | it was so nice meeting you! 84 | hey mate! 85 | Hey team how is everyone doing? 86 | it's so good to finally meet in person. 87 | how are things going with you and your family 88 | I have faith that things are going smoothly for you and yours 89 | I'm excited to see you all! 90 | how's your day been? 91 | just wanted to say hi 92 | how have you been 93 | I'm excited to see you! 94 | great to see you again! 95 | great to see you again! 96 | I'm excited to see you! 97 | It's good to see you again! 98 | I've heard a lot about you and your family. 99 | just stopping by to say hello 100 | it's been too long! 101 | I'm excited to see you! 102 | good morning! 103 | how's your day going so far? 104 | good seeing you again 105 | I've heard a lot of good things about you and your family. 106 | hope all is good in the hood! 107 | it's good to see you as well! 108 | good night! 109 | hey mate! 110 | it's good to see you as well! 111 | I'm glad to see you all! 112 | hey buddy! 113 | Nice meeting you 114 | it's great to meet you. 115 | how are things going with you and your family? 116 | it's great to see you as well! 117 | I trust that things are going smoothly for you 118 | "hello there, how's it going? hope things are good!" 119 | how's it going? 120 | "hello there, how are you? hope things are good!" 121 | i'm happy I got the chance to finally meet you in person. 122 | I'm so thrilled to see you all! 
123 | it's nice to see you too! 124 | I trust that things are going smoothly for you and yours 125 | I'm glad i got the chance to finally meet you in person. 126 | i am dropping in to say hey! 127 | hey good to see you! 128 | dropping in to let you know i'm thinking of you. 129 | it's great meeting you 130 | I'm happy to see you! 131 | I just stopped by to say hello 132 | "hello there, how are you? hope things are good!" 133 | how is your family doing 134 | I'm so thrilled to see you! 135 | I trust that things are going smoothly for you 136 | i am stopping by to say hi 137 | It's great to finally meet you in person 138 | it's great to see you too! 139 | it's great to see you as well! 140 | it's great seeing you! 141 | "hi there, how are you? hope things are good!" 142 | I just stopped by to say hello 143 | nice to see you all! 144 | I'm excited to see you all! 145 | I've heard nothing but good things about you. 146 | i am dropping in to say hey! 147 | it's wonderful meeting you 148 | how's it going? 149 | i hope everything is going well with you 150 | dropped in because I was thinking of you. 151 | it's nice to see you! 152 | good to see you again! 153 | nice seeing you all! 154 | I'm so excited to see you all! 155 | I'm thrilled to be here 156 | good morning 157 | I'm thrilled to see you all! 158 | just stopping by to say hello 159 | i hope life has been treating you good 160 | how are things? 161 | how's your day going? 162 | hey friend! 163 | I'm thrilled to see you! 164 | I'm glad i got the chance to finally meet in person! 165 | it's nice seeing you all again! 166 | I'm excited to see you! 167 | It's great seeing you! 168 | it's nice seeing you all! 169 | how are you doing? 170 | I've heard nothing but good things about you and your family. 171 | I've heard nothing but good things about you and your family. 172 | do you have any plans for today 173 | how are you doing? 174 | good evening 175 | I'm excited to see you! 176 | it's good to see you all again! 177 | I just wanted to say hi 178 | I've heard a lot about you. 179 | do you have any plans for the weekend 180 | dropped in because I was thinking of you. 181 | it's good to see you too! 182 | I'm excited to see you! 183 | great seeing you all! 184 | I'm glad to see you 185 | it's great to see you again! 186 | Just saying hello 187 | how is everything? 188 | good to see you again! 189 | nice to finally meet in person 190 | good seeing you all! 191 | good to see you! 192 | it's been a while! 193 | It's nice seeing you. 194 | good to see you all again! 195 | it's nice seeing you again! 196 | nice seeing you again 197 | great seeing you all again! 198 | Hope all is good in the hood! 199 | It was good meeting everyone 200 | I'm glad to see you all! 201 | hey buddy! 202 | I'm glad to see you! 203 | how is everyone doing 204 | I'm happy I got the chance to finally meet you in person. 205 | hello 206 | it's a pleasure to meet you! 207 | it was so nice meeting you 208 | i'm stopping by to say hey! 209 | Hope all is well with your loved ones. 210 | how's your day? 211 | greetings friend! how are you doing today? 212 | hi friend! how are things going with your family? 213 | it's great seeing you all again! 214 | hey my friend 215 | hey good to see you! 216 | how is everything going? 217 | how has your day been so far 218 | welcome back 219 | "greetings friend, how are you doing today?" 220 | it's good to see you all! 221 | nice to see you! 222 | good to see you all 223 | Hope all is well with your family. 224 | I'm glad to see you all! 
225 | what's new 226 | great seeing you again 227 | Just saying hello 228 | it's nice meeting you 229 | nice meeting you! 230 | "hi, how are you?" 231 | how are things going? 232 | just saying hi 233 | just saying hi 234 | I'm happy to see you all! 235 | hope all is well 236 | it's great to see you again! 237 | i hope everything is going well for you 238 | It's great to see you again! 239 | good seeing you all again! 240 | it's nice seeing you again! 241 | It's good to see you! 242 | how are things everyone? 243 | nice seeing you! 244 | great seeing you! 245 | good to see you! 246 | it's great to see you! 247 | it's been a long time! 248 | how's it going? 249 | how is everyone doing? 250 | i hope life has been treating you right 251 | how is life treating you 252 | welcome back 253 | it's great to see you too! 254 | "hey there, nice to see you!" 255 | nice to see you again! 256 | Good to see you all! 257 | it's nice seeing you as well! 258 | it's great to see you again! 259 | nice seeing you again! 260 | I'm excited to see you all! 261 | good evening! 262 | Hey team how's it going? 263 | how have you been? 264 | how is everyone doing? i hope all's well. 265 | it's nice to see you as well! 266 | "hey team, how is everything going?" 267 | it's nice to see you all! 268 | "hi there, how are you? hope things are good!" 269 | hey pal! 270 | it's nice seeing you! 271 | I'm happy to see you all! 272 | how's your day? 273 | I'm excited to see you! 274 | I'm so excited to see you! 275 | just wanted to say hi 276 | good to see you again! 277 | greetings! how are you doing today? 278 | nice meeting you again. 279 | great seeing you 280 | I'm happy to see you 281 | it's good to see you again! 282 | It's nice to meet you! 283 | I'm glad to see you! 284 | I just wanted to say hello 285 | It was good meeting you 286 | welcome back! 287 | greetings! how are you doing today? 288 | It's great to see you! 289 | what's up? 290 | I'm glad to see you! 291 | hey my friend! 292 | I'm happy to see you 293 | it's great to meet someone like yourself! 294 | good seeing you 295 | I've heard a lot about you and your family. 296 | how is your family doing? 297 | it's great seeing you all! 298 | do you have any plans for today? 299 | -------------------------------------------------------------------------------- /data/greetings/greetings_labeled.tsv: -------------------------------------------------------------------------------- 1 | I have faith that things are going smoothly for you and yours 0 2 | long time no see! 0 3 | nice to see you too! 0 4 | I'm happy to see you! 0 5 | it was so good meeting you 1 6 | I'm stopping by to say hi 0 7 | I pray that things are going smoothly for you and yours 0 8 | I'm glad to see you! 0 9 | it's nice seeing you too! 0 10 | i hope life has been treating you good. 1 11 | it's great to finally meet in person. 0 12 | I'm stopping by to say hi 0 13 | i hope all is good with your family. 1 14 | nice to meet you too! 0 15 | I'm thrilled to see you! 0 16 | it's a pleasure to see you again! 0 17 | it's good to see you too! 1 18 | I'm happy to see you! 0 19 | i am stopping by to say hi 0 20 | I've heard a lot of good things about you and your family. 1 21 | hope all is well with your loved ones. 0 22 | nice to meet you! 0 23 | i'm stopping by to say hey! 0 24 | dropping in to let you know i'm thinking of you. 0 25 | nice seeing you again 0 26 | how's it going? 0 27 | hi, how are you? 0 28 | it's so nice to meet someone like yourself! 0 29 | welcome! 0 30 | I'm excited to see you! 
0 31 | It's so nice to meet you! 0 32 | nice seeing you all again! 0 33 | I'm thrilled to see you all! 0 34 | hope all is well with your family. 0 35 | nice seeing you! 0 36 | it's nice to see you again! 0 37 | I just wanted to say hello 0 38 | I pray that things are going smoothly for you and yours 0 39 | nice to see you all again! 0 40 | nice to meet you in person 0 41 | how are you doing 0 42 | nice seeing you again. 0 43 | I've heard a lot of good things abou you. 1 44 | do you have any plans for the weekend 0 45 | it's good to see you again! 1 46 | good seeing everyone 1 47 | good to see you again! 1 48 | I'm happy to see you! 0 49 | I'm so happy to see you 0 50 | how's your day going? 0 51 | it's nice seeing you too! 0 52 | what's up? 0 53 | it's good seeing you again. 1 54 | hello there, how's it going? hope things are good! 1 55 | what is new 0 56 | I'm so happy to see you 0 57 | It's been a while since we last talked. 0 58 | hi there 0 59 | I'm happy to see you all! 0 60 | what's up? 0 61 | welcome! 0 62 | It's great seeing you again! 0 63 | it's good to see you again! 1 64 | it's nice to see you all again! 0 65 | I hope all is good with your family. 1 66 | what's up? 0 67 | i hope all is well with you and yours 0 68 | hey friend! 0 69 | good afternoon 1 70 | it's good seeing you! 1 71 | I trust that things are going smoothly for you and yours 0 72 | hi there, how's it going? hope things are good! 1 73 | good afternoon! 1 74 | It's exciting to see everyone 0 75 | hey pal! 0 76 | hi there, how's it going? hope things are good! 1 77 | I just wanted to say hi 0 78 | I'm thrilled to see you! 0 79 | I'm glad to see you 0 80 | I'm glad i got the chance to finally meet you in person. 0 81 | how is everything going with you and yours 0 82 | it's good to see you! 1 83 | it was so nice meeting you! 0 84 | hey mate! 0 85 | Hey team how is everyone doing? 0 86 | it's so good to finally meet in person. 1 87 | how are things going with you and your family 0 88 | I have faith that things are going smoothly for you and yours 0 89 | I'm excited to see you all! 0 90 | how's your day been? 0 91 | just wanted to say hi 0 92 | how have you been 0 93 | I'm excited to see you! 0 94 | great to see you again! 0 95 | great to see you again! 0 96 | I'm excited to see you! 0 97 | It's good to see you again! 1 98 | I've heard a lot about you and your family. 0 99 | just stopping by to say hello 0 100 | it's been too long! 0 101 | I'm excited to see you! 0 102 | good morning! 1 103 | how's your day going so far? 0 104 | good seeing you again 1 105 | I've heard a lot of good things about you and your family. 1 106 | hope all is good in the hood! 1 107 | it's good to see you as well! 1 108 | good night! 1 109 | hey mate! 0 110 | it's good to see you as well! 1 111 | I'm glad to see you all! 0 112 | hey buddy! 0 113 | Nice meeting you 0 114 | it's great to meet you. 0 115 | how are things going with you and your family? 0 116 | it's great to see you as well! 0 117 | I trust that things are going smoothly for you 0 118 | hello there, how's it going? hope things are good! 1 119 | how's it going? 0 120 | hello there, how are you? hope things are good! 1 121 | i'm happy I got the chance to finally meet you in person. 0 122 | I'm so thrilled to see you all! 0 123 | it's nice to see you too! 0 124 | I trust that things are going smoothly for you and yours 0 125 | I'm glad i got the chance to finally meet you in person. 0 126 | i am dropping in to say hey! 0 127 | hey good to see you! 
1 128 | dropping in to let you know i'm thinking of you. 0 129 | it's great meeting you 0 130 | I'm happy to see you! 0 131 | I just stopped by to say hello 0 132 | hello there, how are you? hope things are good! 1 133 | how is your family doing 0 134 | I'm so thrilled to see you! 0 135 | I trust that things are going smoothly for you 0 136 | i am stopping by to say hi 0 137 | It's great to finally meet you in person 0 138 | it's great to see you too! 0 139 | it's great to see you as well! 0 140 | it's great seeing you! 0 141 | hi there, how are you? hope things are good! 1 142 | I just stopped by to say hello 0 143 | nice to see you all! 0 144 | I'm excited to see you all! 0 145 | I've heard nothing but good things about you. 1 146 | i am dropping in to say hey! 0 147 | it's wonderful meeting you 0 148 | how's it going? 0 149 | i hope everything is going well with you 0 150 | dropped in because I was thinking of you. 0 151 | it's nice to see you! 0 152 | good to see you again! 1 153 | nice seeing you all! 0 154 | I'm so excited to see you all! 0 155 | I'm thrilled to be here 0 156 | good morning 1 157 | I'm thrilled to see you all! 0 158 | just stopping by to say hello 0 159 | i hope life has been treating you good 1 160 | how are things? 0 161 | how's your day going? 0 162 | hey friend! 0 163 | I'm thrilled to see you! 0 164 | I'm glad i got the chance to finally meet in person! 0 165 | it's nice seeing you all again! 0 166 | I'm excited to see you! 0 167 | It's great seeing you! 0 168 | it's nice seeing you all! 0 169 | how are you doing? 0 170 | I've heard nothing but good things about you and your family. 1 171 | I've heard nothing but good things about you and your family. 1 172 | do you have any plans for today 0 173 | how are you doing? 0 174 | good evening 1 175 | I'm excited to see you! 0 176 | it's good to see you all again! 1 177 | I just wanted to say hi 0 178 | I've heard a lot about you. 0 179 | do you have any plans for the weekend 0 180 | dropped in because I was thinking of you. 0 181 | it's good to see you too! 1 182 | I'm excited to see you! 0 183 | great seeing you all! 0 184 | I'm glad to see you 0 185 | it's great to see you again! 0 186 | Just saying hello 0 187 | how is everything? 0 188 | good to see you again! 1 189 | nice to finally meet in person 0 190 | good seeing you all! 1 191 | good to see you! 1 192 | it's been a while! 0 193 | It's nice seeing you. 0 194 | good to see you all again! 1 195 | it's nice seeing you again! 0 196 | nice seeing you again 0 197 | great seeing you all again! 0 198 | Hope all is good in the hood! 1 199 | It was good meeting everyone 1 200 | I'm glad to see you all! 0 201 | hey buddy! 0 202 | I'm glad to see you! 0 203 | how is everyone doing 0 204 | I'm happy I got the chance to finally meet you in person. 0 205 | hello 0 206 | it's a pleasure to meet you! 0 207 | it was so nice meeting you 0 208 | i'm stopping by to say hey! 0 209 | Hope all is well with your loved ones. 0 210 | how's your day? 0 211 | greetings friend! how are you doing today? 0 212 | hi friend! how are things going with your family? 0 213 | it's great seeing you all again! 0 214 | hey my friend 0 215 | hey good to see you! 1 216 | how is everything going? 0 217 | how has your day been so far 0 218 | welcome back 0 219 | greetings friend, how are you doing today? 0 220 | it's good to see you all! 1 221 | nice to see you! 0 222 | good to see you all 1 223 | Hope all is well with your family. 0 224 | I'm glad to see you all! 
0 225 | what's new 0 226 | great seeing you again 0 227 | Just saying hello 0 228 | it's nice meeting you 0 229 | nice meeting you! 0 230 | hi, how are you? 0 231 | how are things going? 0 232 | just saying hi 0 233 | just saying hi 0 234 | I'm happy to see you all! 0 235 | hope all is well 0 236 | it's great to see you again! 0 237 | i hope everything is going well for you 0 238 | It's great to see you again! 0 239 | good seeing you all again! 1 240 | it's nice seeing you again! 0 241 | It's good to see you! 1 242 | how are things everyone? 0 243 | nice seeing you! 0 244 | great seeing you! 0 245 | good to see you! 1 246 | it's great to see you! 0 247 | it's been a long time! 0 248 | how's it going? 0 249 | how is everyone doing? 0 250 | i hope life has been treating you right 0 251 | how is life treating you 0 252 | welcome back 0 253 | it's great to see you too! 0 254 | hey there, nice to see you! 0 255 | nice to see you again! 0 256 | Good to see you all! 0 257 | it's nice seeing you as well! 0 258 | it's great to see you again! 0 259 | nice seeing you again! 0 260 | I'm excited to see you all! 0 261 | good evening! 1 262 | Hey team how's it going? 0 263 | how have you been? 0 264 | how is everyone doing? i hope all's well. 0 265 | it's nice to see you as well! 0 266 | hey team, how is everything going? 0 267 | it's nice to see you all! 0 268 | hi there, how are you? hope things are good! 1 269 | hey pal! 0 270 | it's nice seeing you! 0 271 | I'm happy to see you all! 0 272 | how's your day? 0 273 | I'm excited to see you! 0 274 | I'm so excited to see you! 0 275 | just wanted to say hi 0 276 | good to see you again! 1 277 | greetings! how are you doing today? 0 278 | nice meeting you again. 0 279 | great seeing you 0 280 | I'm happy to see you 0 281 | it's good to see you again! 1 282 | It's nice to meet you! 0 283 | I'm glad to see you! 0 284 | I just wanted to say hello 0 285 | It was good meeting you 1 286 | welcome back! 0 287 | greetings! how are you doing today? 0 288 | It's great to see you! 0 289 | what's up? 0 290 | I'm glad to see you! 0 291 | hey my friend! 0 292 | I'm happy to see you 0 293 | it's great to meet someone like yourself! 0 294 | good seeing you 1 295 | I've heard a lot about you and your family. 0 296 | how is your family doing? 0 297 | it's great seeing you all! 0 298 | do you have any plans for today? 
0 299 | -------------------------------------------------------------------------------- /data/greetings/word-level-vocab.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "1.0", 3 | "truncation": { 4 | "direction": "Right", 5 | "max_length": 512, 6 | "strategy": "LongestFirst", 7 | "stride": 0 8 | }, 9 | "padding": null, 10 | "added_tokens": [ 11 | { 12 | "id": 0, 13 | "content": "[UNK]", 14 | "single_word": false, 15 | "lstrip": false, 16 | "rstrip": false, 17 | "normalized": false, 18 | "special": true 19 | }, 20 | { 21 | "id": 1, 22 | "content": "[CLS]", 23 | "single_word": false, 24 | "lstrip": false, 25 | "rstrip": false, 26 | "normalized": false, 27 | "special": true 28 | }, 29 | { 30 | "id": 2, 31 | "content": "[SEP]", 32 | "single_word": false, 33 | "lstrip": false, 34 | "rstrip": false, 35 | "normalized": false, 36 | "special": true 37 | }, 38 | { 39 | "id": 3, 40 | "content": "[PAD]", 41 | "single_word": false, 42 | "lstrip": false, 43 | "rstrip": false, 44 | "normalized": false, 45 | "special": true 46 | }, 47 | { 48 | "id": 4, 49 | "content": "[MASK]", 50 | "single_word": false, 51 | "lstrip": false, 52 | "rstrip": false, 53 | "normalized": false, 54 | "special": true 55 | } 56 | ], 57 | "normalizer": { 58 | "type": "Sequence", 59 | "normalizers": [ 60 | { 61 | "type": "NFD" 62 | }, 63 | { 64 | "type": "Lowercase" 65 | }, 66 | { 67 | "type": "StripAccents" 68 | } 69 | ] 70 | }, 71 | "pre_tokenizer": { 72 | "type": "Sequence", 73 | "pretokenizers": [ 74 | { 75 | "type": "Digits", 76 | "individual_digits": true 77 | }, 78 | { 79 | "type": "Whitespace" 80 | } 81 | ] 82 | }, 83 | "post_processor": { 84 | "type": "TemplateProcessing", 85 | "single": [ 86 | { 87 | "SpecialToken": { 88 | "id": "[CLS]", 89 | "type_id": 0 90 | } 91 | }, 92 | { 93 | "Sequence": { 94 | "id": "A", 95 | "type_id": 0 96 | } 97 | }, 98 | { 99 | "SpecialToken": { 100 | "id": "[SEP]", 101 | "type_id": 0 102 | } 103 | } 104 | ], 105 | "pair": [ 106 | { 107 | "Sequence": { 108 | "id": "A", 109 | "type_id": 0 110 | } 111 | }, 112 | { 113 | "Sequence": { 114 | "id": "B", 115 | "type_id": 1 116 | } 117 | } 118 | ], 119 | "special_tokens": { 120 | "[CLS]": { 121 | "id": "[CLS]", 122 | "ids": [ 123 | 1 124 | ], 125 | "tokens": [ 126 | "[CLS]" 127 | ] 128 | }, 129 | "[SEP]": { 130 | "id": "[SEP]", 131 | "ids": [ 132 | 2 133 | ], 134 | "tokens": [ 135 | "[SEP]" 136 | ] 137 | } 138 | } 139 | }, 140 | "decoder": null, 141 | "model": { 142 | "type": "WordLevel", 143 | "vocab": { 144 | "[UNK]": 0, 145 | "[CLS]": 1, 146 | "[SEP]": 2, 147 | "[PAD]": 3, 148 | "[MASK]": 4, 149 | "you": 5, 150 | "!": 6, 151 | "'": 7, 152 | "to": 8, 153 | "i": 9, 154 | "see": 10, 155 | "s": 11, 156 | "it": 12, 157 | "good": 13, 158 | "m": 14, 159 | "how": 15, 160 | "all": 16, 161 | "nice": 17, 162 | "?": 18, 163 | "again": 19, 164 | "are": 20, 165 | "seeing": 21, 166 | ".": 22, 167 | "going": 23, 168 | "things": 24, 169 | "great": 25, 170 | "hope": 26, 171 | "is": 27, 172 | "your": 28, 173 | "hey": 29, 174 | "hi": 30, 175 | "in": 31, 176 | "meet": 32, 177 | "say": 33, 178 | "and": 34, 179 | "family": 35, 180 | "well": 36, 181 | "just": 37, 182 | "so": 38, 183 | "\"": 39, 184 | ",": 40, 185 | "doing": 41, 186 | "excited": 42, 187 | "for": 43, 188 | "happy": 44, 189 | "hello": 45, 190 | "glad": 46, 191 | "with": 47, 192 | "a": 48, 193 | "been": 49, 194 | "meeting": 50, 195 | "by": 51, 196 | "person": 52, 197 | "there": 53, 198 | "too": 54, 199 | "!\"": 55, 200 | "finally": 56, 201 | 
"heard": 57, 202 | "the": 58, 203 | "ve": 59, 204 | "about": 60, 205 | "everyone": 61, 206 | "have": 62, 207 | "smoothly": 63, 208 | "stopping": 64, 209 | "that": 65, 210 | "thrilled": 66, 211 | "yours": 67, 212 | "day": 68, 213 | "friend": 69, 214 | "of": 70, 215 | "was": 71, 216 | "as": 72, 217 | "everything": 73, 218 | "lot": 74, 219 | "today": 75, 220 | "wanted": 76, 221 | "what": 77, 222 | "chance": 78, 223 | "got": 79, 224 | "welcome": 80, 225 | "?\"": 81, 226 | "am": 82, 227 | "any": 83, 228 | "do": 84, 229 | "dropping": 85, 230 | "greetings": 86, 231 | "has": 87, 232 | "life": 88, 233 | "plans": 89, 234 | "saying": 90, 235 | "thinking": 91, 236 | "treating": 92, 237 | "trust": 93, 238 | "up": 94, 239 | "back": 95, 240 | "but": 96, 241 | "long": 97, 242 | "nothing": 98, 243 | "team": 99, 244 | "afternoon": 100, 245 | "because": 101, 246 | "buddy": 102, 247 | "dropped": 103, 248 | "evening": 104, 249 | "faith": 105, 250 | "far": 106, 251 | "hood": 107, 252 | "know": 108, 253 | "let": 109, 254 | "like": 110, 255 | "loved": 111, 256 | "mate": 112, 257 | "morning": 113, 258 | "my": 114, 259 | "new": 115, 260 | "ones": 116, 261 | "pal": 117, 262 | "pleasure": 118, 263 | "pray": 119, 264 | "someone": 120, 265 | "stopped": 121, 266 | "time": 122, 267 | "weekend": 123, 268 | "while": 124, 269 | "yourself": 125, 270 | "abou": 126, 271 | "be": 127, 272 | "exciting": 128, 273 | "here": 129, 274 | "last": 130, 275 | "night": 131, 276 | "no": 132, 277 | "right": 133, 278 | "since": 134, 279 | "talked": 135, 280 | "we": 136, 281 | "wonderful": 137 282 | }, 283 | "unk_token": "[UNK]" 284 | } 285 | } -------------------------------------------------------------------------------- /data/simple/merges.txt: -------------------------------------------------------------------------------- 1 | #version: 0.2 - Trained by `huggingface/tokenizers` 2 | h e 3 | i s 4 | T he 5 | Ġ is 6 | e d 7 | Ġ b 8 | Ġ p 9 | Ġ s 10 | a n 11 | Ġ c 12 | o w 13 | Ġb l 14 | i n 15 | t e 16 | Ġ r 17 | Ġb r 18 | a r 19 | a c 20 | u r 21 | l e 22 | Ġ g 23 | e l 24 | i c 25 | e e 26 | Ġg r 27 | g e 28 | Ġ o 29 | l ow 30 | Ġ y 31 | h i 32 | ow n 33 | Ġbr own 34 | a s 35 | i ed 36 | e r 37 | Ġ w 38 | ac k 39 | p le 40 | r an 41 | f r 42 | Ġbl ack 43 | Ġo ran 44 | Ġoran ge 45 | Ġp in 46 | u e 47 | Ġbl ue 48 | Ġ m 49 | Ġr ed 50 | Ġp ur 51 | Ġpur ple 52 | t ed 53 | Ġpin k 54 | el low 55 | Ġy ellow 56 | hi te 57 | Ġw hite 58 | ee n 59 | Ġgr een 60 | o n 61 | Ġ d 62 | Ġ t 63 | u t 64 | o t 65 | l ed 66 | a m 67 | Ġp e 68 | u s 69 | fr ied 70 | r y 71 | Ġc o 72 | e n 73 | Ġ a 74 | k ed 75 | ar d 76 | Ġs o 77 | Ġ f 78 | u m 79 | he d 80 | Ġs t 81 | o as 82 | oas ted 83 | a d 84 | u c 85 | b er 86 | Ġ h 87 | Ġs te 88 | o ked 89 | Ġs p 90 | ur n 91 | Ġp o 92 | a u 93 | an d 94 | i p 95 | i l 96 | Ġs m 97 | l i 98 | b e 99 | ic k 100 | ar in 101 | Ġc r 102 | a t 103 | at o 104 | i t 105 | Ġ le 106 | g g 107 | a l 108 | ee t 109 | m en 110 | a v 111 | a p 112 | r ied 113 | Ġd ried 114 | f low 115 | flow er 116 | m on 117 | Ġc and 118 | Ġcand ied 119 | Ġsm oked 120 | am ed 121 | Ġste amed 122 | Ġm us 123 | an t 124 | Ġr ad 125 | Ġt urn 126 | Ġturn ip 127 | r o 128 | ac hed 129 | Ġpo ached 130 | a b 131 | Ġc h 132 | o m 133 | ee p 134 | Ġd eep 135 | Ġ ar 136 | hi o 137 | Ġs au 138 | te ed 139 | Ġsau teed 140 | i r 141 | Ġco l 142 | Ġst ir 143 | n ut 144 | Ġgr il 145 | ber ry 146 | Ġgril led 147 | i led 148 | o iled 149 | Ġ fr 150 | Ġb oiled 151 | Ġbrown ed 152 | in ed 153 | Ġbr ined 154 | Ġp ick 155 | Ġpick led 156 | t y 157 | u n 158 | i on 159 
| r ow 160 | ic row 161 | Ġm icrow 162 | av ed 163 | Ġmicrow aved 164 | a g 165 | Ġt oasted 166 | a is 167 | c u 168 | Ġb ar 169 | Ġbr ais 170 | be cu 171 | Ġbar becu 172 | Ġbrais ed 173 | Ġbarbecu ed 174 | a ted 175 | Ġm arin 176 | Ġmarin ated 177 | c h 178 | er men 179 | Ġf ermen 180 | Ġfermen ted 181 | g ed 182 | ic y 183 | Ġa ged 184 | c hed 185 | an ched 186 | Ġbl anched 187 | Ġc ur 188 | Ġcur ed 189 | Ġr oasted 190 | Ġ k 191 | ac h 192 | te r 193 | Ġp ot 194 | Ġpot ato 195 | is h 196 | u it 197 | fr uit 198 | as p 199 | c on 200 | Ġco con 201 | Ġcocon ut 202 | e gg 203 | l ant 204 | p lant 205 | Ġ egg 206 | Ġegg plant 207 | w eet 208 | Ġs weet 209 | Ġc uc 210 | um ber 211 | Ġcuc umber 212 | a w 213 | o c 214 | v e 215 | Ġ j 216 | ge r 217 | e an 218 | Ġb ean 219 | Ġgr ap 220 | Ġgrap e 221 | q u 222 | Ġs qu 223 | as h 224 | Ġsqu ash 225 | a le 226 | Ġk ale 227 | a ge 228 | b age 229 | Ġc ab 230 | Ġcab bage 231 | a ked 232 | r ot 233 | Ġb aked 234 | Ġc ar 235 | Ġcar rot 236 | w ed 237 | Ġste wed 238 | Ġb it 239 | Ġbit ter 240 | l ard 241 | Ġcol lard 242 | m el 243 | in e 244 | Ġc au 245 | li flower 246 | Ġcau liflower 247 | o z 248 | Ġfr oz 249 | Ġfroz en 250 | p p 251 | in ach 252 | Ġpe pp 253 | Ġsp inach 254 | Ġrad ish 255 | Ġpepp er 256 | f f 257 | l u 258 | Ġf lu 259 | ff y 260 | Ġflu ffy 261 | Ġpe a 262 | Ġco oked 263 | t ard 264 | Ġmus tard 265 | Ġs un 266 | Ġsun flower 267 | s el 268 | Ġbr us 269 | Ġso gg 270 | Ġsp ro 271 | sel s 272 | Ġbrus sels 273 | Ġsogg y 274 | Ġspro ut 275 | g u 276 | l a 277 | u gu 278 | Ġsp icy 279 | Ġar ugu 280 | Ġarugu la 281 | o ry 282 | Ġb urn 283 | ic ory 284 | Ġch icory 285 | Ġburn t 286 | e le 287 | h ro 288 | Ġc ele 289 | Ġt om 290 | Ġmus hro 291 | Ġch ard 292 | Ġcele ry 293 | Ġtom ato 294 | Ġmushro om 295 | l mon 296 | t t 297 | y be 298 | Ġa lmon 299 | Ġso ybe 300 | uc e 301 | Ġle tt 302 | Ġalmon d 303 | Ġsoybe an 304 | Ġlett uce 305 | a te 306 | n ion 307 | Ġb eet 308 | Ġo nion 309 | n t 310 | Ġle nt 311 | Ġlent il 312 | k in 313 | p kin 314 | z uc 315 | Ġ asp 316 | Ġ zuc 317 | Ġp um 318 | in i 319 | ar ag 320 | Ġw al 321 | Ġso ur 322 | ch ini 323 | Ġasp arag 324 | Ġzuc chini 325 | Ġpum pkin 326 | Ġwal nut 327 | Ġasparag us 328 | c hio 329 | t ac 330 | is tac 331 | Ġp istac 332 | Ġr aw 333 | ic chio 334 | Ġh ot 335 | Ġrad icchio 336 | Ġpistac hio 337 | i o 338 | he w 339 | Ġc as 340 | el ic 341 | Ġy am 342 | Ġd elic 343 | io us 344 | Ġcas hew 345 | Ġdelic ious 346 | i a 347 | ac ad 348 | as ty 349 | Ġm acad 350 | Ġt asty 351 | am ia 352 | Ġmacad amia 353 | c o 354 | f t 355 | Ġbr oc 356 | Ġso ft 357 | co li 358 | Ġbroc coli 359 | an ut 360 | Ġpe anut 361 | e am 362 | Ġcr eam 363 | Ġcream y 364 | a z 365 | el nut 366 | Ġa p 367 | Ġh az 368 | Ġhaz elnut 369 | h o 370 | k e 371 | t ic 372 | te n 373 | Ġr ot 374 | Ġar tic 375 | ho ke 376 | Ġrot ten 377 | Ġartic hoke 378 | m y 379 | Ġs al 380 | Ġy um 381 | Ġsal ty 382 | Ġyum my 383 | c an 384 | Ġ fried 385 | Ġpe can 386 | Ġr ip 387 | Ġst ick 388 | Ġrip e 389 | Ġstick y 390 | m e 391 | p y 392 | is py 393 | Ġcr ispy 394 | el ion 395 | Ġd and 396 | Ġh ard 397 | Ġdand elion 398 | Ġr ut 399 | ab ag 400 | Ġrut abag 401 | Ġrutabag a 402 | d i 403 | Ġ en 404 | di ve 405 | Ġen dive 406 | Ġcol d 407 | Ġcr un 408 | ch y 409 | Ġcrun chy 410 | e s 411 | Ġfr es 412 | Ġfres h 413 | o ot 414 | Ġsm oot 415 | Ġsmoot h 416 | l ic 417 | ar lic 418 | Ġg arlic 419 | e as 420 | Ġgr eas 421 | Ġgreas y 422 | u icy 423 | Ġj uicy 424 | in ger 425 | Ġg inger 426 | Ġpe ach 427 | t ine 428 | Ġc le 429 | men tine 430 | Ġcle 
mentine 431 | a ter 432 | Ġw ater 433 | mel on 434 | Ġwater melon 435 | e ap 436 | Ġpin eap 437 | Ġgrape fruit 438 | Ġpineap ple 439 | i g 440 | o u 441 | p e 442 | r ry 443 | he rry 444 | Ġc ant 445 | Ġc herry 446 | Ġf ig 447 | al ou 448 | Ġcant alou 449 | Ġcantalou pe 450 | Ġle mon 451 | i m 452 | s im 453 | Ġp er 454 | sim mon 455 | Ġper simmon 456 | i w 457 | Ġk iw 458 | Ġkiw i 459 | c t 460 | d e 461 | e y 462 | e ct 463 | n ect 464 | Ġ nect 465 | Ġblack berry 466 | on ey 467 | Ġh oney 468 | arin e 469 | de w 470 | Ġnect arine 471 | Ġhoney dew 472 | c he 473 | l y 474 | Ġ ly 475 | che e 476 | Ġly chee 477 | k r 478 | Ġo kr 479 | Ġokr a 480 | g o 481 | l um 482 | Ġp lum 483 | an go 484 | Ġm ango 485 | Ġb an 486 | an a 487 | Ġap ple 488 | Ġban ana 489 | g ran 490 | Ġr asp 491 | Ġpo me 492 | gran ate 493 | Ġrasp berry 494 | Ġpome granate 495 | Ġf ish 496 | Ġpo mel 497 | Ġpomel o 498 | el on 499 | Ġm elon 500 | Ġd ate 501 | ack fruit 502 | Ġpe ar 503 | Ġj ackfruit 504 | u av 505 | an ger 506 | Ġg uav 507 | Ġt anger 508 | Ġguav a 509 | Ġtanger ine 510 | i an 511 | v oc 512 | ur ian 513 | Ġd urian 514 | Ġa voc 515 | ad o 516 | Ġavoc ado 517 | a y 518 | Ġp ap 519 | ar fruit 520 | Ġst arfruit 521 | ay a 522 | Ġpap aya 523 | r aw 524 | Ġst raw 525 | Ġstraw berry 526 | Ġo li 527 | Ġoli ve 528 | r ic 529 | Ġblue berry 530 | Ġap ric 531 | Ġapric ot 532 | Ġ li 533 | Ġt am 534 | arin d 535 | Ġli me 536 | Ġtam arind 537 | -------------------------------------------------------------------------------- /data/simple/vocab.json: -------------------------------------------------------------------------------- 1 | {"":0,"":1,"":2,"":3,"":4,"!":5,"\"":6,"#":7,"$":8,"%":9,"&":10,"'":11,"(":12,")":13,"*":14,"+":15,",":16,"-":17,".":18,"/":19,"0":20,"1":21,"2":22,"3":23,"4":24,"5":25,"6":26,"7":27,"8":28,"9":29,":":30,";":31,"<":32,"=":33,">":34,"?":35,"@":36,"A":37,"B":38,"C":39,"D":40,"E":41,"F":42,"G":43,"H":44,"I":45,"J":46,"K":47,"L":48,"M":49,"N":50,"O":51,"P":52,"Q":53,"R":54,"S":55,"T":56,"U":57,"V":58,"W":59,"X":60,"Y":61,"Z":62,"[":63,"\\":64,"]":65,"^":66,"_":67,"`":68,"a":69,"b":70,"c":71,"d":72,"e":73,"f":74,"g":75,"h":76,"i":77,"j":78,"k":79,"l":80,"m":81,"n":82,"o":83,"p":84,"q":85,"r":86,"s":87,"t":88,"u":89,"v":90,"w":91,"x":92,"y":93,"z":94,"{":95,"|":96,"}":97,"~":98,"¡":99,"¢":100,"£":101,"¤":102,"¥":103,"¦":104,"§":105,"¨":106,"©":107,"ª":108,"«":109,"¬":110,"®":111,"¯":112,"°":113,"±":114,"²":115,"³":116,"´":117,"µ":118,"¶":119,"·":120,"¸":121,"¹":122,"º":123,"»":124,"¼":125,"½":126,"¾":127,"¿":128,"À":129,"Á":130,"Â":131,"Ã":132,"Ä":133,"Å":134,"Æ":135,"Ç":136,"È":137,"É":138,"Ê":139,"Ë":140,"Ì":141,"Í":142,"Î":143,"Ï":144,"Ð":145,"Ñ":146,"Ò":147,"Ó":148,"Ô":149,"Õ":150,"Ö":151,"×":152,"Ø":153,"Ù":154,"Ú":155,"Û":156,"Ü":157,"Ý":158,"Þ":159,"ß":160,"à":161,"á":162,"â":163,"ã":164,"ä":165,"å":166,"æ":167,"ç":168,"è":169,"é":170,"ê":171,"ë":172,"ì":173,"í":174,"î":175,"ï":176,"ð":177,"ñ":178,"ò":179,"ó":180,"ô":181,"õ":182,"ö":183,"÷":184,"ø":185,"ù":186,"ú":187,"û":188,"ü":189,"ý":190,"þ":191,"ÿ":192,"Ā":193,"ā":194,"Ă":195,"ă":196,"Ą":197,"ą":198,"Ć":199,"ć":200,"Ĉ":201,"ĉ":202,"Ċ":203,"ċ":204,"Č":205,"č":206,"Ď":207,"ď":208,"Đ":209,"đ":210,"Ē":211,"ē":212,"Ĕ":213,"ĕ":214,"Ė":215,"ė":216,"Ę":217,"ę":218,"Ě":219,"ě":220,"Ĝ":221,"ĝ":222,"Ğ":223,"ğ":224,"Ġ":225,"ġ":226,"Ģ":227,"ģ":228,"Ĥ":229,"ĥ":230,"Ħ":231,"ħ":232,"Ĩ":233,"ĩ":234,"Ī":235,"ī":236,"Ĭ":237,"ĭ":238,"Į":239,"į":240,"İ":241,"ı":242,"IJ":243,"ij":244,"Ĵ":245,"ĵ":246,"Ķ":247,"ķ":248,"ĸ":249,"Ĺ":250,"ĺ":251,"Ļ":252,"ļ":253,"Ľ":25
4,"ľ":255,"Ŀ":256,"ŀ":257,"Ł":258,"ł":259,"Ń":260,"he":261,"is":262,"The":263,"Ġis":264,"ed":265,"Ġb":266,"Ġp":267,"Ġs":268,"an":269,"Ġc":270,"ow":271,"Ġbl":272,"in":273,"te":274,"Ġr":275,"Ġbr":276,"ar":277,"ac":278,"ur":279,"le":280,"Ġg":281,"el":282,"ic":283,"ee":284,"Ġgr":285,"ge":286,"Ġo":287,"low":288,"Ġy":289,"hi":290,"own":291,"Ġbrown":292,"as":293,"ied":294,"er":295,"Ġw":296,"ack":297,"ple":298,"ran":299,"fr":300,"Ġblack":301,"Ġoran":302,"Ġorange":303,"Ġpin":304,"ue":305,"Ġblue":306,"Ġm":307,"Ġred":308,"Ġpur":309,"Ġpurple":310,"ted":311,"Ġpink":312,"ellow":313,"Ġyellow":314,"hite":315,"Ġwhite":316,"een":317,"Ġgreen":318,"on":319,"Ġd":320,"Ġt":321,"ut":322,"ot":323,"led":324,"am":325,"Ġpe":326,"us":327,"fried":328,"ry":329,"Ġco":330,"en":331,"Ġa":332,"ked":333,"ard":334,"Ġso":335,"Ġf":336,"um":337,"hed":338,"Ġst":339,"oas":340,"oasted":341,"ad":342,"uc":343,"ber":344,"Ġh":345,"Ġste":346,"oked":347,"Ġsp":348,"urn":349,"Ġpo":350,"au":351,"and":352,"ip":353,"il":354,"Ġsm":355,"li":356,"be":357,"ick":358,"arin":359,"Ġcr":360,"at":361,"ato":362,"it":363,"Ġle":364,"gg":365,"al":366,"eet":367,"men":368,"av":369,"ap":370,"ried":371,"Ġdried":372,"flow":373,"flower":374,"mon":375,"Ġcand":376,"Ġcandied":377,"Ġsmoked":378,"amed":379,"Ġsteamed":380,"Ġmus":381,"ant":382,"Ġrad":383,"Ġturn":384,"Ġturnip":385,"ro":386,"ached":387,"Ġpoached":388,"ab":389,"Ġch":390,"om":391,"eep":392,"Ġdeep":393,"Ġar":394,"hio":395,"Ġsau":396,"teed":397,"Ġsauteed":398,"ir":399,"Ġcol":400,"Ġstir":401,"nut":402,"Ġgril":403,"berry":404,"Ġgrilled":405,"iled":406,"oiled":407,"Ġfr":408,"Ġboiled":409,"Ġbrowned":410,"ined":411,"Ġbrined":412,"Ġpick":413,"Ġpickled":414,"ty":415,"un":416,"ion":417,"row":418,"icrow":419,"Ġmicrow":420,"aved":421,"Ġmicrowaved":422,"ag":423,"Ġtoasted":424,"ais":425,"cu":426,"Ġbar":427,"Ġbrais":428,"becu":429,"Ġbarbecu":430,"Ġbraised":431,"Ġbarbecued":432,"ated":433,"Ġmarin":434,"Ġmarinated":435,"ch":436,"ermen":437,"Ġfermen":438,"Ġfermented":439,"ged":440,"icy":441,"Ġaged":442,"ched":443,"anched":444,"Ġblanched":445,"Ġcur":446,"Ġcured":447,"Ġroasted":448,"Ġk":449,"ach":450,"ter":451,"Ġpot":452,"Ġpotato":453,"ish":454,"uit":455,"fruit":456,"asp":457,"con":458,"Ġcocon":459,"Ġcoconut":460,"egg":461,"lant":462,"plant":463,"Ġegg":464,"Ġeggplant":465,"weet":466,"Ġsweet":467,"Ġcuc":468,"umber":469,"Ġcucumber":470,"aw":471,"oc":472,"ve":473,"Ġj":474,"ger":475,"ean":476,"Ġbean":477,"Ġgrap":478,"Ġgrape":479,"qu":480,"Ġsqu":481,"ash":482,"Ġsquash":483,"ale":484,"Ġkale":485,"age":486,"bage":487,"Ġcab":488,"Ġcabbage":489,"aked":490,"rot":491,"Ġbaked":492,"Ġcar":493,"Ġcarrot":494,"wed":495,"Ġstewed":496,"Ġbit":497,"Ġbitter":498,"lard":499,"Ġcollard":500,"mel":501,"ine":502,"Ġcau":503,"liflower":504,"Ġcauliflower":505,"oz":506,"Ġfroz":507,"Ġfrozen":508,"pp":509,"inach":510,"Ġpepp":511,"Ġspinach":512,"Ġradish":513,"Ġpepper":514,"ff":515,"lu":516,"Ġflu":517,"ffy":518,"Ġfluffy":519,"Ġpea":520,"Ġcooked":521,"tard":522,"Ġmustard":523,"Ġsun":524,"Ġsunflower":525,"sel":526,"Ġbrus":527,"Ġsogg":528,"Ġspro":529,"sels":530,"Ġbrussels":531,"Ġsoggy":532,"Ġsprout":533,"gu":534,"la":535,"ugu":536,"Ġspicy":537,"Ġarugu":538,"Ġarugula":539,"ory":540,"Ġburn":541,"icory":542,"Ġchicory":543,"Ġburnt":544,"ele":545,"hro":546,"Ġcele":547,"Ġtom":548,"Ġmushro":549,"Ġchard":550,"Ġcelery":551,"Ġtomato":552,"Ġmushroom":553,"lmon":554,"tt":555,"ybe":556,"Ġalmon":557,"Ġsoybe":558,"uce":559,"Ġlett":560,"Ġalmond":561,"Ġsoybean":562,"Ġlettuce":563,"ate":564,"nion":565,"Ġbeet":566,"Ġonion":567,"nt":568,"Ġlent":569,"Ġlentil":570,"kin":571,"pkin":572,
"zuc":573,"Ġasp":574,"Ġzuc":575,"Ġpum":576,"ini":577,"arag":578,"Ġwal":579,"Ġsour":580,"chini":581,"Ġasparag":582,"Ġzucchini":583,"Ġpumpkin":584,"Ġwalnut":585,"Ġasparagus":586,"chio":587,"tac":588,"istac":589,"Ġpistac":590,"Ġraw":591,"icchio":592,"Ġhot":593,"Ġradicchio":594,"Ġpistachio":595,"io":596,"hew":597,"Ġcas":598,"elic":599,"Ġyam":600,"Ġdelic":601,"ious":602,"Ġcashew":603,"Ġdelicious":604,"ia":605,"acad":606,"asty":607,"Ġmacad":608,"Ġtasty":609,"amia":610,"Ġmacadamia":611,"co":612,"ft":613,"Ġbroc":614,"Ġsoft":615,"coli":616,"Ġbroccoli":617,"anut":618,"Ġpeanut":619,"eam":620,"Ġcream":621,"Ġcreamy":622,"az":623,"elnut":624,"Ġap":625,"Ġhaz":626,"Ġhazelnut":627,"ho":628,"ke":629,"tic":630,"ten":631,"Ġrot":632,"Ġartic":633,"hoke":634,"Ġrotten":635,"Ġartichoke":636,"my":637,"Ġsal":638,"Ġyum":639,"Ġsalty":640,"Ġyummy":641,"can":642,"Ġfried":643,"Ġpecan":644,"Ġrip":645,"Ġstick":646,"Ġripe":647,"Ġsticky":648,"me":649,"py":650,"ispy":651,"Ġcrispy":652,"elion":653,"Ġdand":654,"Ġhard":655,"Ġdandelion":656,"Ġrut":657,"abag":658,"Ġrutabag":659,"Ġrutabaga":660,"di":661,"Ġen":662,"dive":663,"Ġendive":664,"Ġcold":665,"Ġcrun":666,"chy":667,"Ġcrunchy":668,"es":669,"Ġfres":670,"Ġfresh":671,"oot":672,"Ġsmoot":673,"Ġsmooth":674,"lic":675,"arlic":676,"Ġgarlic":677,"eas":678,"Ġgreas":679,"Ġgreasy":680,"uicy":681,"Ġjuicy":682,"inger":683,"Ġginger":684,"Ġpeach":685,"tine":686,"Ġcle":687,"mentine":688,"Ġclementine":689,"ater":690,"Ġwater":691,"melon":692,"Ġwatermelon":693,"eap":694,"Ġpineap":695,"Ġgrapefruit":696,"Ġpineapple":697,"ig":698,"ou":699,"pe":700,"rry":701,"herry":702,"Ġcant":703,"Ġcherry":704,"Ġfig":705,"alou":706,"Ġcantalou":707,"Ġcantaloupe":708,"Ġlemon":709,"im":710,"sim":711,"Ġper":712,"simmon":713,"Ġpersimmon":714,"iw":715,"Ġkiw":716,"Ġkiwi":717,"ct":718,"de":719,"ey":720,"ect":721,"nect":722,"Ġnect":723,"Ġblackberry":724,"oney":725,"Ġhoney":726,"arine":727,"dew":728,"Ġnectarine":729,"Ġhoneydew":730,"che":731,"ly":732,"Ġly":733,"chee":734,"Ġlychee":735,"kr":736,"Ġokr":737,"Ġokra":738,"go":739,"lum":740,"Ġplum":741,"ango":742,"Ġmango":743,"Ġban":744,"ana":745,"Ġapple":746,"Ġbanana":747,"gran":748,"Ġrasp":749,"Ġpome":750,"granate":751,"Ġraspberry":752,"Ġpomegranate":753,"Ġfish":754,"Ġpomel":755,"Ġpomelo":756,"elon":757,"Ġmelon":758,"Ġdate":759,"ackfruit":760,"Ġpear":761,"Ġjackfruit":762,"uav":763,"anger":764,"Ġguav":765,"Ġtanger":766,"Ġguava":767,"Ġtangerine":768,"ian":769,"voc":770,"urian":771,"Ġdurian":772,"Ġavoc":773,"ado":774,"Ġavocado":775,"ay":776,"Ġpap":777,"arfruit":778,"Ġstarfruit":779,"aya":780,"Ġpapaya":781,"raw":782,"Ġstraw":783,"Ġstrawberry":784,"Ġoli":785,"Ġolive":786,"ric":787,"Ġblueberry":788,"Ġapric":789,"Ġapricot":790,"Ġli":791,"Ġtam":792,"arind":793,"Ġlime":794,"Ġtamarind":795} -------------------------------------------------------------------------------- /data/simple/word-level-vocab.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "1.0", 3 | "truncation": { 4 | "direction": "Right", 5 | "max_length": 512, 6 | "strategy": "LongestFirst", 7 | "stride": 0 8 | }, 9 | "padding": null, 10 | "added_tokens": [ 11 | { 12 | "id": 0, 13 | "content": "[UNK]", 14 | "single_word": false, 15 | "lstrip": false, 16 | "rstrip": false, 17 | "normalized": false, 18 | "special": true 19 | }, 20 | { 21 | "id": 1, 22 | "content": "[CLS]", 23 | "single_word": false, 24 | "lstrip": false, 25 | "rstrip": false, 26 | "normalized": false, 27 | "special": true 28 | }, 29 | { 30 | "id": 2, 31 | "content": "[SEP]", 32 | "single_word": false, 33 
| "lstrip": false, 34 | "rstrip": false, 35 | "normalized": false, 36 | "special": true 37 | }, 38 | { 39 | "id": 3, 40 | "content": "[PAD]", 41 | "single_word": false, 42 | "lstrip": false, 43 | "rstrip": false, 44 | "normalized": false, 45 | "special": true 46 | }, 47 | { 48 | "id": 4, 49 | "content": "[MASK]", 50 | "single_word": false, 51 | "lstrip": false, 52 | "rstrip": false, 53 | "normalized": false, 54 | "special": true 55 | } 56 | ], 57 | "normalizer": { 58 | "type": "Sequence", 59 | "normalizers": [ 60 | { 61 | "type": "NFD" 62 | }, 63 | { 64 | "type": "Lowercase" 65 | }, 66 | { 67 | "type": "StripAccents" 68 | } 69 | ] 70 | }, 71 | "pre_tokenizer": { 72 | "type": "Sequence", 73 | "pretokenizers": [ 74 | { 75 | "type": "Digits", 76 | "individual_digits": true 77 | }, 78 | { 79 | "type": "Whitespace" 80 | } 81 | ] 82 | }, 83 | "post_processor": { 84 | "type": "TemplateProcessing", 85 | "single": [ 86 | { 87 | "SpecialToken": { 88 | "id": "[CLS]", 89 | "type_id": 0 90 | } 91 | }, 92 | { 93 | "Sequence": { 94 | "id": "A", 95 | "type_id": 0 96 | } 97 | }, 98 | { 99 | "SpecialToken": { 100 | "id": "[SEP]", 101 | "type_id": 0 102 | } 103 | } 104 | ], 105 | "pair": [ 106 | { 107 | "Sequence": { 108 | "id": "A", 109 | "type_id": 0 110 | } 111 | }, 112 | { 113 | "Sequence": { 114 | "id": "B", 115 | "type_id": 1 116 | } 117 | } 118 | ], 119 | "special_tokens": { 120 | "[CLS]": { 121 | "id": "[CLS]", 122 | "ids": [ 123 | 1 124 | ], 125 | "tokens": [ 126 | "[CLS]" 127 | ] 128 | }, 129 | "[SEP]": { 130 | "id": "[SEP]", 131 | "ids": [ 132 | 2 133 | ], 134 | "tokens": [ 135 | "[SEP]" 136 | ] 137 | } 138 | } 139 | }, 140 | "decoder": null, 141 | "model": { 142 | "type": "WordLevel", 143 | "vocab": { 144 | "[UNK]": 0, 145 | "[CLS]": 1, 146 | "[SEP]": 2, 147 | "[PAD]": 3, 148 | "[MASK]": 4, 149 | ".": 5, 150 | "is": 6, 151 | "the": 7, 152 | "orange": 8, 153 | "black": 9, 154 | "red": 10, 155 | "purple": 11, 156 | "pink": 12, 157 | "blue": 13, 158 | "yellow": 14, 159 | "white": 15, 160 | "brown": 16, 161 | "green": 17, 162 | "fried": 18, 163 | "-": 19, 164 | "dried": 20, 165 | "candied": 21, 166 | "smoked": 22, 167 | "steamed": 23, 168 | "turnip": 24, 169 | "poached": 25, 170 | "deep": 26, 171 | "sauteed": 27, 172 | "stir": 28, 173 | "grilled": 29, 174 | "boiled": 30, 175 | "browned": 31, 176 | "brined": 32, 177 | "pickled": 33, 178 | "microwaved": 34, 179 | "toasted": 35, 180 | "barbecued": 36, 181 | "braised": 37, 182 | "marinated": 38, 183 | "fermented": 39, 184 | "aged": 40, 185 | "blanched": 41, 186 | "cured": 42, 187 | "roasted": 43, 188 | "potato": 44, 189 | "coconut": 45, 190 | "eggplant": 46, 191 | "sweet": 47, 192 | "cucumber": 48, 193 | "bean": 49, 194 | "squash": 50, 195 | "kale": 51, 196 | "cabbage": 52, 197 | "baked": 53, 198 | "carrot": 54, 199 | "stewed": 55, 200 | "bitter": 56, 201 | "collard": 57, 202 | "cauliflower": 58, 203 | "frozen": 59, 204 | "pepper": 60, 205 | "radish": 61, 206 | "spinach": 62, 207 | "fluffy": 63, 208 | "pea": 64, 209 | "cooked": 65, 210 | "mustard": 66, 211 | "sunflower": 67, 212 | "brussels": 68, 213 | "soggy": 69, 214 | "sprout": 70, 215 | "arugula": 71, 216 | "spicy": 72, 217 | "burnt": 73, 218 | "chicory": 74, 219 | "celery": 75, 220 | "chard": 76, 221 | "mushroom": 77, 222 | "tomato": 78, 223 | "almond": 79, 224 | "lettuce": 80, 225 | "soybean": 81, 226 | "beet": 82, 227 | "onion": 83, 228 | "lentil": 84, 229 | "asparagus": 85, 230 | "pumpkin": 86, 231 | "sour": 87, 232 | "walnut": 88, 233 | "zucchini": 89, 234 | "hot": 90, 235 | "pistachio": 91, 236 
| "radicchio": 92, 237 | "raw": 93, 238 | "cashew": 94, 239 | "delicious": 95, 240 | "yam": 96, 241 | "macadamia": 97, 242 | "tasty": 98, 243 | "broccoli": 99, 244 | "soft": 100, 245 | "peanut": 101, 246 | "creamy": 102, 247 | "hazelnut": 103, 248 | "artichoke": 104, 249 | "rotten": 105, 250 | "salty": 106, 251 | "yummy": 107, 252 | "pecan": 108, 253 | "ripe": 109, 254 | "sticky": 110, 255 | "crispy": 111, 256 | "dandelion": 112, 257 | "hard": 113, 258 | "rutabaga": 114, 259 | "endive": 115, 260 | "cold": 116, 261 | "crunchy": 117, 262 | "fresh": 118, 263 | "smooth": 119, 264 | "garlic": 120, 265 | "greasy": 121, 266 | "juicy": 122, 267 | "ginger": 123, 268 | "peach": 124, 269 | "clementine": 125, 270 | "grape": 126, 271 | "watermelon": 127, 272 | "grapefruit": 128, 273 | "pineapple": 129, 274 | "cantaloupe": 130, 275 | "cherry": 131, 276 | "fig": 132, 277 | "lemon": 133, 278 | "persimmon": 134, 279 | "kiwi": 135, 280 | "blackberry": 136, 281 | "honeydew": 137, 282 | "nectarine": 138, 283 | "lychee": 139, 284 | "okra": 140, 285 | "mango": 141, 286 | "plum": 142, 287 | "apple": 143, 288 | "banana": 144, 289 | "pomegranate": 145, 290 | "raspberry": 146, 291 | "fish": 147, 292 | "pomelo": 148, 293 | "date": 149, 294 | "melon": 150, 295 | "jackfruit": 151, 296 | "pear": 152, 297 | "guava": 153, 298 | "tangerine": 154, 299 | "avocado": 155, 300 | "durian": 156, 301 | "papaya": 157, 302 | "starfruit": 158, 303 | "strawberry": 159, 304 | "olive": 160, 305 | "apricot": 161, 306 | "blueberry": 162, 307 | "lime": 163, 308 | "tamarind": 164 309 | }, 310 | "unk_token": "[UNK]" 311 | } 312 | } -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | diffusion.textgen.info -------------------------------------------------------------------------------- /docs/controllable.md: -------------------------------------------------------------------------------- 1 | ## Classifier-guided Controllable Generation 2 | 3 | * The unconditional generation is used to sample some sentences from a distribution of interest. However, a more interesting task is to generate sentences that satisfy some constraints. For example, we may want to generate sentences containing a certain color. 4 | 5 | - For this walkthrough, let's say we want to generate sentences that contain {"red", "blue", "green", "white"}. We create such a labeled dataset in `data/simple/simple_labeled.tsv`. 6 | 7 | ```sh 8 | $ shuf data/simple/simple_labeled.tsv|head -2 9 | The purple pumpkin is juicy. 0 10 | The green pear is sweet. 1 11 | ``` 12 | - The labeled file is required to train a classifier. In general, if you are working with a dataset name `dset`, a labeled file should be present at `data/{dset}/dset_labeled.tsv` with two columns (sentence and label). 13 | 14 | Let's start! 15 | 16 | ### Step 0: Train a diffusion model. 17 | 18 | * This is the backbone diffusion model whose generations we want to guide. We will use a diffusion model trained on the `simple` dataset introduced in the README. 19 | * Please download the new checkpoint (word-level vocab) from [here](https://drive.google.com/drive/folders/1zPiopN0MqhkYNlUza6zOChPqWyoDofLh?usp=sharing) and put it in the `ckpts/simplev2` folder. 
20 | 21 | ### Step 1: Train a classifier 22 | 23 | 24 | * Train a classifier 25 | 26 | ```sh 27 | python -u src/controllable/classifier.py --model_name_or_path ckpts/simplev2/ema_0.9999_005001.pt 28 | ``` 29 | 30 | * This trains a classifier on the latent/noisy samples ($$x_t$$). 31 | 32 | - It is sufficient to only specify the checkpoint! The name of the dataset and other hyperparameters are loaded from the diffusion model's config file (`ckpts/simplev2/ema_0.9999_005001.pt`). However, the classifier does require the labeled file to be present at `data/{dset}/dset_labeled.tsv`. 33 | 34 | - The classifier is saved at `ckpts/simplev2/classifier.pt`. 35 | 36 | 37 | ### Step 2: Run controllable generation 38 | 39 | ```sh 40 | bash scripts/ctrl_text_sample.sh ckpts/simplev2/ema_0.9999_005001.pt 300 50 41 | ``` 42 | 43 | - Note that we use only 300 diffusion steps vs. 2000 for training. This works because the decoding is actually DDIM style: we approximate `x0` at each step, which is used for denoising. 44 | 45 | - The outputs are generated at: `ckpts/simplev2/ema_0.9999_005001.pt.samples_50.steps-300.clamp-no_clamp.txt.ctrl`. 46 | 47 | - Let's also generate 500 samples from the unguided model for comparison: 48 | 49 | ```sh 50 | CUDA_VISIBLE_DEVICES=8 && bash scripts/text_sample.sh ckpts/simplev2/ema_0.9999_005001.pt 300 500 51 | ``` 52 | 53 | * Let's compare the outputs of the two models: 54 | 55 | ```sh 56 | # top 5 colors in the unguided output: 57 | 58 | (diffusion) amadaan@sa:~/home2/minimal-text-diffusion$ cut -f3 -d" " ckpts/simplev2/ema_0.9999_005001.pt.samples_500.steps-300.clamp-no_clamp.txt | sort | uniq -c | sed 's/^\s*//g' | sort -n|tail -5 59 | 30 purple 60 | 53 yellow 61 | 69 green 62 | 111 pink 63 | 166 white 64 | ``` 65 | 66 | ```sh 67 | # top 5 colors in the guided output: 68 | (diffusion) amadaan@sa:~/home2/minimal-text-diffusion$ cut -f3 -d" " ckpts/simplev2/ema_0.9999_005001.pt.samples_500.steps-300.clamp-no_clamp.txt.ctrl.sample1 | sort | uniq -c | sed 's/^\s*//g' | sort -n|tail -5 69 | 15 pink 70 | 16 black 71 | 25 purple 72 | 124 yellow 73 | 269 green 74 | ``` 75 | 76 | * 50% of the sentences in the guided output contain the color word "green" vs. 69/500 = 14% in the unguided output. It looks like it's working! (recall that green was one of the 4 colors we specified in the classifier for label 1). 77 | 78 | 79 | 80 | ## Implementation Details 81 | 82 | - The files relevant to controllable generation are in `src/controllable/`. 83 | 84 | - Listing 85 | src/controllable/ 86 | ├── classifier.py 87 | ├── controllable_text_sample.py 88 | └── langevin.py 89 | 90 | 91 | * Here: 92 | - `classifier.py` trains a classifier on the latents of the diffusion model. 93 | 94 | - `controllable_text_sample.py` runs controllable generation. 95 | 96 | - `langevin.py` refines the embeddings with classifier guidance (using Langevin dynamics). 97 | 98 | 99 | - At a high level, the procedure is as follows: 100 | a) `p_sample_loop_langevin_progressive` in `src/modeling/diffusion/gaussian_diffusion.py` first creates an approximate `x_{t-1}` and then calls `langevin_binary_classifier` in `src/controllable/langevin.py` 101 | 102 | b) `langevin_binary_classifier` then refines the embeddings with classifier guidance. This is the Langevin dynamics step. $x_{t-1} = x_{t-1} + \epsilon \nabla_x \log p(y = 1 \mid x_{t-1})$ where $\log p(y \mid x_{t-1})$ is the probability of $y = 1$ given the noisy input $x_{t-1}$. 
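A minimal sketch of this refinement step is below (hypothetical tensor and argument names; the actual implementation is `langevin_binary_classifier` in `src/controllable/langevin.py`, and full Langevin dynamics would also inject a small noise term at each step, omitted here to match the update written above):

```python
# Rough sketch of the classifier-guided refinement. Assumed interface: the classifier takes the
# noisy latents x_{t-1} and the timestep t and returns binary logits of shape (batch, 2).
import torch

def refine_with_classifier(x_tm1, t, classifier, step_size=0.1, num_steps=3):
    """Nudge x_{t-1} toward latents the classifier assigns to label 1."""
    for _ in range(num_steps):
        x = x_tm1.detach().requires_grad_(True)
        log_probs = torch.log_softmax(classifier(x, t), dim=-1)
        log_p_y1 = log_probs[:, 1].sum()            # log p(y = 1 | x_{t-1}), summed over the batch
        grad = torch.autograd.grad(log_p_y1, x)[0]  # \nabla_x log p(y = 1 | x_{t-1})
        x_tm1 = (x + step_size * grad).detach()     # x_{t-1} <- x_{t-1} + eps * grad
    return x_tm1
```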
The controllable generation is currently only done for labels = 1, but this can be changed by flipping the labels in `langevin_binary_classifier`. (TODO: add support for dynamic labels). 103 | 104 | -------------------------------------------------------------------------------- /docs/imgs/greetings_training_finished.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/minimal-text-diffusion/9303ffd481a2f647da24c6053e4dec44fd086a8d/docs/imgs/greetings_training_finished.png -------------------------------------------------------------------------------- /docs/imgs/greetings_training_loop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/minimal-text-diffusion/9303ffd481a2f647da24c6053e4dec44fd086a8d/docs/imgs/greetings_training_loop.png -------------------------------------------------------------------------------- /docs/old_experiments.md: -------------------------------------------------------------------------------- 1 | ## Experiments and Results 2 | 3 | * I've tried experiments with the following hyperparameters: 4 | 5 | 1. Embeddings/vocab: pre-trained `bert-base-uncased` vs. initialized randomly. 6 | 7 | 2. Model backbone: pre-trained `bert-base-uncased` vs. initialized from scratch. 8 | 9 | 3. Embeddings fine-tuning: fine-tuned vs. frozen. 10 | 11 | Out of the 8 possible combinations, the best results were obtained with the following hyperparameters: 12 | 13 | | File | Sample Sentences | Perplexity | % Unique Lines | % Unique Tokens | 14 | |----------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------|----------------|-----------------| 15 | | MODEL_PT-False_EMBEDS_PT-False-FREEZE_EMBEDS-False | The yellow lentil is stir-fried., The green lemon is stir-fried., The white kiwi is deep-fried., The orange turnip is stir-fried., The blue blackberry is deep-fried. | 212.70 | 80.0 | 3.76 | 16 | | MODEL_PT-False_EMBEDS_PT-False-FREEZE_EMBEDS-True | The green spinach is stir-fried., The pink pomelo is stir-fried., The brown onion is stir-fried., The yellow artichoke is stir-fried., The blue pomegranate is deep-fried. | 218.77 | 74.2 | 3.76 | 17 | | MODEL_PT-True_EMBEDS_PT-True-FREEZE_EMBEDS-False | the yellow poccoli isggy., the red pe is sauteed., the green spinach is candied., the green danmelli isy., the brown kale is candied. | 1424.21 | 78.0 | 6.1 | 18 | 19 | 20 | --- 21 | 22 | - The best setting in terms of diversity is using pre-trained BERT, BERT embeddings, and fine-tuning the embeddings. However, this setting also has the highest (i.e., worst) perplexity because it generates weird sentences (which could be a good thing depending on the application!). 23 | 24 | - Some random samples from this setting: 25 | 26 | - Update 10/24: The following samples were likely from the `MODEL_PT-False_EMBEDS_PT-False-FREEZE_EMBEDS-False` setting. 27 | 28 | ``` 29 | the purple chard is braised. 30 | the pink eggplant is soft. 31 | the blue perychmmon is fluffy. 32 | the pink eggplant is fluffy. 33 | the orange macadamia is spicy. 34 | the blue almond is poached. 35 | the black avychnd is steamed. 36 | the brown radicchio is delicious. 37 | the blue yam is microwaved. 38 | the black pistachio is dried. 39 | ``` 40 | 41 | - The model was trained on a single RTX 2080 Ti GPU for 25k steps.
The training time was ~1 hour. 42 | 43 | --- -------------------------------------------------------------------------------- /docs/training_on_your_own_dataset.md: -------------------------------------------------------------------------------- 1 | ## Steps to train a model on the simple greetings dataset. 2 | 3 | This dataset is small to allow faster training, and the test data is simply a fraction of the [training data](https://github.com/madaan/minimal-text-diffusion/blob/main/data/greetings-train.txt). The dataset was generated using few-shot prompting. 4 | 5 | 6 | ### Step 1: Tokenization 7 | 8 | - For diffusion models trained on images, there is little pre-processing required. Each image is already a tensor of `height x width x channels`. However, when dealing with a text corpus, we need to do some pre-processing. 9 | 10 | Specifically, we need to (i) convert the text to a sequence of tokens (integers or IDs) (tokenization) and then (ii) map each token to a continuous vector (embeddings). 11 | 12 | [Tokenization](https://huggingface.co/course/chapter2/4?fw=pt) is an important design choice for training a language generation model. I found word-level tokenization to be the most effective. Still, the implementation in `src/utils/custom_tokenizer` also includes BPE if you want to experiment (_intuitively, BPE trivially increases the dimensionality, so that might be hurting the performance_). 13 | 14 | Since we are creating the vocabulary from scratch, the embeddings for each token are randomly initialized. The embeddings are learned during training. 15 | 16 | 17 | * To train a word-level tokenizer: 18 | 19 | ```sh 20 | python src/utils/custom_tokenizer.py train-word-level data/greetings/greetings.txt 21 | ``` 22 | 23 | - This creates a vocab file in `data/greetings/word-level-vocab.json`. 24 | 25 | - The defaults have been changed to use word-level tokenization, but please see `def create_tokenizer` in `src/utils/custom_tokenizer.py` for more details. The training code looks for the tokenizer file in `data/greetings/word-level-vocab.json` (more generally, `data/{dataset}/word-level-vocab.json`). 26 | 27 | ### Step 2: Training 28 | 29 | - After creating the tokenizer, you can start training: 30 | 31 | ```sh 32 | bash scripts/run_train.sh greetings 1 False False False 5000 33 | ``` 34 | 35 | Here, the options mean: 36 | - `greetings`: dataset name 37 | - `1`: GPU ID 38 | - `False`: whether to use a pretrained model. We are training from scratch, so we set this to `False`. A finer point is that the goal of the model is not to learn good sentence representations. Instead, the goal of the model is to take noisy embeddings of text (`xt`) and predict the clean version (`x0`). So, using a pre-trained model isn't necessary. 39 | - `False`: whether to use pretrained embeddings. We are using our word-level vocab, so this is set to `False`. 40 | - `False`: whether to freeze the embeddings. We are training from scratch, so we set this to `False` as we want the embeddings to be learned. 41 | - `5000`: number of training steps. This is a hyperparameter that you can tune. I found that 5000 steps are sufficient for `greetings` and training finishes in ~15 minutes. 42 | 43 | Some boolean options may appear redundant, but they allow interesting ablations (e.g., using pre-trained embeddings but not a pre-trained model or freezing pre-trained embeddings). 44 | 45 | 46 | * After starting the job, you should see a wandb URL. This is where you can monitor the training progress.
The training process is also displayed on the terminal. 47 | 48 | 49 | - ![training_loop](imgs/greetings_training_loop.png) 50 | 51 | 52 | - ![training_finished](imgs/greetings_training_finished.png) 53 | 54 | 55 | 56 | * The checkpoint is saved in `ckpts/greetings/.` 57 | 58 | 59 | 60 | ### Step 3: Evaluation 61 | 62 | - After training, you can evaluate the model using the following command: 63 | 64 | ```sh 65 | CUDA_VISIBLE_DEVICES=9 && bash scripts/text_sample.sh ckpts/greetings/ema_0.9999_005000.pt 2000 50 66 | ``` 67 | 68 | * The sampling finishes with the following note: 69 | ```written the decoded output to ckpts/greetings/ema_0.9999_005000.pt.samples_50.steps-2000.clamp-no_clamp.txt``` 70 | 71 | 72 | Let's see some random samples from the model (the command cleans the special tokens): 73 | 74 | ```sh 75 | shuf ckpts/greetings/ema_0.9999_005000.pt.samples_50.steps-2000.clamp-no_clamp.txt|head -n 10|cut -f 2 -d '['|cut -f2 -d ']'|sed 's/^\s*//g' 76 | 77 | i's that to see 78 | i's nice are see you right 79 | i'm let everyone you 80 | i's stopped a you you! 81 | i's glad hi see you right 82 | i's glad to you 83 | i's that to you you! 84 | i's glad to you you right 85 | i've chance to you you! 86 | i's greetings a you 87 | ``` 88 | 89 | * Not bad for a model trained for 5000 steps on a small dataset! 90 | 91 | 92 | 93 | 94 | ## Controllable generation 95 | 96 | - Let's perform controllable generation. Say we only want sentences that contain the word `good`. We assign `1` to such sentences and create a labeled file here: `data/greetings/greetings_labeled.tsv` 97 | 98 | 99 | 100 | ### Step 1: Train a classifier on the latents 101 | 102 | ```sh 103 | python -u src/controllable/classifier.py --model_name_or_path ckpts/greetings/ema_0.9999_005000.pt --classifier_num_epochs 50 104 | ``` 105 | 106 | - We are using the diffusion model trained above to train the classifier. Note that we don't really use the denoising process during training the classifier. We are only using the diffusion model to get the latents (i.e., run the forward or generative process). 107 | 108 | ### Step 2: Run generation! 109 | 110 | 111 | ```sh 112 | bash scripts/ctrl_text_sample.sh ckpts/greetings/ema_0.9999_005000.pt 300 50 113 | ``` 114 | 115 | 116 | ## TODO 117 | 118 | - [ ] Expose arguments of the langevin function. 119 | -------------------------------------------------------------------------------- /minimal-text-diffusion.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/minimal-text-diffusion/9303ffd481a2f647da24c6053e4dec44fd086a8d/minimal-text-diffusion.gif -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Automatically generated by https://github.com/damnever/pigar. 
2 | 3 | blobfile == 1.3.3 4 | boto3 == 1.24.85 5 | botocore == 1.27.85 6 | datasets==1.8.0 7 | ftfy == 6.1.1 8 | huggingface_hub==0.4.0 9 | mpi4py == 3.0.3 10 | numpy == 1.23.1 11 | pandas == 1.5.0 12 | regex == 2022.9.11 13 | requests == 2.28.1 14 | sacremoses == 0.0.53 15 | sentencepiece == 0.1.97 16 | six == 1.16.0 17 | spacy == 3.2.4 18 | tokenizers == 0.12.1 19 | torch == 1.10.0 20 | tqdm == 4.49.0 21 | transformers == 4.21.3 22 | wandb == 0.13.3 -------------------------------------------------------------------------------- /scripts/install.sh: -------------------------------------------------------------------------------- 1 | conda install mpi4py 2 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch 3 | pip install -e improved-diffusion/ 4 | pip install -e transformers/ 5 | pip install spacy==3.2.4 6 | pip install datasets==1.8.0 7 | pip install huggingface_hub==0.4.0 8 | pip install wandb 9 | -------------------------------------------------------------------------------- /scripts/run_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -u 3 | 4 | DSET=${1:-simple} 5 | 6 | GPU=${2:-0} 7 | INIT_PRETRAINED_MODEL=${3:-"True"} 8 | USE_PRETRAINED_EMBEDDINGS=${4:-"True"} 9 | FREEZE_EMBEDDINGS=${5:-"False"} 10 | 11 | LR_ANNEAL_STEPS=${6:-25001} 12 | LR=${7:-0.0001} 13 | DIFFUSION_STEPS=${8:-2000} 14 | NOISE_SCHEDULE=${9:-sqrt} 15 | BATCH_SIZE=${10:-64} 16 | SEQ_LEN=${11:-50} 17 | 18 | CHECKPOINT_PATH=${12:-"ckpts/${DSET}"} 19 | TRAIN_TXT_PATH=${13:-data/${DSET}-train.txt} 20 | VAL_TXT_PATH=${14:-data/${DSET}-test.txt} 21 | IN_CHANNELS=${15:-128} 22 | WEIGHT_DECAY=${16:-0.0} 23 | SEED=${17:-10708} 24 | DROPOUT=${18:-0.1} 25 | NUM_HEADS=${19:-4} 26 | CONFIG_NAME=${20:-"bert-base-uncased"} 27 | 28 | 29 | NOTES=${18:-"Pre-trained models, pre-trained embeddings, embeddings not frozen"} 30 | 31 | mkdir -p ${CHECKPOINT_PATH} 32 | 33 | # PLEASE NOTE THE CHECKPOINT PATH! 
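# Example invocation (mirrors docs/training_on_your_own_dataset.md): train from scratch
# on the greetings dataset, on GPU 1, with no pre-trained model/embeddings, for 5000 steps:
#   bash scripts/run_train.sh greetings 1 False False False 5000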
34 | # NOTE: You can use the following checkpoint path if you're sweeping over hyperparams 35 | # ${DSET}_${CHECKPOINT_PATH}/MODEL_PT-${INIT_PRETRAINED_MODEL}_EMBEDS_PT-${USE_PRETRAINED_EMBEDDINGS}-FREEZE_EMBEDS-${FREEZE_EMBEDDINGS}" 36 | 37 | 38 | 39 | 40 | ARGS=(--checkpoint_path ${CHECKPOINT_PATH} 41 | --save_interval 50000 --lr ${LR} 42 | --batch_size ${BATCH_SIZE} 43 | --diffusion_steps ${DIFFUSION_STEPS} 44 | --noise_schedule ${NOISE_SCHEDULE} 45 | --sequence_len ${SEQ_LEN} --seed ${SEED} 46 | --dropout ${DROPOUT} --in_channel ${IN_CHANNELS} 47 | --out_channel ${IN_CHANNELS} 48 | --weight_decay ${WEIGHT_DECAY} 49 | --predict_xstart True 50 | --train_txt_path ${TRAIN_TXT_PATH} 51 | --dataset ${DSET} 52 | --val_txt_path ${VAL_TXT_PATH} 53 | --num_heads ${NUM_HEADS} 54 | --config_name ${CONFIG_NAME} 55 | --init_pretrained ${INIT_PRETRAINED_MODEL} 56 | --freeze_embeddings ${FREEZE_EMBEDDINGS} 57 | --use_pretrained_embeddings ${USE_PRETRAINED_EMBEDDINGS} 58 | --notes \""${NOTES}"\") 59 | 60 | 61 | if [ ${LR_ANNEAL_STEPS} -eq 0 ]; then 62 | LR_ANNEAL_STEPS=100 63 | DEBUG=true 64 | else 65 | DEBUG=false 66 | fi 67 | 68 | ARGS+=(--lr_anneal_steps $LR_ANNEAL_STEPS) 69 | 70 | 71 | 72 | if [ $DEBUG = true ]; then 73 | ARGS+=(--debug) 74 | fi 75 | 76 | 77 | 78 | 79 | 80 | export CUDA_VISIBLE_DEVICES=$GPU && python -u src/train_infer/train.py "${ARGS[@]}" 81 | 82 | 83 | -------------------------------------------------------------------------------- /scripts/text_sample.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | MODEL_NAME=$1 5 | # dir of MODEL_NAME 6 | 7 | DIFFUSION_STEPS=${2:-20} 8 | 9 | NUM_SAMPLES=${3:-10} 10 | 11 | OUT_DIR=${4} 12 | 13 | if [ -z "$OUT_DIR" ]; then 14 | OUT_DIR=${MODEL_NAME} 15 | fi 16 | 17 | BATCH_SIZE=${5:-50} 18 | TOP_P=${6:-0.9} 19 | CLAMP=${7:-no_clamp} 20 | SEQ_LEN=${8:-10} 21 | SEED=${9:-10708} 22 | 23 | python -u src/train_infer/text_sample.py --model_name_or_path ${MODEL_NAME} \ 24 | --batch_size ${BATCH_SIZE} --num_samples ${NUM_SAMPLES} --top_p ${TOP_P} \ 25 | --seed ${SEED} \ 26 | --out_dir ${OUT_DIR} --diffusion_steps ${DIFFUSION_STEPS} --clamp ${CLAMP} --sequence_len ${SEQ_LEN} -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/minimal-text-diffusion/9303ffd481a2f647da24c6053e4dec44fd086a8d/src/__init__.py -------------------------------------------------------------------------------- /src/controllable/controllable_text_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 
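(The wording above is inherited from the image-diffusion codebase this repository builds on.
Here, the script samples latents with classifier guidance via the langevin_fn hook below,
decodes them back into text, and writes the sentences to a *.txt.ctrl file.)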
4 | """ 5 | import os, json 6 | import sys 7 | from typing import List 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | from transformers import set_seed 12 | from functools import partial 13 | from src.utils import dist_util, logger 14 | 15 | 16 | from src.utils.args_utils import * 17 | from train_infer.factory_methods import create_model_and_diffusion 18 | from src.utils.args_utils import create_argparser, args_to_dict, model_and_diffusion_defaults 19 | from src.utils.custom_tokenizer import create_tokenizer 20 | from src.controllable.langevin import langevin_binary_classifier 21 | from src.controllable.classifier import DiffusionBertForSequenceClassification 22 | 23 | 24 | def main(): 25 | 26 | args = create_argparser().parse_args() 27 | 28 | set_seed(args.seed) 29 | dist_util.setup_dist() 30 | logger.configure() 31 | 32 | # load configurations. 33 | args.checkpoint_path = os.path.split(args.model_name_or_path)[0] 34 | 35 | config_path = os.path.join(args.checkpoint_path, "training_args.json") 36 | training_args = read_training_args(config_path) 37 | training_args["batch_size"] = args.batch_size 38 | # overwrite this because we want to allow generation for any diffusion step. 39 | training_args["diffusion_steps"] = args.diffusion_steps 40 | training_args["model_name_or_path"] = args.model_name_or_path 41 | training_args["clamp"] = args.clamp 42 | training_args["out_dir"] = args.out_dir 43 | training_args["num_samples"] = args.num_samples 44 | 45 | args.__dict__.update(training_args) 46 | args.sigma_small = True 47 | 48 | logger.info(f"Init pretrained = {args.init_pretrained}") 49 | logger.info(f"Freeze embeddings = {args.freeze_embeddings}") 50 | logger.info(f"Use pretrained embeddings = {args.use_pretrained_embeddings}") 51 | 52 | model, diffusion = create_model_and_diffusion( 53 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 54 | ) 55 | model.load_state_dict(dist_util.load_state_dict(args.model_name_or_path, map_location="cpu")) 56 | model.eval() 57 | 58 | tokenizer = create_tokenizer( 59 | return_pretokenized=args.use_pretrained_embeddings, path=f"data/{args.dataset}/" 60 | ) 61 | 62 | model.config.update({"embedding_dim": args.in_channel}) 63 | model.config.update({"train_diffusion_steps": args.diffusion_steps}) 64 | model.config.update({"vocab_size": tokenizer.vocab_size}) 65 | 66 | classifier = DiffusionBertForSequenceClassification.load_from_checkpoint( 67 | checkpoint_path=args.checkpoint_path + "/classifier.pt", 68 | config=model.config, 69 | diffusion_model=diffusion, 70 | ).to("cuda") 71 | 72 | # freeze the classifier 73 | for param in classifier.parameters(): 74 | param.requires_grad = False 75 | 76 | langevin_classifier_wrapper = partial(langevin_binary_classifier, classifier=classifier) 77 | 78 | pytorch_total_params = sum(p.numel() for p in model.parameters()) 79 | logger.log(f"the parameter count is {pytorch_total_params}") 80 | 81 | diffusion.rescale_timesteps = True 82 | 83 | model.to(dist_util.dev()) 84 | model.eval() # DEBUG 85 | 86 | logger.log(f"Generating {args.num_samples} samples") 87 | logger.log(f"Clamping is set to {args.clamp}") 88 | all_samples = [] 89 | while len(all_samples) * args.batch_size < args.num_samples: 90 | model_kwargs = {} 91 | sample_shape = (args.batch_size, args.sequence_len, model.word_embedding.weight.shape[1]) 92 | sample = diffusion.p_sample_loop( 93 | model, 94 | sample_shape, 95 | clip_denoised=args.clip_denoised, 96 | denoised_fn=None, 97 | model_kwargs=model_kwargs, 98 | 
top_p=args.top_p, 99 | progress=True, 100 | tokenizer=tokenizer, 101 | log_verbose=True, 102 | langevin_fn=langevin_classifier_wrapper, 103 | ) 104 | 105 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 106 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 107 | all_samples.extend([sample.cpu().numpy() for sample in gathered_samples]) 108 | 109 | logger.log(f"created {len(all_samples)} samples") 110 | 111 | arr = np.concatenate(all_samples, axis=0) 112 | arr = arr[: args.num_samples * args.mbr_sample] 113 | 114 | x_t = th.tensor(arr).cuda() 115 | 116 | logits = model.get_logits(x_t) # bsz, seqlen, vocab 117 | cands = th.topk(logits, k=1, dim=-1) 118 | 119 | decoded_sentences = [] 120 | 121 | for seq in cands.indices: 122 | decoded_sentence = tokenizer.decode(seq.squeeze(1).tolist()) 123 | decoded_sentences.append(decoded_sentence) 124 | 125 | dist.barrier() 126 | logger.log("sampling complete") 127 | 128 | write_outputs(args=args, sentences=decoded_sentences) 129 | 130 | 131 | def load_embeddings(checkpoint_path, tokenizer, emb_dim): 132 | embeddings = th.nn.Embedding(tokenizer.vocab_size, emb_dim) 133 | embeddings.load_state_dict(th.load(f"{checkpoint_path}/random_emb.torch")) 134 | return embeddings 135 | 136 | 137 | def read_training_args(config_path): 138 | with open(config_path, "r") as f: 139 | return json.load(f) 140 | 141 | 142 | def write_outputs(args: dict, sentences: List[str]) -> None: 143 | 144 | model_dir = os.path.split(args.model_name_or_path)[0] 145 | model_base_name = os.path.split(args.model_name_or_path)[1] 146 | 147 | num_samples = len(sentences) 148 | output_file_basepath = ( 149 | os.path.join( 150 | model_dir, 151 | f"{model_base_name}.samples_{num_samples}.steps-{args.diffusion_steps}.clamp-{args.clamp}", 152 | ) 153 | + ".txt.ctrl" 154 | ) 155 | 156 | with open(output_file_basepath, "w") as text_fout: 157 | for generated_sentence in sentences: 158 | text_fout.write(generated_sentence + "\n") 159 | 160 | print(f"written the decoded output to {output_file_basepath}") 161 | 162 | 163 | if __name__ == "__main__": 164 | main() 165 | -------------------------------------------------------------------------------- /src/controllable/langevin.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilizes a trained classifier model to guide the diffusion process. 3 | 4 | - Given: 5 | 1. input embeddings 6 | 2. A classifier model 7 | 3. Labels 8 | 9 | The classifier model is used to refine the input embeddings such that the logits of the classifier model are maximized for the labels. 10 | """ 11 | import torch 12 | 13 | 14 | def langevin_binary_classifier(classifier, label_ids, x_t, t, num_langevin_steps: int = 3, step_size: float=1e-3): # current best. 
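    """
    Refine the noisy latents x_t with a few steps of gradient-based (Langevin-style) guidance.

    :param classifier: classifier trained on diffusion latents; must expose label_logp(...).
    :param label_ids: target label ids whose log-probability under the classifier is maximized.
    :param x_t: noisy latents at the current diffusion step.
    :param t: current timestep(s), passed through to the classifier.
    :param num_langevin_steps: number of gradient-ascent refinement steps.
    :param step_size: learning rate of the Adagrad optimizer used for the updates.
    :return: the refined latents, detached, with the same shape as x_t.
    """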
15 | 16 | x_t_as_params = torch.nn.Parameter(x_t) 17 | 18 | with torch.enable_grad(): 19 | for i in range(num_langevin_steps): 20 | optimizer = torch.optim.Adagrad([x_t_as_params], lr=step_size) 21 | 22 | optimizer.zero_grad() 23 | model_out = classifier.label_logp(inputs_with_added_noise=x_t_as_params, 24 | labels=label_ids, 25 | t=t) 26 | loss = -model_out.loss # logp 27 | loss.backward() 28 | # print(f"{i}> grad norm: {x_t_as_params.grad.data.norm(2)} | loss: {loss}") 29 | 30 | optimizer.step() 31 | 32 | 33 | x_t_as_params = torch.nn.Parameter(x_t_as_params.data.detach()) 34 | 35 | return x_t_as_params.data.detach() 36 | -------------------------------------------------------------------------------- /src/modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/minimal-text-diffusion/9303ffd481a2f647da24c6053e4dec44fd086a8d/src/modeling/__init__.py -------------------------------------------------------------------------------- /src/modeling/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/minimal-text-diffusion/9303ffd481a2f647da24c6053e4dec44fd086a8d/src/modeling/diffusion/__init__.py -------------------------------------------------------------------------------- /src/modeling/diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | # print(logvar2.shape) 34 | # temp1 = 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2)) 35 | # print(f'const = {temp1.mean()}, coef={(th.exp(-logvar2) * 0.5).mean()}, mse={((mean1 - mean2) ** 2).mean().item()}') 36 | 37 | return 0.5 * ( 38 | -1.0 39 | + logvar2 40 | - logvar1 41 | + th.exp(logvar1 - logvar2) 42 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 43 | ) 44 | 45 | 46 | def approx_standard_normal_cdf(x): 47 | """ 48 | A fast approximation of the cumulative distribution function of the 49 | standard normal. 50 | """ 51 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 52 | 53 | 54 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 55 | """ 56 | Compute the log-likelihood of a Gaussian distribution discretizing to a 57 | given image. 58 | 59 | :param x: the target images. It is assumed that this was uint8 values, 60 | rescaled to the range [-1, 1]. 61 | :param means: the Gaussian mean Tensor. 62 | :param log_scales: the Gaussian log stddev Tensor. 
63 | :return: a tensor like x of log probabilities (in nats). 64 | """ 65 | assert x.shape == means.shape == log_scales.shape 66 | centered_x = x - means 67 | inv_stdv = th.exp(-log_scales) 68 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 69 | cdf_plus = approx_standard_normal_cdf(plus_in) 70 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 71 | cdf_min = approx_standard_normal_cdf(min_in) 72 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 73 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 74 | cdf_delta = cdf_plus - cdf_min 75 | log_probs = th.where( 76 | x < -0.999, 77 | log_cdf_plus, 78 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 79 | ) 80 | assert log_probs.shape == x.shape 81 | return log_probs 82 | 83 | def gaussian_density(x, *, means, log_scales): 84 | from torch.distributions import Normal 85 | normal_dist = Normal(means, log_scales.exp()) 86 | logp = normal_dist.log_prob(x) 87 | return logp 88 | -------------------------------------------------------------------------------- /src/modeling/diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 
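    For example, a per-token loss of shape [N x seq_len x channels] is reduced to a
    tensor of shape [N], one scalar per example in the batch.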
89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /src/modeling/diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 
14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 
113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /src/modeling/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from src.modeling.diffusion.gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 
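    For example, space_timesteps(100, "ddim25") returns the 25 evenly strided steps
    {0, 4, 8, ..., 96}, while space_timesteps(2000, "50") keeps 50 roughly evenly
    spaced steps out of the original 2000.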
28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | # print(kwargs.keys()) 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | # print('called p_mean_var') 93 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 94 | 95 | def training_losses( 96 | self, model, *args, **kwargs 97 | ): # pylint: disable=signature-differs 98 | # print('called training_losses') 99 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 100 | 101 | def _wrap_model(self, model): 102 | if isinstance(model, _WrappedModel): 103 | return model 104 | return _WrappedModel( 105 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 106 | ) 107 | 108 | def _scale_timesteps(self, t): 109 | # Scaling is done by the wrapped model. 
110 | return t 111 | 112 | 113 | class _WrappedModel: 114 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 115 | self.model = model 116 | self.timestep_map = timestep_map 117 | self.rescale_timesteps = rescale_timesteps 118 | self.original_num_steps = original_num_steps 119 | 120 | def __call__(self, x, ts, **kwargs): 121 | # print(ts) 122 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 123 | new_ts = map_tensor[ts] 124 | # print(new_ts) 125 | if self.rescale_timesteps: 126 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 127 | # temp = self.model(x, new_ts, **kwargs) 128 | # print(temp.shape) 129 | # return temp 130 | # print(new_ts) 131 | return self.model(x, new_ts, **kwargs) 132 | -------------------------------------------------------------------------------- /src/modeling/diffusion/rounding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # bert results 3 | 4 | # print( os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling')) 5 | # sys.path.insert(0, 'diffusion_lm/transformers/examples/pytorch/language-modeling') 6 | # sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling')) 7 | # from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise 8 | import json 9 | 10 | def load_embeddings_and_tokenizer(modality=None, mode=None, model_name_or_path=None, emb_dim=None, checkpoint_path=None, extra_args=None): 11 | 12 | path_save_tokenizer = '{}/vocab.json'.format(checkpoint_path) 13 | print(f'loading from {path_save_tokenizer}') 14 | with open(path_save_tokenizer, 'r') as f: 15 | vocab = json.load(f) 16 | print(len(vocab)) 17 | tokenizer = {v: k for k, v in vocab.items()} 18 | model = torch.nn.Embedding(tokenizer.vocab_size, emb_dim) 19 | path_save = '{}/random_emb.torch'.format(checkpoint_path) 20 | model.load_state_dict(torch.load(path_save)) 21 | 22 | return model, tokenizer 23 | 24 | 25 | def load_tokenizer(modality, mode, model_name_or_path): 26 | import json 27 | path_save_tokenizer = '{}/vocab.json'.format(model_name_or_path) 28 | with open(path_save_tokenizer, 'r') as f: 29 | vocab = json.load(f) 30 | tokenizer = {v: k for k, v in vocab.items()} 31 | 32 | return tokenizer 33 | 34 | def rounding_func(mode, text_emb_lst, model, tokenizer, emb_scale_factor=1.0): 35 | decoded_out_lst = [] 36 | if mode in ['random', 'random_up_proj', 'glove']: 37 | down_proj_emb = model.weight # input_embs 38 | down_proj_emb2 = None 39 | 40 | 41 | def get_knn(down_proj_emb, text_emb, dist='cos'): 42 | 43 | if dist == 'cos': 44 | adjacency = down_proj_emb @ text_emb.transpose(1, 0).to(down_proj_emb.device) 45 | elif dist == 'l2': 46 | adjacency = down_proj_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand( 47 | down_proj_emb.size(0), -1, -1) 48 | adjacency = -torch.norm(adjacency, dim=-1) 49 | topk_out = torch.topk(adjacency, k=6, dim=0) 50 | return topk_out.values, topk_out.indices 51 | 52 | dist = 'l2' 53 | # print(npzfile['arr_0'].shape) 54 | for text_emb in text_emb_lst: 55 | import torch 56 | text_emb = torch.tensor(text_emb) 57 | # print(text_emb.shape) 58 | if len(text_emb.shape) > 2: 59 | text_emb = text_emb.view(-1, text_emb.size(-1)) 60 | else: 61 | text_emb = text_emb 62 | val, indices = get_knn((down_proj_emb2 if dist == 'cos' else down_proj_emb), 63 | text_emb.to(down_proj_emb.device), dist=dist) 64 | # 
generated_lst.append(tuple(indices[0].tolist())) 65 | 66 | # print(indices[0].tolist()) 67 | # for i in range(64): 68 | # print([tokenizer[x.item()] for x in indices[:,i]]) 69 | decoded_out = " ".join([tokenizer[i] for i in indices[0].tolist()]) 70 | decoded_out_lst.append(decoded_out) 71 | 72 | return decoded_out_lst 73 | 74 | -------------------------------------------------------------------------------- /src/modeling/predictor/transformer_model.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | # from transformers import BertEncoder 4 | from transformers.models.bert.modeling_bert import BertEncoder 5 | import torch 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | from src.modeling.diffusion.nn import ( 10 | SiLU, 11 | linear, 12 | timestep_embedding, 13 | ) 14 | 15 | 16 | class TransformerNetModel(nn.Module): 17 | """ 18 | A transformer model to be used in Diffusion Model Training. 19 | 20 | :param in_channels: channels in the input Tensor. 21 | :param model_channels: base channel count for the model. 22 | :param out_channels: channels in the output Tensor. 23 | :param dropout: the dropout probability. 24 | :param channel_mult: channel multiplier for each level of the UNet. 25 | :param dims: determines if the signal is 1D, 2D, or 3D. 26 | :param num_classes: if specified (as an int), then this model will be 27 | class-conditional with `num_classes` classes. TODO for the next version 28 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 29 | :param num_heads: the number of attention heads in each attention layer. 30 | """ 31 | 32 | def __init__( 33 | self, 34 | in_channels, 35 | model_channels, 36 | out_channels, 37 | init_pretrained, 38 | freeze_embeddings, 39 | use_pretrained_embeddings, 40 | dropout=0, 41 | use_checkpoint=False, 42 | num_heads=1, 43 | config=None, 44 | config_name="bert-base-uncased", 45 | vocab_size=None, 46 | logits_mode=1, 47 | ): 48 | super().__init__() 49 | 50 | if config is None: 51 | config = AutoConfig.from_pretrained(config_name) 52 | config.hidden_dropout_prob = dropout 53 | # config.hidden_size = 512 54 | 55 | self.in_channels = in_channels 56 | self.model_channels = model_channels 57 | self.out_channels = out_channels 58 | self.dropout = dropout 59 | self.use_checkpoint = use_checkpoint 60 | self.num_heads = num_heads 61 | self.logits_mode = logits_mode 62 | self.vocab_size = vocab_size 63 | self.init_pretrained = init_pretrained 64 | self.freeze_embeddings = freeze_embeddings 65 | self.use_pretrained_embeddings = use_pretrained_embeddings 66 | self.config = config 67 | self.config_name = config_name 68 | 69 | time_embed_dim = model_channels * 4 70 | self.time_embed = nn.Sequential( 71 | linear(model_channels, time_embed_dim), 72 | SiLU(), 73 | linear(time_embed_dim, config.hidden_size), 74 | ) 75 | 76 | self.build_xstart_predictor() 77 | self.build_input_output_projections() 78 | self.build_embeddings() 79 | 80 | self.register_buffer( 81 | "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) 82 | ) 83 | 84 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 85 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 86 | 87 | def build_xstart_predictor(self): 88 | if self.init_pretrained: 89 | from transformers.models.bert.modeling_bert import BertModel 90 | 91 | temp_bert = BertModel.from_pretrained(self.config_name, config=self.config) 92 | del temp_bert.pooler 93 | self.input_transformers 
= temp_bert.encoder 94 | else: 95 | self.input_transformers = BertEncoder(self.config) 96 | 97 | def build_input_output_projections(self): 98 | if self.use_pretrained_embeddings: 99 | self.input_up_proj = nn.Identity() 100 | self.output_down_proj = nn.Identity() 101 | else: # need to adapt the model to the embedding size 102 | self.input_up_proj = nn.Sequential( 103 | nn.Linear(self.in_channels, self.config.hidden_size), 104 | nn.Tanh(), 105 | nn.Linear(self.config.hidden_size, self.config.hidden_size), 106 | ) 107 | 108 | self.output_down_proj = nn.Sequential( 109 | nn.Linear(self.config.hidden_size, self.config.hidden_size), 110 | nn.Tanh(), 111 | nn.Linear(self.config.hidden_size, self.out_channels), 112 | ) 113 | 114 | def build_embeddings(self): 115 | if self.use_pretrained_embeddings: 116 | from transformers.models.bert.modeling_bert import BertModel 117 | 118 | temp_bert = BertModel.from_pretrained(self.config_name, config=self.config) 119 | self.word_embedding = temp_bert.embeddings.word_embeddings 120 | self.position_embeddings = temp_bert.embeddings.position_embeddings 121 | else: 122 | self.word_embedding = nn.Embedding(self.vocab_size, self.in_channels) 123 | self.position_embeddings = nn.Embedding( 124 | self.config.max_position_embeddings, self.config.hidden_size 125 | ) 126 | 127 | self.lm_head = nn.Linear(self.in_channels, self.word_embedding.weight.shape[0]) 128 | 129 | if self.freeze_embeddings: 130 | self.word_embedding.weight.requires_grad = False 131 | self.position_embeddings.weight.requires_grad = False 132 | 133 | with th.no_grad(): 134 | self.lm_head.weight = self.word_embedding.weight 135 | 136 | def get_embeds(self, input_ids): 137 | return self.word_embedding(input_ids) 138 | 139 | def get_logits(self, hidden_repr): 140 | return self.lm_head(hidden_repr) 141 | 142 | def forward(self, x, timesteps, y=None, src_ids=None, src_mask=None, attention_mask=None): 143 | """ 144 | Apply the model to an input batch. 145 | 146 | :param x: an [N x C x ...] Tensor of inputs. 147 | :param timesteps: a 1-D batch of timesteps. 148 | :param y: an [N] Tensor of labels, if class-conditional. 149 | :return: an [N x C x ...] Tensor of outputs. 
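        For this text model, x is a [batch x seq_len x embedding_dim] tensor of noisy
        token embeddings, and the returned tensor has the same batch and sequence
        dimensions (the model's estimate of the clean embeddings when predict_xstart
        is used).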
150 | """ 151 | 152 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 153 | 154 | emb_x = self.input_up_proj(x) 155 | seq_length = x.size(1) 156 | position_ids = self.position_ids[:, :seq_length] 157 | emb_inputs = ( 158 | self.position_embeddings(position_ids) 159 | + emb_x 160 | + emb.unsqueeze(1).expand(-1, seq_length, -1) 161 | ) 162 | emb_inputs = self.dropout(self.LayerNorm(emb_inputs)) 163 | 164 | # https://github.com/huggingface/transformers/blob/e95d433d77727a9babadf008dd621a2326d37303/src/transformers/modeling_utils.py#L700 165 | if attention_mask is not None: 166 | attention_mask = attention_mask[:, None, None, :] 167 | 168 | input_trans_hidden_states = self.input_transformers( 169 | emb_inputs, attention_mask=attention_mask 170 | ).last_hidden_state 171 | 172 | h = self.output_down_proj(input_trans_hidden_states) 173 | h = h.type(x.dtype) 174 | return h 175 | -------------------------------------------------------------------------------- /src/train_infer/factory_methods.py: -------------------------------------------------------------------------------- 1 | import src.modeling.diffusion.gaussian_diffusion as gd 2 | from src.modeling.diffusion.respace import SpacedDiffusion, space_timesteps 3 | from src.modeling.predictor.transformer_model import TransformerNetModel 4 | 5 | 6 | def create_model_and_diffusion( 7 | class_cond, 8 | learn_sigma, 9 | sigma_small, 10 | num_channels, 11 | num_heads, 12 | dropout, 13 | diffusion_steps, 14 | noise_schedule, 15 | timestep_respacing, 16 | use_kl, 17 | predict_xstart, 18 | rescale_timesteps, 19 | rescale_learned_sigmas, 20 | use_checkpoint, 21 | model_arch, 22 | in_channel, 23 | out_channel, 24 | training_mode, 25 | vocab_size, 26 | config_name, 27 | logits_mode, 28 | init_pretrained, 29 | freeze_embeddings, 30 | use_pretrained_embeddings, 31 | **kwargs, 32 | ): 33 | model = create_model( 34 | num_channels, 35 | learn_sigma=learn_sigma, 36 | class_cond=class_cond, 37 | use_checkpoint=use_checkpoint, 38 | num_heads=num_heads, 39 | dropout=dropout, 40 | in_channel=in_channel, 41 | out_channel=out_channel, 42 | training_mode=training_mode, 43 | vocab_size=vocab_size, 44 | config_name=config_name, 45 | logits_mode=logits_mode, 46 | init_pretrained=init_pretrained, 47 | freeze_embeddings=freeze_embeddings, 48 | use_pretrained_embeddings=use_pretrained_embeddings, 49 | ) 50 | diffusion = create_gaussian_diffusion( 51 | steps=diffusion_steps, 52 | learn_sigma=learn_sigma, 53 | sigma_small=sigma_small, 54 | noise_schedule=noise_schedule, 55 | use_kl=use_kl, 56 | predict_xstart=predict_xstart, 57 | rescale_timesteps=rescale_timesteps, 58 | rescale_learned_sigmas=rescale_learned_sigmas, 59 | timestep_respacing=timestep_respacing, 60 | model_arch=model_arch, 61 | training_mode=training_mode, 62 | ) 63 | return model, diffusion 64 | 65 | 66 | def create_model( 67 | num_channels, 68 | learn_sigma, 69 | use_checkpoint, 70 | class_cond, # TODO for the next version 71 | num_heads, 72 | dropout, 73 | init_pretrained, 74 | freeze_embeddings, 75 | use_pretrained_embeddings, 76 | in_channel, 77 | out_channel, 78 | training_mode, 79 | vocab_size, 80 | config_name, 81 | logits_mode, 82 | ): 83 | 84 | return TransformerNetModel( 85 | in_channels=in_channel, 86 | model_channels=num_channels, 87 | out_channels=(out_channel if not learn_sigma else out_channel * 2), 88 | dropout=dropout, 89 | use_checkpoint=use_checkpoint, 90 | num_heads=num_heads, 91 | config_name=config_name, 92 | vocab_size=vocab_size, 93 | logits_mode=logits_mode, 94 | 
init_pretrained=init_pretrained, 95 | use_pretrained_embeddings=use_pretrained_embeddings, 96 | freeze_embeddings=freeze_embeddings, 97 | 98 | ) 99 | 100 | 101 | def create_gaussian_diffusion( 102 | *, 103 | steps=1000, 104 | learn_sigma=False, 105 | sigma_small=False, 106 | noise_schedule="linear", 107 | use_kl=False, 108 | predict_xstart=False, 109 | rescale_timesteps=False, 110 | rescale_learned_sigmas=False, 111 | timestep_respacing="", 112 | model_arch="transformer", 113 | training_mode="diffusion-lm", 114 | ): 115 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 116 | 117 | if use_kl: 118 | loss_type = gd.LossType.E2E_KL 119 | else: 120 | loss_type = gd.LossType.E2E_MSE 121 | 122 | if not timestep_respacing: 123 | timestep_respacing = [steps] 124 | 125 | # Whether variance is learned or fixed 126 | model_var_type = None 127 | if not learn_sigma: 128 | if sigma_small: 129 | model_var_type = gd.ModelVarType.FIXED_SMALL 130 | else: 131 | model_var_type = gd.ModelVarType.FIXED_LARGE 132 | else: 133 | model_var_type = gd.ModelVarType.LEARNED_RANGE 134 | 135 | # what is the interpretation of the output generated by the model? Is it generating the noise or the mean directly? 136 | 137 | model_mean_type = None 138 | if not predict_xstart: 139 | model_mean_type = gd.ModelMeanType.EPSILON # predicts noise 140 | else: # predicts starting x (x0 estimate, possibly used by DDIM?) 141 | model_mean_type = gd.ModelMeanType.START_X 142 | 143 | return SpacedDiffusion( 144 | use_timesteps=space_timesteps(steps, timestep_respacing), 145 | betas=betas, 146 | model_var_type=model_var_type, 147 | model_mean_type=model_mean_type, 148 | loss_type=loss_type, 149 | rescale_timesteps=rescale_timesteps, 150 | model_arch=model_arch, 151 | training_mode=training_mode, 152 | ) 153 | -------------------------------------------------------------------------------- /src/train_infer/text_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 4 | """ 5 | import os, json 6 | from typing import List 7 | import numpy as np 8 | import torch as th 9 | import torch.distributed as dist 10 | from transformers import set_seed 11 | from src.utils import dist_util, logger 12 | 13 | from src.utils.args_utils import * 14 | from train_infer.factory_methods import create_model_and_diffusion 15 | from src.utils.args_utils import create_argparser, args_to_dict, model_and_diffusion_defaults 16 | from src.utils.custom_tokenizer import create_tokenizer 17 | 18 | 19 | 20 | 21 | 22 | def main(): 23 | 24 | args = create_argparser().parse_args() 25 | 26 | set_seed(args.seed) 27 | dist_util.setup_dist() 28 | logger.configure() 29 | 30 | # load configurations. 
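    # The hyperparameters saved at training time (training_args.json, written by
    # src/train_infer/train.py) are merged into args below so that sampling reuses the
    # same model/tokenizer configuration; only a few fields (batch size, diffusion steps,
    # output locations) are overridden from the command line.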
31 | args.checkpoint_path = os.path.split(args.model_name_or_path)[0] 32 | 33 | config_path = os.path.join(args.checkpoint_path, "training_args.json") 34 | training_args = read_training_args(config_path) 35 | training_args["batch_size"] = args.batch_size 36 | training_args["diffusion_steps"] = args.diffusion_steps 37 | training_args['model_name_or_path'] = args.model_name_or_path 38 | training_args["clamp"] = args.clamp 39 | training_args['out_dir'] = args.out_dir 40 | training_args['num_samples'] = args.num_samples 41 | 42 | args.__dict__.update(training_args) 43 | args.sigma_small = True 44 | 45 | 46 | logger.info(f"Init pretrained = {args.init_pretrained}") 47 | logger.info(f"Freeze embeddings = {args.freeze_embeddings}") 48 | logger.info(f"Use pretrained embeddings = {args.use_pretrained_embeddings}") 49 | 50 | model, diffusion = create_model_and_diffusion( 51 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 52 | ) 53 | model.load_state_dict(dist_util.load_state_dict(args.model_name_or_path, map_location="cpu")) 54 | model.eval() 55 | 56 | tokenizer = create_tokenizer(return_pretokenized=args.use_pretrained_embeddings, path=f"data/{args.dataset}/") 57 | 58 | pytorch_total_params = sum(p.numel() for p in model.parameters()) 59 | logger.log(f"the parameter count is {pytorch_total_params}") 60 | 61 | diffusion.rescale_timesteps = True 62 | 63 | model.to(dist_util.dev()) 64 | model.eval() # DEBUG 65 | 66 | 67 | logger.log("sampling...") 68 | logger.log(f"Clamping is set to {args.clamp}") 69 | all_samples = [] 70 | while len(all_samples) * args.batch_size < args.num_samples: 71 | model_kwargs = {} 72 | sample_shape = (args.batch_size, args.sequence_len, model.word_embedding.weight.shape[1]) 73 | sample = diffusion.p_sample_loop( 74 | model, 75 | sample_shape, 76 | clip_denoised=args.clip_denoised, 77 | denoised_fn=None, 78 | model_kwargs=model_kwargs, 79 | top_p=args.top_p, 80 | progress=True, 81 | tokenizer=tokenizer, 82 | log_verbose=True 83 | ) 84 | 85 | 86 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 87 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 88 | all_samples.extend([sample.cpu().numpy() for sample in gathered_samples]) 89 | 90 | logger.log(f"created {len(all_samples) * args.batch_size} samples") 91 | 92 | arr = np.concatenate(all_samples, axis=0) 93 | arr = arr[: args.num_samples * args.mbr_sample] 94 | 95 | x_t = th.tensor(arr).cuda() 96 | 97 | logits = model.get_logits(x_t) # bsz, seqlen, vocab 98 | cands = th.topk(logits, k=1, dim=-1) 99 | 100 | decoded_sentences = [] 101 | 102 | for seq in cands.indices: 103 | decoded_sentence = tokenizer.decode(seq.squeeze(1).tolist()) 104 | decoded_sentences.append(decoded_sentence) 105 | 106 | dist.barrier() 107 | logger.log("sampling complete") 108 | 109 | write_outputs(args=args, sentences=decoded_sentences) 110 | 111 | 112 | def load_embeddings(checkpoint_path, tokenizer, emb_dim): 113 | embeddings = th.nn.Embedding(tokenizer.vocab_size, emb_dim) 114 | embeddings.load_state_dict(th.load(f'{checkpoint_path}/random_emb.torch')) 115 | return embeddings 116 | 117 | 118 | def read_training_args(config_path): 119 | with open(config_path, "r") as f: 120 | return json.load(f) 121 | 122 | 123 | def write_outputs(args: dict, sentences: List[str]) -> None: 124 | 125 | model_dir = os.path.split(args.model_name_or_path)[0] 126 | model_base_name = os.path.split(args.model_name_or_path)[1] 127 | 128 | num_samples = len(sentences) 129 | output_file_basepath = 
os.path.join( 130 | model_dir, 131 | f"{model_base_name}.samples_{num_samples}.steps-{args.diffusion_steps}.clamp-{args.clamp}", 132 | ) + ".txt" 133 | with open(output_file_basepath, "w") as text_fout: 134 | for generated_sentence in sentences: 135 | text_fout.write(generated_sentence + "\n") 136 | 137 | print(f"written the decoded output to {output_file_basepath}") 138 | 139 | 140 | if __name__ == "__main__": 141 | main() 142 | -------------------------------------------------------------------------------- /src/train_infer/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 3 | """ 4 | 5 | import json, os 6 | import pathlib 7 | import pprint 8 | import sys 9 | import wandb 10 | from transformers import set_seed 11 | import os 12 | 13 | from src.utils import dist_util, logger 14 | from src.modeling.diffusion.resample import create_named_schedule_sampler 15 | from train_infer.factory_methods import create_model_and_diffusion 16 | from train_loop import TrainLoop 17 | from src.utils import data_utils_sentencepiece 18 | from src.utils.args_utils import create_argparser, args_to_dict, model_and_diffusion_defaults 19 | from src.utils.custom_tokenizer import create_tokenizer 20 | 21 | 22 | def main(): 23 | args = create_argparser().parse_args() 24 | set_seed(args.seed) 25 | dist_util.setup_dist() # DEBUG ** 26 | logger.configure() 27 | 28 | 29 | logger.log("creating data loader") 30 | 31 | pathlib.Path(args.checkpoint_path).mkdir(parents=True, exist_ok=True) 32 | 33 | tokenizer = create_tokenizer(return_pretokenized=args.use_pretrained_embeddings, path=f"data/{args.dataset}/") 34 | 35 | train_dataloader = data_utils_sentencepiece.get_dataloader( 36 | tokenizer=tokenizer, 37 | data_path=args.train_txt_path, 38 | batch_size=args.batch_size, 39 | max_seq_len=args.sequence_len 40 | ) 41 | 42 | val_dataloader = data_utils_sentencepiece.get_dataloader( 43 | tokenizer=tokenizer, 44 | data_path=args.val_txt_path, 45 | batch_size=args.batch_size, 46 | max_seq_len=args.sequence_len 47 | ) 48 | 49 | 50 | args.vocab_size = tokenizer.vocab_size 51 | 52 | logger.log("creating model and diffusion...") 53 | 54 | 55 | 56 | model, diffusion = create_model_and_diffusion( 57 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 58 | ) 59 | model.to(dist_util.dev()) # DEBUG ** 60 | # model.cuda() # DEBUG ** 61 | 62 | print(model) 63 | 64 | pytorch_total_params = sum(p.numel() for p in model.parameters()) 65 | 66 | logger.log(f"the parameter count is {pytorch_total_params}") 67 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 68 | 69 | logger.log(f"saving the hyperparameters to {args.checkpoint_path}/training_args.json") 70 | with open(f"{args.checkpoint_path}/training_args.json", "w") as f: 71 | json.dump(args.__dict__, f, indent=2) 72 | 73 | if args.debug: 74 | wandb.init(mode="disabled") 75 | else: 76 | wandb.init( 77 | project=os.getenv("WANDB_PROJECT", "minimial-text-diffusion"), 78 | name=args.checkpoint_path + make_wandb_name_from_args(args), 79 | notes=args.notes, 80 | ) 81 | wandb.config.update(args.__dict__, allow_val_change=True) 82 | 83 | logger.log("training...") 84 | TrainLoop( 85 | model=model, 86 | diffusion=diffusion, 87 | data=train_dataloader, 88 | batch_size=args.batch_size, 89 | microbatch=args.microbatch, 90 | lr=args.lr, 91 | ema_rate=args.ema_rate, 92 | log_interval=args.log_interval, 93 | save_interval=args.save_interval, 94 | 
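        # Checkpoints (including the EMA copies such as ema_0.9999_*.pt used for sampling)
        # are written under args.checkpoint_path, e.g. ckpts/<dataset>/ when training is
        # launched via scripts/run_train.sh.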
resume_checkpoint=args.resume_checkpoint, 95 | use_fp16=args.use_fp16, 96 | fp16_scale_growth=args.fp16_scale_growth, 97 | schedule_sampler=schedule_sampler, 98 | weight_decay=args.weight_decay, 99 | lr_anneal_steps=args.lr_anneal_steps, 100 | checkpoint_path=args.checkpoint_path, 101 | gradient_clipping=args.gradient_clipping, 102 | eval_data=val_dataloader, 103 | eval_interval=args.eval_interval, 104 | ).run_loop() 105 | 106 | 107 | def make_wandb_name_from_args(args): 108 | keys_to_add = ["batch_size", "lr", "num_heads", "lr_anneal_steps", "config_name", "seed", "in_channel"] 109 | name = "" 110 | for key in keys_to_add: 111 | name += f"{key}={getattr(args, key)}_" 112 | return name 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /src/utils/args_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for command line arguments. 3 | """ 4 | 5 | import argparse 6 | 7 | 8 | 9 | def create_argparser(): 10 | defaults = dict( 11 | data_dir="", 12 | schedule_sampler="uniform", 13 | lr=1e-4, 14 | weight_decay=0.0, 15 | lr_anneal_steps=30000, 16 | batch_size=1, 17 | microbatch=-1, # -1 disables microbatches 18 | ema_rate="0.9999", # comma-separated list of EMA values 19 | log_interval=50, 20 | save_interval=25000, 21 | resume_checkpoint="", 22 | use_fp16=False, 23 | fp16_scale_growth=1e-3, 24 | seed=101, 25 | gradient_clipping=-1.0, 26 | eval_interval=2000, 27 | checkpoint_path="diff_models", 28 | train_txt_path="data/quotes_train.txt", 29 | val_txt_path="data/quotes_valid.txt", 30 | dataset="", 31 | notes="", 32 | ) 33 | text_defaults = dict( 34 | modality="text", 35 | emb_scale_factor=1.0, 36 | in_channel=16, 37 | out_channel=16, 38 | noise_level=0.0, 39 | cache_mode="no", 40 | use_bert_tokenizer="no", 41 | padding_mode="block", 42 | preprocessing_num_workers=1, 43 | tok_thresh=150 44 | ) 45 | 46 | guided_generation_defaults = dict( 47 | classifier_num_epochs=15 48 | ) 49 | 50 | defaults.update(model_and_diffusion_defaults()) 51 | defaults.update(text_defaults) 52 | defaults.update(guided_generation_defaults) 53 | defaults.update(decoding_defaults()) 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--debug", action="store_true") 56 | 57 | add_dict_to_argparser(parser, defaults) 58 | return parser 59 | 60 | 61 | def model_and_diffusion_defaults(): 62 | """ 63 | Defaults for text-diffusion model training. 
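Every key returned here is registered as a --flag by create_argparser() / add_dict_to_argparser(), so each default can be overridden from the command line. Illustrative invocation (the flag values are placeholders, not tuned recommendations): python src/train_infer/train.py --dataset simple --diffusion_steps 2000 --sequence_len 32 --batch_size 64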
64 | """ 65 | return dict( 66 | sequence_len=64, 67 | num_channels=16, 68 | num_heads=4, 69 | dropout=0.0, 70 | learn_sigma=False, 71 | sigma_small=False, 72 | class_cond=False, 73 | diffusion_steps=10000, 74 | noise_schedule="linear", 75 | timestep_respacing="", 76 | use_kl=False, 77 | predict_xstart=False, 78 | rescale_timesteps=True, 79 | rescale_learned_sigmas=True, 80 | use_checkpoint=False, 81 | model_arch="transformer", 82 | in_channel=16, 83 | out_channel=16, 84 | vocab_size=66, 85 | config_name="bert-base-uncased", 86 | logits_mode=1, 87 | training_mode="diffusion-lm", 88 | init_pretrained=False, 89 | freeze_embeddings=False, 90 | use_pretrained_embeddings=True, 91 | ) 92 | 93 | 94 | def decoding_defaults(): 95 | return dict( 96 | num_samples=50, 97 | top_p=0.9, 98 | out_dir="", 99 | model_name_or_path="", 100 | checkpoint_path="", 101 | use_ddim=False, 102 | clip_denoised=False, 103 | batch_size=64, 104 | mbr_sample=1, 105 | verbose="yes", 106 | clamp="clamp", 107 | preprocessing_num_workers=1, 108 | emb_scale_factor=1.0, 109 | classifier_path="", 110 | ) 111 | 112 | 113 | def add_dict_to_argparser(parser, default_dict): 114 | for k, v in default_dict.items(): 115 | v_type = type(v) 116 | if v is None: 117 | v_type = str 118 | elif isinstance(v, bool): 119 | v_type = str2bool 120 | parser.add_argument(f"--{k}", default=v, type=v_type) 121 | 122 | 123 | def args_to_dict(args, keys): 124 | return {k: getattr(args, k) for k in keys} 125 | 126 | 127 | def str2bool(v): 128 | """ 129 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 130 | """ 131 | if isinstance(v, bool): 132 | return v 133 | if v.lower() in ("yes", "true", "t", "y", "1"): 134 | return True 135 | elif v.lower() in ("no", "false", "f", "n", "0"): 136 | return False 137 | else: 138 | raise argparse.ArgumentTypeError("boolean value expected") 139 | -------------------------------------------------------------------------------- /src/utils/custom_tokenizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import pathlib 4 | import torch 5 | from transformers import AutoTokenizer 6 | 7 | from tokenizers.processors import BertProcessing 8 | from tokenizers import ByteLevelBPETokenizer, decoders 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | 12 | def create_tokenizer(return_pretokenized, path, tokenizer_type: str = "word-level"): 13 | if return_pretokenized: 14 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 15 | return tokenizer 16 | 17 | if tokenizer_type == "byte-level": 18 | return read_byte_level(path) 19 | elif tokenizer_type == "word-level": 20 | return read_word_level(path) 21 | else: 22 | raise ValueError(f"Invalid tokenizer type: {tokenizer_type}") 23 | 24 | def train_bytelevel( 25 | path, 26 | vocab_size=10000, 27 | min_frequency=1, 28 | special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"], 29 | ): 30 | 31 | tokenizer = ByteLevelBPETokenizer() 32 | 33 | # Customize training 34 | tokenizer.train( 35 | files=[path], 36 | vocab_size=vocab_size, 37 | min_frequency=min_frequency, 38 | special_tokens=special_tokens, 39 | ) 40 | 41 | tokenizer.save_model(str(pathlib.Path(path).parent)) 42 | 43 | 44 | 45 | def read_byte_level(path: str): 46 | tokenizer = ByteLevelBPETokenizer( 47 | f"{path}/vocab.json", 48 | f"{path}/merges.txt", 49 | ) 50 | 51 | tokenizer._tokenizer.post_processor = BertProcessing( 52 | ("</s>", tokenizer.token_to_id("</s>")), 53 | ("<s>", tokenizer.token_to_id("<s>")), 54 | ) 55 | 56 | 
tokenizer.enable_truncation(max_length=512) 57 | 58 | print( 59 | tokenizer.encode( 60 | "Bores can be divided into two classes; those who have their own particular subject, and those who do not need a subject." 61 | ).tokens 62 | ) 63 | 64 | with open(f"{path}/vocab.json", "r") as fin: 65 | vocab = json.load(fin) 66 | 67 | # add length method to tokenizer object 68 | tokenizer.vocab_size = len(vocab) 69 | 70 | # add length property to tokenizer object 71 | tokenizer.__len__ = property(lambda self: self.vocab_size) 72 | 73 | tokenizer.decoder = decoders.ByteLevel() 74 | print(tokenizer.vocab_size) 75 | 76 | print( 77 | tokenizer.encode( 78 | "Bores can be divided into two classes; those who have their own particular subject, and those who do not need a subject." 79 | ).ids 80 | ) 81 | 82 | print( 83 | tokenizer.decode( 84 | tokenizer.encode( 85 | "Bores can be divided into two classes; those who have their own particular subject, and those who do not need a subject." 86 | ).ids, 87 | skip_special_tokens=True, 88 | ) 89 | ) 90 | 91 | ids = tokenizer.encode( 92 | "Bores can be divided into two classes; those who have their own particular subject, and those who do not need a subject." 93 | ).ids 94 | tensor = torch.tensor(ids) 95 | print(tokenizer.decode(tensor.tolist(), skip_special_tokens=True)) 96 | print(f"Vocab size: {tokenizer.vocab_size}") 97 | 98 | return tokenizer 99 | 100 | 101 | def read_word_level(path: str): 102 | 103 | from transformers import PreTrainedTokenizerFast 104 | 105 | logging.info(f"Loading tokenizer from {path}/word-level-vocab.json") 106 | tokenizer = PreTrainedTokenizerFast( 107 | tokenizer_file=f"{str(pathlib.Path(path))}/word-level-vocab.json", 108 | bos_token="[CLS]", 109 | eos_token="[SEP]", 110 | unk_token="[UNK]", 111 | sep_token="[SEP]", 112 | pad_token="[PAD]", 113 | cls_token="[CLS]", 114 | mask_token="[MASK]", 115 | padding_side="right", 116 | ) 117 | 118 | # add length property to tokenizer object 119 | tokenizer.__len__ = property(lambda self: self.vocab_size) 120 | 121 | return tokenizer 122 | 123 | 124 | def train_word_level_tokenizer( 125 | path: str, 126 | vocab_size: int = 10000, 127 | special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"], 128 | ): 129 | 130 | from tokenizers import Tokenizer, normalizers, pre_tokenizers 131 | from tokenizers.models import WordLevel 132 | from tokenizers.normalizers import NFD, Lowercase, StripAccents 133 | from tokenizers.pre_tokenizers import Digits, Whitespace 134 | from tokenizers.processors import TemplateProcessing 135 | from tokenizers.trainers import WordLevelTrainer 136 | 137 | tokenizer = Tokenizer(WordLevel(unk_token="[UNK]")) 138 | tokenizer.normalizer = normalizers.Sequence([NFD(), Lowercase(), StripAccents()]) 139 | tokenizer.pre_tokenizer = pre_tokenizers.Sequence( 140 | [Digits(individual_digits=True), Whitespace()] 141 | ) 142 | tokenizer.post_processor = TemplateProcessing( 143 | single="[CLS] $A [SEP]", special_tokens=[("[CLS]", 1), ("[SEP]", 2)] 144 | ) 145 | 146 | trainer = WordLevelTrainer(vocab_size=vocab_size, special_tokens=special_tokens) 147 | tokenizer.train(files=[path], trainer=trainer) 148 | 149 | tokenizer.__len__ = property(lambda self: self.vocab_size) 150 | 151 | tokenizer.enable_truncation(max_length=512) 152 | 153 | print(tokenizer.encode("the red.").ids) 154 | 155 | print(tokenizer.encode("the red.")) 156 | 157 | tokenizer.save(f"{str(pathlib.Path(path).parent)}/word-level-vocab.json") 158 | 159 | 160 | if __name__ == "__main__": 161 | import sys 162 | 163 | if 
sys.argv[1] == "train-word-level": 164 | train_word_level_tokenizer(path=sys.argv[2]) 165 | elif sys.argv[1] == "train-byte-level": 166 | train_bytelevel(path=sys.argv[2]) 167 | elif sys.argv[1] == "create": 168 | create_tokenizer(path=sys.argv[2]) 169 | -------------------------------------------------------------------------------- /src/utils/data_utils_sentencepiece.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import pandas as pd 4 | from torch.utils.data import DataLoader, Dataset 5 | import torch 6 | from functools import partial 7 | 8 | logging.basicConfig(level=logging.INFO) 9 | 10 | # BAD: this should not be global 11 | # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 12 | 13 | 14 | 15 | 16 | def get_dataloader(tokenizer, data_path, batch_size, max_seq_len): 17 | dataset = TextDataset(tokenizer=tokenizer, data_path=data_path) 18 | 19 | dataloader = DataLoader( 20 | dataset, 21 | batch_size=batch_size, # 20, 22 | drop_last=True, 23 | shuffle=True, 24 | num_workers=1, 25 | collate_fn=partial(TextDataset.collate_pad, cutoff=max_seq_len), 26 | ) 27 | 28 | while True: 29 | for batch in dataloader: 30 | yield batch 31 | 32 | 33 | class TextDataset(Dataset): 34 | def __init__( 35 | self, 36 | tokenizer, 37 | data_path: str, 38 | has_labels: bool = False 39 | ) -> None: 40 | super().__init__() 41 | self.data_path = data_path 42 | self.tokenizer = tokenizer 43 | self.read_data() 44 | if has_labels: 45 | self.read_labels() 46 | 47 | def read_data(self): 48 | logging.info("Reading data from {}".format(self.data_path)) 49 | data = pd.read_csv(self.data_path, sep="\t", header=None) # read text file 50 | logging.info(f"Tokenizing {len(data)} sentences") 51 | 52 | self.text = data[0].apply(lambda x: x.strip()).tolist() 53 | # encoded_input = self.tokenizer(self.questions, self.paragraphs) 54 | 55 | # check if tokenizer has a method 'encode_batch' 56 | if hasattr(self.tokenizer, 'encode_batch'): 57 | 58 | encoded_input = self.tokenizer.encode_batch(self.text) 59 | self.input_ids = [x.ids for x in encoded_input] 60 | 61 | else: 62 | encoded_input = self.tokenizer(self.text) 63 | self.input_ids = encoded_input["input_ids"] 64 | 65 | def read_labels(self): 66 | self.labels = pd.read_csv(self.data_path, sep="\t", header=None)[1].tolist() 67 | # check if labels are already numerical 68 | self.labels = [str(x) for x in self.labels] 69 | if isinstance(self.labels[0], int): 70 | return 71 | # if not, convert to numerical 72 | all_labels = sorted(list(set(self.labels))) 73 | self.label_to_idx = {label: i for i, label in enumerate(all_labels)} 74 | self.idx_to_label = {i: label for i, label in self.label_to_idx.items()} 75 | self.labels = [self.label_to_idx[label] for label in self.labels] 76 | 77 | 78 | 79 | def __len__(self) -> int: 80 | return len(self.text) 81 | 82 | def __getitem__(self, i): 83 | out_dict = { 84 | "input_ids": self.input_ids[i], 85 | # "attention_mask": [1] * len(self.input_ids[i]), 86 | } 87 | if hasattr(self, "labels"): 88 | out_dict["label"] = self.labels[i] 89 | return out_dict 90 | 91 | @staticmethod 92 | def collate_pad(batch, cutoff: int): 93 | max_token_len = 0 94 | num_elems = len(batch) 95 | # batch[0] -> __getitem__[0] --> returns a tuple (embeddings, out_dict) 96 | 97 | for i in range(num_elems): 98 | max_token_len = max(max_token_len, len(batch[i]["input_ids"])) 99 | 100 | max_token_len = min(cutoff, max_token_len) 101 | 102 | tokens = torch.zeros(num_elems, 
max_token_len).long() 103 | tokens_mask = torch.zeros(num_elems, max_token_len).long() 104 | 105 | has_labels = False 106 | if "label" in batch[0]: 107 | labels = torch.zeros(num_elems).long() 108 | has_labels = True 109 | 110 | for i in range(num_elems): 111 | toks = batch[i]["input_ids"] 112 | length = len(toks) 113 | tokens[i, :length] = torch.LongTensor(toks) 114 | tokens_mask[i, :length] = 1 115 | if has_labels: 116 | labels[i] = batch[i]["label"] 117 | 118 | # TODO: the first return None is just for backward compatibility -- can be removed 119 | if has_labels: 120 | return None, {"input_ids": tokens, "attention_mask": tokens_mask, "labels": labels} 121 | else: 122 | return None, {"input_ids": tokens, "attention_mask": tokens_mask} 123 | -------------------------------------------------------------------------------- /src/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 10 #8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 24 | """ 25 | if dist.is_initialized(): 26 | return 27 | 28 | comm = MPI.COMM_WORLD 29 | backend = "gloo" if not th.cuda.is_available() else "nccl" 30 | 31 | if backend == "gloo": 32 | hostname = "localhost" 33 | else: 34 | hostname = socket.gethostbyname(socket.getfqdn()) 35 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 36 | os.environ["RANK"] = str(comm.rank) 37 | os.environ["WORLD_SIZE"] = str(comm.size) 38 | 39 | 40 | port = comm.bcast(_find_free_port(), root=0) 41 | os.environ["MASTER_PORT"] = str(port) 42 | dist.init_process_group(backend=backend, init_method="env://") 43 | 44 | 45 | def dev(): 46 | """ 47 | Get the device to use for torch.distributed. 48 | """ 49 | if th.cuda.is_available(): 50 | return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file without redundant fetches across MPI ranks. 57 | """ 58 | if MPI.COMM_WORLD.Get_rank() == 0: 59 | with bf.BlobFile(path, "rb") as f: 60 | data = f.read() 61 | else: 62 | data = None 63 | data = MPI.COMM_WORLD.bcast(data) 64 | return th.load(io.BytesIO(data), **kwargs) 65 | 66 | 67 | def sync_params(params): 68 | """ 69 | Synchronize a sequence of Tensors across ranks from rank 0. 70 | """ 71 | for p in params: 72 | with th.no_grad(): 73 | dist.broadcast(p, 0) 74 | 75 | 76 | def _find_free_port(): 77 | try: 78 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 79 | s.bind(("", 0)) 80 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 81 | return s.getsockname()[1] 82 | finally: 83 | s.close() 84 | -------------------------------------------------------------------------------- /src/utils/eval_ppl.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluates perplexity of a model on a dataset. 
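For reference: the per-sequence score computed below is exp( sum_t -log p(x_t | x_<t) / N ), where both the sum and the token count N run only over non-padding positions; compute_perplexity() builds exactly this masked, shifted average before exponentiating, and mean_perplexity is the arithmetic mean over sequences.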
3 | Directly taken from https://huggingface.co/spaces/evaluate-measurement/perplexity/blob/main/perplexity.py 4 | """ 5 | 6 | from itertools import chain 7 | from typing import List 8 | import torch 9 | 10 | import numpy as np 11 | import torch 12 | from torch.nn import CrossEntropyLoss 13 | from transformers import AutoModelForCausalLM, AutoTokenizer 14 | 15 | device = "cuda" if torch.cuda.is_available() else "cpu" 16 | 17 | model_id = "distilgpt2" 18 | model = AutoModelForCausalLM.from_pretrained(model_id) 19 | model = model.to(device) 20 | 21 | tokenizer = AutoTokenizer.from_pretrained(model_id) 22 | 23 | 24 | def compute_perplexity(data, batch_size: int = 16, add_start_token: bool = True): 25 | 26 | 27 | 28 | # if batch_size > 1 (which generally leads to padding being required), and 29 | # if there is not an already assigned pad_token, assign an existing 30 | # special token to also be the padding token 31 | if tokenizer.pad_token is None and batch_size > 1: 32 | existing_special_tokens = list(tokenizer.special_tokens_map_extended.values()) 33 | # check that the model already has at least one special token defined 34 | assert ( 35 | len(existing_special_tokens) > 0 36 | ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1." 37 | # assign one of the special tokens to also be the pad token 38 | tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]}) 39 | 40 | if add_start_token: 41 | # leave room for token to be added: 42 | assert ( 43 | tokenizer.bos_token is not None 44 | ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False" 45 | max_tokenized_len = model.config.max_length - 1 46 | else: 47 | max_tokenized_len = model.config.max_length 48 | 49 | encodings = tokenizer( 50 | data, 51 | add_special_tokens=False, 52 | padding=True, 53 | truncation=True, 54 | max_length=max_tokenized_len, 55 | return_tensors="pt", 56 | return_attention_mask=True, 57 | ).to(device) 58 | 59 | encoded_texts = encodings["input_ids"] 60 | attn_masks = encodings["attention_mask"] 61 | 62 | # check that each input is long enough: 63 | if add_start_token: 64 | assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long." 65 | else: 66 | assert torch.all( 67 | torch.ge(attn_masks.sum(1), 2) 68 | ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings." 
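# The batched loop below: (optionally) prepend the BOS token, run the causal LM,
# shift logits and labels by one position, average the token-level cross-entropy
# weighted by the attention mask so padding is ignored, and exponentiate to get
# one perplexity value per input sequence.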
69 | 70 | ppls = [] 71 | loss_fct = CrossEntropyLoss(reduction="none") 72 | 73 | for start_index in tqdm(range(0, len(encoded_texts), batch_size)): 74 | end_index = min(start_index + batch_size, len(encoded_texts)) 75 | encoded_batch = encoded_texts[start_index:end_index] 76 | attn_mask = attn_masks[start_index:end_index] 77 | 78 | if add_start_token: 79 | bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device) 80 | encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1) 81 | attn_mask = torch.cat( 82 | [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1 83 | ) 84 | 85 | labels = encoded_batch 86 | 87 | with torch.no_grad(): 88 | out_logits = model(encoded_batch, attention_mask=attn_mask).logits 89 | 90 | shift_logits = out_logits[..., :-1, :].contiguous() 91 | shift_labels = labels[..., 1:].contiguous() 92 | shift_attention_mask_batch = attn_mask[..., 1:].contiguous() 93 | 94 | perplexity_batch = torch.exp( 95 | (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1) 96 | / shift_attention_mask_batch.sum(1) 97 | ) 98 | 99 | ppls += perplexity_batch.tolist() 100 | 101 | return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)} 102 | 103 | 104 | def calculate_perplexity_for_file(path: str): 105 | # read lines 106 | special_tokens = ["[CLS]", "[SEP]", "[PAD]", "[MASK]", "", ""] 107 | with open(path, "r") as f: 108 | lines = f.readlines() 109 | 110 | lines = [remove_all(line.strip(), special_tokens) for line in lines if len(line.strip()) > 0] 111 | try: 112 | num_unique_lines = len(set(lines)) 113 | perc_unique_lines = round(num_unique_lines * 100 / len(lines), 2) 114 | all_tokens = list(chain(*[line.split() for line in lines])) 115 | perc_unique_tokens = round(len(set(all_tokens)) * 100 / len(all_tokens), 2) 116 | 117 | return {"data": lines, "ppl": compute_perplexity(lines)['mean_perplexity'], "perc_unique_lines": perc_unique_lines, "perc_unique_tokens": perc_unique_tokens} 118 | except Exception as e: 119 | return {"data": [], "ppl": 1e6, "perc_unique_lines": 0, "perc_unique_tokens": 0} 120 | 121 | 122 | def remove_all(line: str, special_toks: List[str]) -> str: 123 | for tok in special_toks: 124 | line = line.replace(tok, "").strip() 125 | return line 126 | 127 | 128 | if __name__ == '__main__': 129 | import sys 130 | import glob 131 | from tqdm import tqdm 132 | from pprint import pprint 133 | import json 134 | files = glob.glob(sys.argv[1]) 135 | res = dict() 136 | for file in tqdm(files): 137 | res[file] = calculate_perplexity_for_file(file) 138 | 139 | 140 | # sort by perplexity 141 | res = {k: v for k, v in sorted(res.items(), key=lambda item: item[1]['ppl'])} 142 | 143 | for file in res: 144 | # show a few lines 145 | print(f"File: {file}") 146 | pprint(res[file]['data'][:5]) 147 | 148 | # show the perplexity 149 | print(f"Perplexity: {res[file]['ppl']}") 150 | print(f"Percentage of unique lines: {res[file]['perc_unique_lines']}") 151 | print(f"Percentage of unique tokens: {res[file]['perc_unique_tokens']}") 152 | print("-" * 100) 153 | 154 | # Create a nice MARKDOWN report with: i) sample sentences, ii) perplexity, iii) percentage of unique lines, iv) percentage of unique tokens 155 | 156 | import random 157 | print("| File | Sample Sentences | Perplexity | % Unique Lines | % Unique Tokens |") 158 | for file in res: 159 | sentences = set(res[file]['data']) 160 | # pick 5 random sentences 161 | sentences = random.sample(sentences, 5) if len(sentences) 
> 5 else sentences 162 | 163 | filename = "#".join(file.split("/")[:-1]) 164 | # print row 165 | print('-' * 80) 166 | if res[file]['perc_unique_tokens'] > 0: 167 | print(f"| {filename} | {', '.join(sentences)} | {res[file]['ppl']} | {res[file]['perc_unique_lines']} | {res[file]['perc_unique_tokens']} |") 168 | 169 | 170 | with open("perplexity.json", "w") as f: 171 | json.dump(res, f) 172 | 173 | -------------------------------------------------------------------------------- /src/utils/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | 9 | def convert_module_to_f16(l): 10 | """ 11 | Convert primitive modules to float16. 12 | """ 13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 14 | l.weight.data = l.weight.data.half() 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | l.bias.data = l.bias.data.float() 25 | 26 | 27 | def make_master_params(model_params): 28 | """ 29 | Copy model parameters into a (differently-shaped) list of full-precision 30 | parameters. 31 | """ 32 | master_params = _flatten_dense_tensors( 33 | [param.detach().float() for param in model_params] 34 | ) 35 | master_params = nn.Parameter(master_params) 36 | master_params.requires_grad = True 37 | return [master_params] 38 | 39 | 40 | def model_grads_to_master_grads(model_params, master_params): 41 | """ 42 | Copy the gradients from the model parameters into the master parameters 43 | from make_master_params(). 44 | """ 45 | master_params[0].grad = _flatten_dense_tensors( 46 | [param.grad.data.detach().float() for param in model_params] 47 | ) 48 | 49 | 50 | def master_params_to_model_params(model_params, master_params): 51 | """ 52 | Copy the master parameter data back into the model parameters. 53 | """ 54 | # Without copying to a list, if a generator is passed, this will 55 | # silently not copy any parameters. 56 | model_params = list(model_params) 57 | 58 | for param, master_param in zip( 59 | model_params, unflatten_master_params(model_params, master_params) 60 | ): 61 | param.detach().copy_(master_param) 62 | 63 | 64 | def unflatten_master_params(model_params, master_params): 65 | """ 66 | Unflatten the master parameters to look like model_params. 
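Concretely, the flat FP32 master tensor is sliced back into views whose shapes match the entries of model_params, so master_params_to_model_params() can copy the data back one parameter at a time.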
67 | """ 68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 69 | 70 | 71 | def zero_grad(model_params): 72 | for param in model_params: 73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 74 | if param.grad is not None: 75 | param.grad.detach_() 76 | param.grad.zero_() 77 | -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | import wandb 18 | 19 | DEBUG = 10 20 | INFO = 20 21 | WARN = 30 22 | ERROR = 40 23 | 24 | DISABLED = 50 25 | 26 | 27 | class KVWriter(object): 28 | def writekvs(self, kvs): 29 | raise NotImplementedError 30 | 31 | 32 | class SeqWriter(object): 33 | def writeseq(self, seq): 34 | raise NotImplementedError 35 | 36 | 37 | class HumanOutputFormat(KVWriter, SeqWriter): 38 | def __init__(self, filename_or_file): 39 | if isinstance(filename_or_file, str): 40 | self.file = open(filename_or_file, "wt") 41 | self.own_file = True 42 | else: 43 | assert hasattr(filename_or_file, "read"), ( 44 | "expected file or str, got %s" % filename_or_file 45 | ) 46 | self.file = filename_or_file 47 | self.own_file = False 48 | 49 | def writekvs(self, kvs): 50 | # Create strings for printing 51 | key2str = {} 52 | for (key, val) in sorted(kvs.items()): 53 | if hasattr(val, "__float__"): 54 | valstr = "%-8.3g" % val 55 | else: 56 | valstr = str(val) 57 | key2str[self._truncate(key)] = self._truncate(valstr) 58 | 59 | # Find max widths 60 | if len(key2str) == 0: 61 | print("WARNING: tried to write empty key-value dict") 62 | return 63 | else: 64 | keywidth = max(map(len, key2str.keys())) 65 | valwidth = max(map(len, key2str.values())) 66 | 67 | # Write out the data 68 | dashes = "-" * (keywidth + valwidth + 7) 69 | lines = [dashes] 70 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 71 | lines.append( 72 | "| %s%s | %s%s |" 73 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 74 | ) 75 | lines.append(dashes) 76 | self.file.write("\n".join(lines) + "\n") 77 | 78 | # Flush the output to the file 79 | self.file.flush() 80 | 81 | def _truncate(self, s): 82 | maxlen = 30 83 | return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s 84 | 85 | def writeseq(self, seq): 86 | seq = list(seq) 87 | for (i, elem) in enumerate(seq): 88 | self.file.write(elem) 89 | if i < len(seq) - 1: # add space unless this is the last one 90 | self.file.write(" ") 91 | self.file.write("\n") 92 | self.file.flush() 93 | 94 | def close(self): 95 | if self.own_file: 96 | self.file.close() 97 | 98 | 99 | class JSONOutputFormat(KVWriter): 100 | def __init__(self, filename): 101 | self.file = open(filename, "wt") 102 | 103 | def writekvs(self, kvs): 104 | for k, v in sorted(kvs.items()): 105 | if hasattr(v, "dtype"): 106 | kvs[k] = float(v) 107 | self.file.write(json.dumps(kvs) + "\n") 108 | self.file.flush() 109 | 110 | def close(self): 111 | self.file.close() 112 | 113 | 114 | class CSVOutputFormat(KVWriter): 115 | def __init__(self, filename): 116 | self.file = open(filename, "w+t") 117 | self.keys = [] 118 | self.sep = "," 119 | 120 | def writekvs(self, kvs): 121 | # Add our current row to the history 122 | extra_keys = list(kvs.keys() - self.keys) 123 | extra_keys.sort() 124 | if extra_keys: 125 | self.keys.extend(extra_keys) 126 | self.file.seek(0) 127 | lines = self.file.readlines() 128 | self.file.seek(0) 129 | for (i, k) in enumerate(self.keys): 130 | if i > 0: 131 | self.file.write(",") 132 | self.file.write(k) 133 | self.file.write("\n") 134 | for line in lines[1:]: 135 | self.file.write(line[:-1]) 136 | self.file.write(self.sep * len(extra_keys)) 137 | self.file.write("\n") 138 | for (i, k) in enumerate(self.keys): 139 | if i > 0: 140 | self.file.write(",") 141 | v = kvs.get(k) 142 | if v is not None: 143 | self.file.write(str(v)) 144 | self.file.write("\n") 145 | self.file.flush() 146 | 147 | def close(self): 148 | self.file.close() 149 | 150 | 151 | class TensorBoardOutputFormat(KVWriter): 152 | """ 153 | Dumps key/value pairs into TensorBoard's numeric format. 154 | """ 155 | 156 | def __init__(self, dir): 157 | os.makedirs(dir, exist_ok=True) 158 | self.dir = dir 159 | self.step = 1 160 | prefix = "events" 161 | path = osp.join(osp.abspath(dir), prefix) 162 | import tensorflow as tf 163 | from tensorflow.python import pywrap_tensorflow 164 | from tensorflow.core.util import event_pb2 165 | from tensorflow.python.util import compat 166 | 167 | self.tf = tf 168 | self.event_pb2 = event_pb2 169 | self.pywrap_tensorflow = pywrap_tensorflow 170 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 171 | 172 | def writekvs(self, kvs): 173 | def summary_val(k, v): 174 | kwargs = {"tag": k, "simple_value": float(v)} 175 | return self.tf.Summary.Value(**kwargs) 176 | 177 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 178 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 179 | event.step = ( 180 | self.step 181 | ) # is there any reason why you'd want to specify the step? 
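# Each writekvs() call emits one TensorFlow Event whose Summary carries a simple_value
# per key; the monotonically increasing self.step becomes the x-axis TensorBoard plots
# these scalars against.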
182 | self.writer.WriteEvent(event) 183 | self.writer.Flush() 184 | self.step += 1 185 | 186 | def close(self): 187 | if self.writer: 188 | self.writer.Close() 189 | self.writer = None 190 | 191 | 192 | def make_output_format(format, ev_dir, log_suffix=""): 193 | os.makedirs(ev_dir, exist_ok=True) 194 | if format == "stdout": 195 | return HumanOutputFormat(sys.stdout) 196 | elif format == "log": 197 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 198 | elif format == "json": 199 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 200 | elif format == "csv": 201 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 202 | elif format == "tensorboard": 203 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 204 | else: 205 | raise ValueError("Unknown format specified: %s" % (format,)) 206 | 207 | 208 | # ================================================================ 209 | # API 210 | # ================================================================ 211 | 212 | 213 | def logkv(key, val): 214 | """ 215 | Log a value of some diagnostic 216 | Call this once for each diagnostic quantity, each iteration 217 | If called many times, last value will be used. 218 | """ 219 | get_current().logkv(key, val) 220 | 221 | 222 | def logkv_mean(key, val): 223 | """ 224 | The same as logkv(), but if called many times, values averaged. 225 | """ 226 | get_current().logkv_mean(key, val) 227 | 228 | 229 | def logkvs(d): 230 | """ 231 | Log a dictionary of key-value pairs 232 | """ 233 | for (k, v) in d.items(): 234 | logkv(k, v) 235 | 236 | 237 | def dumpkvs(): 238 | """ 239 | Write all of the diagnostics from the current iteration 240 | """ 241 | return get_current().dumpkvs() 242 | 243 | 244 | def getkvs(): 245 | return get_current().name2val 246 | 247 | 248 | def log(*args, level=INFO): 249 | """ 250 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 251 | """ 252 | get_current().log(*args, level=level) 253 | 254 | 255 | def debug(*args): 256 | log(*args, level=DEBUG) 257 | 258 | 259 | def info(*args): 260 | log(*args, level=INFO) 261 | 262 | 263 | def warn(*args): 264 | log(*args, level=WARN) 265 | 266 | 267 | def error(*args): 268 | log(*args, level=ERROR) 269 | 270 | 271 | def set_level(level): 272 | """ 273 | Set logging threshold on current logger. 274 | """ 275 | get_current().set_level(level) 276 | 277 | 278 | def set_comm(comm): 279 | get_current().set_comm(comm) 280 | 281 | 282 | def get_dir(): 283 | """ 284 | Get directory that log files are being written to. 
285 | will be None if there is no output directory (i.e., if you didn't call start) 286 | """ 287 | return get_current().get_dir() 288 | 289 | 290 | record_tabular = logkv 291 | dump_tabular = dumpkvs 292 | 293 | 294 | @contextmanager 295 | def profile_kv(scopename): 296 | logkey = "wait_" + scopename 297 | tstart = time.time() 298 | try: 299 | yield 300 | finally: 301 | get_current().name2val[logkey] += time.time() - tstart 302 | 303 | 304 | def profile(n): 305 | """ 306 | Usage: 307 | @profile("my_func") 308 | def my_func(): code 309 | """ 310 | 311 | def decorator_with_name(func): 312 | def func_wrapper(*args, **kwargs): 313 | with profile_kv(n): 314 | return func(*args, **kwargs) 315 | 316 | return func_wrapper 317 | 318 | return decorator_with_name 319 | 320 | 321 | # ================================================================ 322 | # Backend 323 | # ================================================================ 324 | 325 | 326 | def get_current(): 327 | if Logger.CURRENT is None: 328 | _configure_default_logger() 329 | 330 | return Logger.CURRENT 331 | 332 | 333 | class Logger(object): 334 | DEFAULT = None # A logger with no output files. (See right below class definition) 335 | # So that you can still log to the terminal without setting up any output files 336 | CURRENT = None # Current logger being used by the free functions above 337 | 338 | def __init__(self, dir, output_formats, comm=None): 339 | self.name2val = defaultdict(float) # values this iteration 340 | self.name2cnt = defaultdict(int) 341 | self.level = INFO 342 | self.dir = dir 343 | self.output_formats = output_formats 344 | self.comm = comm 345 | 346 | # Logging API, forwarded 347 | # ---------------------------------------- 348 | def logkv(self, key, val): 349 | self.name2val[key] = val 350 | 351 | def logkv_mean(self, key, val): 352 | oldval, cnt = self.name2val[key], self.name2cnt[key] 353 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 354 | self.name2cnt[key] = cnt + 1 355 | 356 | def dumpkvs(self, prefix=None): 357 | if self.comm is None: 358 | d = self.name2val 359 | else: 360 | d = mpi_weighted_mean( 361 | self.comm, 362 | { 363 | name: (val, self.name2cnt.get(name, 1)) 364 | for (name, val) in self.name2val.items() 365 | }, 366 | ) 367 | if self.comm.rank != 0: 368 | d["dummy"] = 1 # so we don't get a warning about empty dict 369 | # LISA 370 | wandb.log({**d}) 371 | out = d.copy() # Return the dict for unit testing purposes 372 | for fmt in self.output_formats: 373 | if isinstance(fmt, KVWriter): 374 | fmt.writekvs(d) 375 | self.name2val.clear() 376 | self.name2cnt.clear() 377 | return out 378 | 379 | def log(self, *args, level=INFO): 380 | if self.level <= level: 381 | self._do_log(args) 382 | 383 | # Configuration 384 | # ---------------------------------------- 385 | def set_level(self, level): 386 | self.level = level 387 | 388 | def set_comm(self, comm): 389 | self.comm = comm 390 | 391 | def get_dir(self): 392 | return self.dir 393 | 394 | def close(self): 395 | for fmt in self.output_formats: 396 | fmt.close() 397 | 398 | # Misc 399 | # ---------------------------------------- 400 | def _do_log(self, args): 401 | for fmt in self.output_formats: 402 | if isinstance(fmt, SeqWriter): 403 | fmt.writeseq(map(str, args)) 404 | 405 | 406 | def get_rank_without_mpi_import(): 407 | # check environment variables here instead of importing mpi4py 408 | # to avoid calling MPI_Init() when this module is imported 409 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 410 | if varname in 
os.environ: 411 | return int(os.environ[varname]) 412 | return 0 413 | 414 | 415 | def mpi_weighted_mean(comm, local_name2valcount): 416 | """ 417 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 418 | Perform a weighted average over dicts that are each on a different node 419 | Input: local_name2valcount: dict mapping key -> (value, count) 420 | Returns: key -> mean 421 | """ 422 | all_name2valcount = comm.gather(local_name2valcount) 423 | if comm.rank == 0: 424 | name2sum = defaultdict(float) 425 | name2count = defaultdict(float) 426 | for n2vc in all_name2valcount: 427 | for (name, (val, count)) in n2vc.items(): 428 | try: 429 | val = float(val) 430 | except ValueError: 431 | if comm.rank == 0: 432 | warnings.warn( 433 | "WARNING: tried to compute mean on non-float {}={}".format( 434 | name, val 435 | ) 436 | ) 437 | else: 438 | name2sum[name] += val * count 439 | name2count[name] += count 440 | return {name: name2sum[name] / name2count[name] for name in name2sum} 441 | else: 442 | return {} 443 | 444 | 445 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 446 | """ 447 | If comm is provided, average all numerical stats across that comm 448 | """ 449 | if dir is None: 450 | dir = os.getenv("OPENAI_LOGDIR") 451 | if dir is None: 452 | dir = osp.join( 453 | tempfile.gettempdir(), 454 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 455 | ) 456 | assert isinstance(dir, str) 457 | dir = os.path.expanduser(dir) 458 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 459 | 460 | rank = get_rank_without_mpi_import() 461 | if rank > 0: 462 | log_suffix = log_suffix + "-rank%03i" % rank 463 | 464 | if format_strs is None: 465 | if rank == 0: 466 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 467 | else: 468 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 469 | format_strs = filter(None, format_strs) 470 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 471 | 472 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 473 | if output_formats: 474 | log("Logging to %s" % dir) 475 | 476 | 477 | def _configure_default_logger(): 478 | configure() 479 | Logger.DEFAULT = Logger.CURRENT 480 | 481 | 482 | def reset(): 483 | if Logger.CURRENT is not Logger.DEFAULT: 484 | Logger.CURRENT.close() 485 | Logger.CURRENT = Logger.DEFAULT 486 | log("Reset logger") 487 | 488 | 489 | @contextmanager 490 | def scoped_configure(dir=None, format_strs=None, comm=None): 491 | prevlogger = Logger.CURRENT 492 | configure(dir=dir, format_strs=format_strs, comm=comm) 493 | try: 494 | yield 495 | finally: 496 | Logger.CURRENT.close() 497 | Logger.CURRENT = prevlogger 498 | 499 | -------------------------------------------------------------------------------- /src/utils/show_sampling_progress.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import List 3 | list_of_colors_from_red_to_blue = [f"\033[38;2;{r};0;{b}m" for r, b in zip(range(255, 0, -10), range(0, 255, 10))] 4 | 5 | def pprint_sentences(sentences: List[str], banner: str = "", sep: str = ""): 6 | """ 7 | Given a list of sentences, prints them with a gradient of colors from red to blue 8 | """ 9 | print() 10 | print(f"\033[1m{'=' * 20} {banner} {'=' * 20}\033[0m") 11 | for i, sentence in enumerate(sentences): 12 | sentence_color = list_of_colors_from_red_to_blue[i] 13 | if i == 
len(sentences) - 1: 14 | print(f"{sentence_color}{sentence}\033[0m") 15 | else: 16 | print(f"{sentence_color}{sentence}\033[0m", end=sep) 17 | print() 18 | 19 | 20 | if __name__ == '__main__': 21 | sentences = [ 22 | "This is a sentence", 23 | "This is another sentence", 24 | "This is a third sentence", 25 | "This is a fourth sentence", 26 | "This is a fifth sentence", 27 | "This is a sixth sentence", 28 | "This is a seventh sentence", 29 | "This is an eighth sentence", 30 | "This is a ninth sentence", 31 | "This is a tenth sentence", 32 | "This is an eleventh sentence", 33 | "This is a twelfth sentence", 34 | "This is a thirteenth sentence", 35 | "This is a fourteenth sentence", 36 | "This is a fifteenth sentence", 37 | "This is a sixteenth sentence", 38 | "This is a seventeenth sentence", 39 | "This is an eighteenth sentence", 40 | "This is a nineteenth sentence", 41 | "This is a twentieth sentence", 42 | ] 43 | for i in range(1, len(sentences) + 1): 44 | pprint_sentences(sentences[:i], sep= " -> ") 45 | print("---") 46 | 47 | -------------------------------------------------------------------------------- /src/utils/test_util.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | 4 | def compute_logp(args, model, x, input_ids): 5 | word_emb = model.weight 6 | sigma = 0.1 7 | if args.model_arch == '1d-unet': 8 | x = x.permute(0, 2, 1) 9 | 10 | bsz, seqlen, dim = x.shape 11 | 12 | x_flat = x.reshape(-1, x.size(-1)).unsqueeze(0) # 1, bsz*sample*seqlen, dim 13 | word_emb_flat = word_emb.unsqueeze(1) # vocab, 1, dim 14 | diff = (x_flat - word_emb_flat) ** 2 # vocab, bsz*seqlen, dim 15 | 16 | logp_expanded = -diff.sum(dim=-1) / (2 * sigma ** 2) # vocab, bsz*seqlen 17 | logp_expanded = logp_expanded.permute((1, 0)) 18 | # print(th.topk(logp_expanded.view(bsz, seqlen, -1), k=5, dim=-1)[0]) 19 | # print(input_ids[0]) 20 | ce = th.nn.CrossEntropyLoss(reduction='none') 21 | loss = ce(logp_expanded, input_ids.view(-1)).view(bsz, seqlen) 22 | # print(loss[0]) 23 | 24 | # print(loss.shape) 25 | return loss 26 | 27 | def get_weights(model, args): 28 | if hasattr(model, 'transformer'): 29 | input_embs = model.transformer.wte # input_embs 30 | down_proj = model.down_proj 31 | down_proj_emb = down_proj(input_embs.weight) 32 | print(down_proj_emb.shape) 33 | # model = th.nn.Embedding(down_proj_emb.shape[1], down_proj_emb.shape[0]) 34 | model = th.nn.Embedding(down_proj_emb.size(0), down_proj_emb.size(1)) 35 | print(args.emb_scale_factor) 36 | model.weight.data = down_proj_emb * args.emb_scale_factor 37 | 38 | elif hasattr(model, 'weight'): 39 | pass 40 | else: 41 | raise NotImplementedError 42 | 43 | model.weight.requires_grad = False 44 | return model 45 | 46 | def denoised_fn_round(args, model, text_emb, t): 47 | 48 | down_proj_emb = model.weight # input_embs 49 | # print(t) 50 | old_shape = text_emb.shape 51 | old_device = text_emb.device 52 | 53 | def get_efficient_knn(down_proj_emb, text_emb, dist='l2'): 54 | if dist == 'l2': 55 | emb_norm = (down_proj_emb**2).sum(-1).view(-1, 1) #vocab 56 | text_emb_t = th.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) #d, bsz*seqlen 57 | arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) #bsz*seqlen, 1 58 | # print(emb_norm.shape, arr_norm.shape) 59 | dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * th.mm(down_proj_emb, text_emb_t) #(vocab, d) x (d, bsz*seqlen) 60 | dist = th.clamp(dist, 0.0, np.inf) 61 | # print(dist.shape) 62 | topk_out = th.topk(-dist, k=1, dim=0) 
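# The distance matrix above uses the expansion ||e - x||^2 = ||e||^2 + ||x||^2 - 2 e.x,
# so all (vocab x bsz*seqlen) squared L2 distances come from a single matrix multiply;
# clamping at 0 absorbs small negative values from floating-point error, and topk over
# the negated distances picks the nearest vocabulary embedding for every position.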
63 | # adjacency = down_proj_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand( 64 | # down_proj_emb.size(0), -1, -1) 65 | # adjacency = -th.norm(adjacency, dim=-1) 66 | # topk_out = th.topk(adjacency, k=1, dim=0) 67 | # print(topk_out1.indices == topk_out.indices) 68 | # assert th.all(topk_out1.indices == topk_out.indices) 69 | return topk_out.values, topk_out.indices 70 | 71 | def get_knn(down_proj_emb, text_emb, dist='l2'): 72 | if dist == 'l2': 73 | adjacency = down_proj_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand( 74 | down_proj_emb.size(0), -1, -1) 75 | adjacency = -th.norm(adjacency, dim=-1) 76 | topk_out = th.topk(adjacency, k=1, dim=0) 77 | return topk_out.values, topk_out.indices 78 | 79 | dist = 'l2' 80 | if len(text_emb.shape) > 2: 81 | text_emb = text_emb.reshape(-1, text_emb.size(-1)) 82 | else: 83 | text_emb = text_emb 84 | # val, indices = get_knn(down_proj_emb, 85 | # text_emb.to(down_proj_emb.device), dist=dist) 86 | val, indices = get_efficient_knn(down_proj_emb, 87 | text_emb.to(down_proj_emb.device), dist=dist) 88 | rounded_tokens = indices[0] 89 | # print(rounded_tokens.shape) 90 | new_embeds = model(rounded_tokens).view(old_shape).to(old_device) 91 | if args.model_arch == '1d-unet': 92 | new_embeds = new_embeds.permute(0, 2, 1) 93 | return new_embeds 94 | 95 | def load_results(json_path, load_dict): 96 | import json 97 | with open(json_path, 'w') as f: 98 | json.dump(load_dict, f, indent=2) 99 | --------------------------------------------------------------------------------