├── .gitignore
├── CITATION.cff
├── CNAME
├── LICENSE.md
├── README.md
├── __init__.py
├── data
├── greetings-test.txt
├── greetings-train.txt
├── greetings
│ ├── greetings-test.txt
│ ├── greetings-train.txt
│ ├── greetings.py
│ ├── greetings.txt
│ ├── greetings_labeled.tsv
│ └── word-level-vocab.json
├── simple-test.txt
├── simple-train.txt
└── simple
│ ├── merges.txt
│ ├── simple.txt
│ ├── simple_labeled.tsv
│ ├── vocab.json
│ └── word-level-vocab.json
├── docs
├── CNAME
├── controllable.md
├── imgs
│ ├── greetings_training_finished.png
│ └── greetings_training_loop.png
├── old_experiments.md
└── training_on_your_own_dataset.md
├── minimal-text-diffusion.gif
├── requirements.txt
├── scripts
├── install.sh
├── run_train.sh
└── text_sample.sh
└── src
├── __init__.py
├── controllable
├── classifier.py
├── controllable_text_sample.py
└── langevin.py
├── modeling
├── __init__.py
├── diffusion
│ ├── __init__.py
│ ├── gaussian_diffusion.py
│ ├── losses.py
│ ├── nn.py
│ ├── resample.py
│ ├── respace.py
│ └── rounding.py
└── predictor
│ └── transformer_model.py
├── train_infer
├── factory_methods.py
├── text_sample.py
├── train.py
└── train_loop.py
└── utils
├── args_utils.py
├── custom_tokenizer.py
├── data_utils_sentencepiece.py
├── dist_util.py
├── eval_ppl.py
├── fp16_util.py
├── logger.py
├── show_sampling_progress.py
└── test_util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Logging
2 |
3 | wandb/
4 | checkpoints/
5 | generation_outputs
6 | diff_models
7 | data/
8 | SWEEP
9 | ipynb/
10 | scripts/
11 | *ckpt*
12 | *checkpoints*
13 | nohup.out
14 |
15 | # Byte-compiled / optimized / DLL files
16 | __pycache__/
17 | *.py[cod]
18 | *$py.class
19 |
20 | # C extensions
21 | *.so
22 |
23 | # Distribution / packaging
24 | .Python
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | wheels/
37 | pip-wheel-metadata/
38 | share/python-wheels/
39 | *.egg-info/
40 | .installed.cfg
41 | *.egg
42 | MANIFEST
43 |
44 | # PyInstaller
45 | # Usually these files are written by a python script from a template
46 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
47 | *.manifest
48 | *.spec
49 |
50 | # Installer logs
51 | pip-log.txt
52 | pip-delete-this-directory.txt
53 |
54 | # Unit test / coverage reports
55 | htmlcov/
56 | .tox/
57 | .nox/
58 | .coverage
59 | .coverage.*
60 | .cache
61 | nosetests.xml
62 | coverage.xml
63 | *.cover
64 | *.py,cover
65 | .hypothesis/
66 | .pytest_cache/
67 |
68 | # Translations
69 | *.mo
70 | *.pot
71 |
72 | # Django stuff:
73 | *.log
74 | local_settings.py
75 | db.sqlite3
76 | db.sqlite3-journal
77 |
78 | # Flask stuff:
79 | instance/
80 | .webassets-cache
81 |
82 | # Scrapy stuff:
83 | .scrapy
84 |
85 | # Sphinx documentation
86 | docs/_build/
87 |
88 | # PyBuilder
89 | target/
90 |
91 | # Jupyter Notebook
92 | .ipynb_checkpoints
93 |
94 | # IPython
95 | profile_default/
96 | ipython_config.py
97 |
98 | # pyenv
99 | .python-version
100 |
101 | # pipenv
102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
105 | # install all needed dependencies.
106 | #Pipfile.lock
107 |
108 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
109 | __pypackages__/
110 |
111 | # Celery stuff
112 | celerybeat-schedule
113 | celerybeat.pid
114 |
115 | # SageMath parsed files
116 | *.sage.py
117 |
118 | # Environments
119 | .env
120 | .venv
121 | env/
122 | venv/
123 | ENV/
124 | env.bak/
125 | venv.bak/
126 |
127 | # Spyder project settings
128 | .spyderproject
129 | .spyproject
130 |
131 | # Rope project settings
132 | .ropeproject
133 |
134 | # mkdocs documentation
135 | /site
136 |
137 | # mypy
138 | .mypy_cache/
139 | .dmypy.json
140 | dmypy.json
141 |
142 | # Pyre type checker
143 | .pyre/
144 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | type: software
3 | id: "minimal_text_diffusion2022"
4 | authors:
5 | - family-names: Madaan
6 | given-names: Aman
7 | doi: 10.5281/zenodo.7374939
8 | month: 12
9 | title: "Minimal text diffusion"
10 | url: https://github.com/madaan/minimal-text-diffusion
11 | version: 0.1
12 | year: 2022
13 |
--------------------------------------------------------------------------------
/CNAME:
--------------------------------------------------------------------------------
1 | diffusion.textgen.info
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Aman Madaan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Minimal text diffusion
2 |
3 |
4 |
5 | _A minimal implementation of diffusion models of text: learns a diffusion model of a given text corpus, allowing one to generate text samples from the learned model._
6 |
7 |
8 |
9 | ----
10 |
11 |
12 | |  |
13 | |:--:|
14 | | Diffusion in action: a DDPM model gradually denoising random text _`hotnutggy pi greentedsty rawyaented`_ to _`the white eggplant is dried`_ and _`mac clement star fe honey spin theapple purpleip`_ to _`the brown radicchio is sour`_|
15 |
16 |
17 | ----
18 |
19 | This repo has been refactored by taking a large amount of code from https://github.com/XiangLi1999/Diffusion-LM (which includes some code from: https://github.com/openai/glide-text2im), thanks to the authors for their work!
20 |
21 | The main idea was to retain _just enough code_ to allow training a simple diffusion model and generating samples, remove image-related terms, and make it easier to use.
22 |
23 | I've included an extremely simple corpus (`data/simple-{train,test}.txt`) I used for quick iterations and testing.
24 |
25 | ---
26 |
27 |
28 | ## Table of Contents
29 |
30 | - [Minimal text diffusion](#minimal-text-diffusion)
31 | * [Table of Contents](#table-of-contents)
32 | * [Getting started](#getting-started)
33 | + [Setup](#setup)
34 | + [Preparing dataset](#preparing-dataset)
35 | + [Training](#training)
36 | + [Inference](#inference)
37 | * [Training from scratch on the greetings dataset](#training-from-scratch-on-the-greetings-dataset)
38 | * [Experiments with using pre-trained models and embeddings](#experiments-with-using-pre-trained-models-and-embeddings)
39 | * [Controllable Generation](#controllable-generation)
40 | * [Gory details](#gory-details)
41 | + [Training](#training-1)
42 | + [Evolving input](#evolving-input)
43 | + [Sampling](#sampling)
44 | * [TODO](#todo)
45 | + [Opportunities for further minimization](#opportunities-for-further-minimization)
46 | * [Acknowledgements](#acknowledgements)
47 | * [License](#license)
48 |
49 | ---
50 |
51 | ## Getting started
52 |
53 | ### Setup
54 |
55 | - Install the requirements: `pip install -r requirements.txt`
56 |
57 | - Some of the dependencies might be easier to install via conda:
58 | ```sh
59 | conda install mpi4py
60 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
61 | ```
62 |
63 | ### Preparing dataset
64 |
65 | - We will use `data/simple/simple.txt` as a running example. To begin, we need to create a tokenizer over the dataset. I found that word-level tokenization works best, but the implementation in `src/utils/custom_tokenizer.py` also includes an option to create a BPE tokenizer.
66 |
67 |
68 | ```sh
69 | python src/utils/custom_tokenizer.py train-word-level data/simple/simple.txt
70 | ```
71 |
72 |
73 | ### Training
74 |
75 | - To train a model, run `scripts/run_train.sh`. By default, this will train a model on the `simple` corpus. However, you can change this to any text file using the `--train_data` argument. Note that you may have to increase the sequence length (`--seq_len`) if the sentences in your corpus are longer than those in the simple corpus. The other default arguments are set to match the best setting I found for the simple corpus (see discussion below).
76 |
77 | - Once training finishes, the model will be saved in `ckpts/simple`. You can then use this model to generate samples.
78 |
79 | - The checkpoint can also be downloaded from [here](https://drive.google.com/drive/folders/1UXx1HJVeWdAjlTNTiCydnCCHD431Q4yh?usp=sharing).
80 |
81 |
82 |
83 | ### Inference
84 |
85 | - To generate samples, run:
86 |
87 | ```sh
88 | bash scripts/text_sample.sh ckpts/simple/ema_0.9999_025000.pt 2000 10
89 | ```
90 | - Here:
91 | * `ckpts/simple/ema_0.9999_025000.pt` is the path to the checkpoint
92 | * `2000` is the number of diffusion steps.
93 | * `10` is the number of samples to generate.
94 |
95 | - By default, this will generate 10 samples from the model trained on the simple corpus. Changing `SEED` in `scripts/text_sample.sh` will generate different samples. You can also change the number of samples generated by changing the `NUM_SAMPLES` argument.
96 |
97 | - During inference (denoising), the intermediate sentences will be printed to the console.
98 |
99 | - The generated samples will be saved in `ckpts/simple/`.
100 |
101 | - Complete set of outputs are available [here](https://drive.google.com/drive/folders/1UXx1HJVeWdAjlTNTiCydnCCHD431Q4yh?usp=sharing).
102 |
103 |
104 | ## Training from scratch on the greetings dataset
105 |
106 | - I've added another training-from-scratch tutorial here: [greetings](./docs/training_on_your_own_dataset.md).
107 |
108 | ## Experiments with using pre-trained models and embeddings
109 |
110 | - Update 10/24: The most fluent/realistic outputs are obtained using i) word-level tokenization, ii) initializing a model from scratch, and iii) fine-tuning the embeddings. This is the default in `run_train.sh` now. Please see [docs/old_experiments.md](docs/old_experiments.md) for details on the experiments I ran before this update.
111 |
112 | ## Controllable Generation
113 |
114 | - The diffusion model can be combined with a classifier to perform classifier-guided diffusion. Please see details in [docs/controllable.md](docs/controllable.md).
115 |
116 |
117 | ## Gory details
118 |
119 |
120 | * Below are my rough notes on how the code works. [TODO] Clean this up and add more details.
121 |
122 | ### Training
123 |
124 | * Input text is embedded. This is the mean of `x_start_mean`. Some noise is added to `x_start_mean` to get `x_start`.
125 |
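126 | * Using random `t`, a noisy version of the input is created from q(x_t | x_0). In closed form, $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon$, where $\epsilon \sim \mathcal{N}(0, I)$ and $\bar{\alpha}_t = \prod_{s=1}^{t} (1 - \beta_s)$. (The single-step expression $x_t = \sqrt{1 - \beta_t}\, x_{t-1} + \sqrt{\beta_t}\, \epsilon$ describes $q(x_t \mid x_{t-1})$, not $q(x_t \mid x_0)$.) The function used for this is `q_sample`. Any operation that involves going ahead in the diffusion process is carried out by functions that start with `q_`.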
127 |
128 | * `x_t` is fed to the transformer model. Then, the transformer model is trained to generate an approximation of `x_start` given `x_t` and `t` (the timestep). Specifically, the embedded text is passed through a BERT encoder and downsampled. The size of the output embeddings and input embeddings is the same for this reason. Maybe this is the trick mentioned in the paper where they want to tie each weight with the `x_start` term, but I'm not sure how it's different from DDIM.
129 |
130 | * The loss has several terms:
131 | 1) Difference between the actual `x_start` and the output of the transformer model. This is the MSE loss.
132 | 2) The mean of `xT` should be close to zero. This is the `tT_loss` term. It is obtained by calling `q_mean_variance` for t=T. `q_mean_variance` is like `q_sample`, but it returns the mean and variance of the distribution `x_t | x_0` instead of a sample.
133 |
134 | 3) Decoder NLL loss. This is the `decoder_nll` term. It is obtained by calling `token_discrete_loss`. `token_discrete_loss` calls `get_logits`, which in turn uses the embeddings to convert to logits. The logits are then used to calculate the NLL loss. Essentially, this is how the embeddings are trained.
135 |
136 | ```py
137 |
138 | def get_logits(self, hidden_repr):
139 | return self.lm_head(hidden_repr)
140 | ```
141 |
142 |
143 | - One thing to note is that:
144 |
145 | ```py
146 | print(model.lm_head.weight == model.word_embedding.weight)
147 | print(model.lm_head.weight.shape, model.word_embedding.weight.shape)
148 | ```
149 |
150 | They are identical! Intuitively, the model is trained to predict the embedded input. Thus, having a linear layer with the weights from `word_embedding` is like doing a nearest neighbor search. While initializing, the weights are assigned to `lm_head` from `word_embedding` under `torch.no_grad()`, so that the gradients are not computed for `lm_head`.
151 |
152 |
153 | ### Evolving input
154 |
155 | - Note that the embeddings are *trained*. Although initial embeddings are passed in training losses, they are not used. Instead, the `get_embeds` method is used to get the embeddings. This is because the embeddings are trained to predict the input text. Thus, the embeddings are not the same as the input embeddings.
156 |
157 |
158 | ### Sampling
159 |
160 | * `p_mean_variance`: returns the distribution `p(x_{t-1} | x_t)` (the mean and variance). In addition, returns a prediction for the initial `x_0`.
161 |
162 | * `q_posterior_mean_variance`: returns the distribution `q(x_{t-1} | x_t, x_0)`.
163 |
164 | * Additionally, recall that our model is trained to predict `x_start` given `x_t` and `t`.
165 |
166 | - Putting these together, we can sample from the model. The sampling is done in the following way:
167 |
168 | 1. Starting with noise `xT`, the model first predicts an estimate of `x_start`.
169 |
170 | 2. The `xT` and `x_start` are used to generate `x_{T-1}` using `q_posterior_mean_variance` (`x_{T-1} ~ q(x_{T-1} | x_T, x_start)`).
171 |
172 | The process is repeated until `x_0` is generated.
173 |
174 | ---
175 |
176 |
177 | ## TODO
178 |
179 | - [ ] Add more details to the inner workings section.
180 | - [x] Add classifier-guided sampling (see [docs/controllable.md](docs/controllable.md)).
181 | - [ ] Add more experiments.
182 |
183 |
184 | ### Opportunities for further minimization
185 |
186 | - [ ] `logger.py` can be completely deleted.
187 | - [ ] `args_utils.py` and `factory_methods.py` can be combined.
188 |
189 |
190 |
191 | ---
192 |
193 | ## Acknowledgements
194 |
195 | - Thanks to the team behind [Diffusion-LM Improves Controllable Text Generation](http://arxiv.org/abs/2205.14217) for releasing their code, which I used as a starting point.
196 | - Thanks to the authors of several open-source implementations of DDPM/DDIM, helpful blogs, and videos. Some of the ones I bookmarked are:
197 |
198 | | **Title** | **Url** |
199 | |:---:|:---:|
200 | | Tutorial on Denoising Diffusion-based Generative Modeling: Foundations and Applications | https://www.youtube.com/watch?v=cS6JQpEY9cs |
201 | | Composable Text Control Operations in Latent Space with Ordinary Differential Equations | http://arxiv.org/abs/2208.00638 |
202 | | Diffusion-LM Improves Controllable Text Generation | http://arxiv.org/abs/2205.14217 |
203 | | Step-unrolled Denoising Autoencoders for Text Generation | http://arxiv.org/abs/2112.06749 |
204 | | Latent Diffusion Energy-Based Model for Interpretable Text Modeling | http://arxiv.org/abs/2206.05895 |
205 | | Parti - Scaling Autoregressive Models for Content-Rich Text-to-Image Generation (Paper Explained) | https://www.youtube.com/watch?v=qS-iYnp00uc |
206 | | Deep Unsupervised Learning using Nonequilibrium Thermodynamics | http://arxiv.org/abs/1503.03585 |
207 | | lucidrains/denoising-diffusion-pytorch | https://github.com/lucidrains/denoising-diffusion-pytorch |
208 | | Guidance: a cheat code for diffusion models | https://benanne.github.io/2022/05/26/guidance.html |
209 | | Cold Diffusion: Inverting Arbitrary Image Transforms Without Noise | http://arxiv.org/abs/2208.09392 |
210 | | Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning | http://arxiv.org/abs/2208.04202 |
211 | | Diffusion Maps for Textual Network Embedding | https://proceedings.neurips.cc/paper/2018/hash/211a7a84d3d5ce4d80347da11e0c85ed-Abstract.html |
212 | | Diffusion-LM Improves Controllable Text Generation | https://github.com/XiangLi1999/Diffusion-LM |
213 | | Denoising Diffusion Probabilistic Models | http://arxiv.org/abs/2006.11239 |
214 | | Variational Diffusion Models | http://arxiv.org/abs/2107.00630 |
215 | | Elucidating the Design Space of Diffusion-Based Generative Models | http://arxiv.org/abs/2206.00364 |
216 | | Diffusion Models Beat GANs on Image Synthesis | http://arxiv.org/abs/2105.05233 |
217 | | guided-diffusion | https://github.com/openai/guided-diffusion |
218 | | Minimal implementation of diffusion models ⚛ | https://github.com/VSehwag/minimal-diffusion |
219 | | minDiffusion | https://github.com/cloneofsimo/minDiffusion |
220 | | What are Diffusion Models? | https://lilianweng.github.io/posts/2021-07-11-diffusion-models/ |
221 | | High-Resolution Image Synthesis with Latent Diffusion Models | http://arxiv.org/abs/2112.10752 |
222 | | Generative Modeling by Estimating Gradients of the Data Distribution \| Yang Song | https://yang-song.net/blog/2021/score/ |
223 | | GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models | http://arxiv.org/abs/2112.10741 |
224 | | Blended Diffusion for Text-driven Editing of Natural Images | http://arxiv.org/abs/2111.14818 |
225 | | Generative Modeling by Estimating Gradients of the Data Distribution | http://arxiv.org/abs/1907.05600 |
226 | | Diffusion Schr\"odinger Bridge with Applications to Score-Based Generative Modeling | http://arxiv.org/abs/2106.01357 |
227 | | Score-based Generative Modeling in Latent Space | http://arxiv.org/abs/2106.05931 |
228 | | A Connection Between Score Matching and Denoising Autoencoders | https://direct.mit.edu/neco/article/23/7/1661-1674/7677 |
229 | | Maximum Likelihood Training of Score-Based Diffusion Models | http://arxiv.org/abs/2101.09258 |
230 |
231 |
232 | ## License
233 |
234 | - MIT License
235 |
236 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/minimal-text-diffusion/9303ffd481a2f647da24c6053e4dec44fd086a8d/__init__.py
--------------------------------------------------------------------------------
/data/greetings-test.txt:
--------------------------------------------------------------------------------
1 | it's nice seeing you all!
2 | I'm so thrilled to see you!
3 | I'm excited to see you all!
4 | what's up?
5 | how is your family doing
6 | i hope life has been treating you good
7 | i hope everything is going well for you
8 | It's great to see you again!
9 | i hope life has been treating you good.
10 | hope all is well with your family.
11 | what's up?
12 | good seeing you again
13 | great to see you again!
14 | I've heard a lot about you and your family.
15 | welcome!
16 | it's great to meet someone like yourself!
17 | it's great to see you!
18 | how's it going?
19 | it's good to see you!
20 | I'm glad to see you
21 | how are things going with you and your family
22 | it's nice to see you too!
23 | it's nice meeting you
24 | nice to see you too!
25 | it's good to see you too!
26 | I'm happy to see you all!
27 | long time no see!
28 | hi there
29 | good morning!
30 | I've heard a lot about you and your family.
31 | good seeing everyone
32 | do you have any plans for today
33 | nice seeing you all again!
34 | I'm so excited to see you!
35 | I hope all is good with your family.
36 | dropped in because I was thinking of you.
37 | it was so nice meeting you!
38 | I've heard a lot of good things about you and your family.
39 | good morning
40 | I'm happy I got the chance to finally meet you in person.
41 | nice to meet you!
42 | "hi there, how's it going? hope things are good!"
43 | good to see you again!
44 | good evening!
45 | nice seeing you!
46 | I'm happy to see you!
47 | good seeing you all again!
48 | nice to meet you too!
49 | dropping in to let you know i'm thinking of you.
50 | how are things going?
51 | i hope life has been treating you right
52 | I'm happy to see you all!
53 | just wanted to say hi
54 | how is everyone doing
55 | it's nice to see you!
56 | I've heard a lot of good things abou you.
57 | it's great to see you as well!
58 | how's your day?
59 | i hope all is well with you and yours
60 | I'm glad to see you all!
61 | "hello there, how's it going? hope things are good!"
62 | hey buddy!
63 | it's been a long time!
64 | great seeing you
65 | "hello there, how are you? hope things are good!"
66 | it's great to see you again!
67 | it's nice seeing you all again!
68 | greetings! how are you doing today?
69 | I'm glad to see you all!
70 | how have you been
71 | I'm excited to see you!
72 | it's good to see you all!
73 | Good to see you all!
74 | it's good to see you as well!
75 | I'm thrilled to see you all!
76 | it's great meeting you
77 | it's great to see you again!
78 | it's good to see you again!
79 | how is everyone doing?
80 | nice seeing you again!
81 | nice to see you all again!
82 | it's nice to see you all!
83 | I have faith that things are going smoothly for you and yours
84 | how is everything?
85 | it's great seeing you all!
86 | I'm thrilled to see you!
87 | how's your day going?
88 | nice to see you again!
89 | I just wanted to say hi
90 | what's up?
91 | I've heard a lot about you.
92 | hello
93 | It's great seeing you again!
94 | I'm glad to see you!
95 | it's so nice to meet someone like yourself!
96 | hi friend! how are things going with your family?
97 | hope all is good in the hood!
98 | I'm glad i got the chance to finally meet you in person.
99 | hope all is well with your loved ones.
100 | It's great seeing you!
101 |
--------------------------------------------------------------------------------
/data/greetings-train.txt:
--------------------------------------------------------------------------------
1 | I have faith that things are going smoothly for you and yours
2 | long time no see!
3 | nice to see you too!
4 | I'm happy to see you!
5 | it was so good meeting you
6 | I'm stopping by to say hi
7 | I pray that things are going smoothly for you and yours
8 | I'm glad to see you!
9 | it's nice seeing you too!
10 | i hope life has been treating you good.
11 | it's great to finally meet in person.
12 | I'm stopping by to say hi
13 | i hope all is good with your family.
14 | nice to meet you too!
15 | I'm thrilled to see you!
16 | it's a pleasure to see you again!
17 | it's good to see you too!
18 | I'm happy to see you!
19 | i am stopping by to say hi
20 | I've heard a lot of good things about you and your family.
21 | hope all is well with your loved ones.
22 | nice to meet you!
23 | i'm stopping by to say hey!
24 | dropping in to let you know i'm thinking of you.
25 | nice seeing you again
26 | how's it going?
27 | "hi, how are you?"
28 | it's so nice to meet someone like yourself!
29 | welcome!
30 | I'm excited to see you!
31 | It's so nice to meet you!
32 | nice seeing you all again!
33 | I'm thrilled to see you all!
34 | hope all is well with your family.
35 | nice seeing you!
36 | it's nice to see you again!
37 | I just wanted to say hello
38 | I pray that things are going smoothly for you and yours
39 | nice to see you all again!
40 | nice to meet you in person
41 | how are you doing
42 | nice seeing you again.
43 | I've heard a lot of good things abou you.
44 | do you have any plans for the weekend
45 | it's good to see you again!
46 | good seeing everyone
47 | good to see you again!
48 | I'm happy to see you!
49 | I'm so happy to see you
50 | how's your day going?
51 | it's nice seeing you too!
52 | what's up?
53 | it's good seeing you again.
54 | "hello there, how's it going? hope things are good!"
55 | what is new
56 | I'm so happy to see you
57 | It's been a while since we last talked.
58 | hi there
59 | I'm happy to see you all!
60 | what's up?
61 | welcome!
62 | It's great seeing you again!
63 | it's good to see you again!
64 | it's nice to see you all again!
65 | I hope all is good with your family.
66 | what's up?
67 | i hope all is well with you and yours
68 | hey friend!
69 | good afternoon
70 | it's good seeing you!
71 | I trust that things are going smoothly for you and yours
72 | "hi there, how's it going? hope things are good!"
73 | good afternoon!
74 | It's exciting to see everyone
75 | hey pal!
76 | "hi there, how's it going? hope things are good!"
77 | I just wanted to say hi
78 | I'm thrilled to see you!
79 | I'm glad to see you
80 | I'm glad i got the chance to finally meet you in person.
81 | how is everything going with you and yours
82 | it's good to see you!
83 | it was so nice meeting you!
84 | hey mate!
85 | Hey team how is everyone doing?
86 | it's so good to finally meet in person.
87 | how are things going with you and your family
88 | I have faith that things are going smoothly for you and yours
89 | I'm excited to see you all!
90 | how's your day been?
91 | just wanted to say hi
92 | how have you been
93 | I'm excited to see you!
94 | great to see you again!
95 | great to see you again!
96 | I'm excited to see you!
97 | It's good to see you again!
98 | I've heard a lot about you and your family.
99 | just stopping by to say hello
100 | it's been too long!
101 | I'm excited to see you!
102 | good morning!
103 | how's your day going so far?
104 | good seeing you again
105 | I've heard a lot of good things about you and your family.
106 | hope all is good in the hood!
107 | it's good to see you as well!
108 | good night!
109 | hey mate!
110 | it's good to see you as well!
111 | I'm glad to see you all!
112 | hey buddy!
113 | Nice meeting you
114 | it's great to meet you.
115 | how are things going with you and your family?
116 | it's great to see you as well!
117 | I trust that things are going smoothly for you
118 | "hello there, how's it going? hope things are good!"
119 | how's it going?
120 | "hello there, how are you? hope things are good!"
121 | i'm happy I got the chance to finally meet you in person.
122 | I'm so thrilled to see you all!
123 | it's nice to see you too!
124 | I trust that things are going smoothly for you and yours
125 | I'm glad i got the chance to finally meet you in person.
126 | i am dropping in to say hey!
127 | hey good to see you!
128 | dropping in to let you know i'm thinking of you.
129 | it's great meeting you
130 | I'm happy to see you!
131 | I just stopped by to say hello
132 | "hello there, how are you? hope things are good!"
133 | how is your family doing
134 | I'm so thrilled to see you!
135 | I trust that things are going smoothly for you
136 | i am stopping by to say hi
137 | It's great to finally meet you in person
138 | it's great to see you too!
139 | it's great to see you as well!
140 | it's great seeing you!
141 | "hi there, how are you? hope things are good!"
142 | I just stopped by to say hello
143 | nice to see you all!
144 | I'm excited to see you all!
145 | I've heard nothing but good things about you.
146 | i am dropping in to say hey!
147 | it's wonderful meeting you
148 | how's it going?
149 | i hope everything is going well with you
150 | dropped in because I was thinking of you.
151 | it's nice to see you!
152 | good to see you again!
153 | nice seeing you all!
154 | I'm so excited to see you all!
155 | I'm thrilled to be here
156 | good morning
157 | I'm thrilled to see you all!
158 | just stopping by to say hello
159 | i hope life has been treating you good
160 | how are things?
161 | how's your day going?
162 | hey friend!
163 | I'm thrilled to see you!
164 | I'm glad i got the chance to finally meet in person!
165 | it's nice seeing you all again!
166 | I'm excited to see you!
167 | It's great seeing you!
168 | it's nice seeing you all!
169 | how are you doing?
170 | I've heard nothing but good things about you and your family.
171 | I've heard nothing but good things about you and your family.
172 | do you have any plans for today
173 | how are you doing?
174 | good evening
175 | I'm excited to see you!
176 | it's good to see you all again!
177 | I just wanted to say hi
178 | I've heard a lot about you.
179 | do you have any plans for the weekend
180 | dropped in because I was thinking of you.
181 | it's good to see you too!
182 | I'm excited to see you!
183 | great seeing you all!
184 | I'm glad to see you
185 | it's great to see you again!
186 | Just saying hello
187 | how is everything?
188 | good to see you again!
189 | nice to finally meet in person
190 | good seeing you all!
191 | good to see you!
192 | it's been a while!
193 | It's nice seeing you.
194 | good to see you all again!
195 | it's nice seeing you again!
196 | nice seeing you again
197 | great seeing you all again!
198 | Hope all is good in the hood!
199 | It was good meeting everyone
200 | I'm glad to see you all!
201 | hey buddy!
202 | I'm glad to see you!
203 | how is everyone doing
204 | I'm happy I got the chance to finally meet you in person.
205 | hello
206 | it's a pleasure to meet you!
207 | it was so nice meeting you
208 | i'm stopping by to say hey!
209 | Hope all is well with your loved ones.
210 | how's your day?
211 | greetings friend! how are you doing today?
212 | hi friend! how are things going with your family?
213 | it's great seeing you all again!
214 | hey my friend
215 | hey good to see you!
216 | how is everything going?
217 | how has your day been so far
218 | welcome back
219 | "greetings friend, how are you doing today?"
220 | it's good to see you all!
221 | nice to see you!
222 | good to see you all
223 | Hope all is well with your family.
224 | I'm glad to see you all!
225 | what's new
226 | great seeing you again
227 | Just saying hello
228 | it's nice meeting you
229 | nice meeting you!
230 | "hi, how are you?"
231 | how are things going?
232 | just saying hi
233 | just saying hi
234 | I'm happy to see you all!
235 | hope all is well
236 | it's great to see you again!
237 | i hope everything is going well for you
238 | It's great to see you again!
239 | good seeing you all again!
240 | it's nice seeing you again!
241 | It's good to see you!
242 | how are things everyone?
243 | nice seeing you!
244 | great seeing you!
245 | good to see you!
246 | it's great to see you!
247 | it's been a long time!
248 | how's it going?
249 | how is everyone doing?
250 | i hope life has been treating you right
251 | how is life treating you
252 | welcome back
253 | it's great to see you too!
254 | "hey there, nice to see you!"
255 | nice to see you again!
256 | Good to see you all!
257 | it's nice seeing you as well!
258 | it's great to see you again!
259 | nice seeing you again!
260 | I'm excited to see you all!
261 | good evening!
262 | Hey team how's it going?
263 | how have you been?
264 | how is everyone doing? i hope all's well.
265 | it's nice to see you as well!
266 | "hey team, how is everything going?"
267 | it's nice to see you all!
268 | "hi there, how are you? hope things are good!"
269 | hey pal!
270 | it's nice seeing you!
271 | I'm happy to see you all!
272 | how's your day?
273 | I'm excited to see you!
274 | I'm so excited to see you!
275 | just wanted to say hi
276 | good to see you again!
277 | greetings! how are you doing today?
278 | nice meeting you again.
279 | great seeing you
280 | I'm happy to see you
281 | it's good to see you again!
282 | It's nice to meet you!
283 | I'm glad to see you!
284 | I just wanted to say hello
285 | It was good meeting you
286 | welcome back!
287 | greetings! how are you doing today?
288 | It's great to see you!
289 | what's up?
290 | I'm glad to see you!
291 | hey my friend!
292 | I'm happy to see you
293 | it's great to meet someone like yourself!
294 | good seeing you
295 | I've heard a lot about you and your family.
296 | how is your family doing?
297 | it's great seeing you all!
298 | do you have any plans for today?
299 |
--------------------------------------------------------------------------------
/data/greetings/greetings-test.txt:
--------------------------------------------------------------------------------
1 | it's nice seeing you all!
2 | I'm so thrilled to see you!
3 | I'm excited to see you all!
4 | what's up?
5 | how is your family doing
6 | i hope life has been treating you good
7 | i hope everything is going well for you
8 | It's great to see you again!
9 | i hope life has been treating you good.
10 | hope all is well with your family.
11 | what's up?
12 | good seeing you again
13 | great to see you again!
14 | I've heard a lot about you and your family.
15 | welcome!
16 | it's great to meet someone like yourself!
17 | it's great to see you!
18 | how's it going?
19 | it's good to see you!
20 | I'm glad to see you
21 | how are things going with you and your family
22 | it's nice to see you too!
23 | it's nice meeting you
24 | nice to see you too!
25 | it's good to see you too!
26 | I'm happy to see you all!
27 | long time no see!
28 | hi there
29 | good morning!
30 | I've heard a lot about you and your family.
31 | good seeing everyone
32 | do you have any plans for today
33 | nice seeing you all again!
34 | I'm so excited to see you!
35 | I hope all is good with your family.
36 | dropped in because I was thinking of you.
37 | it was so nice meeting you!
38 | I've heard a lot of good things about you and your family.
39 | good morning
40 | I'm happy I got the chance to finally meet you in person.
41 | nice to meet you!
42 | "hi there, how's it going? hope things are good!"
43 | good to see you again!
44 | good evening!
45 | nice seeing you!
46 | I'm happy to see you!
47 | good seeing you all again!
48 | nice to meet you too!
49 | dropping in to let you know i'm thinking of you.
50 | how are things going?
51 | i hope life has been treating you right
52 | I'm happy to see you all!
53 | just wanted to say hi
54 | how is everyone doing
55 | it's nice to see you!
56 | I've heard a lot of good things abou you.
57 | it's great to see you as well!
58 | how's your day?
59 | i hope all is well with you and yours
60 | I'm glad to see you all!
61 | "hello there, how's it going? hope things are good!"
62 | hey buddy!
63 | it's been a long time!
64 | great seeing you
65 | "hello there, how are you? hope things are good!"
66 | it's great to see you again!
67 | it's nice seeing you all again!
68 | greetings! how are you doing today?
69 | I'm glad to see you all!
70 | how have you been
71 | I'm excited to see you!
72 | it's good to see you all!
73 | Good to see you all!
74 | it's good to see you as well!
75 | I'm thrilled to see you all!
76 | it's great meeting you
77 | it's great to see you again!
78 | it's good to see you again!
79 | how is everyone doing?
80 | nice seeing you again!
81 | nice to see you all again!
82 | it's nice to see you all!
83 | I have faith that things are going smoothly for you and yours
84 | how is everything?
85 | it's great seeing you all!
86 | I'm thrilled to see you!
87 | how's your day going?
88 | nice to see you again!
89 | I just wanted to say hi
90 | what's up?
91 | I've heard a lot about you.
92 | hello
93 | It's great seeing you again!
94 | I'm glad to see you!
95 | it's so nice to meet someone like yourself!
96 | hi friend! how are things going with your family?
97 | hope all is good in the hood!
98 | I'm glad i got the chance to finally meet you in person.
99 | hope all is well with your loved ones.
100 | It's great seeing you!
101 |
--------------------------------------------------------------------------------
/data/greetings/greetings-train.txt:
--------------------------------------------------------------------------------
1 | I have faith that things are going smoothly for you and yours
2 | long time no see!
3 | nice to see you too!
4 | I'm happy to see you!
5 | it was so good meeting you
6 | I'm stopping by to say hi
7 | I pray that things are going smoothly for you and yours
8 | I'm glad to see you!
9 | it's nice seeing you too!
10 | i hope life has been treating you good.
11 | it's great to finally meet in person.
12 | I'm stopping by to say hi
13 | i hope all is good with your family.
14 | nice to meet you too!
15 | I'm thrilled to see you!
16 | it's a pleasure to see you again!
17 | it's good to see you too!
18 | I'm happy to see you!
19 | i am stopping by to say hi
20 | I've heard a lot of good things about you and your family.
21 | hope all is well with your loved ones.
22 | nice to meet you!
23 | i'm stopping by to say hey!
24 | dropping in to let you know i'm thinking of you.
25 | nice seeing you again
26 | how's it going?
27 | "hi, how are you?"
28 | it's so nice to meet someone like yourself!
29 | welcome!
30 | I'm excited to see you!
31 | It's so nice to meet you!
32 | nice seeing you all again!
33 | I'm thrilled to see you all!
34 | hope all is well with your family.
35 | nice seeing you!
36 | it's nice to see you again!
37 | I just wanted to say hello
38 | I pray that things are going smoothly for you and yours
39 | nice to see you all again!
40 | nice to meet you in person
41 | how are you doing
42 | nice seeing you again.
43 | I've heard a lot of good things abou you.
44 | do you have any plans for the weekend
45 | it's good to see you again!
46 | good seeing everyone
47 | good to see you again!
48 | I'm happy to see you!
49 | I'm so happy to see you
50 | how's your day going?
51 | it's nice seeing you too!
52 | what's up?
53 | it's good seeing you again.
54 | "hello there, how's it going? hope things are good!"
55 | what is new
56 | I'm so happy to see you
57 | It's been a while since we last talked.
58 | hi there
59 | I'm happy to see you all!
60 | what's up?
61 | welcome!
62 | It's great seeing you again!
63 | it's good to see you again!
64 | it's nice to see you all again!
65 | I hope all is good with your family.
66 | what's up?
67 | i hope all is well with you and yours
68 | hey friend!
69 | good afternoon
70 | it's good seeing you!
71 | I trust that things are going smoothly for you and yours
72 | "hi there, how's it going? hope things are good!"
73 | good afternoon!
74 | It's exciting to see everyone
75 | hey pal!
76 | "hi there, how's it going? hope things are good!"
77 | I just wanted to say hi
78 | I'm thrilled to see you!
79 | I'm glad to see you
80 | I'm glad i got the chance to finally meet you in person.
81 | how is everything going with you and yours
82 | it's good to see you!
83 | it was so nice meeting you!
84 | hey mate!
85 | Hey team how is everyone doing?
86 | it's so good to finally meet in person.
87 | how are things going with you and your family
88 | I have faith that things are going smoothly for you and yours
89 | I'm excited to see you all!
90 | how's your day been?
91 | just wanted to say hi
92 | how have you been
93 | I'm excited to see you!
94 | great to see you again!
95 | great to see you again!
96 | I'm excited to see you!
97 | It's good to see you again!
98 | I've heard a lot about you and your family.
99 | just stopping by to say hello
100 | it's been too long!
101 | I'm excited to see you!
102 | good morning!
103 | how's your day going so far?
104 | good seeing you again
105 | I've heard a lot of good things about you and your family.
106 | hope all is good in the hood!
107 | it's good to see you as well!
108 | good night!
109 | hey mate!
110 | it's good to see you as well!
111 | I'm glad to see you all!
112 | hey buddy!
113 | Nice meeting you
114 | it's great to meet you.
115 | how are things going with you and your family?
116 | it's great to see you as well!
117 | I trust that things are going smoothly for you
118 | "hello there, how's it going? hope things are good!"
119 | how's it going?
120 | "hello there, how are you? hope things are good!"
121 | i'm happy I got the chance to finally meet you in person.
122 | I'm so thrilled to see you all!
123 | it's nice to see you too!
124 | I trust that things are going smoothly for you and yours
125 | I'm glad i got the chance to finally meet you in person.
126 | i am dropping in to say hey!
127 | hey good to see you!
128 | dropping in to let you know i'm thinking of you.
129 | it's great meeting you
130 | I'm happy to see you!
131 | I just stopped by to say hello
132 | "hello there, how are you? hope things are good!"
133 | how is your family doing
134 | I'm so thrilled to see you!
135 | I trust that things are going smoothly for you
136 | i am stopping by to say hi
137 | It's great to finally meet you in person
138 | it's great to see you too!
139 | it's great to see you as well!
140 | it's great seeing you!
141 | "hi there, how are you? hope things are good!"
142 | I just stopped by to say hello
143 | nice to see you all!
144 | I'm excited to see you all!
145 | I've heard nothing but good things about you.
146 | i am dropping in to say hey!
147 | it's wonderful meeting you
148 | how's it going?
149 | i hope everything is going well with you
150 | dropped in because I was thinking of you.
151 | it's nice to see you!
152 | good to see you again!
153 | nice seeing you all!
154 | I'm so excited to see you all!
155 | I'm thrilled to be here
156 | good morning
157 | I'm thrilled to see you all!
158 | just stopping by to say hello
159 | i hope life has been treating you good
160 | how are things?
161 | how's your day going?
162 | hey friend!
163 | I'm thrilled to see you!
164 | I'm glad i got the chance to finally meet in person!
165 | it's nice seeing you all again!
166 | I'm excited to see you!
167 | It's great seeing you!
168 | it's nice seeing you all!
169 | how are you doing?
170 | I've heard nothing but good things about you and your family.
171 | I've heard nothing but good things about you and your family.
172 | do you have any plans for today
173 | how are you doing?
174 | good evening
175 | I'm excited to see you!
176 | it's good to see you all again!
177 | I just wanted to say hi
178 | I've heard a lot about you.
179 | do you have any plans for the weekend
180 | dropped in because I was thinking of you.
181 | it's good to see you too!
182 | I'm excited to see you!
183 | great seeing you all!
184 | I'm glad to see you
185 | it's great to see you again!
186 | Just saying hello
187 | how is everything?
188 | good to see you again!
189 | nice to finally meet in person
190 | good seeing you all!
191 | good to see you!
192 | it's been a while!
193 | It's nice seeing you.
194 | good to see you all again!
195 | it's nice seeing you again!
196 | nice seeing you again
197 | great seeing you all again!
198 | Hope all is good in the hood!
199 | It was good meeting everyone
200 | I'm glad to see you all!
201 | hey buddy!
202 | I'm glad to see you!
203 | how is everyone doing
204 | I'm happy I got the chance to finally meet you in person.
205 | hello
206 | it's a pleasure to meet you!
207 | it was so nice meeting you
208 | i'm stopping by to say hey!
209 | Hope all is well with your loved ones.
210 | how's your day?
211 | greetings friend! how are you doing today?
212 | hi friend! how are things going with your family?
213 | it's great seeing you all again!
214 | hey my friend
215 | hey good to see you!
216 | how is everything going?
217 | how has your day been so far
218 | welcome back
219 | "greetings friend, how are you doing today?"
220 | it's good to see you all!
221 | nice to see you!
222 | good to see you all
223 | Hope all is well with your family.
224 | I'm glad to see you all!
225 | what's new
226 | great seeing you again
227 | Just saying hello
228 | it's nice meeting you
229 | nice meeting you!
230 | "hi, how are you?"
231 | how are things going?
232 | just saying hi
233 | just saying hi
234 | I'm happy to see you all!
235 | hope all is well
236 | it's great to see you again!
237 | i hope everything is going well for you
238 | It's great to see you again!
239 | good seeing you all again!
240 | it's nice seeing you again!
241 | It's good to see you!
242 | how are things everyone?
243 | nice seeing you!
244 | great seeing you!
245 | good to see you!
246 | it's great to see you!
247 | it's been a long time!
248 | how's it going?
249 | how is everyone doing?
250 | i hope life has been treating you right
251 | how is life treating you
252 | welcome back
253 | it's great to see you too!
254 | "hey there, nice to see you!"
255 | nice to see you again!
256 | Good to see you all!
257 | it's nice seeing you as well!
258 | it's great to see you again!
259 | nice seeing you again!
260 | I'm excited to see you all!
261 | good evening!
262 | Hey team how's it going?
263 | how have you been?
264 | how is everyone doing? i hope all's well.
265 | it's nice to see you as well!
266 | "hey team, how is everything going?"
267 | it's nice to see you all!
268 | "hi there, how are you? hope things are good!"
269 | hey pal!
270 | it's nice seeing you!
271 | I'm happy to see you all!
272 | how's your day?
273 | I'm excited to see you!
274 | I'm so excited to see you!
275 | just wanted to say hi
276 | good to see you again!
277 | greetings! how are you doing today?
278 | nice meeting you again.
279 | great seeing you
280 | I'm happy to see you
281 | it's good to see you again!
282 | It's nice to meet you!
283 | I'm glad to see you!
284 | I just wanted to say hello
285 | It was good meeting you
286 | welcome back!
287 | greetings! how are you doing today?
288 | It's great to see you!
289 | what's up?
290 | I'm glad to see you!
291 | hey my friend!
292 | I'm happy to see you
293 | it's great to meet someone like yourself!
294 | good seeing you
295 | I've heard a lot about you and your family.
296 | how is your family doing?
297 | it's great seeing you all!
298 | do you have any plans for today?
299 |
--------------------------------------------------------------------------------
/data/greetings/greetings.py:
--------------------------------------------------------------------------------
"""Build the toy "greetings" text corpus.

The hand-written ``greetings`` list below is shuffled and written to
``greetings.txt`` (relative to the current working directory) as plain text,
one greeting per line, with no header row and no CSV quoting.
"""
import random

# Hand-written list of 298 greetings. Duplicates are intentional leftovers and
# are kept so the generated corpus matches the committed data files.
greetings = [
    "hi, how are you?",
    "hey good to see you!",
    "how's it going?",
    "what's up?",
    "how's your day?",
    "how's your day going?",
    "how's your day been?",
    "how's your day going so far?",
    "how is everyone doing?",
    "how are you doing?",
    "good to see you!",
    "good to see you again!",
    "nice to meet you!",
    "nice to see you!",
    "it's nice to see you again!",
    "it's a pleasure to meet you!",
    "it's a pleasure to see you again!",
    "nice to see you again!",
    "nice to see you too!",
    "nice to meet you too!",
    "long time no see!",
    "it's been a while!",
    "it's been a long time!",
    "it's been too long!",
    "welcome back!",
    "great seeing you!",
    "good to see you again!",
    "it's good to see you again!",
    "it's good to see you too!",
    "it's good to see you as well!",
    "it's great to see you again!",
    "it's great to see you too!",
    "it's great to see you as well!",
    "it's nice seeing you again!",
    "it's nice seeing you too!",
    "good to see you all",
    "good to see you all again!",
    "it's good to see you all!",
    "it's good to see you all again!",
    "it's nice to see you all!",
    "it's nice to see you all again!",
    "it's nice seeing you all!",
    "it's nice seeing you all again!",
    "nice to see you all!",
    "nice to see you all again!",
    "nice seeing you all!",
    "nice seeing you all again!",
    "good seeing you all!",
    "good seeing you all again!",
    "great seeing you all!",
    "great seeing you all again!",
    "it's great seeing you all!",
    "it's great seeing you all again!",
    "good to see you!",
    "it's good to see you!",
    "it's good to see you too!",
    "it's good to see you as well!",
    "it's nice to see you!",
    "it's nice to see you too!",
    "it's nice to see you as well!",
    "it's great to see you!",
    "it's great to see you too!",
    "it's great to see you as well!",
    "it's nice seeing you!",
    "it's nice seeing you too!",
    "it's nice seeing you as well!",
    "it's great seeing you!",
    "I'm thrilled to be here",
    "It's exciting to see everyone",
    "I'm so excited to see you all!",
    "I'm so excited to see you!",
    "I'm thrilled to see you all!",
    "I'm thrilled to see you!",
    "I'm so thrilled to see you all!",
    "I'm so thrilled to see you!",
    "I'm excited to see you all!",
    "I'm excited to see you!",
    "I'm excited to see you!",
    "I'm thrilled to see you all!",
    "I'm thrilled to see you!",
    "I'm thrilled to see you!",
    "I'm excited to see you all!",
    "I'm excited to see you!",
    "I'm excited to see you!",
    "I'm excited to see you all!",
    "I'm excited to see you!",
    "I'm excited to see you!",
    "I'm happy to see you!",
    "I'm happy to see you all!",
    "I'm happy to see you all!",
    "I'm happy to see you all!",
    "I'm happy to see you!",
    "I'm happy to see you!",
    "I'm happy to see you!",
    "I'm glad to see you!",
    "I'm glad to see you all!",
    "I'm glad to see you all!",
    "I'm glad to see you all!",
    "I'm glad to see you!",
    "I'm glad to see you!",
    "I'm glad to see you!",
    "nice seeing you again!",
    "it's nice seeing you again!",
    "how are things going?",
    "how are things?",
    "what's up?",
    "hi there",
    "hello",
    "good morning",
    "good afternoon",
    "good evening",
    "greetings friend! how are you doing today?",
    "greetings! how are you doing today?",
    "how are you doing?",
    "how's it going?",
    "how is everything going?",
    "how is everything?",
    "how is life treating you",
    "how has your day been so far",
    "hi there, how are you? hope things are good!",
    "hello there, how are you? hope things are good!",
    "hi there, how's it going? hope things are good!",
    "hello there, how's it going? hope things are good!",
    "just saying hi",
    "just wanted to say hi",
    "I just wanted to say hi",
    "Just saying hello",
    "I just wanted to say hello",
    "just stopping by to say hello",
    "I just stopped by to say hello",
    "i am stopping by to say hi",
    "I'm stopping by to say hi",
    "i'm stopping by to say hey!",
    "i am dropping in to say hey!",
    "dropping in to let you know i'm thinking of you.",
    "dropped in because I was thinking of you.",
    "hey friend!",
    "hey buddy!",
    "hey pal!",
    "hey mate!",
    "hey my friend!",
    "what is new",
    "how have you been",
    "how are things going with you and your family",
    "how is your family doing",
    "do you have any plans for today",
    "do you have any plans for the weekend",
    "what's new",
    "what's up?",
    "how have you been?",
    "how are things going with you and your family?",
    "how is your family doing?",
    "do you have any plans for today?",
    "do you have any plans for the weekend",
    "hope all is well",
    "i hope all is well with you and yours",
    "i hope life has been treating you good",
    "i hope life has been treating you right",
    "i hope everything is going well for you",
    "i hope everything is going well with you",
    "i hope life has been treating you good.",
    "i hope all is good with your family.",
    "hope all is well with your family.",
    "hope all is well with your loved ones.",
    "hope all is good in the hood!",
    "I trust that things are going smoothly for you",
    "I trust that things are going smoothly for you and yours",
    "I have faith that things are going smoothly for you and yours",
    "I pray that things are going smoothly for you and yours",
    "I hope all is good with your family.",
    "Hope all is well with your family.",
    "Hope all is well with your loved ones.",
    "Hope all is good in the hood!",
    "I trust that things are going smoothly for you",
    "I trust that things are going smoothly for you and yours",
    "I have faith that things are going smoothly for you and yours",
    "I pray that things are going smoothly for you and yours",
    "welcome!",
    "welcome back",
    "it's good to see you again!",
    "good to see you again!",
    "it's great to see you again!",
    "great to see you again!",
    "nice seeing you!",
    "nice seeing you again",
    "I'm glad to see you",
    "I'm happy to see you",
    "I'm so happy to see you",
    "it's so good to finally meet in person.",
    "I've heard a lot about you.",
    # NOTE(review): "abou" typo is also present in the committed data files;
    # kept so a regenerated corpus stays consistent with them.
    "I've heard a lot of good things abou you.",
    "I've heard nothing but good things about you.",
    "I've heard a lot about you and your family.",
    "I've heard a lot of good things about you and your family.",
    "I've heard nothing but good things about you and your family.",
    "nice to meet you in person",
    "nice to finally meet in person",
    "it's nice meeting you",
    "it's great meeting you",
    "it's wonderful meeting you",
    "it's so nice to meet someone like yourself!",
    "it's great to meet someone like yourself!",
    "it was so nice meeting you",
    "it was so good meeting you",
    "I'm glad i got the chance to finally meet you in person.",
    "i'm happy I got the chance to finally meet you in person.",
    "I've heard a lot about you and your family.",
    "I've heard a lot of good things about you and your family.",
    "I've heard nothing but good things about you and your family.",
    "hi, how are you?",
    "hey good to see you!",
    "how's it going?",
    "what's up?",
    "how's your day?",
    "how's your day going?",
    "how is everyone doing",
    "good morning!",
    "good afternoon!",
    "good evening!",
    "good night!",
    "greetings friend, how are you doing today?",
    "greetings! how are you doing today?",
    "how are you doing",
    "how's it going?",
    "how is everything going with you and yours",
    "hi there, how are you? hope things are good!",
    "hello there, how are you? hope things are good!",
    "hi there, how's it going? hope things are good!",
    "hello there, how's it going? hope things are good!",
    "just saying hi",
    "just wanted to say hi",
    "I just wanted to say hi",
    "Just saying hello",
    "I just wanted to say hello",
    "just stopping by to say hello",
    "I just stopped by to say hello",
    "i am stopping by to say hi",
    "I'm stopping by to say hi",
    "i'm stopping by to say hey!",
    "i am dropping in to say hey!",
    "dropping in to let you know i'm thinking of you.",
    "dropped in because I was thinking of you.",
    "hey friend!",
    "hey buddy!",
    "hey pal!",
    "hey mate!",
    "hey my friend",
    "welcome!",
    "welcome back",
    "it's good to see you again!",
    "good to see you again!",
    "it's great to see you again!",
    "great to see you again!",
    "nice seeing you!",
    "nice seeing you again",
    "I'm glad to see you",
    "I'm happy to see you",
    "I'm so happy to see you",
    "it's good seeing you!",
    "It's great seeing you!",
    "It's great to see you!",
    "It's good to see you!",
    "good seeing you",
    "great seeing you",
    "nice meeting you again.",
    "nice seeing you again.",
    "it's good seeing you again.",
    "It's great seeing you again!",
    "It's great to see you again!",
    "It's good to see you again!",
    "good seeing you again",
    "great seeing you again",
    "nice meeting you!",
    "It's nice to meet you!",
    "It's so nice to meet you!",
    "it's great to finally meet in person.",
    "It's great to finally meet you in person",
    "it was so nice meeting you!",
    "I'm glad i got the chance to finally meet you in person.",
    "I'm happy I got the chance to finally meet you in person.",
    "it's great to meet you.",
    "It was good meeting you",
    "I'm glad i got the chance to finally meet in person!",
    "I'm excited to see you!",
    "It's nice seeing you.",
    "Nice meeting you",
    "Hey team how is everyone doing?",
    "Hey team how's it going?",
    "Good to see you all!",
    "It was good meeting everyone",
    "It's been a while since we last talked.",
    "I'm excited to see you!",
    "good seeing everyone",
    "hey team, how is everything going?",
    "how are things everyone?",
    "hey there, nice to see you!",
    "hi friend! how are things going with your family?",
    "how is everyone doing? i hope all's well.",
]


def write_greetings(path="greetings.txt", seed=None):
    """Shuffle ``greetings`` and write them to *path*, one per line.

    The file is plain UTF-8 text: no header row and no CSV quoting. Entries
    containing commas are written verbatim rather than wrapped in ``"..."``.

    Args:
        path: Destination file path (default ``"greetings.txt"`` in the CWD).
        seed: Optional seed for a reproducible shuffle. ``None`` yields a new
            random order on every run.

    Returns:
        The shuffled list of greetings, in the order it was written.
    """
    shuffled = list(greetings)  # copy so the module-level list is untouched
    random.Random(seed).shuffle(shuffled)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(line + "\n" for line in shuffled)
    return shuffled


if __name__ == "__main__":
    write_greetings()
--------------------------------------------------------------------------------
/data/greetings/greetings.txt:
--------------------------------------------------------------------------------
1 | I have faith that things are going smoothly for you and yours
2 | long time no see!
3 | nice to see you too!
4 | I'm happy to see you!
5 | it was so good meeting you
6 | I'm stopping by to say hi
7 | I pray that things are going smoothly for you and yours
8 | I'm glad to see you!
9 | it's nice seeing you too!
10 | i hope life has been treating you good.
11 | it's great to finally meet in person.
12 | I'm stopping by to say hi
13 | i hope all is good with your family.
14 | nice to meet you too!
15 | I'm thrilled to see you!
16 | it's a pleasure to see you again!
17 | it's good to see you too!
18 | I'm happy to see you!
19 | i am stopping by to say hi
20 | I've heard a lot of good things about you and your family.
21 | hope all is well with your loved ones.
22 | nice to meet you!
23 | i'm stopping by to say hey!
24 | dropping in to let you know i'm thinking of you.
25 | nice seeing you again
26 | how's it going?
27 | "hi, how are you?"
28 | it's so nice to meet someone like yourself!
29 | welcome!
30 | I'm excited to see you!
31 | It's so nice to meet you!
32 | nice seeing you all again!
33 | I'm thrilled to see you all!
34 | hope all is well with your family.
35 | nice seeing you!
36 | it's nice to see you again!
37 | I just wanted to say hello
38 | I pray that things are going smoothly for you and yours
39 | nice to see you all again!
40 | nice to meet you in person
41 | how are you doing
42 | nice seeing you again.
43 | I've heard a lot of good things abou you.
44 | do you have any plans for the weekend
45 | it's good to see you again!
46 | good seeing everyone
47 | good to see you again!
48 | I'm happy to see you!
49 | I'm so happy to see you
50 | how's your day going?
51 | it's nice seeing you too!
52 | what's up?
53 | it's good seeing you again.
54 | "hello there, how's it going? hope things are good!"
55 | what is new
56 | I'm so happy to see you
57 | It's been a while since we last talked.
58 | hi there
59 | I'm happy to see you all!
60 | what's up?
61 | welcome!
62 | It's great seeing you again!
63 | it's good to see you again!
64 | it's nice to see you all again!
65 | I hope all is good with your family.
66 | what's up?
67 | i hope all is well with you and yours
68 | hey friend!
69 | good afternoon
70 | it's good seeing you!
71 | I trust that things are going smoothly for you and yours
72 | "hi there, how's it going? hope things are good!"
73 | good afternoon!
74 | It's exciting to see everyone
75 | hey pal!
76 | "hi there, how's it going? hope things are good!"
77 | I just wanted to say hi
78 | I'm thrilled to see you!
79 | I'm glad to see you
80 | I'm glad i got the chance to finally meet you in person.
81 | how is everything going with you and yours
82 | it's good to see you!
83 | it was so nice meeting you!
84 | hey mate!
85 | Hey team how is everyone doing?
86 | it's so good to finally meet in person.
87 | how are things going with you and your family
88 | I have faith that things are going smoothly for you and yours
89 | I'm excited to see you all!
90 | how's your day been?
91 | just wanted to say hi
92 | how have you been
93 | I'm excited to see you!
94 | great to see you again!
95 | great to see you again!
96 | I'm excited to see you!
97 | It's good to see you again!
98 | I've heard a lot about you and your family.
99 | just stopping by to say hello
100 | it's been too long!
101 | I'm excited to see you!
102 | good morning!
103 | how's your day going so far?
104 | good seeing you again
105 | I've heard a lot of good things about you and your family.
106 | hope all is good in the hood!
107 | it's good to see you as well!
108 | good night!
109 | hey mate!
110 | it's good to see you as well!
111 | I'm glad to see you all!
112 | hey buddy!
113 | Nice meeting you
114 | it's great to meet you.
115 | how are things going with you and your family?
116 | it's great to see you as well!
117 | I trust that things are going smoothly for you
118 | "hello there, how's it going? hope things are good!"
119 | how's it going?
120 | "hello there, how are you? hope things are good!"
121 | i'm happy I got the chance to finally meet you in person.
122 | I'm so thrilled to see you all!
123 | it's nice to see you too!
124 | I trust that things are going smoothly for you and yours
125 | I'm glad i got the chance to finally meet you in person.
126 | i am dropping in to say hey!
127 | hey good to see you!
128 | dropping in to let you know i'm thinking of you.
129 | it's great meeting you
130 | I'm happy to see you!
131 | I just stopped by to say hello
132 | "hello there, how are you? hope things are good!"
133 | how is your family doing
134 | I'm so thrilled to see you!
135 | I trust that things are going smoothly for you
136 | i am stopping by to say hi
137 | It's great to finally meet you in person
138 | it's great to see you too!
139 | it's great to see you as well!
140 | it's great seeing you!
141 | "hi there, how are you? hope things are good!"
142 | I just stopped by to say hello
143 | nice to see you all!
144 | I'm excited to see you all!
145 | I've heard nothing but good things about you.
146 | i am dropping in to say hey!
147 | it's wonderful meeting you
148 | how's it going?
149 | i hope everything is going well with you
150 | dropped in because I was thinking of you.
151 | it's nice to see you!
152 | good to see you again!
153 | nice seeing you all!
154 | I'm so excited to see you all!
155 | I'm thrilled to be here
156 | good morning
157 | I'm thrilled to see you all!
158 | just stopping by to say hello
159 | i hope life has been treating you good
160 | how are things?
161 | how's your day going?
162 | hey friend!
163 | I'm thrilled to see you!
164 | I'm glad i got the chance to finally meet in person!
165 | it's nice seeing you all again!
166 | I'm excited to see you!
167 | It's great seeing you!
168 | it's nice seeing you all!
169 | how are you doing?
170 | I've heard nothing but good things about you and your family.
171 | I've heard nothing but good things about you and your family.
172 | do you have any plans for today
173 | how are you doing?
174 | good evening
175 | I'm excited to see you!
176 | it's good to see you all again!
177 | I just wanted to say hi
178 | I've heard a lot about you.
179 | do you have any plans for the weekend
180 | dropped in because I was thinking of you.
181 | it's good to see you too!
182 | I'm excited to see you!
183 | great seeing you all!
184 | I'm glad to see you
185 | it's great to see you again!
186 | Just saying hello
187 | how is everything?
188 | good to see you again!
189 | nice to finally meet in person
190 | good seeing you all!
191 | good to see you!
192 | it's been a while!
193 | It's nice seeing you.
194 | good to see you all again!
195 | it's nice seeing you again!
196 | nice seeing you again
197 | great seeing you all again!
198 | Hope all is good in the hood!
199 | It was good meeting everyone
200 | I'm glad to see you all!
201 | hey buddy!
202 | I'm glad to see you!
203 | how is everyone doing
204 | I'm happy I got the chance to finally meet you in person.
205 | hello
206 | it's a pleasure to meet you!
207 | it was so nice meeting you
208 | i'm stopping by to say hey!
209 | Hope all is well with your loved ones.
210 | how's your day?
211 | greetings friend! how are you doing today?
212 | hi friend! how are things going with your family?
213 | it's great seeing you all again!
214 | hey my friend
215 | hey good to see you!
216 | how is everything going?
217 | how has your day been so far
218 | welcome back
219 | "greetings friend, how are you doing today?"
220 | it's good to see you all!
221 | nice to see you!
222 | good to see you all
223 | Hope all is well with your family.
224 | I'm glad to see you all!
225 | what's new
226 | great seeing you again
227 | Just saying hello
228 | it's nice meeting you
229 | nice meeting you!
230 | "hi, how are you?"
231 | how are things going?
232 | just saying hi
233 | just saying hi
234 | I'm happy to see you all!
235 | hope all is well
236 | it's great to see you again!
237 | i hope everything is going well for you
238 | It's great to see you again!
239 | good seeing you all again!
240 | it's nice seeing you again!
241 | It's good to see you!
242 | how are things everyone?
243 | nice seeing you!
244 | great seeing you!
245 | good to see you!
246 | it's great to see you!
247 | it's been a long time!
248 | how's it going?
249 | how is everyone doing?
250 | i hope life has been treating you right
251 | how is life treating you
252 | welcome back
253 | it's great to see you too!
254 | "hey there, nice to see you!"
255 | nice to see you again!
256 | Good to see you all!
257 | it's nice seeing you as well!
258 | it's great to see you again!
259 | nice seeing you again!
260 | I'm excited to see you all!
261 | good evening!
262 | Hey team how's it going?
263 | how have you been?
264 | how is everyone doing? i hope all's well.
265 | it's nice to see you as well!
266 | "hey team, how is everything going?"
267 | it's nice to see you all!
268 | "hi there, how are you? hope things are good!"
269 | hey pal!
270 | it's nice seeing you!
271 | I'm happy to see you all!
272 | how's your day?
273 | I'm excited to see you!
274 | I'm so excited to see you!
275 | just wanted to say hi
276 | good to see you again!
277 | greetings! how are you doing today?
278 | nice meeting you again.
279 | great seeing you
280 | I'm happy to see you
281 | it's good to see you again!
282 | It's nice to meet you!
283 | I'm glad to see you!
284 | I just wanted to say hello
285 | It was good meeting you
286 | welcome back!
287 | greetings! how are you doing today?
288 | It's great to see you!
289 | what's up?
290 | I'm glad to see you!
291 | hey my friend!
292 | I'm happy to see you
293 | it's great to meet someone like yourself!
294 | good seeing you
295 | I've heard a lot about you and your family.
296 | how is your family doing?
297 | it's great seeing you all!
298 | do you have any plans for today?
299 |
--------------------------------------------------------------------------------
/data/greetings/greetings_labeled.tsv:
--------------------------------------------------------------------------------
1 | I have faith that things are going smoothly for you and yours 0
2 | long time no see! 0
3 | nice to see you too! 0
4 | I'm happy to see you! 0
5 | it was so good meeting you 1
6 | I'm stopping by to say hi 0
7 | I pray that things are going smoothly for you and yours 0
8 | I'm glad to see you! 0
9 | it's nice seeing you too! 0
10 | i hope life has been treating you good. 1
11 | it's great to finally meet in person. 0
12 | I'm stopping by to say hi 0
13 | i hope all is good with your family. 1
14 | nice to meet you too! 0
15 | I'm thrilled to see you! 0
16 | it's a pleasure to see you again! 0
17 | it's good to see you too! 1
18 | I'm happy to see you! 0
19 | i am stopping by to say hi 0
20 | I've heard a lot of good things about you and your family. 1
21 | hope all is well with your loved ones. 0
22 | nice to meet you! 0
23 | i'm stopping by to say hey! 0
24 | dropping in to let you know i'm thinking of you. 0
25 | nice seeing you again 0
26 | how's it going? 0
27 | hi, how are you? 0
28 | it's so nice to meet someone like yourself! 0
29 | welcome! 0
30 | I'm excited to see you! 0
31 | It's so nice to meet you! 0
32 | nice seeing you all again! 0
33 | I'm thrilled to see you all! 0
34 | hope all is well with your family. 0
35 | nice seeing you! 0
36 | it's nice to see you again! 0
37 | I just wanted to say hello 0
38 | I pray that things are going smoothly for you and yours 0
39 | nice to see you all again! 0
40 | nice to meet you in person 0
41 | how are you doing 0
42 | nice seeing you again. 0
43 | I've heard a lot of good things abou you. 1
44 | do you have any plans for the weekend 0
45 | it's good to see you again! 1
46 | good seeing everyone 1
47 | good to see you again! 1
48 | I'm happy to see you! 0
49 | I'm so happy to see you 0
50 | how's your day going? 0
51 | it's nice seeing you too! 0
52 | what's up? 0
53 | it's good seeing you again. 1
54 | hello there, how's it going? hope things are good! 1
55 | what is new 0
56 | I'm so happy to see you 0
57 | It's been a while since we last talked. 0
58 | hi there 0
59 | I'm happy to see you all! 0
60 | what's up? 0
61 | welcome! 0
62 | It's great seeing you again! 0
63 | it's good to see you again! 1
64 | it's nice to see you all again! 0
65 | I hope all is good with your family. 1
66 | what's up? 0
67 | i hope all is well with you and yours 0
68 | hey friend! 0
69 | good afternoon 1
70 | it's good seeing you! 1
71 | I trust that things are going smoothly for you and yours 0
72 | hi there, how's it going? hope things are good! 1
73 | good afternoon! 1
74 | It's exciting to see everyone 0
75 | hey pal! 0
76 | hi there, how's it going? hope things are good! 1
77 | I just wanted to say hi 0
78 | I'm thrilled to see you! 0
79 | I'm glad to see you 0
80 | I'm glad i got the chance to finally meet you in person. 0
81 | how is everything going with you and yours 0
82 | it's good to see you! 1
83 | it was so nice meeting you! 0
84 | hey mate! 0
85 | Hey team how is everyone doing? 0
86 | it's so good to finally meet in person. 1
87 | how are things going with you and your family 0
88 | I have faith that things are going smoothly for you and yours 0
89 | I'm excited to see you all! 0
90 | how's your day been? 0
91 | just wanted to say hi 0
92 | how have you been 0
93 | I'm excited to see you! 0
94 | great to see you again! 0
95 | great to see you again! 0
96 | I'm excited to see you! 0
97 | It's good to see you again! 1
98 | I've heard a lot about you and your family. 0
99 | just stopping by to say hello 0
100 | it's been too long! 0
101 | I'm excited to see you! 0
102 | good morning! 1
103 | how's your day going so far? 0
104 | good seeing you again 1
105 | I've heard a lot of good things about you and your family. 1
106 | hope all is good in the hood! 1
107 | it's good to see you as well! 1
108 | good night! 1
109 | hey mate! 0
110 | it's good to see you as well! 1
111 | I'm glad to see you all! 0
112 | hey buddy! 0
113 | Nice meeting you 0
114 | it's great to meet you. 0
115 | how are things going with you and your family? 0
116 | it's great to see you as well! 0
117 | I trust that things are going smoothly for you 0
118 | hello there, how's it going? hope things are good! 1
119 | how's it going? 0
120 | hello there, how are you? hope things are good! 1
121 | i'm happy I got the chance to finally meet you in person. 0
122 | I'm so thrilled to see you all! 0
123 | it's nice to see you too! 0
124 | I trust that things are going smoothly for you and yours 0
125 | I'm glad i got the chance to finally meet you in person. 0
126 | i am dropping in to say hey! 0
127 | hey good to see you! 1
128 | dropping in to let you know i'm thinking of you. 0
129 | it's great meeting you 0
130 | I'm happy to see you! 0
131 | I just stopped by to say hello 0
132 | hello there, how are you? hope things are good! 1
133 | how is your family doing 0
134 | I'm so thrilled to see you! 0
135 | I trust that things are going smoothly for you 0
136 | i am stopping by to say hi 0
137 | It's great to finally meet you in person 0
138 | it's great to see you too! 0
139 | it's great to see you as well! 0
140 | it's great seeing you! 0
141 | hi there, how are you? hope things are good! 1
142 | I just stopped by to say hello 0
143 | nice to see you all! 0
144 | I'm excited to see you all! 0
145 | I've heard nothing but good things about you. 1
146 | i am dropping in to say hey! 0
147 | it's wonderful meeting you 0
148 | how's it going? 0
149 | i hope everything is going well with you 0
150 | dropped in because I was thinking of you. 0
151 | it's nice to see you! 0
152 | good to see you again! 1
153 | nice seeing you all! 0
154 | I'm so excited to see you all! 0
155 | I'm thrilled to be here 0
156 | good morning 1
157 | I'm thrilled to see you all! 0
158 | just stopping by to say hello 0
159 | i hope life has been treating you good 1
160 | how are things? 0
161 | how's your day going? 0
162 | hey friend! 0
163 | I'm thrilled to see you! 0
164 | I'm glad i got the chance to finally meet in person! 0
165 | it's nice seeing you all again! 0
166 | I'm excited to see you! 0
167 | It's great seeing you! 0
168 | it's nice seeing you all! 0
169 | how are you doing? 0
170 | I've heard nothing but good things about you and your family. 1
171 | I've heard nothing but good things about you and your family. 1
172 | do you have any plans for today 0
173 | how are you doing? 0
174 | good evening 1
175 | I'm excited to see you! 0
176 | it's good to see you all again! 1
177 | I just wanted to say hi 0
178 | I've heard a lot about you. 0
179 | do you have any plans for the weekend 0
180 | dropped in because I was thinking of you. 0
181 | it's good to see you too! 1
182 | I'm excited to see you! 0
183 | great seeing you all! 0
184 | I'm glad to see you 0
185 | it's great to see you again! 0
186 | Just saying hello 0
187 | how is everything? 0
188 | good to see you again! 1
189 | nice to finally meet in person 0
190 | good seeing you all! 1
191 | good to see you! 1
192 | it's been a while! 0
193 | It's nice seeing you. 0
194 | good to see you all again! 1
195 | it's nice seeing you again! 0
196 | nice seeing you again 0
197 | great seeing you all again! 0
198 | Hope all is good in the hood! 1
199 | It was good meeting everyone 1
200 | I'm glad to see you all! 0
201 | hey buddy! 0
202 | I'm glad to see you! 0
203 | how is everyone doing 0
204 | I'm happy I got the chance to finally meet you in person. 0
205 | hello 0
206 | it's a pleasure to meet you! 0
207 | it was so nice meeting you 0
208 | i'm stopping by to say hey! 0
209 | Hope all is well with your loved ones. 0
210 | how's your day? 0
211 | greetings friend! how are you doing today? 0
212 | hi friend! how are things going with your family? 0
213 | it's great seeing you all again! 0
214 | hey my friend 0
215 | hey good to see you! 1
216 | how is everything going? 0
217 | how has your day been so far 0
218 | welcome back 0
219 | greetings friend, how are you doing today? 0
220 | it's good to see you all! 1
221 | nice to see you! 0
222 | good to see you all 1
223 | Hope all is well with your family. 0
224 | I'm glad to see you all! 0
225 | what's new 0
226 | great seeing you again 0
227 | Just saying hello 0
228 | it's nice meeting you 0
229 | nice meeting you! 0
230 | hi, how are you? 0
231 | how are things going? 0
232 | just saying hi 0
233 | just saying hi 0
234 | I'm happy to see you all! 0
235 | hope all is well 0
236 | it's great to see you again! 0
237 | i hope everything is going well for you 0
238 | It's great to see you again! 0
239 | good seeing you all again! 1
240 | it's nice seeing you again! 0
241 | It's good to see you! 1
242 | how are things everyone? 0
243 | nice seeing you! 0
244 | great seeing you! 0
245 | good to see you! 1
246 | it's great to see you! 0
247 | it's been a long time! 0
248 | how's it going? 0
249 | how is everyone doing? 0
250 | i hope life has been treating you right 0
251 | how is life treating you 0
252 | welcome back 0
253 | it's great to see you too! 0
254 | hey there, nice to see you! 0
255 | nice to see you again! 0
256 | Good to see you all! 0
257 | it's nice seeing you as well! 0
258 | it's great to see you again! 0
259 | nice seeing you again! 0
260 | I'm excited to see you all! 0
261 | good evening! 1
262 | Hey team how's it going? 0
263 | how have you been? 0
264 | how is everyone doing? i hope all's well. 0
265 | it's nice to see you as well! 0
266 | hey team, how is everything going? 0
267 | it's nice to see you all! 0
268 | hi there, how are you? hope things are good! 1
269 | hey pal! 0
270 | it's nice seeing you! 0
271 | I'm happy to see you all! 0
272 | how's your day? 0
273 | I'm excited to see you! 0
274 | I'm so excited to see you! 0
275 | just wanted to say hi 0
276 | good to see you again! 1
277 | greetings! how are you doing today? 0
278 | nice meeting you again. 0
279 | great seeing you 0
280 | I'm happy to see you 0
281 | it's good to see you again! 1
282 | It's nice to meet you! 0
283 | I'm glad to see you! 0
284 | I just wanted to say hello 0
285 | It was good meeting you 1
286 | welcome back! 0
287 | greetings! how are you doing today? 0
288 | It's great to see you! 0
289 | what's up? 0
290 | I'm glad to see you! 0
291 | hey my friend! 0
292 | I'm happy to see you 0
293 | it's great to meet someone like yourself! 0
294 | good seeing you 1
295 | I've heard a lot about you and your family. 0
296 | how is your family doing? 0
297 | it's great seeing you all! 0
298 | do you have any plans for today? 0
299 |
--------------------------------------------------------------------------------
/data/greetings/word-level-vocab.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "1.0",
3 | "truncation": {
4 | "direction": "Right",
5 | "max_length": 512,
6 | "strategy": "LongestFirst",
7 | "stride": 0
8 | },
9 | "padding": null,
10 | "added_tokens": [
11 | {
12 | "id": 0,
13 | "content": "[UNK]",
14 | "single_word": false,
15 | "lstrip": false,
16 | "rstrip": false,
17 | "normalized": false,
18 | "special": true
19 | },
20 | {
21 | "id": 1,
22 | "content": "[CLS]",
23 | "single_word": false,
24 | "lstrip": false,
25 | "rstrip": false,
26 | "normalized": false,
27 | "special": true
28 | },
29 | {
30 | "id": 2,
31 | "content": "[SEP]",
32 | "single_word": false,
33 | "lstrip": false,
34 | "rstrip": false,
35 | "normalized": false,
36 | "special": true
37 | },
38 | {
39 | "id": 3,
40 | "content": "[PAD]",
41 | "single_word": false,
42 | "lstrip": false,
43 | "rstrip": false,
44 | "normalized": false,
45 | "special": true
46 | },
47 | {
48 | "id": 4,
49 | "content": "[MASK]",
50 | "single_word": false,
51 | "lstrip": false,
52 | "rstrip": false,
53 | "normalized": false,
54 | "special": true
55 | }
56 | ],
57 | "normalizer": {
58 | "type": "Sequence",
59 | "normalizers": [
60 | {
61 | "type": "NFD"
62 | },
63 | {
64 | "type": "Lowercase"
65 | },
66 | {
67 | "type": "StripAccents"
68 | }
69 | ]
70 | },
71 | "pre_tokenizer": {
72 | "type": "Sequence",
73 | "pretokenizers": [
74 | {
75 | "type": "Digits",
76 | "individual_digits": true
77 | },
78 | {
79 | "type": "Whitespace"
80 | }
81 | ]
82 | },
83 | "post_processor": {
84 | "type": "TemplateProcessing",
85 | "single": [
86 | {
87 | "SpecialToken": {
88 | "id": "[CLS]",
89 | "type_id": 0
90 | }
91 | },
92 | {
93 | "Sequence": {
94 | "id": "A",
95 | "type_id": 0
96 | }
97 | },
98 | {
99 | "SpecialToken": {
100 | "id": "[SEP]",
101 | "type_id": 0
102 | }
103 | }
104 | ],
105 | "pair": [
106 | {
107 | "Sequence": {
108 | "id": "A",
109 | "type_id": 0
110 | }
111 | },
112 | {
113 | "Sequence": {
114 | "id": "B",
115 | "type_id": 1
116 | }
117 | }
118 | ],
119 | "special_tokens": {
120 | "[CLS]": {
121 | "id": "[CLS]",
122 | "ids": [
123 | 1
124 | ],
125 | "tokens": [
126 | "[CLS]"
127 | ]
128 | },
129 | "[SEP]": {
130 | "id": "[SEP]",
131 | "ids": [
132 | 2
133 | ],
134 | "tokens": [
135 | "[SEP]"
136 | ]
137 | }
138 | }
139 | },
140 | "decoder": null,
141 | "model": {
142 | "type": "WordLevel",
143 | "vocab": {
144 | "[UNK]": 0,
145 | "[CLS]": 1,
146 | "[SEP]": 2,
147 | "[PAD]": 3,
148 | "[MASK]": 4,
149 | "you": 5,
150 | "!": 6,
151 | "'": 7,
152 | "to": 8,
153 | "i": 9,
154 | "see": 10,
155 | "s": 11,
156 | "it": 12,
157 | "good": 13,
158 | "m": 14,
159 | "how": 15,
160 | "all": 16,
161 | "nice": 17,
162 | "?": 18,
163 | "again": 19,
164 | "are": 20,
165 | "seeing": 21,
166 | ".": 22,
167 | "going": 23,
168 | "things": 24,
169 | "great": 25,
170 | "hope": 26,
171 | "is": 27,
172 | "your": 28,
173 | "hey": 29,
174 | "hi": 30,
175 | "in": 31,
176 | "meet": 32,
177 | "say": 33,
178 | "and": 34,
179 | "family": 35,
180 | "well": 36,
181 | "just": 37,
182 | "so": 38,
183 | "\"": 39,
184 | ",": 40,
185 | "doing": 41,
186 | "excited": 42,
187 | "for": 43,
188 | "happy": 44,
189 | "hello": 45,
190 | "glad": 46,
191 | "with": 47,
192 | "a": 48,
193 | "been": 49,
194 | "meeting": 50,
195 | "by": 51,
196 | "person": 52,
197 | "there": 53,
198 | "too": 54,
199 | "!\"": 55,
200 | "finally": 56,
201 | "heard": 57,
202 | "the": 58,
203 | "ve": 59,
204 | "about": 60,
205 | "everyone": 61,
206 | "have": 62,
207 | "smoothly": 63,
208 | "stopping": 64,
209 | "that": 65,
210 | "thrilled": 66,
211 | "yours": 67,
212 | "day": 68,
213 | "friend": 69,
214 | "of": 70,
215 | "was": 71,
216 | "as": 72,
217 | "everything": 73,
218 | "lot": 74,
219 | "today": 75,
220 | "wanted": 76,
221 | "what": 77,
222 | "chance": 78,
223 | "got": 79,
224 | "welcome": 80,
225 | "?\"": 81,
226 | "am": 82,
227 | "any": 83,
228 | "do": 84,
229 | "dropping": 85,
230 | "greetings": 86,
231 | "has": 87,
232 | "life": 88,
233 | "plans": 89,
234 | "saying": 90,
235 | "thinking": 91,
236 | "treating": 92,
237 | "trust": 93,
238 | "up": 94,
239 | "back": 95,
240 | "but": 96,
241 | "long": 97,
242 | "nothing": 98,
243 | "team": 99,
244 | "afternoon": 100,
245 | "because": 101,
246 | "buddy": 102,
247 | "dropped": 103,
248 | "evening": 104,
249 | "faith": 105,
250 | "far": 106,
251 | "hood": 107,
252 | "know": 108,
253 | "let": 109,
254 | "like": 110,
255 | "loved": 111,
256 | "mate": 112,
257 | "morning": 113,
258 | "my": 114,
259 | "new": 115,
260 | "ones": 116,
261 | "pal": 117,
262 | "pleasure": 118,
263 | "pray": 119,
264 | "someone": 120,
265 | "stopped": 121,
266 | "time": 122,
267 | "weekend": 123,
268 | "while": 124,
269 | "yourself": 125,
270 | "abou": 126,
271 | "be": 127,
272 | "exciting": 128,
273 | "here": 129,
274 | "last": 130,
275 | "night": 131,
276 | "no": 132,
277 | "right": 133,
278 | "since": 134,
279 | "talked": 135,
280 | "we": 136,
281 | "wonderful": 137
282 | },
283 | "unk_token": "[UNK]"
284 | }
285 | }
--------------------------------------------------------------------------------
/data/simple/merges.txt:
--------------------------------------------------------------------------------
1 | #version: 0.2 - Trained by `huggingface/tokenizers`
2 | h e
3 | i s
4 | T he
5 | Ġ is
6 | e d
7 | Ġ b
8 | Ġ p
9 | Ġ s
10 | a n
11 | Ġ c
12 | o w
13 | Ġb l
14 | i n
15 | t e
16 | Ġ r
17 | Ġb r
18 | a r
19 | a c
20 | u r
21 | l e
22 | Ġ g
23 | e l
24 | i c
25 | e e
26 | Ġg r
27 | g e
28 | Ġ o
29 | l ow
30 | Ġ y
31 | h i
32 | ow n
33 | Ġbr own
34 | a s
35 | i ed
36 | e r
37 | Ġ w
38 | ac k
39 | p le
40 | r an
41 | f r
42 | Ġbl ack
43 | Ġo ran
44 | Ġoran ge
45 | Ġp in
46 | u e
47 | Ġbl ue
48 | Ġ m
49 | Ġr ed
50 | Ġp ur
51 | Ġpur ple
52 | t ed
53 | Ġpin k
54 | el low
55 | Ġy ellow
56 | hi te
57 | Ġw hite
58 | ee n
59 | Ġgr een
60 | o n
61 | Ġ d
62 | Ġ t
63 | u t
64 | o t
65 | l ed
66 | a m
67 | Ġp e
68 | u s
69 | fr ied
70 | r y
71 | Ġc o
72 | e n
73 | Ġ a
74 | k ed
75 | ar d
76 | Ġs o
77 | Ġ f
78 | u m
79 | he d
80 | Ġs t
81 | o as
82 | oas ted
83 | a d
84 | u c
85 | b er
86 | Ġ h
87 | Ġs te
88 | o ked
89 | Ġs p
90 | ur n
91 | Ġp o
92 | a u
93 | an d
94 | i p
95 | i l
96 | Ġs m
97 | l i
98 | b e
99 | ic k
100 | ar in
101 | Ġc r
102 | a t
103 | at o
104 | i t
105 | Ġ le
106 | g g
107 | a l
108 | ee t
109 | m en
110 | a v
111 | a p
112 | r ied
113 | Ġd ried
114 | f low
115 | flow er
116 | m on
117 | Ġc and
118 | Ġcand ied
119 | Ġsm oked
120 | am ed
121 | Ġste amed
122 | Ġm us
123 | an t
124 | Ġr ad
125 | Ġt urn
126 | Ġturn ip
127 | r o
128 | ac hed
129 | Ġpo ached
130 | a b
131 | Ġc h
132 | o m
133 | ee p
134 | Ġd eep
135 | Ġ ar
136 | hi o
137 | Ġs au
138 | te ed
139 | Ġsau teed
140 | i r
141 | Ġco l
142 | Ġst ir
143 | n ut
144 | Ġgr il
145 | ber ry
146 | Ġgril led
147 | i led
148 | o iled
149 | Ġ fr
150 | Ġb oiled
151 | Ġbrown ed
152 | in ed
153 | Ġbr ined
154 | Ġp ick
155 | Ġpick led
156 | t y
157 | u n
158 | i on
159 | r ow
160 | ic row
161 | Ġm icrow
162 | av ed
163 | Ġmicrow aved
164 | a g
165 | Ġt oasted
166 | a is
167 | c u
168 | Ġb ar
169 | Ġbr ais
170 | be cu
171 | Ġbar becu
172 | Ġbrais ed
173 | Ġbarbecu ed
174 | a ted
175 | Ġm arin
176 | Ġmarin ated
177 | c h
178 | er men
179 | Ġf ermen
180 | Ġfermen ted
181 | g ed
182 | ic y
183 | Ġa ged
184 | c hed
185 | an ched
186 | Ġbl anched
187 | Ġc ur
188 | Ġcur ed
189 | Ġr oasted
190 | Ġ k
191 | ac h
192 | te r
193 | Ġp ot
194 | Ġpot ato
195 | is h
196 | u it
197 | fr uit
198 | as p
199 | c on
200 | Ġco con
201 | Ġcocon ut
202 | e gg
203 | l ant
204 | p lant
205 | Ġ egg
206 | Ġegg plant
207 | w eet
208 | Ġs weet
209 | Ġc uc
210 | um ber
211 | Ġcuc umber
212 | a w
213 | o c
214 | v e
215 | Ġ j
216 | ge r
217 | e an
218 | Ġb ean
219 | Ġgr ap
220 | Ġgrap e
221 | q u
222 | Ġs qu
223 | as h
224 | Ġsqu ash
225 | a le
226 | Ġk ale
227 | a ge
228 | b age
229 | Ġc ab
230 | Ġcab bage
231 | a ked
232 | r ot
233 | Ġb aked
234 | Ġc ar
235 | Ġcar rot
236 | w ed
237 | Ġste wed
238 | Ġb it
239 | Ġbit ter
240 | l ard
241 | Ġcol lard
242 | m el
243 | in e
244 | Ġc au
245 | li flower
246 | Ġcau liflower
247 | o z
248 | Ġfr oz
249 | Ġfroz en
250 | p p
251 | in ach
252 | Ġpe pp
253 | Ġsp inach
254 | Ġrad ish
255 | Ġpepp er
256 | f f
257 | l u
258 | Ġf lu
259 | ff y
260 | Ġflu ffy
261 | Ġpe a
262 | Ġco oked
263 | t ard
264 | Ġmus tard
265 | Ġs un
266 | Ġsun flower
267 | s el
268 | Ġbr us
269 | Ġso gg
270 | Ġsp ro
271 | sel s
272 | Ġbrus sels
273 | Ġsogg y
274 | Ġspro ut
275 | g u
276 | l a
277 | u gu
278 | Ġsp icy
279 | Ġar ugu
280 | Ġarugu la
281 | o ry
282 | Ġb urn
283 | ic ory
284 | Ġch icory
285 | Ġburn t
286 | e le
287 | h ro
288 | Ġc ele
289 | Ġt om
290 | Ġmus hro
291 | Ġch ard
292 | Ġcele ry
293 | Ġtom ato
294 | Ġmushro om
295 | l mon
296 | t t
297 | y be
298 | Ġa lmon
299 | Ġso ybe
300 | uc e
301 | Ġle tt
302 | Ġalmon d
303 | Ġsoybe an
304 | Ġlett uce
305 | a te
306 | n ion
307 | Ġb eet
308 | Ġo nion
309 | n t
310 | Ġle nt
311 | Ġlent il
312 | k in
313 | p kin
314 | z uc
315 | Ġ asp
316 | Ġ zuc
317 | Ġp um
318 | in i
319 | ar ag
320 | Ġw al
321 | Ġso ur
322 | ch ini
323 | Ġasp arag
324 | Ġzuc chini
325 | Ġpum pkin
326 | Ġwal nut
327 | Ġasparag us
328 | c hio
329 | t ac
330 | is tac
331 | Ġp istac
332 | Ġr aw
333 | ic chio
334 | Ġh ot
335 | Ġrad icchio
336 | Ġpistac hio
337 | i o
338 | he w
339 | Ġc as
340 | el ic
341 | Ġy am
342 | Ġd elic
343 | io us
344 | Ġcas hew
345 | Ġdelic ious
346 | i a
347 | ac ad
348 | as ty
349 | Ġm acad
350 | Ġt asty
351 | am ia
352 | Ġmacad amia
353 | c o
354 | f t
355 | Ġbr oc
356 | Ġso ft
357 | co li
358 | Ġbroc coli
359 | an ut
360 | Ġpe anut
361 | e am
362 | Ġcr eam
363 | Ġcream y
364 | a z
365 | el nut
366 | Ġa p
367 | Ġh az
368 | Ġhaz elnut
369 | h o
370 | k e
371 | t ic
372 | te n
373 | Ġr ot
374 | Ġar tic
375 | ho ke
376 | Ġrot ten
377 | Ġartic hoke
378 | m y
379 | Ġs al
380 | Ġy um
381 | Ġsal ty
382 | Ġyum my
383 | c an
384 | Ġ fried
385 | Ġpe can
386 | Ġr ip
387 | Ġst ick
388 | Ġrip e
389 | Ġstick y
390 | m e
391 | p y
392 | is py
393 | Ġcr ispy
394 | el ion
395 | Ġd and
396 | Ġh ard
397 | Ġdand elion
398 | Ġr ut
399 | ab ag
400 | Ġrut abag
401 | Ġrutabag a
402 | d i
403 | Ġ en
404 | di ve
405 | Ġen dive
406 | Ġcol d
407 | Ġcr un
408 | ch y
409 | Ġcrun chy
410 | e s
411 | Ġfr es
412 | Ġfres h
413 | o ot
414 | Ġsm oot
415 | Ġsmoot h
416 | l ic
417 | ar lic
418 | Ġg arlic
419 | e as
420 | Ġgr eas
421 | Ġgreas y
422 | u icy
423 | Ġj uicy
424 | in ger
425 | Ġg inger
426 | Ġpe ach
427 | t ine
428 | Ġc le
429 | men tine
430 | Ġcle mentine
431 | a ter
432 | Ġw ater
433 | mel on
434 | Ġwater melon
435 | e ap
436 | Ġpin eap
437 | Ġgrape fruit
438 | Ġpineap ple
439 | i g
440 | o u
441 | p e
442 | r ry
443 | he rry
444 | Ġc ant
445 | Ġc herry
446 | Ġf ig
447 | al ou
448 | Ġcant alou
449 | Ġcantalou pe
450 | Ġle mon
451 | i m
452 | s im
453 | Ġp er
454 | sim mon
455 | Ġper simmon
456 | i w
457 | Ġk iw
458 | Ġkiw i
459 | c t
460 | d e
461 | e y
462 | e ct
463 | n ect
464 | Ġ nect
465 | Ġblack berry
466 | on ey
467 | Ġh oney
468 | arin e
469 | de w
470 | Ġnect arine
471 | Ġhoney dew
472 | c he
473 | l y
474 | Ġ ly
475 | che e
476 | Ġly chee
477 | k r
478 | Ġo kr
479 | Ġokr a
480 | g o
481 | l um
482 | Ġp lum
483 | an go
484 | Ġm ango
485 | Ġb an
486 | an a
487 | Ġap ple
488 | Ġban ana
489 | g ran
490 | Ġr asp
491 | Ġpo me
492 | gran ate
493 | Ġrasp berry
494 | Ġpome granate
495 | Ġf ish
496 | Ġpo mel
497 | Ġpomel o
498 | el on
499 | Ġm elon
500 | Ġd ate
501 | ack fruit
502 | Ġpe ar
503 | Ġj ackfruit
504 | u av
505 | an ger
506 | Ġg uav
507 | Ġt anger
508 | Ġguav a
509 | Ġtanger ine
510 | i an
511 | v oc
512 | ur ian
513 | Ġd urian
514 | Ġa voc
515 | ad o
516 | Ġavoc ado
517 | a y
518 | Ġp ap
519 | ar fruit
520 | Ġst arfruit
521 | ay a
522 | Ġpap aya
523 | r aw
524 | Ġst raw
525 | Ġstraw berry
526 | Ġo li
527 | Ġoli ve
528 | r ic
529 | Ġblue berry
530 | Ġap ric
531 | Ġapric ot
532 | Ġ li
533 | Ġt am
534 | arin d
535 | Ġli me
536 | Ġtam arind
537 |
--------------------------------------------------------------------------------
/data/simple/vocab.json:
--------------------------------------------------------------------------------
1 | {"":0,"":1,"":2,"":3,"":4,"!":5,"\"":6,"#":7,"$":8,"%":9,"&":10,"'":11,"(":12,")":13,"*":14,"+":15,",":16,"-":17,".":18,"/":19,"0":20,"1":21,"2":22,"3":23,"4":24,"5":25,"6":26,"7":27,"8":28,"9":29,":":30,";":31,"<":32,"=":33,">":34,"?":35,"@":36,"A":37,"B":38,"C":39,"D":40,"E":41,"F":42,"G":43,"H":44,"I":45,"J":46,"K":47,"L":48,"M":49,"N":50,"O":51,"P":52,"Q":53,"R":54,"S":55,"T":56,"U":57,"V":58,"W":59,"X":60,"Y":61,"Z":62,"[":63,"\\":64,"]":65,"^":66,"_":67,"`":68,"a":69,"b":70,"c":71,"d":72,"e":73,"f":74,"g":75,"h":76,"i":77,"j":78,"k":79,"l":80,"m":81,"n":82,"o":83,"p":84,"q":85,"r":86,"s":87,"t":88,"u":89,"v":90,"w":91,"x":92,"y":93,"z":94,"{":95,"|":96,"}":97,"~":98,"¡":99,"¢":100,"£":101,"¤":102,"¥":103,"¦":104,"§":105,"¨":106,"©":107,"ª":108,"«":109,"¬":110,"®":111,"¯":112,"°":113,"±":114,"²":115,"³":116,"´":117,"µ":118,"¶":119,"·":120,"¸":121,"¹":122,"º":123,"»":124,"¼":125,"½":126,"¾":127,"¿":128,"À":129,"Á":130,"Â":131,"Ã":132,"Ä":133,"Å":134,"Æ":135,"Ç":136,"È":137,"É":138,"Ê":139,"Ë":140,"Ì":141,"Í":142,"Î":143,"Ï":144,"Ð":145,"Ñ":146,"Ò":147,"Ó":148,"Ô":149,"Õ":150,"Ö":151,"×":152,"Ø":153,"Ù":154,"Ú":155,"Û":156,"Ü":157,"Ý":158,"Þ":159,"ß":160,"à":161,"á":162,"â":163,"ã":164,"ä":165,"å":166,"æ":167,"ç":168,"è":169,"é":170,"ê":171,"ë":172,"ì":173,"í":174,"î":175,"ï":176,"ð":177,"ñ":178,"ò":179,"ó":180,"ô":181,"õ":182,"ö":183,"÷":184,"ø":185,"ù":186,"ú":187,"û":188,"ü":189,"ý":190,"þ":191,"ÿ":192,"Ā":193,"ā":194,"Ă":195,"ă":196,"Ą":197,"ą":198,"Ć":199,"ć":200,"Ĉ":201,"ĉ":202,"Ċ":203,"ċ":204,"Č":205,"č":206,"Ď":207,"ď":208,"Đ":209,"đ":210,"Ē":211,"ē":212,"Ĕ":213,"ĕ":214,"Ė":215,"ė":216,"Ę":217,"ę":218,"Ě":219,"ě":220,"Ĝ":221,"ĝ":222,"Ğ":223,"ğ":224,"Ġ":225,"ġ":226,"Ģ":227,"ģ":228,"Ĥ":229,"ĥ":230,"Ħ":231,"ħ":232,"Ĩ":233,"ĩ":234,"Ī":235,"ī":236,"Ĭ":237,"ĭ":238,"Į":239,"į":240,"İ":241,"ı":242,"IJ":243,"ij":244,"Ĵ":245,"ĵ":246,"Ķ":247,"ķ":248,"ĸ":249,"Ĺ":250,"ĺ":251,"Ļ":252,"ļ":253,"Ľ":254,"ľ":255,"Ŀ":256,"ŀ":257,"Ł":258,"ł":259,"Ń":260,"he":261,"is":262,
"The":263,"Ġis":264,"ed":265,"Ġb":266,"Ġp":267,"Ġs":268,"an":269,"Ġc":270,"ow":271,"Ġbl":272,"in":273,"te":274,"Ġr":275,"Ġbr":276,"ar":277,"ac":278,"ur":279,"le":280,"Ġg":281,"el":282,"ic":283,"ee":284,"Ġgr":285,"ge":286,"Ġo":287,"low":288,"Ġy":289,"hi":290,"own":291,"Ġbrown":292,"as":293,"ied":294,"er":295,"Ġw":296,"ack":297,"ple":298,"ran":299,"fr":300,"Ġblack":301,"Ġoran":302,"Ġorange":303,"Ġpin":304,"ue":305,"Ġblue":306,"Ġm":307,"Ġred":308,"Ġpur":309,"Ġpurple":310,"ted":311,"Ġpink":312,"ellow":313,"Ġyellow":314,"hite":315,"Ġwhite":316,"een":317,"Ġgreen":318,"on":319,"Ġd":320,"Ġt":321,"ut":322,"ot":323,"led":324,"am":325,"Ġpe":326,"us":327,"fried":328,"ry":329,"Ġco":330,"en":331,"Ġa":332,"ked":333,"ard":334,"Ġso":335,"Ġf":336,"um":337,"hed":338,"Ġst":339,"oas":340,"oasted":341,"ad":342,"uc":343,"ber":344,"Ġh":345,"Ġste":346,"oked":347,"Ġsp":348,"urn":349,"Ġpo":350,"au":351,"and":352,"ip":353,"il":354,"Ġsm":355,"li":356,"be":357,"ick":358,"arin":359,"Ġcr":360,"at":361,"ato":362,"it":363,"Ġle":364,"gg":365,"al":366,"eet":367,"men":368,"av":369,"ap":370,"ried":371,"Ġdried":372,"flow":373,"flower":374,"mon":375,"Ġcand":376,"Ġcandied":377,"Ġsmoked":378,"amed":379,"Ġsteamed":380,"Ġmus":381,"ant":382,"Ġrad":383,"Ġturn":384,"Ġturnip":385,"ro":386,"ached":387,"Ġpoached":388,"ab":389,"Ġch":390,"om":391,"eep":392,"Ġdeep":393,"Ġar":394,"hio":395,"Ġsau":396,"teed":397,"Ġsauteed":398,"ir":399,"Ġcol":400,"Ġstir":401,"nut":402,"Ġgril":403,"berry":404,"Ġgrilled":405,"iled":406,"oiled":407,"Ġfr":408,"Ġboiled":409,"Ġbrowned":410,"ined":411,"Ġbrined":412,"Ġpick":413,"Ġpickled":414,"ty":415,"un":416,"ion":417,"row":418,"icrow":419,"Ġmicrow":420,"aved":421,"Ġmicrowaved":422,"ag":423,"Ġtoasted":424,"ais":425,"cu":426,"Ġbar":427,"Ġbrais":428,"becu":429,"Ġbarbecu":430,"Ġbraised":431,"Ġbarbecued":432,"ated":433,"Ġmarin":434,"Ġmarinated":435,"ch":436,"ermen":437,"Ġfermen":438,"Ġfermented":439,"ged":440,"icy":441,"Ġaged":442,"ched":443,"anched":444,"Ġblanched":445,"Ġcur":446,"Ġcured":447,"Ġ
roasted":448,"Ġk":449,"ach":450,"ter":451,"Ġpot":452,"Ġpotato":453,"ish":454,"uit":455,"fruit":456,"asp":457,"con":458,"Ġcocon":459,"Ġcoconut":460,"egg":461,"lant":462,"plant":463,"Ġegg":464,"Ġeggplant":465,"weet":466,"Ġsweet":467,"Ġcuc":468,"umber":469,"Ġcucumber":470,"aw":471,"oc":472,"ve":473,"Ġj":474,"ger":475,"ean":476,"Ġbean":477,"Ġgrap":478,"Ġgrape":479,"qu":480,"Ġsqu":481,"ash":482,"Ġsquash":483,"ale":484,"Ġkale":485,"age":486,"bage":487,"Ġcab":488,"Ġcabbage":489,"aked":490,"rot":491,"Ġbaked":492,"Ġcar":493,"Ġcarrot":494,"wed":495,"Ġstewed":496,"Ġbit":497,"Ġbitter":498,"lard":499,"Ġcollard":500,"mel":501,"ine":502,"Ġcau":503,"liflower":504,"Ġcauliflower":505,"oz":506,"Ġfroz":507,"Ġfrozen":508,"pp":509,"inach":510,"Ġpepp":511,"Ġspinach":512,"Ġradish":513,"Ġpepper":514,"ff":515,"lu":516,"Ġflu":517,"ffy":518,"Ġfluffy":519,"Ġpea":520,"Ġcooked":521,"tard":522,"Ġmustard":523,"Ġsun":524,"Ġsunflower":525,"sel":526,"Ġbrus":527,"Ġsogg":528,"Ġspro":529,"sels":530,"Ġbrussels":531,"Ġsoggy":532,"Ġsprout":533,"gu":534,"la":535,"ugu":536,"Ġspicy":537,"Ġarugu":538,"Ġarugula":539,"ory":540,"Ġburn":541,"icory":542,"Ġchicory":543,"Ġburnt":544,"ele":545,"hro":546,"Ġcele":547,"Ġtom":548,"Ġmushro":549,"Ġchard":550,"Ġcelery":551,"Ġtomato":552,"Ġmushroom":553,"lmon":554,"tt":555,"ybe":556,"Ġalmon":557,"Ġsoybe":558,"uce":559,"Ġlett":560,"Ġalmond":561,"Ġsoybean":562,"Ġlettuce":563,"ate":564,"nion":565,"Ġbeet":566,"Ġonion":567,"nt":568,"Ġlent":569,"Ġlentil":570,"kin":571,"pkin":572,"zuc":573,"Ġasp":574,"Ġzuc":575,"Ġpum":576,"ini":577,"arag":578,"Ġwal":579,"Ġsour":580,"chini":581,"Ġasparag":582,"Ġzucchini":583,"Ġpumpkin":584,"Ġwalnut":585,"Ġasparagus":586,"chio":587,"tac":588,"istac":589,"Ġpistac":590,"Ġraw":591,"icchio":592,"Ġhot":593,"Ġradicchio":594,"Ġpistachio":595,"io":596,"hew":597,"Ġcas":598,"elic":599,"Ġyam":600,"Ġdelic":601,"ious":602,"Ġcashew":603,"Ġdelicious":604,"ia":605,"acad":606,"asty":607,"Ġmacad":608,"Ġtasty":609,"amia":610,"Ġmacadamia":611,"co":612,"ft":613,"Ġbroc":614
,"Ġsoft":615,"coli":616,"Ġbroccoli":617,"anut":618,"Ġpeanut":619,"eam":620,"Ġcream":621,"Ġcreamy":622,"az":623,"elnut":624,"Ġap":625,"Ġhaz":626,"Ġhazelnut":627,"ho":628,"ke":629,"tic":630,"ten":631,"Ġrot":632,"Ġartic":633,"hoke":634,"Ġrotten":635,"Ġartichoke":636,"my":637,"Ġsal":638,"Ġyum":639,"Ġsalty":640,"Ġyummy":641,"can":642,"Ġfried":643,"Ġpecan":644,"Ġrip":645,"Ġstick":646,"Ġripe":647,"Ġsticky":648,"me":649,"py":650,"ispy":651,"Ġcrispy":652,"elion":653,"Ġdand":654,"Ġhard":655,"Ġdandelion":656,"Ġrut":657,"abag":658,"Ġrutabag":659,"Ġrutabaga":660,"di":661,"Ġen":662,"dive":663,"Ġendive":664,"Ġcold":665,"Ġcrun":666,"chy":667,"Ġcrunchy":668,"es":669,"Ġfres":670,"Ġfresh":671,"oot":672,"Ġsmoot":673,"Ġsmooth":674,"lic":675,"arlic":676,"Ġgarlic":677,"eas":678,"Ġgreas":679,"Ġgreasy":680,"uicy":681,"Ġjuicy":682,"inger":683,"Ġginger":684,"Ġpeach":685,"tine":686,"Ġcle":687,"mentine":688,"Ġclementine":689,"ater":690,"Ġwater":691,"melon":692,"Ġwatermelon":693,"eap":694,"Ġpineap":695,"Ġgrapefruit":696,"Ġpineapple":697,"ig":698,"ou":699,"pe":700,"rry":701,"herry":702,"Ġcant":703,"Ġcherry":704,"Ġfig":705,"alou":706,"Ġcantalou":707,"Ġcantaloupe":708,"Ġlemon":709,"im":710,"sim":711,"Ġper":712,"simmon":713,"Ġpersimmon":714,"iw":715,"Ġkiw":716,"Ġkiwi":717,"ct":718,"de":719,"ey":720,"ect":721,"nect":722,"Ġnect":723,"Ġblackberry":724,"oney":725,"Ġhoney":726,"arine":727,"dew":728,"Ġnectarine":729,"Ġhoneydew":730,"che":731,"ly":732,"Ġly":733,"chee":734,"Ġlychee":735,"kr":736,"Ġokr":737,"Ġokra":738,"go":739,"lum":740,"Ġplum":741,"ango":742,"Ġmango":743,"Ġban":744,"ana":745,"Ġapple":746,"Ġbanana":747,"gran":748,"Ġrasp":749,"Ġpome":750,"granate":751,"Ġraspberry":752,"Ġpomegranate":753,"Ġfish":754,"Ġpomel":755,"Ġpomelo":756,"elon":757,"Ġmelon":758,"Ġdate":759,"ackfruit":760,"Ġpear":761,"Ġjackfruit":762,"uav":763,"anger":764,"Ġguav":765,"Ġtanger":766,"Ġguava":767,"Ġtangerine":768,"ian":769,"voc":770,"urian":771,"Ġdurian":772,"Ġavoc":773,"ado":774,"Ġavocado":775,"ay":776,"Ġpap":777,"arfruit":
778,"Ġstarfruit":779,"aya":780,"Ġpapaya":781,"raw":782,"Ġstraw":783,"Ġstrawberry":784,"Ġoli":785,"Ġolive":786,"ric":787,"Ġblueberry":788,"Ġapric":789,"Ġapricot":790,"Ġli":791,"Ġtam":792,"arind":793,"Ġlime":794,"Ġtamarind":795}
--------------------------------------------------------------------------------
/data/simple/word-level-vocab.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "1.0",
3 | "truncation": {
4 | "direction": "Right",
5 | "max_length": 512,
6 | "strategy": "LongestFirst",
7 | "stride": 0
8 | },
9 | "padding": null,
10 | "added_tokens": [
11 | {
12 | "id": 0,
13 | "content": "[UNK]",
14 | "single_word": false,
15 | "lstrip": false,
16 | "rstrip": false,
17 | "normalized": false,
18 | "special": true
19 | },
20 | {
21 | "id": 1,
22 | "content": "[CLS]",
23 | "single_word": false,
24 | "lstrip": false,
25 | "rstrip": false,
26 | "normalized": false,
27 | "special": true
28 | },
29 | {
30 | "id": 2,
31 | "content": "[SEP]",
32 | "single_word": false,
33 | "lstrip": false,
34 | "rstrip": false,
35 | "normalized": false,
36 | "special": true
37 | },
38 | {
39 | "id": 3,
40 | "content": "[PAD]",
41 | "single_word": false,
42 | "lstrip": false,
43 | "rstrip": false,
44 | "normalized": false,
45 | "special": true
46 | },
47 | {
48 | "id": 4,
49 | "content": "[MASK]",
50 | "single_word": false,
51 | "lstrip": false,
52 | "rstrip": false,
53 | "normalized": false,
54 | "special": true
55 | }
56 | ],
57 | "normalizer": {
58 | "type": "Sequence",
59 | "normalizers": [
60 | {
61 | "type": "NFD"
62 | },
63 | {
64 | "type": "Lowercase"
65 | },
66 | {
67 | "type": "StripAccents"
68 | }
69 | ]
70 | },
71 | "pre_tokenizer": {
72 | "type": "Sequence",
73 | "pretokenizers": [
74 | {
75 | "type": "Digits",
76 | "individual_digits": true
77 | },
78 | {
79 | "type": "Whitespace"
80 | }
81 | ]
82 | },
83 | "post_processor": {
84 | "type": "TemplateProcessing",
85 | "single": [
86 | {
87 | "SpecialToken": {
88 | "id": "[CLS]",
89 | "type_id": 0
90 | }
91 | },
92 | {
93 | "Sequence": {
94 | "id": "A",
95 | "type_id": 0
96 | }
97 | },
98 | {
99 | "SpecialToken": {
100 | "id": "[SEP]",
101 | "type_id": 0
102 | }
103 | }
104 | ],
105 | "pair": [
106 | {
107 | "Sequence": {
108 | "id": "A",
109 | "type_id": 0
110 | }
111 | },
112 | {
113 | "Sequence": {
114 | "id": "B",
115 | "type_id": 1
116 | }
117 | }
118 | ],
119 | "special_tokens": {
120 | "[CLS]": {
121 | "id": "[CLS]",
122 | "ids": [
123 | 1
124 | ],
125 | "tokens": [
126 | "[CLS]"
127 | ]
128 | },
129 | "[SEP]": {
130 | "id": "[SEP]",
131 | "ids": [
132 | 2
133 | ],
134 | "tokens": [
135 | "[SEP]"
136 | ]
137 | }
138 | }
139 | },
140 | "decoder": null,
141 | "model": {
142 | "type": "WordLevel",
143 | "vocab": {
144 | "[UNK]": 0,
145 | "[CLS]": 1,
146 | "[SEP]": 2,
147 | "[PAD]": 3,
148 | "[MASK]": 4,
149 | ".": 5,
150 | "is": 6,
151 | "the": 7,
152 | "orange": 8,
153 | "black": 9,
154 | "red": 10,
155 | "purple": 11,
156 | "pink": 12,
157 | "blue": 13,
158 | "yellow": 14,
159 | "white": 15,
160 | "brown": 16,
161 | "green": 17,
162 | "fried": 18,
163 | "-": 19,
164 | "dried": 20,
165 | "candied": 21,
166 | "smoked": 22,
167 | "steamed": 23,
168 | "turnip": 24,
169 | "poached": 25,
170 | "deep": 26,
171 | "sauteed": 27,
172 | "stir": 28,
173 | "grilled": 29,
174 | "boiled": 30,
175 | "browned": 31,
176 | "brined": 32,
177 | "pickled": 33,
178 | "microwaved": 34,
179 | "toasted": 35,
180 | "barbecued": 36,
181 | "braised": 37,
182 | "marinated": 38,
183 | "fermented": 39,
184 | "aged": 40,
185 | "blanched": 41,
186 | "cured": 42,
187 | "roasted": 43,
188 | "potato": 44,
189 | "coconut": 45,
190 | "eggplant": 46,
191 | "sweet": 47,
192 | "cucumber": 48,
193 | "bean": 49,
194 | "squash": 50,
195 | "kale": 51,
196 | "cabbage": 52,
197 | "baked": 53,
198 | "carrot": 54,
199 | "stewed": 55,
200 | "bitter": 56,
201 | "collard": 57,
202 | "cauliflower": 58,
203 | "frozen": 59,
204 | "pepper": 60,
205 | "radish": 61,
206 | "spinach": 62,
207 | "fluffy": 63,
208 | "pea": 64,
209 | "cooked": 65,
210 | "mustard": 66,
211 | "sunflower": 67,
212 | "brussels": 68,
213 | "soggy": 69,
214 | "sprout": 70,
215 | "arugula": 71,
216 | "spicy": 72,
217 | "burnt": 73,
218 | "chicory": 74,
219 | "celery": 75,
220 | "chard": 76,
221 | "mushroom": 77,
222 | "tomato": 78,
223 | "almond": 79,
224 | "lettuce": 80,
225 | "soybean": 81,
226 | "beet": 82,
227 | "onion": 83,
228 | "lentil": 84,
229 | "asparagus": 85,
230 | "pumpkin": 86,
231 | "sour": 87,
232 | "walnut": 88,
233 | "zucchini": 89,
234 | "hot": 90,
235 | "pistachio": 91,
236 | "radicchio": 92,
237 | "raw": 93,
238 | "cashew": 94,
239 | "delicious": 95,
240 | "yam": 96,
241 | "macadamia": 97,
242 | "tasty": 98,
243 | "broccoli": 99,
244 | "soft": 100,
245 | "peanut": 101,
246 | "creamy": 102,
247 | "hazelnut": 103,
248 | "artichoke": 104,
249 | "rotten": 105,
250 | "salty": 106,
251 | "yummy": 107,
252 | "pecan": 108,
253 | "ripe": 109,
254 | "sticky": 110,
255 | "crispy": 111,
256 | "dandelion": 112,
257 | "hard": 113,
258 | "rutabaga": 114,
259 | "endive": 115,
260 | "cold": 116,
261 | "crunchy": 117,
262 | "fresh": 118,
263 | "smooth": 119,
264 | "garlic": 120,
265 | "greasy": 121,
266 | "juicy": 122,
267 | "ginger": 123,
268 | "peach": 124,
269 | "clementine": 125,
270 | "grape": 126,
271 | "watermelon": 127,
272 | "grapefruit": 128,
273 | "pineapple": 129,
274 | "cantaloupe": 130,
275 | "cherry": 131,
276 | "fig": 132,
277 | "lemon": 133,
278 | "persimmon": 134,
279 | "kiwi": 135,
280 | "blackberry": 136,
281 | "honeydew": 137,
282 | "nectarine": 138,
283 | "lychee": 139,
284 | "okra": 140,
285 | "mango": 141,
286 | "plum": 142,
287 | "apple": 143,
288 | "banana": 144,
289 | "pomegranate": 145,
290 | "raspberry": 146,
291 | "fish": 147,
292 | "pomelo": 148,
293 | "date": 149,
294 | "melon": 150,
295 | "jackfruit": 151,
296 | "pear": 152,
297 | "guava": 153,
298 | "tangerine": 154,
299 | "avocado": 155,
300 | "durian": 156,
301 | "papaya": 157,
302 | "starfruit": 158,
303 | "strawberry": 159,
304 | "olive": 160,
305 | "apricot": 161,
306 | "blueberry": 162,
307 | "lime": 163,
308 | "tamarind": 164
309 | },
310 | "unk_token": "[UNK]"
311 | }
312 | }
--------------------------------------------------------------------------------
/docs/CNAME:
--------------------------------------------------------------------------------
1 | diffusion.textgen.info
--------------------------------------------------------------------------------
/docs/controllable.md:
--------------------------------------------------------------------------------
1 | ## Classifier-guided Controllable Generation
2 |
3 | * The unconditional generation is used to sample some sentences from a distribution of interest. However, a more interesting task is to generate sentences that satisfy some constraints. For example, we may want to generate sentences containing a certain color.
4 |
5 | - For this walkthrough, let's say we want to generate sentences that contain {"red", "blue", "green", "white"}. We create such a labeled dataset in `data/simple/simple_labeled.tsv`.
6 |
7 | ```sh
8 | $ shuf data/simple/simple_labeled.tsv|head -2
9 | The purple pumpkin is juicy. 0
10 | The green pear is sweet. 1
11 | ```
12 | - The labeled file is required to train a classifier. In general, if you are working with a dataset named `dset`, a labeled file should be present at `data/{dset}/{dset}_labeled.tsv` with two columns (sentence and label).
13 |
14 | Let's start!
15 |
16 | ### Step 0: Train a diffusion model.
17 |
18 | * This is the backbone diffusion model whose generations we want to guide. We will use a diffusion model trained on the `simple` dataset introduced in the README.
19 | * Please download the new checkpoint (word-level vocab) from [here](https://drive.google.com/drive/folders/1zPiopN0MqhkYNlUza6zOChPqWyoDofLh?usp=sharing) and put it in the `ckpts/simplev2` folder.
20 |
21 | ### Step 1: Train a classifier
22 |
23 |
24 | * Train a classifier
25 |
26 | ```sh
27 | python -u src/controllable/classifier.py --model_name_or_path ckpts/simplev2/ema_0.9999_005001.pt
28 | ```
29 |
30 | * This trains a classifier on the latent/noisy samples ($$x_t$$).
31 |
32 | - It is sufficient to only specify the checkpoint! The name of the dataset and other hyperparameters are loaded from the diffusion model's training config stored in the checkpoint directory (`ckpts/simplev2/`). However, the classifier does require the labeled file to be present at `data/{dset}/{dset}_labeled.tsv`.
33 |
34 | - The classifier is saved at `ckpts/simplev2/classifier.pt`.
35 |
36 |
37 | ### Step 2: Run controllable generation
38 |
39 | ```sh
40 | bash scripts/ctrl_text_sample.sh ckpts/simplev2/ema_0.9999_005001.pt 300 50
41 | ```
42 |
43 | - Note that we use only 300 diffusion steps vs. 2000 for training. This works because the decoding is actually DDIM style: we approximate `x0` at each step, which is used for denoising.
44 |
45 | - The outputs are generated at: `ckpts/simplev2/ema_0.9999_005001.pt.samples_50.steps-300.clamp-no_clamp.txt.ctrl`.
46 |
47 | - Let's also generate 500 samples from the unguided model for comparison:
48 |
49 | ```sh
50 | CUDA_VISIBLE_DEVICES=8 bash scripts/text_sample.sh ckpts/simplev2/ema_0.9999_005001.pt 300 500
51 | ```
52 |
53 | * Let's compare the outputs of the two models:
54 |
55 | ```sh
56 | # top 5 colors in the unguided output:
57 |
58 | (diffusion) amadaan@sa:~/home2/minimal-text-diffusion$ cut -f3 -d" " ckpts/simplev2/ema_0.9999_005001.pt.samples_500.steps-300.clamp-no_clamp.txt | sort | uniq -c | sed 's/^\s*//g' | sort -n|tail -5
59 | 30 purple
60 | 53 yellow
61 | 69 green
62 | 111 pink
63 | 166 white
64 | ```
65 |
66 | ```sh
67 | # top 5 colors in the guided output:
68 | (diffusion) amadaan@sa:~/home2/minimal-text-diffusion$ cut -f3 -d" " ckpts/simplev2/ema_0.9999_005001.pt.samples_500.steps-300.clamp-no_clamp.txt.ctrl.sample1 | sort | uniq -c | sed 's/^\s*//g' | sort -n|tail -5
69 | 15 pink
70 | 16 black
71 | 25 purple
72 | 124 yellow
73 | 269 green
74 | ```
75 |
76 | * About 54% (269/500) of the sentences in the guided output contain the color word "green" vs. 69/500 ≈ 14% in the unguided output. It looks like it's working! (Recall that green was one of the 4 colors we specified in the classifier for label 1.)
77 |
78 |
79 |
80 | ## Implementation Details
81 |
82 | - The files relevant to controllable generation are in `src/controllable/`.
83 |
84 | - Listing
85 | src/controllable/
86 | ├── classifier.py
87 | ├── controllable_text_sample.py
88 | └── langevin.py
89 |
90 |
91 | * Here:
92 | - `classifier.py` trains a classifier on the latents of the diffusion model.
93 |
94 | - `controllable_text_sample.py` runs controllable generation.
95 |
96 | - `langevin.py` refines the embeddings with classifier guidance (using Langevin dynamics).
97 |
98 |
99 | - At a high level, the procedure is as follows:
100 | a) `p_sample_loop_langevin_progressive` in `src/modeling/diffusion/gaussian_diffusion.py` first creates an approximate `x_{t-1}` and then calls `langevin_binary_classifier` in `src/controllable/langevin.py`
101 |
102 | b) `langevin_binary_classifier` then refines the embeddings with classifier guidance. This is the Langevin dynamics step: $x_{t-1} \leftarrow x_{t-1} + \epsilon \nabla_{x_{t-1}} \log p(y = 1 \mid x_{t-1})$, where $p(y = 1 \mid x_{t-1})$ is the classifier's probability of label $y = 1$ given the noisy input $x_{t-1}$, and $\epsilon$ is the step size. Controllable generation is currently only done for label 1, but this can be changed by flipping the labels in `langevin_binary_classifier`. (TODO: add support for dynamic labels.)
103 |
104 |
--------------------------------------------------------------------------------
/docs/imgs/greetings_training_finished.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/minimal-text-diffusion/9303ffd481a2f647da24c6053e4dec44fd086a8d/docs/imgs/greetings_training_finished.png
--------------------------------------------------------------------------------
/docs/imgs/greetings_training_loop.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/minimal-text-diffusion/9303ffd481a2f647da24c6053e4dec44fd086a8d/docs/imgs/greetings_training_loop.png
--------------------------------------------------------------------------------
/docs/old_experiments.md:
--------------------------------------------------------------------------------
1 | ## Experiments and Results
2 |
3 | * I've tried experiments with the following hyperparameters:
4 |
5 | 1. Embeddings/vocab: pre-trained `bert-base-uncased` vs. initialized randomly.
6 |
7 | 2. Model backbone: pre-trained `bert-base-uncased` vs. initialized from scratch.
8 |
9 | 3. Embeddings fine-tuning: fine-tuned vs. frozen.
10 |
11 | Out of the 8 possible combinations, the best results were obtained with the following hyperparameters:
12 |
13 | | File | Sample Sentences | Perplexity | % Unique Lines | % Unique Tokens |
14 | |----------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------|----------------|-----------------|
15 | | MODEL_PT-False_EMBEDS_PT-False-FREEZE_EMBEDS-False | The yellow lentil is stir-fried., The green lemon is stir-fried., The white kiwi is deep-fried., The orange turnip is stir-fried., The blue blackberry is deep-fried. | 212.70 | 80.0 | 3.76 |
16 | | MODEL_PT-False_EMBEDS_PT-False-FREEZE_EMBEDS-True | The green spinach is stir-fried., The pink pomelo is stir-fried., The brown onion is stir-fried., The yellow artichoke is stir-fried., The blue pomegranate is deep-fried. | 218.77 | 74.2 | 3.76 |
17 | | MODEL_PT-True_EMBEDS_PT-True-FREEZE_EMBEDS-False | the yellow poccoli isggy., the red pe is sauteed., the green spinach is candied., the green danmelli isy., the brown kale is candied. | 1424.21 | 78.0 | 6.1 |
18 |
19 |
20 | ---
21 |
22 | - The best setting in terms of diversity is using pre-trained BERT, BERT embeddings, and fine-tuning the embeddings. However, this setting has the highest perplexity because it generates weird sentences (which could be a good thing depending on the application!).
23 |
24 | - Some random samples from this setting:
25 |
26 | - Update 10/24: The following samples were likely from the `MODEL_PT-False_EMBEDS_PT-False-FREEZE_EMBEDS-False` setting.
27 |
28 | ```
29 | the purple chard is braised.
30 | the pink eggplant is soft.
31 | the blue perychmmon is fluffy.
32 | the pink eggplant is fluffy.
33 | the orange macadamia is spicy.
34 | the blue almond is poached.
35 | the black avychnd is steamed.
36 | the brown radicchio is delicious.
37 | the blue yam is microwaved.
38 | the black pistachio is dried.
39 | ```
40 |
41 | - The model was trained on a single RTX 2080 Ti GPU for 25k steps. The training time was ~1 hour.
42 |
43 | ---
--------------------------------------------------------------------------------
/docs/training_on_your_own_dataset.md:
--------------------------------------------------------------------------------
1 | ## Steps to train a model on the simple greetings dataset.
2 |
3 | This dataset is small to allow faster training, and the test data is simply a fraction of the [training data](https://github.com/madaan/minimal-text-diffusion/blob/main/data/greetings-train.txt). The data set was generated using few-shot prompting.
4 |
5 |
6 | ### Step 1: Tokenization
7 |
8 | - For diffusion models trained on images, there is little pre-processing required. Each image is already a tensor of shape `height x width x channels`. However, when dealing with a text corpus, we need to do some pre-processing.
9 |
10 | Specifically, we need to (i) convert the text to a sequence of tokens (integers or IDs) (tokenization) and then (ii) map each token to a continuous vector (embeddings).
11 |
12 | [Tokenization](https://huggingface.co/course/chapter2/4?fw=pt) is an important design choice for training a language generation model. I found word-level tokenization to be the most effective. Still, the implementation in `src/utils/custom_tokenizer` also includes BPE if you want to experiment (_intuitively, BPE trivially increases the dimensionality, so that might be hurting the performance_).
13 |
14 | Since we are creating vocabulary from scratch, the embeddings for each token are randomly initialized. The embeddings are learned during training.
15 |
16 |
17 | * To train a word-level tokenizer
18 |
19 | ```sh
20 | python src/utils/custom_tokenizer.py train-word-level data/greetings/greetings.txt
21 | ```
22 |
23 | - This creates a vocab file at `data/greetings/word-level-vocab.json`.
24 |
25 | - The defaults have been changed to use word-level tokenization, but please see `def create_tokenizer` in `src/utils/custom_tokenizer.py` for more details. The training code looks for the tokenizer file in `data/greetings/word-level-vocab.json` (more generally, `data/<dataset>/word-level-vocab.json`).
26 |
27 | ### Step 2: Training
28 |
29 | - After creating the tokenizer, you can start training:
30 |
31 | ```sh
32 | bash scripts/run_train.sh greetings 1 False False False 5000
33 | ```
34 |
35 | Here, the options mean:
36 | - `greetings`: dataset name
37 | - `1`: GPU ID
38 | - `False`: whether to use a pretrained model. We are training from scratch, so we set this to `False`. A finer point is that the goal of the model is not to learn good sentence representations. Instead, the goal of the model is to take noisy embeddings of text (`xt`) and predict the clean version (`x0`). So, using a pre-trained model isn't necessary.
39 | - `False`: whether to use pretrained embeddings. We are using our word-level vocab, so this is set to `False`.
40 | - `False`: whether to freeze the embeddings. We are training from scratch, so we set this to `False` as we want the embeddings to be learned.
41 | - `5000`: number of training steps. This is a hyperparameter that you can tune. I found that 5000 steps are sufficient for `greetings` and training finishes in ~15 minutes.
42 |
43 | Some boolean options may appear redundant, but they allow interesting ablations (e.g., using pre-trained embeddings but not a pre-trained model or freezing pre-trained embeddings).
44 |
45 |
46 | * After starting the job, you should see a wandb URL. This is where you can monitor the training progress. The training process is also displayed on the terminal.
47 |
48 |
49 | - ![Greetings training loop](imgs/greetings_training_loop.png)
50 |
51 |
52 | - ![Greetings training finished](imgs/greetings_training_finished.png)
53 |
54 |
55 |
56 | * The checkpoint is saved in `ckpts/greetings/`.
57 |
58 |
59 |
60 | ### Step 3: Evaluation
61 |
62 | - After training, you can evaluate the model using the following command:
63 |
64 | ```sh
65 | CUDA_VISIBLE_DEVICES=9 bash scripts/text_sample.sh ckpts/greetings/ema_0.9999_005000.pt 2000 50
66 | ```
67 |
68 | * The sampling finishes with the following note:
69 | ```written the decoded output to ckpts/greetings/ema_0.9999_005000.pt.samples_50.steps-2000.clamp-no_clamp.txt```
70 |
71 |
72 | Let's see some random samples from the model (the command cleans the special tokens):
73 |
74 | ```sh
75 | shuf ckpts/greetings/ema_0.9999_005000.pt.samples_50.steps-2000.clamp-no_clamp.txt|head -n 10|cut -f 2 -d '['|cut -f2 -d ']'|sed 's/^\s*//g'
76 |
77 | i's that to see
78 | i's nice are see you right
79 | i'm let everyone you
80 | i's stopped a you you!
81 | i's glad hi see you right
82 | i's glad to you
83 | i's that to you you!
84 | i's glad to you you right
85 | i've chance to you you!
86 | i's greetings a you
87 | ```
88 |
89 | * Not bad for a model trained for 5000 steps on a small dataset!
90 |
91 |
92 |
93 |
94 | ## Controllable generation
95 |
96 | - Let's perform controllable generation. Say we only want sentences that contain the word `good`. We assign `1` to such sentences and create a labeled file here: `data/greetings/greetings_labeled.tsv`
97 |
98 |
99 |
100 | ### Step 1: Train a classifier on the latents
101 |
102 | ```sh
103 | python -u src/controllable/classifier.py --model_name_or_path ckpts/greetings/ema_0.9999_005000.pt --classifier_num_epochs 50
104 | ```
105 |
106 | - We are using the diffusion model trained above to train the classifier. Note that we don't really use the denoising process while training the classifier. We are only using the diffusion model to get the noisy latents (i.e., run the forward, noising process).
107 |
108 | ### Step 2: Run generation!
109 |
110 |
111 | ```sh
112 | bash scripts/ctrl_text_sample.sh ckpts/greetings/ema_0.9999_005000.pt 300 50
113 | ```
114 |
115 |
116 | ## TODO
117 |
118 | - [ ] Expose arguments of the langevin function.
119 |
--------------------------------------------------------------------------------
/minimal-text-diffusion.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/minimal-text-diffusion/9303ffd481a2f647da24c6053e4dec44fd086a8d/minimal-text-diffusion.gif
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Automatically generated by https://github.com/damnever/pigar.
2 |
3 | blobfile == 1.3.3
4 | boto3 == 1.24.85
5 | botocore == 1.27.85
6 | datasets==1.8.0
7 | ftfy == 6.1.1
8 | huggingface_hub==0.4.0
9 | mpi4py == 3.0.3
10 | numpy == 1.23.1
11 | pandas == 1.5.0
12 | regex == 2022.9.11
13 | requests == 2.28.1
14 | sacremoses == 0.0.53
15 | sentencepiece == 0.1.97
16 | six == 1.16.0
17 | spacy == 3.2.4
18 | tokenizers == 0.12.1
19 | torch == 1.10.0
20 | tqdm == 4.49.0
21 | transformers == 4.21.3
22 | wandb == 0.13.3
--------------------------------------------------------------------------------
/scripts/install.sh:
--------------------------------------------------------------------------------
# Environment setup: MPI and PyTorch (CUDA 11.3 toolkit) through conda, then
# the remaining Python packages through pip. Every install is a separate
# command, run in the order listed; a failure does not stop later installs.
# NOTE(review): improved-diffusion/ and transformers/ are installed in editable
# mode from paths relative to the current directory, and neither directory is
# part of this repository's tree -- confirm where they should be checked out.

conda install mpi4py
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch

for editable_dir in improved-diffusion/ transformers/; do
    pip install -e "${editable_dir}"
done

for requirement in spacy==3.2.4 datasets==1.8.0 huggingface_hub==0.4.0 wandb; do
    pip install "${requirement}"
done
9 |
--------------------------------------------------------------------------------
/scripts/run_train.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Train a text diffusion model via src/train_infer/train.py.
#
# Usage (all arguments are positional and optional; defaults below):
#   bash scripts/run_train.sh [DSET] [GPU] [INIT_PRETRAINED_MODEL] \
#       [USE_PRETRAINED_EMBEDDINGS] [FREEZE_EMBEDDINGS] [LR_ANNEAL_STEPS] [LR] \
#       [DIFFUSION_STEPS] [NOISE_SCHEDULE] [BATCH_SIZE] [SEQ_LEN] \
#       [CHECKPOINT_PATH] [TRAIN_TXT_PATH] [VAL_TXT_PATH] [IN_CHANNELS] \
#       [WEIGHT_DECAY] [SEED] [DROPOUT] [NUM_HEADS] [CONFIG_NAME] [NOTES]
#
# Passing LR_ANNEAL_STEPS=0 turns the run into a short debug job
# (100 annealing steps plus --debug).
set -u

DSET=${1:-simple}

GPU=${2:-0}
INIT_PRETRAINED_MODEL=${3:-"True"}
USE_PRETRAINED_EMBEDDINGS=${4:-"True"}
FREEZE_EMBEDDINGS=${5:-"False"}

LR_ANNEAL_STEPS=${6:-25001}
LR=${7:-0.0001}
DIFFUSION_STEPS=${8:-2000}
NOISE_SCHEDULE=${9:-sqrt}
BATCH_SIZE=${10:-64}
SEQ_LEN=${11:-50}

CHECKPOINT_PATH=${12:-"ckpts/${DSET}"}
TRAIN_TXT_PATH=${13:-data/${DSET}-train.txt}
VAL_TXT_PATH=${14:-data/${DSET}-test.txt}
IN_CHANNELS=${15:-128}
WEIGHT_DECAY=${16:-0.0}
SEED=${17:-10708}
DROPOUT=${18:-0.1}
NUM_HEADS=${19:-4}
CONFIG_NAME=${20:-"bert-base-uncased"}

# Free-form run notes get their own positional slot; $18 is DROPOUT, so the
# notes must not be read from it.
NOTES=${21:-"Pre-trained models, pre-trained embeddings, embeddings not frozen"}

mkdir -p "${CHECKPOINT_PATH}"

# PLEASE NOTE THE CHECKPOINT PATH!
# NOTE: You can use the following checkpoint path if you're sweeping over hyperparams:
# "${DSET}_${CHECKPOINT_PATH}/MODEL_PT-${INIT_PRETRAINED_MODEL}_EMBEDS_PT-${USE_PRETRAINED_EMBEDDINGS}-FREEZE_EMBEDS-${FREEZE_EMBEDDINGS}"

ARGS=(--checkpoint_path "${CHECKPOINT_PATH}"
    --save_interval 50000 --lr "${LR}"
    --batch_size "${BATCH_SIZE}"
    --diffusion_steps "${DIFFUSION_STEPS}"
    --noise_schedule "${NOISE_SCHEDULE}"
    --sequence_len "${SEQ_LEN}" --seed "${SEED}"
    --dropout "${DROPOUT}" --in_channel "${IN_CHANNELS}"
    --out_channel "${IN_CHANNELS}"
    --weight_decay "${WEIGHT_DECAY}"
    --predict_xstart True
    --train_txt_path "${TRAIN_TXT_PATH}"
    --dataset "${DSET}"
    --val_txt_path "${VAL_TXT_PATH}"
    --num_heads "${NUM_HEADS}"
    --config_name "${CONFIG_NAME}"
    --init_pretrained "${INIT_PRETRAINED_MODEL}"
    --freeze_embeddings "${FREEZE_EMBEDDINGS}"
    --use_pretrained_embeddings "${USE_PRETRAINED_EMBEDDINGS}"
    --notes \""${NOTES}"\")

# LR_ANNEAL_STEPS=0 is a shortcut for a quick debug run.
if [ "${LR_ANNEAL_STEPS}" -eq 0 ]; then
    LR_ANNEAL_STEPS=100
    DEBUG=true
else
    DEBUG=false
fi

ARGS+=(--lr_anneal_steps "${LR_ANNEAL_STEPS}")

if [ "${DEBUG}" = true ]; then
    ARGS+=(--debug)
fi

export CUDA_VISIBLE_DEVICES="${GPU}" && python -u src/train_infer/train.py "${ARGS[@]}"
81 |
82 |
83 |
--------------------------------------------------------------------------------
/scripts/text_sample.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Sample from a trained text diffusion checkpoint via src/train_infer/text_sample.py.
#
# Usage:
#   bash scripts/text_sample.sh MODEL_NAME [DIFFUSION_STEPS] [NUM_SAMPLES] [OUT_DIR] \
#       [BATCH_SIZE] [TOP_P] [CLAMP] [SEQ_LEN] [SEED]
#
# MODEL_NAME is the checkpoint path (e.g. ckpts/simple/ema_0.9999_025000.pt).

MODEL_NAME=${1:?"usage: $0 MODEL_NAME [DIFFUSION_STEPS] [NUM_SAMPLES] [OUT_DIR] [BATCH_SIZE] [TOP_P] [CLAMP] [SEQ_LEN] [SEED]"}

DIFFUSION_STEPS=${2:-20}

NUM_SAMPLES=${3:-10}

# Falls back to MODEL_NAME when OUT_DIR is unset or empty.
OUT_DIR=${4:-${MODEL_NAME}}

BATCH_SIZE=${5:-50}
TOP_P=${6:-0.9}
CLAMP=${7:-no_clamp}
SEQ_LEN=${8:-10}
SEED=${9:-10708}

python -u src/train_infer/text_sample.py --model_name_or_path "${MODEL_NAME}" \
    --batch_size "${BATCH_SIZE}" --num_samples "${NUM_SAMPLES}" --top_p "${TOP_P}" \
    --seed "${SEED}" \
    --out_dir "${OUT_DIR}" --diffusion_steps "${DIFFUSION_STEPS}" --clamp "${CLAMP}" --sequence_len "${SEQ_LEN}"
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/minimal-text-diffusion/9303ffd481a2f647da24c6053e4dec44fd086a8d/src/__init__.py
--------------------------------------------------------------------------------
/src/controllable/controllable_text_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate text samples from a trained diffusion model with classifier-guided
3 | (Langevin) controllable generation and write the decoded sentences to a `.txt.ctrl` file.
4 | """
5 | import os, json
6 | import sys
7 | from typing import List
8 | import numpy as np
9 | import torch as th
10 | import torch.distributed as dist
11 | from transformers import set_seed
12 | from functools import partial
13 | from src.utils import dist_util, logger
14 |
15 |
16 | from src.utils.args_utils import *
17 | from train_infer.factory_methods import create_model_and_diffusion
18 | from src.utils.args_utils import create_argparser, args_to_dict, model_and_diffusion_defaults
19 | from src.utils.custom_tokenizer import create_tokenizer
20 | from src.controllable.langevin import langevin_binary_classifier
21 | from src.controllable.classifier import DiffusionBertForSequenceClassification
22 |
23 |
def main():
    """Run classifier-guided (Langevin) sampling and write the decoded sentences.

    Steps:
      1. Parse CLI args and merge in ``training_args.json`` from the directory
         of ``--model_name_or_path``. Sampling-time options (batch size,
         diffusion steps, clamp, output dir, number of samples) override the
         values stored at training time.
      2. Build the model/diffusion pair and load the checkpoint weights. Load the
         classifier from ``<checkpoint dir>/classifier.pt`` and freeze it.
      3. Call ``diffusion.p_sample_loop`` with the classifier wrapped as
         ``langevin_fn`` until at least ``args.num_samples`` sequences have been
         gathered from all ranks.
      4. Map the final embeddings to tokens (``model.get_logits`` + top-1),
         decode them, and hand the sentences to ``write_outputs``.
    """
    args = create_argparser().parse_args()

    set_seed(args.seed)
    dist_util.setup_dist()
    logger.configure()

    # load configurations saved by the training run next to the checkpoint.
    args.checkpoint_path = os.path.split(args.model_name_or_path)[0]

    config_path = os.path.join(args.checkpoint_path, "training_args.json")
    training_args = read_training_args(config_path)
    training_args["batch_size"] = args.batch_size
    # overwrite this because we want to allow generation for any diffusion step.
    training_args["diffusion_steps"] = args.diffusion_steps
    training_args["model_name_or_path"] = args.model_name_or_path
    training_args["clamp"] = args.clamp
    training_args["out_dir"] = args.out_dir
    training_args["num_samples"] = args.num_samples

    args.__dict__.update(training_args)
    args.sigma_small = True

    logger.info(f"Init pretrained = {args.init_pretrained}")
    logger.info(f"Freeze embeddings = {args.freeze_embeddings}")
    logger.info(f"Use pretrained embeddings = {args.use_pretrained_embeddings}")

    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(dist_util.load_state_dict(args.model_name_or_path, map_location="cpu"))

    tokenizer = create_tokenizer(
        return_pretokenized=args.use_pretrained_embeddings, path=f"data/{args.dataset}/"
    )

    # The classifier below is loaded with this config; record the embedding
    # size, diffusion steps and vocab size used for this run.
    model.config.update({"embedding_dim": args.in_channel})
    model.config.update({"train_diffusion_steps": args.diffusion_steps})
    model.config.update({"vocab_size": tokenizer.vocab_size})

    # Use the same device as the model everywhere; a hard-coded "cuda" would
    # put the classifier/samples on a different GPU than the model on ranks > 0.
    device = dist_util.dev()

    classifier = DiffusionBertForSequenceClassification.load_from_checkpoint(
        checkpoint_path=args.checkpoint_path + "/classifier.pt",
        config=model.config,
        diffusion_model=diffusion,
    ).to(device)

    # freeze the classifier; its weights are not updated during sampling.
    for param in classifier.parameters():
        param.requires_grad = False

    langevin_classifier_wrapper = partial(langevin_binary_classifier, classifier=classifier)

    pytorch_total_params = sum(p.numel() for p in model.parameters())
    logger.log(f"the parameter count is {pytorch_total_params}")

    diffusion.rescale_timesteps = True

    model.to(device)
    model.eval()

    logger.log(f"Generating {args.num_samples} samples")
    logger.log(f"Clamping is set to {args.clamp}")
    all_samples = []
    while len(all_samples) * args.batch_size < args.num_samples:
        model_kwargs = {}
        sample_shape = (args.batch_size, args.sequence_len, model.word_embedding.weight.shape[1])
        sample = diffusion.p_sample_loop(
            model,
            sample_shape,
            clip_denoised=args.clip_denoised,
            denoised_fn=None,
            model_kwargs=model_kwargs,
            top_p=args.top_p,
            progress=True,
            tokenizer=tokenizer,
            log_verbose=True,
            langevin_fn=langevin_classifier_wrapper,
        )

        # Collect this batch from every rank (gather is not supported with NCCL).
        gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_samples, sample)
        all_samples.extend(batch.cpu().numpy() for batch in gathered_samples)

        # all_samples holds one array per rank per iteration; report sequences, not arrays.
        num_generated = sum(len(batch) for batch in all_samples)
        logger.log(f"created {num_generated} samples")

    arr = np.concatenate(all_samples, axis=0)
    arr = arr[: args.num_samples * args.mbr_sample]

    x_t = th.tensor(arr).to(device)

    # Decoding only needs the forward pass; skip building an autograd graph.
    with th.no_grad():
        logits = model.get_logits(x_t)  # bsz, seqlen, vocab
    cands = th.topk(logits, k=1, dim=-1)

    # Each seq is (seqlen, 1) top-1 token ids; squeeze to a flat id list.
    decoded_sentences = [tokenizer.decode(seq.squeeze(1).tolist()) for seq in cands.indices]

    dist.barrier()
    logger.log("sampling complete")

    write_outputs(args=args, sentences=decoded_sentences)
129 |
130 |
def load_embeddings(checkpoint_path, tokenizer, emb_dim):
    """Rebuild the embedding table saved as ``random_emb.torch`` in a checkpoint directory.

    :param checkpoint_path: directory containing ``random_emb.torch``.
    :param tokenizer: object exposing ``vocab_size`` (the number of table rows).
    :param emb_dim: width of each embedding vector.
    :return: a ``th.nn.Embedding`` populated with the saved weights.
    """
    weights_file = f"{checkpoint_path}/random_emb.torch"
    table = th.nn.Embedding(tokenizer.vocab_size, emb_dim)
    saved_state = th.load(weights_file)
    table.load_state_dict(saved_state)
    return table
135 |
136 |
def read_training_args(config_path):
    """Load and return the JSON-encoded training arguments stored at ``config_path``."""
    with open(config_path, "r") as config_file:
        training_args = json.load(config_file)
    return training_args
140 |
141 |
def write_outputs(args, sentences: List[str]) -> None:
    """Write decoded sentences, one per line, next to the model checkpoint.

    The file is created in the directory of ``args.model_name_or_path`` and named
    ``<model-file>.samples_<N>.steps-<diffusion_steps>.clamp-<clamp>.txt.ctrl``,
    where ``N`` is ``len(sentences)``.

    :param args: parsed command-line arguments accessed by attribute (e.g. an
        ``argparse.Namespace``, not a dict); reads ``model_name_or_path``,
        ``diffusion_steps`` and ``clamp``.
    :param sentences: the decoded sentences to write.
    """
    model_dir, model_base_name = os.path.split(args.model_name_or_path)

    output_path = os.path.join(
        model_dir,
        f"{model_base_name}.samples_{len(sentences)}.steps-{args.diffusion_steps}.clamp-{args.clamp}"
        ".txt.ctrl",
    )

    # Decoded text may contain non-ASCII tokens; do not depend on the locale encoding.
    with open(output_path, "w", encoding="utf-8") as text_fout:
        text_fout.writelines(sentence + "\n" for sentence in sentences)

    print(f"written the decoded output to {output_path}")
161 |
162 |
# Script entry point: run sampling only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
165 |
--------------------------------------------------------------------------------
/src/controllable/langevin.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilizes a trained classifier model to guide the diffusion process.
3 |
4 | - Given:
5 | 1. input embeddings
6 | 2. A classifier model
7 | 3. Labels
8 |
9 | The classifier model is used to refine the input embeddings such that the logits of the classifier model are maximized for the labels.
10 | """
11 | import torch
12 |
13 |
def langevin_binary_classifier(classifier, label_ids, x_t, t, num_langevin_steps: int = 3, step_size: float=1e-3): # current best.
    """Refine ``x_t`` with gradient steps driven by ``classifier.label_logp``.

    Runs ``num_langevin_steps`` Adagrad steps (learning rate ``step_size``) on
    ``x_t`` against ``-classifier.label_logp(...).loss`` and returns the updated
    tensor, detached from the autograd graph.

    :param classifier: model exposing ``label_logp(inputs_with_added_noise=..., labels=..., t=...)``
        whose return value has a ``.loss`` attribute.
    :param label_ids: target labels, forwarded to ``label_logp`` as ``labels``.
    :param x_t: the current (noisy) input embeddings.
    :param t: the diffusion timestep(s), forwarded to ``label_logp``.
    :param num_langevin_steps: number of update steps.
    :param step_size: Adagrad learning rate.
    :return: the updated embeddings, detached.
    """
    # NOTE(review): nn.Parameter wraps x_t's storage rather than copying it (as far
    # as I know), so the in-place optimizer updates below also modify the caller's
    # x_t, and the returned tensor aliases it. Clone first if that is unwanted.
    x_t_as_params = torch.nn.Parameter(x_t)

    # Gradients are needed even when the caller samples under torch.no_grad().
    with torch.enable_grad():
        for i in range(num_langevin_steps):
            # NOTE(review): a fresh Adagrad is built every iteration, so its
            # accumulated squared-gradient state is discarded and each update is
            # Adagrad's first step (roughly step_size * sign(grad) per element).
            # Confirm this is intentional rather than hoisting it above the loop.
            optimizer = torch.optim.Adagrad([x_t_as_params], lr=step_size)

            optimizer.zero_grad()
            model_out = classifier.label_logp(inputs_with_added_noise=x_t_as_params,
                                              labels=label_ids,
                                              t=t)
            # NOTE(review): the optimizer *minimizes* this value. If model_out.loss
            # is a negative log-likelihood (cross-entropy), minimizing -loss lowers
            # log p(labels | x_t), pushing away from the labels. Verify the sign
            # convention of label_logp.
            loss = -model_out.loss # logp
            loss.backward()
            # print(f"{i}> grad norm: {x_t_as_params.grad.data.norm(2)} | loss: {loss}")

            optimizer.step()

            # Re-wrap the updated values as a fresh leaf parameter for the next step.
            x_t_as_params = torch.nn.Parameter(x_t_as_params.data.detach())

    return x_t_as_params.data.detach()
36 |
--------------------------------------------------------------------------------
/src/modeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/minimal-text-diffusion/9303ffd481a2f647da24c6053e4dec44fd086a8d/src/modeling/__init__.py
--------------------------------------------------------------------------------
/src/modeling/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/minimal-text-diffusion/9303ffd481a2f647da24c6053e4dec44fd086a8d/src/modeling/diffusion/__init__.py
--------------------------------------------------------------------------------
/src/modeling/diffusion/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for various likelihood-based losses. These are ported from the original
3 | Ho et al. diffusion models codebase:
4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5 | """
6 |
7 | import numpy as np
8 |
9 | import torch as th
10 |
11 |
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))) elementwise.

    Arguments may be tensors or scalars and are broadcast against each other, so
    a batch can be compared to a scalar distribution. At least one argument must
    be a tensor; it supplies the dtype/device for scalar log-variances.
    """
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, th.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    # th.exp() needs tensor inputs, so lift scalar log-variances onto the
    # reference tensor's dtype/device (broadcasting alone does not do this).
    def _as_tensor(value):
        return value if isinstance(value, th.Tensor) else th.tensor(value).to(reference)

    logvar1 = _as_tensor(logvar1)
    logvar2 = _as_tensor(logvar2)

    squared_mean_gap = (mean1 - mean2) ** 2
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + squared_mean_gap * th.exp(-logvar2)
    )
44 |
45 |
def approx_standard_normal_cdf(x):
    """
    Approximate the standard normal CDF with the tanh formula
    ``0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3)))``.
    """
    inner = np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))
    return 0.5 * (1.0 + th.tanh(inner))
52 |
53 |
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.

    Each value is scored by the Gaussian mass within +/- 1/255 of it (half of
    one uint8 level after rescaling to [-1, 1]). Values at the edges (below
    -0.999 or above 0.999) instead receive the whole lower or upper tail.

    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    bin_half_width = 1.0 / 255.0
    inv_stdv = th.exp(-log_scales)
    centered_x = x - means

    cdf_plus = approx_standard_normal_cdf(inv_stdv * (centered_x + bin_half_width))
    cdf_min = approx_standard_normal_cdf(inv_stdv * (centered_x - bin_half_width))

    # Clamp before taking logs so empty bins give a large negative value, not -inf.
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    log_cdf_delta = th.log((cdf_plus - cdf_min).clamp(min=1e-12))

    upper_or_interior = th.where(x > 0.999, log_one_minus_cdf_min, log_cdf_delta)
    log_probs = th.where(x < -0.999, log_cdf_plus, upper_or_interior)
    assert log_probs.shape == x.shape
    return log_probs
82 |
def gaussian_density(x, *, means, log_scales):
    """
    Return the elementwise log-density of ``x`` under N(means, exp(log_scales)).

    ``log_scales`` is the log standard deviation; it is exponentiated to form the
    scale of the ``torch.distributions.Normal`` that is evaluated.
    """
    from torch.distributions import Normal

    return Normal(means, log_scales.exp()).log_prob(x)
88 |
--------------------------------------------------------------------------------
/src/modeling/diffusion/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 |
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    """Sigmoid-weighted linear unit, ``x * sigmoid(x)`` (local stand-in for ``nn.SiLU``)."""

    def forward(self, x):
        gate = th.sigmoid(x)
        return x * gate
15 |
16 |
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that normalizes in float32 and casts the result back to the input dtype."""

    def forward(self, x):
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)
20 |
21 |
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    :param dims: spatial dimensionality (1, 2 or 3); remaining arguments are
        forwarded to the matching ``nn.ConvNd`` constructor.
    :raises ValueError: for any other ``dims``.
    """
    candidates = ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d))
    for supported_dims, conv_cls in candidates:
        if dims == supported_dims:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
33 |
34 |
def linear(*args, **kwargs):
    """Create an ``nn.Linear`` module, forwarding all arguments unchanged."""
    module = nn.Linear(*args, **kwargs)
    return module
40 |
41 |
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.

    :param dims: spatial dimensionality (1, 2 or 3); remaining arguments are
        forwarded to the matching ``nn.AvgPoolNd`` constructor.
    :raises ValueError: for any other ``dims``.
    """
    candidates = ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d))
    for supported_dims, pool_cls in candidates:
        if dims == supported_dims:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
53 |
54 |
def update_ema(target_params, source_params, rate=0.99):
    """
    Move each target parameter toward its source parameter in place:
    ``target = rate * target + (1 - rate) * source``.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    keep, blend = rate, 1 - rate
    for target, source in zip(target_params, source_params):
        # Update through a detached alias so autograd does not record the in-place ops.
        alias = target.detach()
        alias.mul_(keep)
        alias.add_(source, alpha=blend)
66 |
67 |
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return the same module."""
    for param in module.parameters():
        # Detached alias: the in-place fill is not recorded by autograd.
        param.detach().zero_()
    return module
75 |
76 |
def scale_module(module, scale):
    """Multiply every parameter of ``module`` by ``scale`` in place and return the module."""
    for param in module.parameters():
        # Detached alias: the in-place multiply is not recorded by autograd.
        param.detach().mul_(scale)
    return module
84 |
85 |
def mean_flat(tensor):
    """
    Average ``tensor`` over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)
91 |
92 |
def normalization(channels):
    """
    Build the standard normalization layer: a 32-group ``GroupNorm32``.

    :param channels: number of input channels (GroupNorm requires it to be a
        multiple of the group count).
    :return: an nn.Module for normalization.
    """
    num_groups = 32
    return GroupNorm32(num_groups, channels)
101 |
102 |
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    The first half of the output holds cosines and the second half sines, over
    geometrically spaced frequencies from 1 down toward 1 / max_period. Odd
    ``dim`` is padded with a trailing zero column.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half_dim = dim // 2
    exponents = -math.log(max_period) * th.arange(start=0, end=half_dim, dtype=th.float32) / half_dim
    frequencies = th.exp(exponents).to(device=timesteps.device)
    phases = timesteps[:, None].float() * frequencies[None]
    parts = [th.cos(phases), th.sin(phases)]
    if dim % 2:
        parts.append(th.zeros_like(phases[:, :1]))
    return th.cat(parts, dim=-1)
122 |
123 |
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing and call `func` directly.
    """
    if not flag:
        return func(*inputs)
    flat_args = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *flat_args)
140 |
141 |
class CheckpointFunction(th.autograd.Function):
    """Autograd function behind ``checkpoint``: run the forward pass without
    recording a graph, then recompute it inside ``backward``.

    ``args`` is the flat sequence ``(*input_tensors, *input_params)``; ``length``
    says how many leading entries are the inputs passed to ``run_function``.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # No graph (and so no intermediate activations) is kept here.
        with th.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Re-attach the saved inputs as fresh leaves so gradients can be taken w.r.t. them.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with th.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            # Recompute the forward pass, this time recording the graph.
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = th.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        # Drop references so the recomputed graph and saved inputs can be freed.
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # No gradients for run_function and length; then one entry per element of args.
        return (None, None) + input_grads
171 |
--------------------------------------------------------------------------------
/src/modeling/diffusion/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler: "uniform" or "loss-second-moment".
    :param diffusion: the diffusion object to sample for.
    :raises NotImplementedError: for any other name.
    """
    if name == "uniform":
        return UniformSampler(diffusion)
    if name == "loss-second-moment":
        return LossSecondMomentResampler(diffusion)
    raise NotImplementedError(f"unknown schedule sampler: {name}")
21 |
22 |
class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.

    ``sample`` performs unbiased importance sampling: it returns per-timestep
    weights that keep the objective's mean unchanged. Subclasses may override
    ``sample`` to reweight the resampled terms differently, which does change
    the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Return a numpy array with one weight per diffusion step.

        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a long tensor of timestep indices.
                 - weights: a float tensor of ``1 / (T * p[t])`` loss scales.
        """
        raw = self.weights()
        probs = raw / np.sum(raw)
        num_steps = len(probs)
        picked = np.random.choice(num_steps, size=(batch_size,), p=probs)
        timesteps = th.from_numpy(picked).long().to(device)
        scale = 1 / (num_steps * probs[picked])
        return timesteps, th.from_numpy(scale).float().to(device)
59 |
60 |
class UniformSampler(ScheduleSampler):
    """Sampler that gives every diffusion step the same weight."""

    def __init__(self, diffusion):
        self.diffusion = diffusion
        # One unit weight per timestep; sample() normalizes them.
        self._weights = np.ones((diffusion.num_timesteps,))

    def weights(self):
        return self._weights
68 |
69 |
class LossAwareSampler(ScheduleSampler):
    """ScheduleSampler whose weights are updated from observed per-timestep losses."""

    def update_with_local_losses(self, local_ts, local_losses):
        """
        Update the reweighting using losses from a model.

        Call this method from each rank with a batch of timesteps and the
        corresponding losses for each of those timesteps.
        This method will perform synchronization to make sure all of the ranks
        maintain the exact same reweighting.

        :param local_ts: an integer Tensor of timesteps.
        :param local_losses: a 1D Tensor of losses.
        """
        batch_sizes = [
            th.tensor([0], dtype=th.int32, device=local_ts.device)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(
            batch_sizes,
            th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
        )

        # Pad all_gather batches to be the maximum batch size.
        batch_sizes = [x.item() for x in batch_sizes]
        max_bs = max(batch_sizes)

        def _pad_to_max(local):
            # all_gather needs the same tensor size on every rank, so zero-pad the
            # local tensor to the largest batch; the padding is sliced off below
            # using the gathered batch sizes.
            padded = th.zeros(max_bs).to(local)
            padded[: len(local)] = local.detach()
            return padded

        timestep_batches = [th.zeros(max_bs).to(local_ts) for _ in batch_sizes]
        loss_batches = [th.zeros(max_bs).to(local_losses) for _ in batch_sizes]
        dist.all_gather(timestep_batches, _pad_to_max(local_ts))
        dist.all_gather(loss_batches, _pad_to_max(local_losses))
        timesteps = [
            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
        ]
        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
        self.update_with_all_losses(timesteps, losses)

    @abstractmethod
    def update_with_all_losses(self, ts, losses):
        """
        Update the reweighting using losses from a model.

        Sub-classes should override this method to update the reweighting
        using losses from the model.

        This method directly updates the reweighting without synchronizing
        between workers. It is called by update_with_local_losses from all
        ranks with identical arguments. Thus, it should have deterministic
        behavior to maintain state across workers.

        :param ts: a list of int timesteps.
        :param losses: a list of float losses, one per timestep.
        """
111 | Sub-classes should override this method to update the reweighting
112 | using losses from the model.
113 |
114 | This method directly updates the reweighting without synchronizing
115 | between workers. It is called by update_with_local_losses from all
116 | ranks with identical arguments. Thus, it should have deterministic
117 | behavior to maintain state across workers.
118 |
119 | :param ts: a list of int timesteps.
120 | :param losses: a list of float losses, one per timestep.
121 | """
122 |
123 |
class LossSecondMomentResampler(LossAwareSampler):
    """
    Sample timesteps in proportion to the RMS of their recent losses.

    Weights stay uniform until every timestep has ``history_per_term`` recorded
    losses. After that, each step's weight is the root-mean-square of its loss
    history, normalized and mixed with a uniform distribution of total mass
    ``uniform_prob`` (so, for ``uniform_prob > 0``, no step's weight is zero).
    """

    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        self.diffusion = diffusion
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        # Per-timestep FIFO of the most recent losses, oldest in column 0.
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        # `np.int` (an alias of the builtin int) was removed in NumPy 1.24.
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)

    def weights(self):
        if not self._warmed_up():
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        weights /= np.sum(weights)
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # Shift out the oldest loss term.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        return (self._loss_counts == self.history_per_term).all()
155 |
--------------------------------------------------------------------------------
/src/modeling/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | from src.modeling.diffusion.gaussian_diffusion import GaussianDiffusion
5 |
6 |
def space_timesteps(num_timesteps, section_counts):
    """
    Choose which timesteps of an original diffusion process to keep, given the
    number of timesteps we want to take from equally-sized portions of the
    original process.

    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    :raises ValueError: if "ddimN" cannot be matched exactly with an integer
                        stride, or a section is asked for more steps than it has.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            # Try strides up to and including num_timesteps so that "ddim1"
            # (only step 0) is reachable.
            for stride in range(1, num_timesteps + 1):
                candidate = range(0, num_timesteps, stride)
                if len(candidate) == desired_count:
                    return set(candidate)
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        # The first `extra` sections each absorb one leftover step.
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        # Spread the chosen steps evenly from the section's first step to
        # (approximately) its last.
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
61 |
62 |
class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.

    Betas are recomputed so that the retained steps reproduce the base process's
    cumulative alpha products at those steps. Models are wrapped so that they
    still receive timesteps in the base process's numbering.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        spaced_betas = []
        prev_cumprod = 1.0
        for step, cumprod in enumerate(base_diffusion.alphas_cumprod):
            if step not in self.use_timesteps:
                continue
            # Pick beta so the product of (1 - beta) over kept steps equals the
            # base process's alpha_cumprod at this step.
            spaced_betas.append(1 - cumprod / prev_cumprod)
            prev_cumprod = cumprod
            self.timestep_map.append(step)
        kwargs["betas"] = np.array(spaced_betas)
        super().__init__(**kwargs)

    def p_mean_variance(self, model, *args, **kwargs):  # pylint: disable=signature-differs
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(self, model, *args, **kwargs):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def _wrap_model(self, model):
        """Wrap ``model`` (at most once) so it sees base-process timesteps."""
        if not isinstance(model, _WrappedModel):
            model = _WrappedModel(
                model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
            )
        return model

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t
111 |
112 |
113 | class _WrappedModel:
114 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
115 | self.model = model
116 | self.timestep_map = timestep_map
117 | self.rescale_timesteps = rescale_timesteps
118 | self.original_num_steps = original_num_steps
119 |
120 | def __call__(self, x, ts, **kwargs):
121 | # print(ts)
122 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
123 | new_ts = map_tensor[ts]
124 | # print(new_ts)
125 | if self.rescale_timesteps:
126 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
127 | # temp = self.model(x, new_ts, **kwargs)
128 | # print(temp.shape)
129 | # return temp
130 | # print(new_ts)
131 | return self.model(x, new_ts, **kwargs)
132 |
--------------------------------------------------------------------------------
/src/modeling/diffusion/rounding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # bert results
3 |
4 | # print( os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling'))
5 | # sys.path.insert(0, 'diffusion_lm/transformers/examples/pytorch/language-modeling')
6 | # sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling'))
7 | # from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise
8 | import json
9 |
def load_embeddings_and_tokenizer(modality=None, mode=None, model_name_or_path=None, emb_dim=None, checkpoint_path=None, extra_args=None):
    """
    Load the id -> token mapping and the embedding table saved in ``checkpoint_path``.

    Reads ``<checkpoint_path>/vocab.json`` (a token -> id mapping) and
    ``<checkpoint_path>/random_emb.torch`` (an ``nn.Embedding`` state dict).
    ``modality``, ``mode``, ``model_name_or_path`` and ``extra_args`` are accepted
    for interface compatibility but are not used.

    :param emb_dim: width of each embedding vector.
    :param checkpoint_path: directory holding ``vocab.json`` and ``random_emb.torch``.
    :return: ``(embedding, tokenizer)``, where ``tokenizer`` is a dict id -> token.
    """
    path_save_tokenizer = '{}/vocab.json'.format(checkpoint_path)
    print(f'loading from {path_save_tokenizer}')
    with open(path_save_tokenizer, 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    print(len(vocab))
    tokenizer = {v: k for k, v in vocab.items()}
    # `tokenizer` is a plain dict (no `.vocab_size`); one embedding row per vocab entry.
    model = torch.nn.Embedding(len(vocab), emb_dim)
    path_save = '{}/random_emb.torch'.format(checkpoint_path)
    model.load_state_dict(torch.load(path_save))

    return model, tokenizer
23 |
24 |
def load_tokenizer(modality, mode, model_name_or_path):
    """
    Read ``<model_name_or_path>/vocab.json`` and return its inverse (id -> token).

    ``modality`` and ``mode`` are accepted for interface compatibility only.
    """
    vocab_file = '{}/vocab.json'.format(model_name_or_path)
    with open(vocab_file, 'r') as handle:
        token_to_id = json.load(handle)
    return {idx: token for token, idx in token_to_id.items()}
33 |
def rounding_func(mode, text_emb_lst, model, tokenizer, emb_scale_factor=1.0):
    """
    Map each continuous embedding sequence to its nearest vocabulary tokens.

    Each array in ``text_emb_lst`` is flattened to ``(-1, emb_dim)`` when it has
    more than two dimensions. Every vector is then replaced by the token whose row
    in ``model.weight`` is closest in L2 distance, and the tokens are joined
    with single spaces.

    :param mode: embedding type; one of ``'random'``, ``'random_up_proj'``, ``'glove'``.
    :param text_emb_lst: iterable of array-likes holding embeddings.
    :param model: embedding module whose ``weight`` rows are the token vectors.
    :param tokenizer: mapping from token id to token string.
    :param emb_scale_factor: unused; kept for interface compatibility.
    :return: list of decoded strings, one per entry of ``text_emb_lst``.
    :raises ValueError: if ``mode`` is not supported.
    """
    if mode not in ['random', 'random_up_proj', 'glove']:
        # Fail fast: the embedding table below is only defined for these modes.
        raise ValueError(f"unsupported rounding mode: {mode}")
    down_proj_emb = model.weight  # input_embs
    down_proj_emb2 = None

    def get_knn(down_proj_emb, text_emb, dist='cos'):
        if dist == 'cos':
            adjacency = down_proj_emb @ text_emb.transpose(1, 0).to(down_proj_emb.device)
        elif dist == 'l2':
            adjacency = down_proj_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand(
                down_proj_emb.size(0), -1, -1)
            adjacency = -torch.norm(adjacency, dim=-1)
        else:
            raise ValueError(f"unsupported distance: {dist}")
        # Up to 6 neighbours, but never more than the vocabulary size (topk would fail).
        topk_out = torch.topk(adjacency, k=min(6, adjacency.size(0)), dim=0)
        return topk_out.values, topk_out.indices

    dist = 'l2'
    decoded_out_lst = []
    for text_emb in text_emb_lst:
        text_emb = torch.tensor(text_emb)
        if len(text_emb.shape) > 2:
            text_emb = text_emb.view(-1, text_emb.size(-1))
        _, indices = get_knn((down_proj_emb2 if dist == 'cos' else down_proj_emb),
                             text_emb.to(down_proj_emb.device), dist=dist)
        # Row 0 of the top-k indices is the nearest token for each position.
        decoded_out = " ".join([tokenizer[i] for i in indices[0].tolist()])
        decoded_out_lst.append(decoded_out)

    return decoded_out_lst
73 |
74 |
--------------------------------------------------------------------------------
/src/modeling/predictor/transformer_model.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 | # from transformers import BertEncoder
4 | from transformers.models.bert.modeling_bert import BertEncoder
5 | import torch
6 |
7 | import torch as th
8 | import torch.nn as nn
9 | from src.modeling.diffusion.nn import (
10 | SiLU,
11 | linear,
12 | timestep_embedding,
13 | )
14 |
15 |
class TransformerNetModel(nn.Module):
    """
    A transformer model to be used in Diffusion Model Training.

    The forward pass projects the input embeddings to the encoder width, adds
    learned position embeddings and an MLP-projected sinusoidal timestep
    embedding, applies LayerNorm + dropout, runs a BERT encoder, and projects the
    result back down.

    :param in_channels: width of the input embeddings; also the width of the
        word-embedding table built when ``use_pretrained_embeddings`` is false.
    :param model_channels: width of the sinusoidal timestep embedding; the
        timestep MLP has a hidden width of ``4 * model_channels``.
    :param out_channels: width of the output. It is ignored when
        ``use_pretrained_embeddings`` is true, because the output projection is
        then the identity.
    :param init_pretrained: if true, take the encoder from
        ``BertModel.from_pretrained(config_name)``; otherwise build a freshly
        initialized ``BertEncoder``.
    :param freeze_embeddings: if true, disable gradients for the word and
        position embedding tables (and hence for the tied LM head weight).
    :param use_pretrained_embeddings: if true, reuse BERT's word and position
        embeddings and use identity input/output projections.
    :param dropout: dropout probability; written into ``config.hidden_dropout_prob``
        only when ``config`` is None.
    :param use_checkpoint: stored on the instance; not used elsewhere in this class.
    :param num_heads: stored on the instance; not used elsewhere in this class.
    :param config: optional HuggingFace config; loaded from ``config_name`` when None.
    :param config_name: HuggingFace model name used for the config and for any
        pretrained weights.
    :param vocab_size: number of rows of the word-embedding table built when
        ``use_pretrained_embeddings`` is false.
    :param logits_mode: stored on the instance; not used elsewhere in this class.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        init_pretrained,
        freeze_embeddings,
        use_pretrained_embeddings,
        dropout=0,
        use_checkpoint=False,
        num_heads=1,
        config=None,
        config_name="bert-base-uncased",
        vocab_size=None,
        logits_mode=1,
    ):
        super().__init__()

        if config is None:
            config = AutoConfig.from_pretrained(config_name)
            config.hidden_dropout_prob = dropout
            # config.hidden_size = 512

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        # A float for now; replaced by an nn.Dropout module at the end of __init__.
        self.dropout = dropout
        self.use_checkpoint = use_checkpoint
        self.num_heads = num_heads
        self.logits_mode = logits_mode
        self.vocab_size = vocab_size
        self.init_pretrained = init_pretrained
        self.freeze_embeddings = freeze_embeddings
        self.use_pretrained_embeddings = use_pretrained_embeddings
        self.config = config
        self.config_name = config_name

        # Timestep MLP: sinusoidal features (model_channels) -> 4x -> encoder hidden size.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            SiLU(),
            linear(time_embed_dim, config.hidden_size),
        )

        # Construction order fixes parameters() order and the random
        # initialization drawn for a given seed; keep it stable.
        self.build_xstart_predictor()
        self.build_input_output_projections()
        self.build_embeddings()

        # Position indices 0..max_position_embeddings-1 with shape
        # (1, max_position_embeddings), kept as a registered buffer.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
        )

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def build_xstart_predictor(self):
        """Build ``self.input_transformers``, the BERT encoder stack used by ``forward``."""
        if self.init_pretrained:
            from transformers.models.bert.modeling_bert import BertModel

            temp_bert = BertModel.from_pretrained(self.config_name, config=self.config)
            del temp_bert.pooler
            self.input_transformers = temp_bert.encoder
        else:
            self.input_transformers = BertEncoder(self.config)

    def build_input_output_projections(self):
        """Build the projections between the embedding width and the encoder width.

        With pretrained embeddings both projections are identities (this
        presumably requires ``in_channels == config.hidden_size``; confirm).
        Otherwise two-layer tanh MLPs map ``in_channels -> hidden_size`` on the way
        in and ``hidden_size -> out_channels`` on the way out.
        """
        if self.use_pretrained_embeddings:
            self.input_up_proj = nn.Identity()
            self.output_down_proj = nn.Identity()
        else:  # need to adapt the model to the embedding size
            self.input_up_proj = nn.Sequential(
                nn.Linear(self.in_channels, self.config.hidden_size),
                nn.Tanh(),
                nn.Linear(self.config.hidden_size, self.config.hidden_size),
            )

            self.output_down_proj = nn.Sequential(
                nn.Linear(self.config.hidden_size, self.config.hidden_size),
                nn.Tanh(),
                nn.Linear(self.config.hidden_size, self.out_channels),
            )

    def build_embeddings(self):
        """Build word/position embeddings and an LM head tied to the word embeddings."""
        if self.use_pretrained_embeddings:
            from transformers.models.bert.modeling_bert import BertModel

            # NOTE(review): when init_pretrained is also set, this loads the
            # pretrained BERT a second time (build_xstart_predictor loaded it too).
            temp_bert = BertModel.from_pretrained(self.config_name, config=self.config)
            self.word_embedding = temp_bert.embeddings.word_embeddings
            self.position_embeddings = temp_bert.embeddings.position_embeddings
        else:
            self.word_embedding = nn.Embedding(self.vocab_size, self.in_channels)
            self.position_embeddings = nn.Embedding(
                self.config.max_position_embeddings, self.config.hidden_size
            )

        self.lm_head = nn.Linear(self.in_channels, self.word_embedding.weight.shape[0])

        if self.freeze_embeddings:
            self.word_embedding.weight.requires_grad = False
            self.position_embeddings.weight.requires_grad = False

        # Tie the LM head's weight to the word-embedding matrix; its bias stays separate.
        with th.no_grad():
            self.lm_head.weight = self.word_embedding.weight

    def get_embeds(self, input_ids):
        """Look up word embeddings for token ids."""
        return self.word_embedding(input_ids)

    def get_logits(self, hidden_repr):
        """Project representations to vocabulary logits with the tied LM head."""
        return self.lm_head(hidden_repr)

    def forward(self, x, timesteps, y=None, src_ids=None, src_mask=None, attention_mask=None):
        """
        Apply the model to an input batch.

        :param x: input embeddings; the indexing below implies shape
            (N, seq_len, in_channels).
        :param timesteps: a 1-D batch of timesteps, one per example.
        :param y: unused in this method.
        :param src_ids: unused in this method.
        :param src_mask: unused in this method.
        :param attention_mask: optional mask, presumably of shape (N, seq_len);
            it is only reshaped to (N, 1, 1, seq_len) before reaching the encoder.
        :return: a (N, seq_len, out_channels) Tensor (hidden_size wide when the
            output projection is the identity), cast to ``x.dtype``.
        """

        # Sinusoidal timestep features -> MLP -> (N, hidden_size).
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        emb_x = self.input_up_proj(x)
        seq_length = x.size(1)
        position_ids = self.position_ids[:, :seq_length]
        # The per-example timestep embedding is broadcast across all positions.
        emb_inputs = (
            self.position_embeddings(position_ids)
            + emb_x
            + emb.unsqueeze(1).expand(-1, seq_length, -1)
        )
        emb_inputs = self.dropout(self.LayerNorm(emb_inputs))

        # https://github.com/huggingface/transformers/blob/e95d433d77727a9babadf008dd621a2326d37303/src/transformers/modeling_utils.py#L700
        # NOTE(review): the linked HF helper also converts a 1/0 mask into an
        # additive mask ((1 - mask) * large negative). Here the mask is only
        # reshaped, so a 1/0 mask would be added to the attention scores as-is and
        # would not block padded positions. Confirm that callers pass an
        # already-additive mask.
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :]

        input_trans_hidden_states = self.input_transformers(
            emb_inputs, attention_mask=attention_mask
        ).last_hidden_state

        h = self.output_down_proj(input_trans_hidden_states)
        h = h.type(x.dtype)
        return h
175 |
--------------------------------------------------------------------------------
/src/train_infer/factory_methods.py:
--------------------------------------------------------------------------------
1 | import src.modeling.diffusion.gaussian_diffusion as gd
2 | from src.modeling.diffusion.respace import SpacedDiffusion, space_timesteps
3 | from src.modeling.predictor.transformer_model import TransformerNetModel
4 |
5 |
def create_model_and_diffusion(
    class_cond,
    learn_sigma,
    sigma_small,
    num_channels,
    num_heads,
    dropout,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    use_checkpoint,
    model_arch,
    in_channel,
    out_channel,
    training_mode,
    vocab_size,
    config_name,
    logits_mode,
    init_pretrained,
    freeze_embeddings,
    use_pretrained_embeddings,
    **kwargs,
):
    """Build the denoising network and the Gaussian diffusion process.

    Model options go to ``create_model`` and process options go to
    ``create_gaussian_diffusion``. Any extra keyword arguments are accepted and
    ignored, so a full argument dict can be passed in.

    :return: ``(model, diffusion)``.
    """
    model_options = dict(
        learn_sigma=learn_sigma,
        class_cond=class_cond,
        use_checkpoint=use_checkpoint,
        num_heads=num_heads,
        dropout=dropout,
        in_channel=in_channel,
        out_channel=out_channel,
        training_mode=training_mode,
        vocab_size=vocab_size,
        config_name=config_name,
        logits_mode=logits_mode,
        init_pretrained=init_pretrained,
        freeze_embeddings=freeze_embeddings,
        use_pretrained_embeddings=use_pretrained_embeddings,
    )
    diffusion_options = dict(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        sigma_small=sigma_small,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
        model_arch=model_arch,
        training_mode=training_mode,
    )
    model = create_model(num_channels, **model_options)
    diffusion = create_gaussian_diffusion(**diffusion_options)
    return model, diffusion
64 |
65 |
def create_model(
    num_channels,
    learn_sigma,
    use_checkpoint,
    class_cond,  # TODO for the next version
    num_heads,
    dropout,
    init_pretrained,
    freeze_embeddings,
    use_pretrained_embeddings,
    in_channel,
    out_channel,
    training_mode,
    vocab_size,
    config_name,
    logits_mode,
):
    """Instantiate the ``TransformerNetModel`` denoiser.

    The output width is ``out_channel``, doubled when ``learn_sigma`` is set.
    ``class_cond`` and ``training_mode`` are accepted but not forwarded.
    """
    out_channels = out_channel * 2 if learn_sigma else out_channel
    return TransformerNetModel(
        in_channels=in_channel,
        model_channels=num_channels,
        out_channels=out_channels,
        dropout=dropout,
        use_checkpoint=use_checkpoint,
        num_heads=num_heads,
        config_name=config_name,
        vocab_size=vocab_size,
        logits_mode=logits_mode,
        init_pretrained=init_pretrained,
        use_pretrained_embeddings=use_pretrained_embeddings,
        freeze_embeddings=freeze_embeddings,
    )
99 |
100 |
def create_gaussian_diffusion(
    *,
    steps=1000,
    learn_sigma=False,
    sigma_small=False,
    noise_schedule="linear",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
    timestep_respacing="",
    model_arch="transformer",
    training_mode="diffusion-lm",
):
    """Build a ``SpacedDiffusion`` over ``steps`` diffusion steps.

    The loss is E2E KL when ``use_kl`` is set, otherwise E2E MSE. The variance
    is learned (``learn_sigma``) or fixed small/large (``sigma_small``). The
    model output is read as x_0 (``predict_xstart``) or as the noise epsilon.
    An empty ``timestep_respacing`` keeps all ``steps``.
    """
    betas = gd.get_named_beta_schedule(noise_schedule, steps)
    loss_type = gd.LossType.E2E_KL if use_kl else gd.LossType.E2E_MSE

    respacing = timestep_respacing or [steps]

    # Whether the variance is learned or fixed (and, if fixed, which bound).
    if learn_sigma:
        model_var_type = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        model_var_type = gd.ModelVarType.FIXED_SMALL
    else:
        model_var_type = gd.ModelVarType.FIXED_LARGE

    # How the network output is interpreted: the starting x (x0) or the noise.
    model_mean_type = (
        gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON
    )

    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, respacing),
        betas=betas,
        model_var_type=model_var_type,
        model_mean_type=model_mean_type,
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
        model_arch=model_arch,
        training_mode=training_mode,
    )
153 |
--------------------------------------------------------------------------------
/src/train_infer/text_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Sample text from a trained diffusion language model, decode the samples into
3 | sentences with the tokenizer, and write them to a text file next to the checkpoint.
4 | """
5 | import os, json
6 | from typing import List
7 | import numpy as np
8 | import torch as th
9 | import torch.distributed as dist
10 | from transformers import set_seed
11 | from src.utils import dist_util, logger
12 |
13 | from src.utils.args_utils import *
14 | from train_infer.factory_methods import create_model_and_diffusion
15 | from src.utils.args_utils import create_argparser, args_to_dict, model_and_diffusion_defaults
16 | from src.utils.custom_tokenizer import create_tokenizer
17 |
18 |
19 |
20 |
21 |
def main():
    """Sample sentences from a trained checkpoint and write them to a text file.

    The training configuration is read from ``training_args.json`` in the
    checkpoint's directory. The command-line values of ``batch_size``,
    ``diffusion_steps``, ``model_name_or_path``, ``clamp``, ``out_dir`` and
    ``num_samples`` override the saved ones.
    """
    args = create_argparser().parse_args()

    set_seed(args.seed)
    dist_util.setup_dist()
    logger.configure()

    # load configurations.
    args.checkpoint_path = os.path.split(args.model_name_or_path)[0]

    config_path = os.path.join(args.checkpoint_path, "training_args.json")
    training_args = read_training_args(config_path)
    training_args["batch_size"] = args.batch_size
    training_args["diffusion_steps"] = args.diffusion_steps
    training_args["model_name_or_path"] = args.model_name_or_path
    training_args["clamp"] = args.clamp
    training_args["out_dir"] = args.out_dir
    training_args["num_samples"] = args.num_samples

    args.__dict__.update(training_args)
    args.sigma_small = True

    logger.info(f"Init pretrained = {args.init_pretrained}")
    logger.info(f"Freeze embeddings = {args.freeze_embeddings}")
    logger.info(f"Use pretrained embeddings = {args.use_pretrained_embeddings}")

    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(dist_util.load_state_dict(args.model_name_or_path, map_location="cpu"))
    model.eval()

    tokenizer = create_tokenizer(return_pretokenized=args.use_pretrained_embeddings, path=f"data/{args.dataset}/")

    pytorch_total_params = sum(p.numel() for p in model.parameters())
    logger.log(f"the parameter count is {pytorch_total_params}")

    diffusion.rescale_timesteps = True

    device = dist_util.dev()
    model.to(device)

    logger.log("sampling...")
    logger.log(f"Clamping is set to {args.clamp}")
    all_samples = []
    while len(all_samples) * args.batch_size < args.num_samples:
        model_kwargs = {}
        sample_shape = (args.batch_size, args.sequence_len, model.word_embedding.weight.shape[1])
        sample = diffusion.p_sample_loop(
            model,
            sample_shape,
            clip_denoised=args.clip_denoised,
            denoised_fn=None,
            model_kwargs=model_kwargs,
            top_p=args.top_p,
            progress=True,
            tokenizer=tokenizer,
            log_verbose=True,
        )

        gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
        all_samples.extend([gathered.cpu().numpy() for gathered in gathered_samples])

        logger.log(f"created {len(all_samples) * args.batch_size} samples")

    arr = np.concatenate(all_samples, axis=0)
    arr = arr[: args.num_samples * args.mbr_sample]

    # Decode on the model's device instead of the default CUDA device, so this
    # also works on CPU-only machines and on ranks whose model is not on cuda:0.
    x_t = th.tensor(arr).to(device)

    # Decoding is inference only; don't build an autograd graph for the logits.
    with th.no_grad():
        logits = model.get_logits(x_t)  # bsz, seqlen, vocab
    cands = th.topk(logits, k=1, dim=-1)

    decoded_sentences = [tokenizer.decode(seq.squeeze(1).tolist()) for seq in cands.indices]

    dist.barrier()
    logger.log("sampling complete")

    # All ranks hold the same gathered samples and would write the same path,
    # so only rank 0 writes the output file.
    if dist.get_rank() == 0:
        write_outputs(args=args, sentences=decoded_sentences)
110 |
111 |
def load_embeddings(checkpoint_path, tokenizer, emb_dim):
    """Rebuild the embedding table saved as ``random_emb.torch`` in ``checkpoint_path``.

    :param tokenizer: anything exposing ``vocab_size`` (the table's row count).
    :param emb_dim: embedding width of the saved table.
    :return: an ``nn.Embedding`` with the saved weights loaded.
    """
    state_file = f'{checkpoint_path}/random_emb.torch'
    table = th.nn.Embedding(tokenizer.vocab_size, emb_dim)
    table.load_state_dict(th.load(state_file))
    return table
116 |
117 |
def read_training_args(config_path):
    """Parse the JSON file at ``config_path`` and return the decoded object."""
    with open(config_path, "r") as config_file:
        parsed = json.load(config_file)
    return parsed
121 |
122 |
def write_outputs(args: "argparse.Namespace", sentences: List[str]) -> None:
    """Write ``sentences`` one per line next to the model checkpoint.

    The file is ``<model_dir>/<model_file>.samples_<n>.steps-<diffusion_steps>.clamp-<clamp>.txt``,
    where ``n`` is ``len(sentences)``.

    :param args: parsed arguments with ``model_name_or_path``, ``diffusion_steps``
        and ``clamp`` attributes (attribute access, so not a plain dict).
    :param sentences: decoded sentences to write.
    """
    model_dir, model_base_name = os.path.split(args.model_name_or_path)

    num_samples = len(sentences)
    output_file_basepath = os.path.join(
        model_dir,
        f"{model_base_name}.samples_{num_samples}.steps-{args.diffusion_steps}.clamp-{args.clamp}",
    ) + ".txt"
    # Explicit UTF-8 so decoded text is written the same way on every platform.
    with open(output_file_basepath, "w", encoding="utf-8") as text_fout:
        text_fout.writelines(f"{generated_sentence}\n" for generated_sentence in sentences)

    print(f"written the decoded output to {output_file_basepath}")
138 |
139 |
# Run sampling only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
142 |
--------------------------------------------------------------------------------
/src/train_infer/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a diffusion language model on text.
3 | """
4 |
5 | import json, os
6 | import pathlib
7 | import pprint
8 | import sys
9 | import wandb
10 | from transformers import set_seed
11 | import os
12 |
13 | from src.utils import dist_util, logger
14 | from src.modeling.diffusion.resample import create_named_schedule_sampler
15 | from train_infer.factory_methods import create_model_and_diffusion
16 | from train_loop import TrainLoop
17 | from src.utils import data_utils_sentencepiece
18 | from src.utils.args_utils import create_argparser, args_to_dict, model_and_diffusion_defaults
19 | from src.utils.custom_tokenizer import create_tokenizer
20 |
21 |
def main():
    """Train a diffusion language model end to end.

    Parses the CLI, builds the tokenizer, train/validation loaders, model and
    diffusion, saves the arguments to ``{checkpoint_path}/training_args.json``
    (the file ``text_sample.py`` reads back), sets up Weights & Biases
    (disabled with ``--debug``) and runs ``TrainLoop``.
    """
    args = create_argparser().parse_args()
    set_seed(args.seed)
    dist_util.setup_dist()
    logger.configure()

    logger.log("creating data loader")
    pathlib.Path(args.checkpoint_path).mkdir(parents=True, exist_ok=True)

    tokenizer = create_tokenizer(
        return_pretokenized=args.use_pretrained_embeddings,
        path=f"data/{args.dataset}/",
    )

    def make_loader(txt_path):
        # Train and validation loaders differ only in the file they read.
        return data_utils_sentencepiece.get_dataloader(
            tokenizer=tokenizer,
            data_path=txt_path,
            batch_size=args.batch_size,
            max_seq_len=args.sequence_len,
        )

    train_dataloader = make_loader(args.train_txt_path)
    val_dataloader = make_loader(args.val_txt_path)

    # The model's vocabulary size follows the tokenizer actually in use.
    args.vocab_size = tokenizer.vocab_size

    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.to(dist_util.dev())

    print(model)

    param_count = sum(p.numel() for p in model.parameters())
    logger.log(f"the parameter count is {param_count}")
    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    args_file = f"{args.checkpoint_path}/training_args.json"
    logger.log(f"saving the hyperparameters to {args_file}")
    with open(args_file, "w") as f:
        json.dump(args.__dict__, f, indent=2)

    if args.debug:
        wandb.init(mode="disabled")
    else:
        wandb.init(
            project=os.getenv("WANDB_PROJECT", "minimial-text-diffusion"),
            name=args.checkpoint_path + make_wandb_name_from_args(args),
            notes=args.notes,
        )
    wandb.config.update(args.__dict__, allow_val_change=True)

    logger.log("training...")
    trainer = TrainLoop(
        model=model,
        diffusion=diffusion,
        data=train_dataloader,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
        checkpoint_path=args.checkpoint_path,
        gradient_clipping=args.gradient_clipping,
        eval_data=val_dataloader,
        eval_interval=args.eval_interval,
    )
    trainer.run_loop()
105 |
106 |
def make_wandb_name_from_args(args):
    """Return a ``key=value_`` suffix (trailing underscore included) for a fixed set of hyperparameters."""
    keys_to_add = (
        "batch_size",
        "lr",
        "num_heads",
        "lr_anneal_steps",
        "config_name",
        "seed",
        "in_channel",
    )
    return "".join(f"{key}={getattr(args, key)}_" for key in keys_to_add)
113 |
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
116 |
--------------------------------------------------------------------------------
/src/utils/args_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for command line arguments.
3 | """
4 |
5 | import argparse
6 |
7 |
8 |
def create_argparser():
    """Build the CLI parser shared by training and sampling.

    One ``--name`` option is registered per default value (see
    ``add_dict_to_argparser``), plus a ``--debug`` flag. Defaults are merged
    from the training, model/diffusion, text, guided-generation and decoding
    groups in that order; a later group wins on key collisions.
    """
    training_defaults = dict(
        data_dir="",
        schedule_sampler="uniform",
        lr=1e-4,
        weight_decay=0.0,
        lr_anneal_steps=30000,
        batch_size=1,
        microbatch=-1,  # -1 disables microbatches
        ema_rate="0.9999",  # comma-separated list of EMA values
        log_interval=50,
        save_interval=25000,
        resume_checkpoint="",
        use_fp16=False,
        fp16_scale_growth=1e-3,
        seed=101,
        gradient_clipping=-1.0,
        eval_interval=2000,
        checkpoint_path="diff_models",
        train_txt_path="data/quotes_train.txt",
        val_txt_path="data/quotes_valid.txt",
        dataset="",
        notes="",
    )
    text_defaults = dict(
        modality="text",
        emb_scale_factor=1.0,
        in_channel=16,
        out_channel=16,
        noise_level=0.0,
        cache_mode="no",
        use_bert_tokenizer="no",
        padding_mode="block",
        preprocessing_num_workers=1,
        tok_thresh=150,
    )
    guided_generation_defaults = dict(classifier_num_epochs=15)

    # NOTE(review): decoding_defaults() is merged last, so its checkpoint_path=""
    # and batch_size=64 override the training values ("diff_models" and 1)
    # above. Confirm that is intended before relying on those training defaults.
    merged_defaults = {
        **training_defaults,
        **model_and_diffusion_defaults(),
        **text_defaults,
        **guided_generation_defaults,
        **decoding_defaults(),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true")
    add_dict_to_argparser(parser, merged_defaults)
    return parser
59 |
60 |
def model_and_diffusion_defaults():
    """Return a fresh dict of default model and diffusion options for text-diffusion training.

    The keys are also the ones ``args_to_dict`` extracts before calling
    ``create_model_and_diffusion``.
    """
    return {
        "sequence_len": 64,
        "num_channels": 16,
        "num_heads": 4,
        "dropout": 0.0,
        "learn_sigma": False,
        "sigma_small": False,
        "class_cond": False,
        "diffusion_steps": 10000,
        "noise_schedule": "linear",
        "timestep_respacing": "",
        "use_kl": False,
        "predict_xstart": False,
        "rescale_timesteps": True,
        "rescale_learned_sigmas": True,
        "use_checkpoint": False,
        "model_arch": "transformer",
        "in_channel": 16,
        "out_channel": 16,
        "vocab_size": 66,
        "config_name": "bert-base-uncased",
        "logits_mode": 1,
        "training_mode": "diffusion-lm",
        "init_pretrained": False,
        "freeze_embeddings": False,
        "use_pretrained_embeddings": True,
    }
92 |
93 |
def decoding_defaults():
    """Return a fresh dict of default options used when sampling/decoding."""
    return {
        "num_samples": 50,
        "top_p": 0.9,
        "out_dir": "",
        "model_name_or_path": "",
        "checkpoint_path": "",
        "use_ddim": False,
        "clip_denoised": False,
        "batch_size": 64,
        "mbr_sample": 1,
        "verbose": "yes",
        "clamp": "clamp",
        "preprocessing_num_workers": 1,
        "emb_scale_factor": 1.0,
        "classifier_path": "",
    }
111 |
112 |
def add_dict_to_argparser(parser, default_dict):
    """Register a ``--<key>`` option on ``parser`` for every entry of ``default_dict``.

    The option's converter is inferred from its default value: ``str`` for
    ``None``, ``str2bool`` for booleans, and ``type(default)`` otherwise.
    """
    for name, default in default_dict.items():
        if default is None:
            converter = str
        elif isinstance(default, bool):
            converter = str2bool
        else:
            converter = type(default)
        parser.add_argument(f"--{name}", default=default, type=converter)
121 |
122 |
def args_to_dict(args, keys):
    """Return ``{key: getattr(args, key)}`` for each of ``keys``, in order."""
    selected = {}
    for key in keys:
        selected[key] = getattr(args, key)
    return selected
125 |
126 |
def str2bool(v):
    """
    Convert a command-line string such as "yes"/"no" to a bool (bools pass through).

    Raises ``argparse.ArgumentTypeError`` for anything unrecognised.
    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")
139 |
--------------------------------------------------------------------------------
/src/utils/custom_tokenizer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import pathlib
4 | import torch
5 | from transformers import AutoTokenizer
6 |
7 | from tokenizers.processors import BertProcessing
8 | from tokenizers import ByteLevelBPETokenizer, decoders
9 |
10 | logging.basicConfig(level=logging.INFO)
11 |
def create_tokenizer(return_pretokenized, path, tokenizer_type: str = "word-level"):
    """Return the tokenizer for a dataset.

    :param return_pretokenized: if true, return the pretrained
        ``bert-base-uncased`` tokenizer and ignore ``path``.
    :param path: directory holding the locally trained tokenizer files.
    :param tokenizer_type: ``"word-level"`` or ``"byte-level"``.
    :raises ValueError: for any other ``tokenizer_type``.
    """
    if return_pretokenized:
        return AutoTokenizer.from_pretrained("bert-base-uncased")

    if tokenizer_type == "byte-level":
        reader = read_byte_level
    elif tokenizer_type == "word-level":
        reader = read_word_level
    else:
        raise ValueError(f"Invalid tokenizer type: {tokenizer_type}")
    return reader(path)
23 |
def train_bytelevel(
    path,
    vocab_size=10000,
    min_frequency=1,
    special_tokens=None,
):
    """Train a byte-level BPE tokenizer on the text file at ``path``.

    ``vocab.json`` and ``merges.txt`` are written to the directory that
    contains ``path``, which is where ``read_byte_level`` looks for them.

    :param vocab_size: target vocabulary size.
    :param min_frequency: minimum pair frequency for a merge.
    :param special_tokens: tokens to reserve. Defaults to
        ``["<s>", "<pad>", "</s>", "<unk>", "<mask>"]``. A fresh list is built
        per call instead of sharing a mutable default.
    """
    if special_tokens is None:
        # Previously five empty strings, which registered no usable special
        # tokens; read_byte_level expects <s> and </s> to exist.
        special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]

    tokenizer = ByteLevelBPETokenizer()

    # Customize training
    tokenizer.train(
        files=[path],
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        special_tokens=special_tokens,
    )

    tokenizer.save_model(str(pathlib.Path(path).parent))
42 |
43 |
44 |
def read_byte_level(path: str):
    """Load a byte-level BPE tokenizer from ``{path}/vocab.json`` and ``{path}/merges.txt``.

    Encodings are wrapped with ``<s>`` ... ``</s>`` and truncated to 512
    tokens. ``vocab_size`` is set on the returned object (callers such as the
    training script read ``tokenizer.vocab_size``).
    """
    tokenizer = ByteLevelBPETokenizer(
        f"{path}/vocab.json",
        f"{path}/merges.txt",
    )

    # BertProcessing(sep, cls): these were empty strings, so token_to_id("")
    # did not resolve the sentence boundary tokens.
    tokenizer._tokenizer.post_processor = BertProcessing(
        ("</s>", tokenizer.token_to_id("</s>")),
        ("<s>", tokenizer.token_to_id("<s>")),
    )

    tokenizer.enable_truncation(max_length=512)

    # vocab.json stores byte-level symbols (non-ASCII characters), so read it
    # as UTF-8 regardless of the platform's default encoding.
    with open(f"{path}/vocab.json", "r", encoding="utf-8") as fin:
        vocab = json.load(fin)

    # Expose the vocabulary size as an attribute for callers.
    # (The old instance-level `__len__ = property(...)` was removed: len()
    # looks dunder methods up on the type, so it never had any effect.)
    tokenizer.vocab_size = len(vocab)

    tokenizer.decoder = decoders.ByteLevel()

    # Quick round-trip sanity check of the loaded files.
    sample = (
        "Bores can be divided into two classes; those who have their own particular subject, "
        "and those who do not need a subject."
    )
    encoding = tokenizer.encode(sample)
    logging.info("Byte-level sample tokens: %s", encoding.tokens)
    logging.info("Byte-level sample ids: %s", encoding.ids)
    logging.info("Byte-level round trip: %s", tokenizer.decode(encoding.ids, skip_special_tokens=True))
    logging.info("Vocab size: %s", tokenizer.vocab_size)

    return tokenizer
99 |
100 |
def read_word_level(path: str):
    """Load the word-level tokenizer saved as ``word-level-vocab.json`` in ``path``.

    The special tokens follow the BERT convention ([CLS], [SEP], [UNK], [PAD],
    [MASK]) and padding is applied on the right.
    """
    from transformers import PreTrainedTokenizerFast

    # Build the path once so the logged path is exactly the one loaded
    # (a trailing slash in `path` previously produced "//" in the log only).
    tokenizer_file = f"{str(pathlib.Path(path))}/word-level-vocab.json"
    logging.info(f"Loading tokenizer from {tokenizer_file}")
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_file,
        bos_token="[CLS]",
        eos_token="[SEP]",
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        padding_side="right",
    )

    # PreTrainedTokenizerFast already defines __len__ on its class; the old
    # instance-level `__len__ = property(...)` assignment had no effect.
    return tokenizer
122 |
123 |
def train_word_level_tokenizer(
    path: str,
    vocab_size: int = 10000,
    special_tokens=None,
):
    """Train a word-level tokenizer on the text file ``path`` and save it beside the file.

    Text is NFD-normalised, lower-cased and accent-stripped, then split into
    individual digits and on whitespace/punctuation. Each encoding is wrapped
    as ``[CLS] ... [SEP]`` and truncated to 512 tokens. The result is saved as
    ``word-level-vocab.json`` in the directory containing ``path``.

    :param vocab_size: maximum vocabulary size.
    :param special_tokens: special tokens for the trainer. Defaults to
        ``["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]`` (fresh list per call).
    :raises ValueError: if ``[CLS]`` or ``[SEP]`` is not in the trained vocabulary.
    """
    from tokenizers import Tokenizer, normalizers, pre_tokenizers
    from tokenizers.models import WordLevel
    from tokenizers.normalizers import NFD, Lowercase, StripAccents
    from tokenizers.pre_tokenizers import Digits, Whitespace
    from tokenizers.processors import TemplateProcessing
    from tokenizers.trainers import WordLevelTrainer

    if special_tokens is None:
        special_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]

    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.normalizer = normalizers.Sequence([NFD(), Lowercase(), StripAccents()])
    tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
        [Digits(individual_digits=True), Whitespace()]
    )

    trainer = WordLevelTrainer(vocab_size=vocab_size, special_tokens=special_tokens)
    tokenizer.train(files=[path], trainer=trainer)

    # Look the ids up after training instead of hard-coding 1 and 2, so the
    # template stays correct if the caller orders special_tokens differently.
    cls_id = tokenizer.token_to_id("[CLS]")
    sep_id = tokenizer.token_to_id("[SEP]")
    if cls_id is None or sep_id is None:
        raise ValueError("special_tokens must include '[CLS]' and '[SEP]'")
    tokenizer.post_processor = TemplateProcessing(
        single="[CLS] $A [SEP]", special_tokens=[("[CLS]", cls_id), ("[SEP]", sep_id)]
    )

    tokenizer.enable_truncation(max_length=512)

    logging.info("Word-level sample encoding ids: %s", tokenizer.encode("the red.").ids)

    tokenizer.save(f"{str(pathlib.Path(path).parent)}/word-level-vocab.json")
158 |
159 |
# CLI: python custom_tokenizer.py {train-word-level|train-byte-level|create} PATH
if __name__ == "__main__":
    import sys

    usage = "usage: custom_tokenizer.py {train-word-level|train-byte-level|create} PATH"
    if len(sys.argv) < 3:
        sys.exit(usage)

    command, data_path = sys.argv[1], sys.argv[2]
    if command == "train-word-level":
        train_word_level_tokenizer(path=data_path)
    elif command == "train-byte-level":
        train_bytelevel(path=data_path)
    elif command == "create":
        # return_pretokenized has no default, so omitting it raised TypeError.
        # Load the locally trained tokenizer, not the pretrained BERT one.
        create_tokenizer(return_pretokenized=False, path=data_path)
    else:
        sys.exit(usage)
169 |
--------------------------------------------------------------------------------
/src/utils/data_utils_sentencepiece.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | import pandas as pd
4 | from torch.utils.data import DataLoader, Dataset
5 | import torch
6 | from functools import partial
7 |
8 | logging.basicConfig(level=logging.INFO)
9 |
10 | # BAD: this should not be global
11 | # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
12 |
13 |
14 |
15 |
def get_dataloader(tokenizer, data_path, batch_size, max_seq_len):
    """Yield collated batches from ``data_path`` endlessly.

    The underlying ``DataLoader`` shuffles, drops the incomplete final batch,
    uses one worker and is restarted each time it is exhausted. Each batch is
    the ``(None, dict)`` pair returned by ``TextDataset.collate_pad`` with
    sequences limited to ``max_seq_len``.
    """
    loader = DataLoader(
        TextDataset(tokenizer=tokenizer, data_path=data_path),
        batch_size=batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=1,
        collate_fn=partial(TextDataset.collate_pad, cutoff=max_seq_len),
    )

    while True:
        yield from loader
31 |
32 |
class TextDataset(Dataset):
    """Dataset with one example per line of a tab-separated file.

    Column 0 holds the text, which is stripped and tokenized up front. When
    ``has_labels`` is true, column 1 holds a class label that is mapped to a
    contiguous integer id.
    """

    def __init__(
        self,
        tokenizer,
        data_path: str,
        has_labels: bool = False
    ) -> None:
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.read_data()
        if has_labels:
            self.read_labels()

    @staticmethod
    def _read_tsv(path: str) -> pd.DataFrame:
        """Read a header-less TSV; read_data and read_labels share it so their rows line up.

        Column 0 is always read as text, and pandas' default NA markers are
        disabled, so texts such as "NA", "null" or "123" stay strings instead of
        becoming NaN or numbers (which made ``.strip()`` fail).
        """
        return pd.read_csv(path, sep="\t", header=None, dtype={0: str}, keep_default_na=False)

    def read_data(self):
        """Load column 0 into ``self.text`` and its token ids into ``self.input_ids``."""
        logging.info("Reading data from {}".format(self.data_path))
        data = self._read_tsv(self.data_path)  # read text file
        logging.info(f"Tokenizing {len(data)} sentences")

        self.text = data[0].str.strip().tolist()

        # Tokenizers with `encode_batch` return objects carrying `.ids`; otherwise
        # the tokenizer is called and must return a mapping with "input_ids".
        if hasattr(self.tokenizer, "encode_batch"):
            encoded_input = self.tokenizer.encode_batch(self.text)
            self.input_ids = [x.ids for x in encoded_input]
        else:
            encoded_input = self.tokenizer(self.text)
            self.input_ids = encoded_input["input_ids"]

    @staticmethod
    def _label_sort_key(label: str):
        """Order integer-valued labels numerically (1, 2, 10), then others alphabetically."""
        try:
            return (0, int(label), label)
        except ValueError:
            return (1, 0, label)

    def read_labels(self):
        """Read column 1 as labels and map them to contiguous integer ids.

        Sets ``label_to_idx`` (label string -> id), ``idx_to_label``
        (id -> label string) and ``labels`` (one id per row).
        """
        raw_labels = [str(x) for x in self._read_tsv(self.data_path)[1].tolist()]
        # Labels are compared as strings. Sorting them plainly put "10" before
        # "2", so integer labels are ordered by value instead.
        all_labels = sorted(set(raw_labels), key=self._label_sort_key)
        self.label_to_idx = {label: i for i, label in enumerate(all_labels)}
        # Inverse mapping; the old comprehension unpacked the pairs in the wrong
        # order and produced a copy of label_to_idx.
        self.idx_to_label = dict(enumerate(all_labels))
        self.labels = [self.label_to_idx[label] for label in raw_labels]

    def __len__(self) -> int:
        return len(self.text)

    def __getitem__(self, i):
        out_dict = {
            "input_ids": self.input_ids[i],
        }
        if hasattr(self, "labels"):
            out_dict["label"] = self.labels[i]
        return out_dict

    @staticmethod
    def collate_pad(batch, cutoff: int):
        """Pad and truncate a list of ``__getitem__`` dicts into batch tensors.

        Sequences are right-padded with 0 to the longest length in the batch,
        capped at ``cutoff``; longer sequences are truncated to that cap.

        :return: ``(None, {"input_ids", "attention_mask"[, "labels"]})``. The
            leading None is kept for backward compatibility.
        """
        num_elems = len(batch)
        max_token_len = min(cutoff, max((len(item["input_ids"]) for item in batch), default=0))

        tokens = torch.zeros(num_elems, max_token_len).long()
        tokens_mask = torch.zeros(num_elems, max_token_len).long()

        has_labels = bool(batch) and "label" in batch[0]
        labels = torch.zeros(num_elems).long() if has_labels else None

        for i, item in enumerate(batch):
            # Truncate to the row width: writing a longer sequence into a
            # max_token_len-wide row raised a shape-mismatch error.
            toks = item["input_ids"][:max_token_len]
            length = len(toks)
            tokens[i, :length] = torch.LongTensor(toks)
            tokens_mask[i, :length] = 1
            if has_labels:
                labels[i] = item["label"]

        # TODO: the first return None is just for backward compatibility -- can be removed
        if has_labels:
            return None, {"input_ids": tokens, "attention_mask": tokens_mask, "labels": labels}
        else:
            return None, {"input_ids": tokens, "attention_mask": tokens_mask}
123 |
--------------------------------------------------------------------------------
/src/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 |
5 | import io
6 | import os
7 | import socket
8 |
9 | import blobfile as bf
10 | from mpi4py import MPI
11 | import torch as th
12 | import torch.distributed as dist
13 |
# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE); see dev().
# NOTE(review): set to 10 here, while the old comment hints the previous value
# was 8 -- confirm it matches the number of GPUs per machine actually in use.
GPUS_PER_NODE = 10  # 8

# NOTE(review): presumably a retry budget for process-group setup, but no
# function visible in this file reads it -- confirm whether it is used elsewhere.
SETUP_RETRY_COUNT = 3
19 |
20 |
def setup_dist():
    """
    Setup a distributed process group from the MPI world (no-op if already initialised).

    NCCL with this host's address is used when CUDA is available, otherwise
    Gloo on localhost. Rank 0's master address and a free port are broadcast
    to every rank through MPI before ``init_process_group`` is called.
    """
    if dist.is_initialized():
        return

    comm = MPI.COMM_WORLD
    use_nccl = th.cuda.is_available()
    backend = "nccl" if use_nccl else "gloo"
    hostname = socket.gethostbyname(socket.getfqdn()) if use_nccl else "localhost"

    os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
    os.environ["RANK"] = str(comm.rank)
    os.environ["WORLD_SIZE"] = str(comm.size)

    master_port = comm.bcast(_find_free_port(), root=0)
    os.environ["MASTER_PORT"] = str(master_port)
    dist.init_process_group(backend=backend, init_method="env://")
43 |
44 |
def dev():
    """
    Get the device to use for torch.distributed.

    CPU when CUDA is unavailable; otherwise the GPU at index
    ``MPI rank % GPUS_PER_NODE``.
    """
    if not th.cuda.is_available():
        return th.device("cpu")
    gpu_index = MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE
    return th.device(f"cuda:{gpu_index}")
52 |
53 |
def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file without redundant fetches across MPI ranks.

    Only rank 0 reads the bytes at ``path``; they are broadcast to all ranks
    and deserialised with ``th.load(..., **kwargs)`` on each.
    """
    comm = MPI.COMM_WORLD
    payload = None
    if comm.Get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            payload = f.read()
    payload = comm.bcast(payload)
    return th.load(io.BytesIO(payload), **kwargs)
65 |
66 |
def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0 (in place, without autograd).
    """
    with th.no_grad():
        for tensor in params:
            dist.broadcast(tensor, 0)
74 |
75 |
76 | def _find_free_port():
77 | try:
78 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
79 | s.bind(("", 0))
80 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
81 | return s.getsockname()[1]
82 | finally:
83 | s.close()
84 |
--------------------------------------------------------------------------------
/src/utils/eval_ppl.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluates perplexity of a model on a dataset.
3 | Directly taken from https://huggingface.co/spaces/evaluate-measurement/perplexity/blob/main/perplexity.py
4 | """
5 |
6 | from itertools import chain
7 | from typing import List
8 | import torch
9 |
10 | import numpy as np
11 | import torch
12 | from torch.nn import CrossEntropyLoss
13 | from transformers import AutoModelForCausalLM, AutoTokenizer
14 |
# Module-level scoring model and tokenizer, created once when this module is
# imported (from_pretrained may download "distilgpt2" the first time) and used
# as globals by compute_perplexity.
device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "distilgpt2"
model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(model_id)
22 |
23 |
def compute_perplexity(data, batch_size: int = 16, add_start_token: bool = True):
    """Score each text in ``data`` with the module-level ``model``/``tokenizer``.

    :param data: list of strings to score.
    :param batch_size: number of texts per forward pass.
    :param add_start_token: prepend the tokenizer's BOS token so the first real
        token of each text is also scored.
    :return: ``{"perplexities": [...], "mean_perplexity": float}``.
    """
    # if batch_size > 1 (which generally leads to padding being required), and
    # if there is not an already assigned pad_token, assign an existing
    # special token to also be the padding token
    if tokenizer.pad_token is None and batch_size > 1:
        existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
        # check that the model already has at least one special token defined
        assert (
            len(existing_special_tokens) > 0
        ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
        # assign one of the special tokens to also be the pad token
        tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

    # `model.config.max_length` is the default *generation* length (typically
    # 20), not the context size, so truncating to it silently cut longer texts.
    # Prefer the positional-embedding size when the config provides one.
    context_len = getattr(model.config, "max_position_embeddings", None) or model.config.max_length

    if add_start_token:
        # leave room for the BOS token to be added:
        assert (
            tokenizer.bos_token is not None
        ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
        max_tokenized_len = context_len - 1
    else:
        max_tokenized_len = context_len

    encodings = tokenizer(
        data,
        add_special_tokens=False,
        padding=True,
        truncation=True,
        max_length=max_tokenized_len,
        return_tensors="pt",
        return_attention_mask=True,
    ).to(device)

    encoded_texts = encodings["input_ids"]
    attn_masks = encodings["attention_mask"]

    # check that each input is long enough:
    if add_start_token:
        assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
    else:
        assert torch.all(
            torch.ge(attn_masks.sum(1), 2)
        ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."

    ppls = []
    loss_fct = CrossEntropyLoss(reduction="none")

    # Plain range: `tqdm` was used here but never imported (NameError).
    for start_index in range(0, len(encoded_texts), batch_size):
        end_index = min(start_index + batch_size, len(encoded_texts))
        encoded_batch = encoded_texts[start_index:end_index]
        attn_mask = attn_masks[start_index:end_index]

        if add_start_token:
            bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device)
            encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
            attn_mask = torch.cat(
                [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1
            )

        labels = encoded_batch

        with torch.no_grad():
            out_logits = model(encoded_batch, attention_mask=attn_mask).logits

        # Token t is predicted from positions < t; padding is masked out.
        shift_logits = out_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

        perplexity_batch = torch.exp(
            (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
            / shift_attention_mask_batch.sum(1)
        )

        ppls += perplexity_batch.tolist()

    return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}
102 |
103 |
def calculate_perplexity_for_file(path: str):
    """
    Score the generated text in ``path`` and report simple diversity statistics.

    Each non-empty line is stripped of tokenizer special tokens. Lines that are
    empty after cleaning are dropped: an empty string has no tokens, and
    ``compute_perplexity`` would otherwise reject the entire file.

    Returns a dict with:
        ``data``: the cleaned lines,
        ``ppl``: their mean perplexity (``compute_perplexity``),
        ``perc_unique_lines`` / ``perc_unique_tokens``: percentage of distinct
        lines / whitespace-separated tokens, rounded to 2 decimals.

    If the file has no usable lines, or perplexity computation fails (a warning
    is emitted), a sentinel with empty ``data`` and ``ppl=1e6`` is returned so
    callers can still rank files.
    """
    import warnings

    # NOTE(review): the two empty-string entries remove nothing; confirm which
    # tokens were intended here.
    special_tokens = ["[CLS]", "[SEP]", "[PAD]", "[MASK]", "", ""]
    failed = {"data": [], "ppl": 1e6, "perc_unique_lines": 0, "perc_unique_tokens": 0}

    with open(path, "r", encoding="utf-8") as f:
        raw_lines = [line.strip() for line in f]

    cleaned = [remove_all(line, special_tokens) for line in raw_lines if line]
    lines = [line for line in cleaned if line]
    if not lines:
        return failed

    num_unique_lines = len(set(lines))
    perc_unique_lines = round(num_unique_lines * 100 / len(lines), 2)
    # Every kept line is non-blank, so all_tokens is non-empty.
    all_tokens = list(chain.from_iterable(line.split() for line in lines))
    perc_unique_tokens = round(len(set(all_tokens)) * 100 / len(all_tokens), 2)

    try:
        ppl = compute_perplexity(lines)["mean_perplexity"]
    except Exception as e:  # best effort: one bad file must not abort the sweep
        warnings.warn(f"Could not compute perplexity for {path}: {e!r}")
        return failed

    return {"data": lines, "ppl": ppl, "perc_unique_lines": perc_unique_lines, "perc_unique_tokens": perc_unique_tokens}
120 |
121 |
def remove_all(line: str, special_toks: List[str]) -> str:
    """
    Delete every occurrence of each token in ``special_toks`` from ``line``.

    Tokens are removed one after another, and surrounding whitespace is
    stripped after each removal. With no tokens the line is returned unchanged.
    """
    result = line
    for token in special_toks:
        without_token = result.replace(token, "")
        result = without_token.strip()
    return result
126 |
127 |
if __name__ == '__main__':
    import sys
    import glob
    import json
    import random
    from pprint import pprint
    from tqdm import tqdm

    # Usage: python eval_ppl.py '<glob pattern>' [more patterns or files ...]
    # Several arguments are accepted, so an unquoted pattern that the shell has
    # already expanded still scores every file (previously only the first was used).
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} '<glob pattern>' [<glob pattern> ...]")
    files = sorted({f for pattern in sys.argv[1:] for f in glob.glob(pattern)})

    res = dict()
    for file in tqdm(files):
        res[file] = calculate_perplexity_for_file(file)

    # sort by perplexity (lowest first)
    res = dict(sorted(res.items(), key=lambda item: item[1]['ppl']))

    for file in res:
        # show a few lines
        print(f"File: {file}")
        pprint(res[file]['data'][:5])

        # show the perplexity
        print(f"Perplexity: {res[file]['ppl']}")
        print(f"Percentage of unique lines: {res[file]['perc_unique_lines']}")
        print(f"Percentage of unique tokens: {res[file]['perc_unique_tokens']}")
        print("-" * 100)

    # Markdown report with: i) sample sentences, ii) perplexity,
    # iii) percentage of unique lines, iv) percentage of unique tokens.
    # A header separator row is required for a valid markdown table, and no
    # extra lines may appear between rows.
    print("| File | Sample Sentences | Perplexity | % Unique Lines | % Unique Tokens |")
    print("| --- | --- | --- | --- | --- |")
    for file, stats in res.items():
        # Failed/empty files carry perc_unique_tokens == 0; leave them out.
        if stats['perc_unique_tokens'] <= 0:
            continue
        # random.sample needs a sequence (sets are rejected since Python 3.11).
        sentences = sorted(set(stats['data']))
        if len(sentences) > 5:
            sentences = random.sample(sentences, 5)

        # Row label: the file's directory path joined with "#" (the file name
        # itself is not shown).
        filename = "#".join(file.split("/")[:-1])
        samples = ", ".join(s.replace("|", "\\|") for s in sentences)
        print(f"| {filename} | {samples} | {stats['ppl']} | {stats['perc_unique_lines']} | {stats['perc_unique_tokens']} |")

    with open("perplexity.json", "w", encoding="utf-8") as f:
        json.dump(res, f)
172 |
173 |
--------------------------------------------------------------------------------
/src/utils/fp16_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers to train with 16-bit precision.
3 | """
4 |
5 | import torch.nn as nn
6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
7 |
8 |
def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.

    Only Conv1d/Conv2d/Conv3d layers are converted; other modules are left
    untouched. Convolutions built with ``bias=False`` have no bias to convert.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        if l.bias is not None:
            l.bias.data = l.bias.data.half()
16 |
17 |
def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().

    Only Conv1d/Conv2d/Conv3d layers are converted; other modules are left
    untouched. Convolutions built with ``bias=False`` have no bias to convert.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        if l.bias is not None:
            l.bias.data = l.bias.data.float()
25 |
26 |
def make_master_params(model_params):
    """
    Build full-precision master parameters for ``model_params``.

    Every parameter is detached, cast to float32 and concatenated into a
    single flat tensor, which is wrapped as one trainable ``nn.Parameter``.
    The result is returned as a one-element list.
    """
    fp32_copies = [p.detach().float() for p in model_params]
    flat_master = nn.Parameter(_flatten_dense_tensors(fp32_copies))
    flat_master.requires_grad = True
    return [flat_master]
38 |
39 |
def model_grads_to_master_grads(model_params, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().

    Gradients are cast to float32 and flattened in parameter order into
    ``master_params[0].grad``. A parameter whose ``.grad`` is None (e.g. it
    did not take part in the backward pass) contributes zeros instead of
    raising AttributeError.
    """
    master_params[0].grad = _flatten_dense_tensors(
        [
            param.grad.detach().float()
            if param.grad is not None
            else param.detach().new_zeros(param.shape).float()
            for param in model_params
        ]
    )
48 |
49 |
def master_params_to_model_params(model_params, master_params):
    """
    Copy the values held by the flat master parameter back into ``model_params``.

    ``model_params`` is materialised as a list first because it is traversed
    twice (once to unflatten, once to copy); a generator would be exhausted by
    the first pass and nothing would be copied.
    """
    params = list(model_params)
    unflattened = unflatten_master_params(params, master_params)
    for dst, src in zip(params, unflattened):
        dst.detach().copy_(src)
62 |
63 |
def unflatten_master_params(model_params, master_params):
    """
    Split the flat master parameter into tensors shaped like ``model_params``.

    The master tensor is detached first, so the pieces carry no autograd history.
    """
    flat_master = master_params[0].detach()
    return _unflatten_dense_tensors(flat_master, model_params)
69 |
70 |
def zero_grad(model_params):
    """
    Reset every existing gradient in ``model_params`` to zero, in place.

    Parameters without a gradient (``.grad is None``) are skipped; each
    gradient is detached from its graph before being zeroed.
    """
    # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
    for grad in (p.grad for p in model_params):
        if grad is None:
            continue
        grad.detach_()
        grad.zero_()
77 |
--------------------------------------------------------------------------------
/src/utils/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4 | """
5 |
6 | import os
7 | import sys
8 | import shutil
9 | import os.path as osp
10 | import json
11 | import time
12 | import datetime
13 | import tempfile
14 | import warnings
15 | from collections import defaultdict
16 | from contextlib import contextmanager
17 | import wandb
18 |
# Message levels for log(): a message is written when the logger's level is
# <= the message's level.
DEBUG, INFO, WARN, ERROR = 10, 20, 30, 40

# Higher than every named level; setting it silences debug/info/warn/error.
DISABLED = 50
25 |
26 |
class KVWriter:
    """Interface for output formats that accept key/value dicts."""

    def writekvs(self, kvs):
        """Write one dict of diagnostics; subclasses must override."""
        raise NotImplementedError
30 |
31 |
class SeqWriter:
    """Interface for output formats that accept sequences of strings."""

    def writeseq(self, seq):
        """Write one sequence of strings; subclasses must override."""
        raise NotImplementedError
35 |
36 |
class HumanOutputFormat(KVWriter, SeqWriter):
    """
    Human-readable output: key/value dicts as an ASCII table, sequences as one
    space-separated line.

    ``filename_or_file`` is either a path (opened for writing and closed by
    ``close``) or an already-open writable stream such as ``sys.stdout``
    (left open by ``close``).
    """

    def __init__(self, filename_or_file):
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, "wt", encoding="utf-8")
            self.own_file = True
        else:
            # Only write()/flush() are ever called on the stream, so require
            # write(); the former read() check rejected write-only streams.
            assert hasattr(filename_or_file, "write"), (
                "expected file or str, got %s" % filename_or_file
            )
            self.file = filename_or_file
            self.own_file = False

    def writekvs(self, kvs):
        # Create strings for printing; numeric values use a fixed-width %g.
        key2str = {}
        for (key, val) in sorted(kvs.items()):
            if hasattr(val, "__float__"):
                valstr = "%-8.3g" % val
            else:
                valstr = str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # Find max widths
        if len(key2str) == 0:
            print("WARNING: tried to write empty key-value dict")
            return
        keywidth = max(map(len, key2str.keys()))
        valwidth = max(map(len, key2str.values()))

        # Write out the table in one write (rows sorted case-insensitively), then flush.
        dashes = "-" * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
            lines.append(
                "| %s%s | %s%s |"
                % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
            )
        lines.append(dashes)
        self.file.write("\n".join(lines) + "\n")
        self.file.flush()

    def _truncate(self, s):
        # Cells longer than 30 characters are cut to 27 and suffixed with "...".
        maxlen = 30
        return s[: maxlen - 3] + "..." if len(s) > maxlen else s

    def writeseq(self, seq):
        seq = list(seq)
        for (i, elem) in enumerate(seq):
            self.file.write(elem)
            if i < len(seq) - 1:  # add space unless this is the last one
                self.file.write(" ")
        self.file.write("\n")
        self.file.flush()

    def close(self):
        # Only close files this object opened itself.
        if self.own_file:
            self.file.close()
97 |
98 |
class JSONOutputFormat(KVWriter):
    """Write each key/value dict as one JSON object per line."""

    def __init__(self, filename):
        self.file = open(filename, "wt")

    def writekvs(self, kvs):
        # Values carrying a ``dtype`` (e.g. numpy/torch scalars) are converted
        # to Python floats for JSON. Build a new dict rather than mutating
        # ``kvs``: Logger.dumpkvs passes the same dict to every output format.
        record = {k: float(v) if hasattr(v, "dtype") else v for k, v in kvs.items()}
        self.file.write(json.dumps(record) + "\n")
        self.file.flush()

    def close(self):
        self.file.close()
112 |
113 |
class CSVOutputFormat(KVWriter):
    """
    Append key/value dicts as rows of a CSV file.

    The columns are the union of all keys seen so far. When a dict introduces
    new keys they are appended (sorted) to the header, and the header plus
    every earlier row is rewritten in place with empty cells for the new
    columns. Values are written with ``str`` and are not quoted or escaped.
    """

    def __init__(self, filename):
        self.file = open(filename, "w+t")
        self.keys = []
        self.sep = ","

    def writekvs(self, kvs):
        # Add our current row to the history
        extra_keys = sorted(kvs.keys() - self.keys)
        if extra_keys:
            self.keys.extend(extra_keys)
            self._rewrite_with_new_columns(len(extra_keys))
        # Missing (or None) values become empty cells.
        cells = []
        for k in self.keys:
            v = kvs.get(k)
            cells.append("" if v is None else str(v))
        self.file.write(self.sep.join(cells) + "\n")
        self.file.flush()

    def _rewrite_with_new_columns(self, num_new_keys):
        """Rewrite the header and pad earlier rows after ``num_new_keys`` columns were added."""
        self.file.seek(0)
        lines = self.file.readlines()
        self.file.seek(0)
        self.file.write(self.sep.join(self.keys) + "\n")
        padding = self.sep * num_new_keys
        for line in lines[1:]:
            # line[:-1] drops the trailing newline written by writekvs.
            self.file.write(line[:-1] + padding + "\n")
        # Every rewritten line is longer than before, so no stale text remains
        # and the position is now at the end, ready for the new row.

    def close(self):
        self.file.close()
149 |
150 |
class TensorBoardOutputFormat(KVWriter):
    """
    Dumps key/value pairs into TensorBoard's numeric format.

    Each writekvs() call becomes one event with a monotonically increasing
    step, starting at 1. TensorFlow is imported lazily so the dependency is
    only needed when this format is used.

    NOTE(review): relies on TF1-style APIs (``pywrap_tensorflow.EventsWriter``,
    ``tf.Summary``) that may not exist in TensorFlow 2 -- confirm before use.
    """

    def __init__(self, dir):
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self.step = 1
        events_path = osp.join(osp.abspath(dir), "events")
        import tensorflow as tf
        from tensorflow.python import pywrap_tensorflow
        from tensorflow.core.util import event_pb2
        from tensorflow.python.util import compat

        self.tf = tf
        self.event_pb2 = event_pb2
        self.pywrap_tensorflow = pywrap_tensorflow
        self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(events_path))

    def writekvs(self, kvs):
        values = [
            self.tf.Summary.Value(tag=key, simple_value=float(val))
            for key, val in kvs.items()
        ]
        summary = self.tf.Summary(value=values)
        event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
        event.step = self.step
        self.writer.WriteEvent(event)
        self.writer.Flush()
        self.step += 1

    def close(self):
        if not self.writer:
            return
        self.writer.Close()
        self.writer = None
190 |
191 |
def make_output_format(format, ev_dir, log_suffix=""):
    """
    Build the writer for one output ``format`` inside ``ev_dir``.

    ``ev_dir`` is created if missing. Supported formats are "stdout", "log",
    "json", "csv" and "tensorboard"; file-backed formats put ``log_suffix``
    into their file or directory name. Raises ValueError for anything else.
    """
    os.makedirs(ev_dir, exist_ok=True)
    builders = {
        "stdout": lambda: HumanOutputFormat(sys.stdout),
        "log": lambda: HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)),
        "json": lambda: JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)),
        "csv": lambda: CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)),
        "tensorboard": lambda: TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)),
    }
    # Compare with == (not a dict lookup) so unhashable inputs behave as before.
    for name, build in builders.items():
        if format == name:
            return build()
    raise ValueError("Unknown format specified: %s" % (format,))
206 |
207 |
208 | # ================================================================
209 | # API
210 | # ================================================================
211 |
212 |
def logkv(key, val):
    """
    Record ``val`` as the current value of diagnostic ``key``.

    Call once per diagnostic quantity per iteration; if called several times
    before the next dumpkvs(), the last value wins.
    """
    current = get_current()
    current.logkv(key, val)


def logkv_mean(key, val):
    """
    Like logkv(), but repeated calls before the next dumpkvs() are averaged.
    """
    current = get_current()
    current.logkv_mean(key, val)


def logkvs(d):
    """
    Record every key/value pair of the dict ``d`` via logkv().
    """
    for key in d:
        logkv(key, d[key])


def dumpkvs():
    """
    Write all diagnostics recorded this iteration to the output formats.

    Returns a copy of the values that were written.
    """
    current = get_current()
    return current.dumpkvs()


def getkvs():
    """Return the live dict of diagnostics recorded so far this iteration."""
    current = get_current()
    return current.name2val
246 |
247 |
def log(*args, level=INFO):
    """
    Send ``args`` (converted with ``str``) to every sequence-capable output
    format of the current logger, if ``level`` passes its threshold.

    HumanOutputFormat writes the pieces on one line separated by spaces.
    """
    current = get_current()
    current.log(*args, level=level)


def debug(*args):
    """log() at DEBUG level."""
    log(*args, level=DEBUG)


def info(*args):
    """log() at INFO level."""
    log(*args, level=INFO)


def warn(*args):
    """log() at WARN level."""
    log(*args, level=WARN)


def error(*args):
    """log() at ERROR level."""
    log(*args, level=ERROR)


def set_level(level):
    """
    Set the logging threshold of the current logger.
    """
    current = get_current()
    current.set_level(level)


def set_comm(comm):
    """Set the communicator the current logger uses to average values in dumpkvs()."""
    current = get_current()
    current.set_comm(comm)


def get_dir():
    """
    Return the directory log files are written to.

    configure() always resolves a directory (falling back to a temp dir), so
    this is normally a str; a Logger constructed directly may hold None.
    """
    current = get_current()
    return current.get_dir()


# Legacy names from the baselines logger this module was copied from.
record_tabular = logkv
dump_tabular = dumpkvs
292 |
293 |
@contextmanager
def profile_kv(scopename):
    """
    Add the seconds spent inside the ``with`` body to diagnostic ``wait_<scopename>``.

    Time accumulates across uses within an iteration and is recorded even if
    the body raises. A monotonic clock is used, so wall-clock adjustments
    cannot produce negative or inflated durations.
    """
    logkey = "wait_" + scopename
    tstart = time.perf_counter()
    try:
        yield
    finally:
        get_current().name2val[logkey] += time.perf_counter() - tstart


def profile(n):
    """
    Decorator that times every call of the wrapped function with profile_kv(n).

    Usage:
        @profile("my_func")
        def my_func(): code

    The wrapper keeps the wrapped function's name, docstring and other metadata.
    """
    import functools

    def decorator_with_name(func):
        @functools.wraps(func)
        def func_wrapper(*args, **kwargs):
            with profile_kv(n):
                return func(*args, **kwargs)

        return func_wrapper

    return decorator_with_name
319 |
320 |
321 | # ================================================================
322 | # Backend
323 | # ================================================================
324 |
325 |
def get_current():
    """Return the active Logger, configuring a default one on first use."""
    if Logger.CURRENT is not None:
        return Logger.CURRENT
    _configure_default_logger()
    return Logger.CURRENT
331 |
332 |
class Logger(object):
    """
    Accumulates key/value diagnostics for one iteration and forwards them,
    together with free-form log messages, to a list of output formats.
    """

    # Logger installed by _configure_default_logger(); reset() falls back to it.
    DEFAULT = None
    # Logger the module-level free functions forward to (see get_current()).
    CURRENT = None

    def __init__(self, dir, output_formats, comm=None):
        self.name2val = defaultdict(float)  # values this iteration
        self.name2cnt = defaultdict(int)  # number of logkv_mean() calls per key
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats
        self.comm = comm

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        # Incremental running mean over the calls made this iteration.
        oldval, cnt = self.name2val[key], self.name2cnt[key]
        self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
        self.name2cnt[key] = cnt + 1

    def dumpkvs(self, prefix=None):
        """
        Write this iteration's values to W&B (if a run is active) and to every
        KVWriter output format, then clear them. ``prefix`` is unused.

        With a communicator, values are first averaged across ranks (weighted by
        logkv_mean() counts); only rank 0 receives the averages.

        Returns a copy of the dict that was written.
        """
        if self.comm is None:
            d = self.name2val
        else:
            d = mpi_weighted_mean(
                self.comm,
                {
                    name: (val, self.name2cnt.get(name, 1))
                    for (name, val) in self.name2val.items()
                },
            )
            if self.comm.rank != 0:
                d["dummy"] = 1  # so we don't get a warning about empty dict
        # Mirror metrics to Weights & Biases only when a run is active:
        # wandb.log raises if wandb.init() has not been called.
        if getattr(wandb, "run", None) is not None:
            wandb.log({**d})
        out = d.copy()  # Return the dict for unit testing purposes
        for fmt in self.output_formats:
            if isinstance(fmt, KVWriter):
                fmt.writekvs(d)
        self.name2val.clear()
        self.name2cnt.clear()
        return out

    def log(self, *args, level=INFO):
        # Emit only messages at or above the configured threshold.
        if self.level <= level:
            self._do_log(args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        self.level = level

    def set_comm(self, comm):
        self.comm = comm

    def get_dir(self):
        return self.dir

    def close(self):
        for fmt in self.output_formats:
            fmt.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        for fmt in self.output_formats:
            if isinstance(fmt, SeqWriter):
                fmt.writeseq(map(str, args))
404 |
405 |
def get_rank_without_mpi_import():
    """
    Return this process's MPI rank from the environment, or 0 if unset.

    PMI_RANK is checked before OMPI_COMM_WORLD_RANK. Environment variables are
    read instead of importing mpi4py, to avoid calling MPI_Init() when this
    module is imported.
    """
    rank_vars = ("PMI_RANK", "OMPI_COMM_WORLD_RANK")
    return next((int(os.environ[name]) for name in rank_vars if name in os.environ), 0)
413 |
414 |
def mpi_weighted_mean(comm, local_name2valcount):
    """
    Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
    Perform a weighted average over dicts that are each on a different node
    Input: local_name2valcount: dict mapping key -> (value, count)
    Returns: key -> mean on rank 0; an empty dict on every other rank.

    Values that cannot be converted to float are skipped with a warning.
    """
    all_name2valcount = comm.gather(local_name2valcount)
    if comm.rank != 0:
        return {}
    name2sum = defaultdict(float)
    name2count = defaultdict(float)
    for n2vc in all_name2valcount:
        for (name, (val, count)) in n2vc.items():
            try:
                val = float(val)
            except (TypeError, ValueError):
                # float() raises TypeError (not ValueError) for e.g. None or lists.
                warnings.warn(
                    "WARNING: tried to compute mean on non-float {}={}".format(
                        name, val
                    )
                )
                continue
            name2sum[name] += val * count
            name2count[name] += count
    return {name: name2sum[name] / name2count[name] for name in name2sum}
443 |
444 |
def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
    """
    Create a Logger and install it as Logger.CURRENT.

    If comm is provided, average all numerical stats across that comm.

    ``dir`` may be a str or os.PathLike; it defaults to $OPENAI_LOGDIR, else to a
    timestamped directory in the system temp dir, and is created if missing.
    ``format_strs`` defaults to $OPENAI_LOG_FORMAT ("stdout,log,csv") on rank 0
    and $OPENAI_LOG_FORMAT_MPI ("log") on other ranks; non-zero ranks also get
    a "-rankNNN" suffix on their log file names.
    """
    if dir is None:
        dir = os.getenv("OPENAI_LOGDIR")
    if dir is None:
        dir = osp.join(
            tempfile.gettempdir(),
            datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
        )
    if isinstance(dir, os.PathLike):
        dir = os.fspath(dir)
    assert isinstance(dir, str)
    dir = os.path.expanduser(dir)
    os.makedirs(dir, exist_ok=True)

    rank = get_rank_without_mpi_import()
    if rank > 0:
        log_suffix = log_suffix + "-rank%03i" % rank

    if format_strs is None:
        if rank == 0:
            format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
        else:
            format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
    format_strs = filter(None, format_strs)  # drop empty entries, e.g. from "a,,b"
    output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]

    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
    if output_formats:
        log("Logging to %s" % dir)
475 |
476 |
def _configure_default_logger():
    """Configure a logger with default settings and remember it as Logger.DEFAULT."""
    configure()
    default_logger = Logger.CURRENT
    Logger.DEFAULT = default_logger
480 |
481 |
def reset():
    """Close the current logger, unless it is the default, and switch back to the default."""
    current = Logger.CURRENT
    if current is Logger.DEFAULT:
        return
    current.close()
    Logger.CURRENT = Logger.DEFAULT
    log("Reset logger")
487 |
488 |
@contextmanager
def scoped_configure(dir=None, format_strs=None, comm=None):
    """
    Install a freshly configured logger for the duration of the ``with`` body.

    On exit, even if the body raises, the scoped logger is closed and the
    previously active logger is restored.
    """
    saved_logger = Logger.CURRENT
    configure(dir=dir, format_strs=format_strs, comm=comm)
    try:
        yield
    finally:
        scoped_logger = Logger.CURRENT
        scoped_logger.close()
        Logger.CURRENT = saved_logger
498 |
499 |
--------------------------------------------------------------------------------
/src/utils/show_sampling_progress.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import List
# 26 complete 24-bit ANSI foreground escape codes, stepping red down and blue
# up by 10 per entry (from (255, 0, 0) to (5, 0, 250)).
list_of_colors_from_red_to_blue = [
    "\033[38;2;%d;0;%dm" % (red, blue)
    for red, blue in zip(range(255, 0, -10), range(0, 255, 10))
]
4 |
def _gradient_color(index: int, total: int) -> str:
    """Return the 24-bit ANSI color for position ``index`` of ``total`` (red first, blue last)."""
    frac = index / (total - 1) if total > 1 else 0.0
    red = round(255 * (1 - frac))
    blue = round(255 * frac)
    return f"\033[38;2;{red};0;{blue}m"


def pprint_sentences(sentences: List[str], banner: str = "", sep: str = ""):
    """
    Given a list of sentences, prints them with a gradient of colors from red to blue

    The gradient is spread over the whole list, so the first sentence is pure
    red and the last pure blue for any number of sentences. Sentences are
    joined by ``sep`` on one line, preceded by a bold ``banner`` line.
    """
    print()
    print(f"\033[1m{'=' * 20} {banner} {'=' * 20}\033[0m")
    total = len(sentences)
    for i, sentence in enumerate(sentences):
        # The color is already a complete escape sequence. The old code added a
        # "\033[38;5;" prefix (producing a malformed sequence) and indexed a fixed
        # 26-entry palette, which raised IndexError for longer lists.
        colored = f"{_gradient_color(i, total)}{sentence}\033[0m"
        if i == total - 1:
            print(colored)
        else:
            print(colored, end=sep)
    print()
18 |
19 |
if __name__ == '__main__':
    # Demo: print ever-longer prefixes of the list to preview the color gradient.
    sentences = [
        "This is a sentence",
        "This is another sentence",
        "This is a third sentence",
        "This is a fourth sentence",
        "This is a fifth sentence",
        "This is a sixth sentence",
        "This is a seventh sentence",
        "This is an eighth sentence",
        "This is a ninth sentence",
        "This is a tenth sentence",
        "This is an eleventh sentence",
        "This is a twelfth sentence",
        "This is a thirteenth sentence",
        "This is a fourteenth sentence",
        "This is a fifteenth sentence",
        "This is a sixteenth sentence",
        "This is a seventeenth sentence",
        "This is an eighteenth sentence",
        "This is a nineteenth sentence",
        "This is a twentieth sentence",
    ]
    for prefix_len, _ in enumerate(sentences, start=1):
        pprint_sentences(sentences[:prefix_len], sep=" -> ")
        print("---")
46 |
47 |
--------------------------------------------------------------------------------
/src/utils/test_util.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | import numpy as np
3 |
def compute_logp(args, model, x, input_ids):
    """
    Per-position cross-entropy of ``input_ids`` given continuous embeddings ``x``.

    Each position of ``x`` is compared with every row of ``model.weight``. The
    logit for word ``v`` is ``-||x - e_v||^2 / (2 * sigma^2)`` with
    ``sigma = 0.1``. The returned losses have shape ``(bsz, seqlen)``.

    When ``args.model_arch == '1d-unet'``, ``x`` is first permuted from
    ``(bsz, dim, seqlen)`` to ``(bsz, seqlen, dim)``.
    """
    sigma = 0.1
    word_emb = model.weight  # (vocab, dim)
    if args.model_arch == '1d-unet':
        x = x.permute(0, 2, 1)
    bsz, seqlen, dim = x.shape

    # (1, bsz*seqlen, dim) - (vocab, 1, dim) broadcasts to (vocab, bsz*seqlen, dim)
    flat_x = x.reshape(-1, x.size(-1)).unsqueeze(0)
    sq_dist = ((flat_x - word_emb.unsqueeze(1)) ** 2).sum(dim=-1)  # (vocab, bsz*seqlen)
    logits = (-sq_dist / (2 * sigma ** 2)).permute((1, 0))  # (bsz*seqlen, vocab)

    ce = th.nn.CrossEntropyLoss(reduction='none')
    return ce(logits, input_ids.view(-1)).view(bsz, seqlen)
26 |
def get_weights(model, args):
    """
    Return a frozen embedding module holding the (projected) word embeddings.

    * If ``model`` has a ``transformer`` attribute, its input embeddings
      (``model.transformer.wte``) are passed through ``model.down_proj``,
      scaled by ``args.emb_scale_factor`` and stored in a new ``nn.Embedding``.
    * If ``model`` already has a ``weight`` (e.g. it is an Embedding), it is
      used as is.

    In both cases ``weight.requires_grad`` is set to False.

    Raises:
        NotImplementedError: for any other kind of model.
    """
    if hasattr(model, 'transformer'):
        input_embs = model.transformer.wte  # input_embs
        down_proj = model.down_proj
        down_proj_emb = down_proj(input_embs.weight)
        print(down_proj_emb.shape)
        model = th.nn.Embedding(down_proj_emb.size(0), down_proj_emb.size(1))
        print(args.emb_scale_factor)
        model.weight.data = down_proj_emb * args.emb_scale_factor

    elif not hasattr(model, 'weight'):
        # The old `assert NotImplementedError` always passed (a class is truthy)
        # and failed later with a confusing AttributeError instead.
        raise NotImplementedError(
            f"get_weights does not support models of type {type(model).__name__}"
        )

    model.weight.requires_grad = False
    return model
45 |
def denoised_fn_round(args, model, text_emb, t):
    """
    Snap every embedding vector in ``text_emb`` to its nearest row of ``model.weight``.

    Nearest means smallest squared L2 distance. The chosen token ids are looked
    up through ``model`` and reshaped back to ``text_emb``'s original shape and
    device. ``t`` is accepted for interface compatibility and is not used.
    """
    down_proj_emb = model.weight  # (vocab, dim)
    old_shape = text_emb.shape
    old_device = text_emb.device

    # Flatten to (N, dim), where dim is the last axis of text_emb.
    # NOTE(review): for '1d-unet' the input is not permuted before flattening,
    # but the output is permuted below -- confirm this asymmetry is intended.
    flat_emb = text_emb.reshape(-1, text_emb.size(-1)) if text_emb.dim() > 2 else text_emb
    flat_emb = flat_emb.to(down_proj_emb.device)

    # ||e||^2 + ||x||^2 - 2 e.x gives all pairwise squared distances without
    # materialising a (vocab, N, dim) difference tensor.
    emb_norm = (down_proj_emb ** 2).sum(-1).view(-1, 1)  # (vocab, 1)
    arr_norm = (flat_emb ** 2).sum(-1).view(-1, 1)  # (N, 1)
    dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * th.mm(down_proj_emb, flat_emb.t())  # (vocab, N)
    dist = th.clamp(dist, 0.0, np.inf)  # remove tiny negatives from rounding
    rounded_tokens = th.topk(-dist, k=1, dim=0).indices[0]  # (N,)

    new_embeds = model(rounded_tokens).view(old_shape).to(old_device)
    if args.model_arch == '1d-unet':
        new_embeds = new_embeds.permute(0, 2, 1)
    return new_embeds
94 |
def load_results(json_path, load_dict):
    """
    Write ``load_dict`` to ``json_path`` as JSON indented by 2 spaces.

    Despite its name this function saves: the file is opened for writing and
    overwritten.
    """
    import json

    with open(json_path, 'w') as out_file:
        out_file.write(json.dumps(load_dict, indent=2))
99 |
--------------------------------------------------------------------------------