├── .gitattributes ├── .gitignore ├── Int2Int.pdf ├── LICENSE ├── README.md ├── data ├── elliptic_rank.test └── elliptic_rank.train ├── src ├── __init__.py ├── dataset.py ├── envs │ ├── __init__.py │ ├── arithmetic.py │ ├── encoders.py │ └── generators.py ├── evaluator.py ├── logger.py ├── model │ ├── __init__.py │ ├── lstm.py │ └── transformer.py ├── optim.py ├── slurm.py ├── trainer.py └── utils.py ├── tools └── ReadXP.ipynb └── train.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.train filter=lfs diff=lfs merge=lfs -text 2 | *.test filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /Int2Int.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-charton/Int2Int/7a62494006c1c2f98b14818dc050a1e53854d2c9/Int2Int.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 François Charton 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Int2Int 2 | 3 | ## Integer sequence to integer sequence translator 4 | 5 | This is complete code for a sequence to sequence transformer, the goal is to train it (supervisedly) to translate short sequences of integers into short sequences of integers. The manual can be found in Int2Int.pdf, and at https://arxiv.org/abs/2502.17513. 6 | Please cite the Arxiv paper (instead of this repository). 7 | 8 | ### My first experiment 9 | 10 | One way to run it out of the box is to run (on an environment where you have pytorch and numpy installed): 11 | 12 | ``` 13 | python train.py --dump_path /some_path_on_your_computer/ --exp_name my_first_experiment --exp_id 1 --operation "gcd" 14 | ``` 15 | 16 | This should train a transformer to compute the GCD of integers, encoded in base 1000, generated on the fly (between 1 and a million). 17 | 18 | You can try other operations with the parameter --operation : modular_add, modular_mul, fraction_compare, fraction_add, fraction_simplify, matrix_rank... 19 | 20 | Modular operations use modulo 67 by default, you can change it with --modulo 21 | 22 | Integers are encoded in base 1000 by default, you can change it with --base 23 | 24 | These problem specific parameters are to be found in src/envs/arithmetic.py 25 | 26 | If you don't have a NVIDIA GPU, you can train on CPU by stating --cpu true (it will be slow, you may want to set --epoch_size to something small, like 10000...) 27 | 28 | 29 | ### Running my first experiment 30 | 31 | If everything goes fine, the log should appear on the screen. It will also be saved in 32 | 33 | > /some_path_on_your_computer/my_first_experiment/1/train.log 34 | 35 | (if you don't provide an exp_id, the program will create one for you) 36 | 37 | After indicating the hyperparameters you use, and the model you train (and its number of parameters), the logs will describe the training, which is a series of **training epochs** (300,000 examples by default, can be changed via --epoch_size), separated by evaluations over a test set (generated on the fly in the case of the GCD). During epochs, the model will report the training loss (this should go down, but it can be bumpy, if it bumps too much your learning rate is probably too large). 38 | 39 | ``` 40 | INFO - 09/18/24 20:41:05 - 0:00:16 - 400 - 977.02 equations/s - 9768.07 words/s - ARITHMETIC: 0.5479 - LR: 1.0000e-04 41 | INFO - 09/18/24 20:41:11 - 0:00:23 - 600 - 940.88 equations/s - 9408.46 words/s - ARITHMETIC: 0.4181 - LR: 1.0000e-04 42 | INFO - 09/18/24 20:41:18 - 0:00:30 - 800 - 955.94 equations/s - 9558.79 words/s - ARITHMETIC: 0.3715 - LR: 1.0000e-04 43 | ``` 44 | 45 | Training losses are logged every 200 optimization steps (this is configurable, in trainer.py), here, you see it going down 0.55 to 0.42 to 0.37. Life is good! 46 | The eq/s and words/s give an idea of the learning speed, eqs are examples: here with 950 examples / s you expect to complete a 300k example epoch in a little more than 6 minutes (we are on a GPU). 47 | 48 | At the end of each epoch, the model runs test on a sample of size --eval_size (10k by default, you can make this smaller), examples are evaluated in batches of --batch_size_eval. During evaluation, the lines 49 | 50 | ``` 51 | INFO - 09/18/24 20:57:28 - 0:16:39 - (7168/10000) Found 102/128 valid top-1 predictions. Generating solutions ... 52 | INFO - 09/18/24 20:57:28 - 0:16:39 - Found 102/128 solutions in beam hypotheses. 53 | ``` 54 | 55 | indicate how many solutions were correct in each eval batch (here 102 our of 128). At the end, you should have a small report saying : 56 | 57 | ``` 58 | INFO - 09/18/24 20:57:29 - 0:16:40 - 8459/10000 (84.59%) equations were evaluated correctly. 59 | INFO - 09/18/24 20:57:29 - 0:16:40 - 1: 6104 / 6107 (99.95%) 60 | INFO - 09/18/24 20:57:29 - 0:16:40 - 2: 1581 / 1581 (100.00%) 61 | INFO - 09/18/24 20:57:29 - 0:16:40 - 4: 356 / 356 (100.00%) 62 | INFO - 09/18/24 20:57:29 - 0:16:40 - 5: 241 / 245 (98.37%) 63 | INFO - 09/18/24 20:57:29 - 0:16:40 - 8: 86 / 86 (100.00%) 64 | INFO - 09/18/24 20:57:29 - 0:16:40 - 10: 56 / 57 (98.25%) 65 | INFO - 09/18/24 20:57:29 - 0:16:40 - 20: 22 / 22 (100.00%) 66 | INFO - 09/18/24 20:57:29 - 0:16:40 - 25: 8 / 8 (100.00%) 67 | INFO - 09/18/24 20:57:29 - 0:16:40 - 40: 2 / 2 (100.00%) 68 | INFO - 09/18/24 20:57:29 - 0:16:40 - 50: 3 / 3 (100.00%) 69 | ``` 70 | 71 | 84.6% of the test GCD were correctly calculated. Correct model predictions were GCD 1, 2, 4, 5, 8 .... products of powers of divisors of the base... 72 | 73 | At the end of the peoch, the model exports a python dictionary containing detailed results. This is what you want to load (in a notebook) to draw learning curves, etc. 74 | 75 | ### Training from a data file 76 | 77 | Training and test files can be provided with the parameters: `--train_data` and `--eval_data` (setting `--eval_size` to `-1` will cause the model to evaluate on all the eval data). 78 | 79 | Training and test examples are written, one per line, as sequence of tokens, separated by whitespaces, the input and output being separated by a tab. 80 | 81 | 82 | One specify the data type of the input and output, e.g.: `--operation "data" --data_types '"int[5]:int"` 83 | 84 | The supported data types at the moment are: 85 | - `int` -- an integer 86 | 87 | encoded as `p ad ... a0` where `p` in `{+, -}` and `ai` are the digits of `a` in base `1000` (by default), e.g., `-3500` is represented as `- 3 500` 88 | 89 | 90 | - `int[n]` -- an integer array of length 91 | 92 | represented as `Vn z1 ... zn` where `zi` are encoded as above 93 | 94 | - `range(a, b)` -- an integer in the range `{a,...,b-1}` 95 | 96 | encoded as a string in base 10, e.g., via `1101`. 97 | 98 | For example, here are some python functions to encode the above data types, respectively: 99 | 100 | ```python3 101 | def encode_integer(val, base=1000, digit_sep=" "): 102 | if val == 0: 103 | return '+ 0' 104 | sgn = '+' if val >= 0 else '-' 105 | val = abs(val) 106 | r = [] 107 | while val > 0: 108 | r.append(str(val % base)) 109 | val = val//base 110 | r.append(sgn) 111 | r.reverse() 112 | return digit_sep.join(r) 113 | 114 | def encode_integer_array(x, base=1000): 115 | return f'V{len(x)} ' + " ".join(encode_integer(int(z), base) for z in x) 116 | 117 | def encode_range(x): 118 | return str(int(x)) 119 | ``` 120 | 121 | For example, for `GCD` we would use `int[2]:int` where 122 | 123 | ``` 124 | V2 + 1 24 + 16\t+ 16\n 125 | ``` 126 | 127 | represents `GCD (1024, 16) = 16`, in base `1000`. Note that here `V2`, `1`, `24`, `16` are words/tokens. 128 | 129 | For an elliptic curve and if it has nontrivial rank, I would have something like `int[5]:range(2)` 130 | 131 | ``` 132 | V5 + 0 - 1 + 0 - 84 375 258 - 298 283 918 238\t1 133 | ``` 134 | 135 | The code is organised as follows: 136 | 137 | train.py is the main file, you run python train.py with some parameters, you can train from generated data (using envs/generators), generate and export data (setting --export_data to true), or train and test from external data (using train_data and test_data) 138 | 139 | src/envs contain the math-specific code 140 | 141 | -------------------------------------------------------------------------------- /data/elliptic_rank.test: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:f47efea8bbabfff142064ee2b5c14f7c0781b3a42bfe76a697e0f1f17bba64ae 3 | size 329086 4 | -------------------------------------------------------------------------------- /data/elliptic_rank.train: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:07e7a7dd1b0708a20a635242448ed9ba52c4930624479ee0c55925305dac19eb 3 | size 32959926 4 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-charton/Int2Int/7a62494006c1c2f98b14818dc050a1e53854d2c9/src/__init__.py -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import os 10 | import io 11 | import sys 12 | import numpy as np 13 | 14 | 15 | import torch 16 | from torch.utils.data.dataset import Dataset 17 | 18 | logger = getLogger() 19 | 20 | 21 | class EnvDataset(Dataset): 22 | def __init__(self, env, task, train, params, path, size=None, type=None): 23 | super(EnvDataset).__init__() 24 | self.env = env 25 | self.train = train 26 | self.batch_size = params.batch_size 27 | self.env_base_seed = params.env_base_seed 28 | self.path = path 29 | self.global_rank = params.global_rank 30 | self.count = 0 31 | self.type = type 32 | self.two_classes = params.two_classes 33 | self.first_class_prob = params.first_class_prob 34 | self.first_class_size = params.first_class_size 35 | self.batch_type = 0 36 | 37 | self.decoder_only = (params.architecture == 'decoder_only') 38 | self.encoder_only = (params.architecture == 'encoder_only') 39 | assert size is None or not self.train 40 | assert not params.batch_load or params.reload_size > 0 41 | 42 | # batching 43 | self.num_workers = params.num_workers 44 | self.batch_size = params.batch_size 45 | 46 | self.batch_load = params.batch_load 47 | self.reload_size = params.reload_size 48 | self.local_rank = params.local_rank 49 | self.n_gpu_per_node = params.n_gpu_per_node 50 | 51 | self.basepos = 0 52 | self.nextpos = 0 53 | self.seekpos = 0 54 | #self.once=True 55 | 56 | # generation, or reloading from file 57 | if path is not None: 58 | assert os.path.isfile(path) 59 | if params.batch_load and self.train: 60 | self.load_chunk() 61 | else: 62 | logger.info(f"Loading data from {path} ...") 63 | with io.open(path, mode="r", encoding="utf-8") as f: 64 | # either reload the entire file, or the first N lines 65 | # (for the training set) 66 | if not train: 67 | lines = [line.rstrip() for line in f] 68 | else: 69 | lines = [] 70 | for i, line in enumerate(f): 71 | if i == params.reload_size: 72 | break 73 | if i % params.n_gpu_per_node == params.local_rank: 74 | lines.append(line.rstrip()) 75 | self.data = [xy.split("\t") for xy in lines] 76 | self.data = [xy for xy in self.data if len(xy) == 2] 77 | logger.info(f"Loaded {len(self.data)} equations from the disk.") 78 | #logger.info(f"{self.data[0]}\n{self.data[1]}") 79 | #logger.info(f"{self.data[0][0].split()} {self.data[0][1].split()}" ) 80 | 81 | 82 | # dataset size: infinite iterator for train, finite for valid / test 83 | # (default of 10000 if no file provided) 84 | if self.train: 85 | self.size = 1 << 60 86 | elif size is None: 87 | self.size = 10000 if path is None else len(self.data) 88 | else: 89 | assert size > 0 90 | self.size = size 91 | 92 | def load_chunk(self): 93 | self.basepos = self.nextpos 94 | logger.info( 95 | f"Loading data from {self.path} ... seekpos {self.seekpos}, " 96 | f"basepos {self.basepos}" 97 | ) 98 | endfile = False 99 | with io.open(self.path, mode="r", encoding="utf-8") as f: 100 | f.seek(self.seekpos, 0) 101 | lines = [] 102 | for i in range(self.reload_size): 103 | line = f.readline() 104 | if not line: 105 | endfile = True 106 | break 107 | if i % self.n_gpu_per_node == self.local_rank: 108 | lines.append(line.rstrip()) 109 | self.seekpos = 0 if endfile else f.tell() 110 | 111 | self.data = [xy.split("\t") for xy in lines] 112 | self.data = [xy for xy in self.data if len(xy) == 2] 113 | self.nextpos = self.basepos + len(self.data) 114 | logger.info( 115 | f"Loaded {len(self.data)} equations from the disk. seekpos {self.seekpos}, " 116 | f"nextpos {self.nextpos}" 117 | ) 118 | if len(self.data) == 0: 119 | self.load_chunk() 120 | 121 | def batch_sequences(self, sequences, pad_index, bos_index, eos_index, no_bos = False): 122 | """ 123 | Take as input a list of n sequences (torch.LongTensor vectors) and return 124 | a tensor of size (slen, n) where slen is the length of the longest 125 | sentence, and a vector lengths containing the length of each sentence. 126 | """ 127 | initial_offset = 0 if no_bos else 1 128 | lengths = torch.LongTensor([len(s) + 1 + initial_offset for s in sequences]) 129 | sent = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_( 130 | pad_index 131 | ) 132 | assert lengths.min().item() > 1 + initial_offset 133 | 134 | if not no_bos: 135 | sent[0] = bos_index 136 | 137 | for i, s in enumerate(sequences): 138 | sent[initial_offset : lengths[i] - 1, i].copy_(s) 139 | sent[lengths[i] - 1, i] = eos_index 140 | 141 | return sent, lengths 142 | 143 | def collate_fn(self, elements): 144 | """ 145 | Collate samples into a batch. 146 | """ 147 | x, y = zip(*elements) 148 | nb_eqs = [self.env.code_class(xi, yi) for xi, yi in zip(x, y)] 149 | #if self.once: 150 | # for xi,yi in zip(x,y): 151 | # logger.info(f"{xi},{yi}") 152 | # self.once=False 153 | 154 | if self.decoder_only: 155 | xy = [ xi + [''] + yi for xi, yi in zip(x,y)] 156 | xy = [torch.LongTensor([self.env.word2id[w] for w in seq]) for seq in xy] 157 | xy, xy_len = self.batch_sequences(xy, self.env.pad_index, self.env.eos_index, self.env.eos_index) 158 | if self.train: 159 | return (xy, xy_len), torch.LongTensor(nb_eqs) 160 | else: 161 | x = [torch.LongTensor([self.env.word2id[w] for w in seq]) for seq in x] 162 | x, x_len = self.batch_sequences(x, self.env.pad_index, self.env.eos_index, self.env.sep_index) 163 | return (x, x_len), (xy, xy_len), torch.LongTensor(nb_eqs) 164 | else: 165 | x = [torch.LongTensor([self.env.word2id[w] for w in seq]) for seq in x] 166 | y = [torch.LongTensor([self.env.word2id[w] for w in seq]) for seq in y] 167 | x, x_len = self.batch_sequences(x, self.env.pad_index, self.env.eos_index, self.env.eos_index, True) 168 | y, y_len = self.batch_sequences(y, self.env.pad_index, self.env.eos_index, self.env.eos_index, True if self.encoder_only else False) 169 | return (x, x_len), (y, y_len), torch.LongTensor(nb_eqs) 170 | 171 | 172 | def init_rng(self): 173 | """ 174 | Initialize random generator for training. 175 | """ 176 | if hasattr(self.env, "rng"): 177 | return 178 | if self.train: 179 | worker_id = self.get_worker_id() 180 | self.env.worker_id = worker_id 181 | self.env.rng = np.random.RandomState( 182 | [worker_id, self.global_rank, self.env_base_seed] 183 | ) 184 | logger.info( 185 | f"Initialized random generator for worker {worker_id}, with seed " 186 | f"{[worker_id, self.global_rank, self.env_base_seed]} " 187 | f"(base seed={self.env_base_seed})." 188 | ) 189 | else: 190 | self.env.rng = np.random.RandomState(None) 191 | 192 | def get_worker_id(self): 193 | """ 194 | Get worker ID. 195 | """ 196 | if not self.train: 197 | return 0 198 | worker_info = torch.utils.data.get_worker_info() 199 | assert (worker_info is None) == (self.num_workers == 0) 200 | return 0 if worker_info is None else worker_info.id 201 | 202 | def __len__(self): 203 | """ 204 | Return dataset size. 205 | """ 206 | return self.size 207 | 208 | def __getitem__(self, index): 209 | """ 210 | Return a training sample. 211 | Either generate it, or read it from file. 212 | """ 213 | self.init_rng() 214 | if self.path is None: 215 | return self.generate_sample() 216 | else: 217 | return self.read_sample(index) 218 | 219 | def read_sample(self, index): 220 | """ 221 | Read a sample. 222 | """ 223 | idx = index 224 | if self.train: 225 | if self.batch_load: 226 | if index >= self.nextpos: 227 | self.load_chunk() 228 | idx = index - self.basepos 229 | else: 230 | if self.two_classes: 231 | if self.env.rng.rand()< self.first_class_prob: 232 | idx = (self.env.rng.randint(self.first_class_size)) % len(self.data) 233 | else: 234 | idx = (self.first_class_size + self.env.rng.randint(len(self.data) - self.first_class_size)) % len(self.data) 235 | else: 236 | index = self.env.rng.randint(len(self.data)) 237 | idx = index 238 | 239 | 240 | x, y = self.data[idx] 241 | x = x.split() 242 | y = y.split() 243 | assert len(x) >= 1 and len(y) >= 1 244 | return x, y 245 | 246 | def generate_sample(self): 247 | """ 248 | Generate a sample. 249 | """ 250 | while True: 251 | try: 252 | xy = self.env.gen_expr(self.type) 253 | if xy is None: 254 | continue 255 | x, y = xy 256 | break 257 | except Exception as e: 258 | logger.error( 259 | 'An unknown exception of type {0} occurred for worker {4} in line {1} for expression "{2}". Arguments:{3!r}.'.format( 260 | type(e).__name__, 261 | sys.exc_info()[-1].tb_lineno, 262 | "F", 263 | e.args, 264 | self.get_worker_id(), 265 | ) 266 | ) 267 | continue 268 | self.count += 1 269 | 270 | return x, y 271 | -------------------------------------------------------------------------------- /src/envs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | 10 | from .arithmetic import ArithmeticEnvironment 11 | 12 | ENVS = { 13 | 'arithmetic': ArithmeticEnvironment, 14 | } 15 | 16 | logger = getLogger() 17 | 18 | 19 | def build_env(params): 20 | """ 21 | Build environment. 22 | """ 23 | env = ENVS[params.env_name](params) 24 | 25 | # tasks 26 | tasks = [x for x in params.tasks.split(',') if len(x) > 0] 27 | assert len(tasks) == len(set(tasks)) > 0 28 | assert all(task in env.TRAINING_TASKS for task in tasks) 29 | params.tasks = tasks 30 | logger.info(f'Training tasks: {", ".join(tasks)}') 31 | 32 | return env 33 | -------------------------------------------------------------------------------- /src/envs/arithmetic.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | 10 | import numpy as np 11 | import src.envs.encoders as encoders 12 | import src.envs.generators as generators 13 | 14 | 15 | from torch.utils.data import DataLoader 16 | from src.dataset import EnvDataset 17 | 18 | from ..utils import bool_flag 19 | 20 | 21 | SPECIAL_WORDS = ["", "", "", "(", ")"] 22 | SPECIAL_WORDS = SPECIAL_WORDS + [f"" for i in range(10)] 23 | 24 | logger = getLogger() 25 | 26 | 27 | class InvalidPrefixExpression(Exception): 28 | def __init__(self, data): 29 | self.data = data 30 | 31 | def __str__(self): 32 | return repr(self.data) 33 | 34 | def data_type_to_encoder(params, typ): 35 | tensor_dim = typ.count('[') 36 | assert tensor_dim == typ.count(']') 37 | if typ.startswith('int'): 38 | ext = typ[3:] 39 | if tensor_dim > 0: 40 | dims = [int(elt.strip('[')) for elt in ext.strip(']').split(']')] 41 | max_dim = max(dims) 42 | assert ext == ''.join(f'[{d}]' for d in dims) 43 | return encoders.NumberArray(params, max_dim, 'V', tensor_dim) 44 | else: 45 | assert typ == 'int' 46 | return encoders.PositionalInts(params.base) 47 | elif typ.startswith('range'): 48 | assert tensor_dim == 0, "at the moment we don't support arrays of ranges, use int arrays" 49 | if ',' in typ: 50 | low, high = map(int, typ[6:-1].split(',')) 51 | assert typ == f'range({low},{high})' 52 | else: 53 | low = 0 54 | high = int(typ[6:-1]) 55 | assert typ == f'range({high})' 56 | return encoders.SymbolicInts(low, high - 1) # SymbolicInts is inclusive on the high end 57 | else: 58 | assert False, "type not supported" 59 | 60 | 61 | 62 | 63 | 64 | 65 | class ArithmeticEnvironment(object): 66 | 67 | TRAINING_TASKS = {"arithmetic"} 68 | 69 | def __init__(self, params): 70 | self.max_len = params.max_len 71 | self.operation = params.operation 72 | 73 | self.base = params.base 74 | self.max_class = params.max_class 75 | 76 | #if self.operation == 'data': 77 | #assert params.data_types, "argument --data_types is required" 78 | #i, o = params.data_types.split(':') 79 | #self.input_encoder = data_type_to_encoder(params, i) 80 | #self.output_encoder = data_type_to_encoder(params, o) 81 | # self.input_encoder = encoders.NumberArray(params, 5, 'V', 1) 82 | # self.output_encoder = encoders.SymbolicInts(0, 10) 83 | # dims=[] 84 | # self.generator = generators.Sequence(params, dims) 85 | self.export_pred = params.export_pred 86 | self.n_eval_metrics = params.n_eval_metrics 87 | self.n_error_metrics = params.n_error_metrics 88 | 89 | 90 | if self.operation == 'matrix_rank': 91 | dims = [params.dim1, params.dim2] 92 | max_dim = 100 93 | tensor_dim = 2 94 | self.output_encoder = encoders.SymbolicInts(1, max_dim) 95 | else: 96 | dims = [] 97 | max_dim = 4 if self.operation in ["fraction_compare", "fraction_determinant", "fraction_add", "fraction_product"] else 2 98 | tensor_dim = 1 99 | if self.operation in ["fraction_add", "fraction_product", "fraction_simplify"]: 100 | self.output_encoder = encoders.NumberArray(params, 2, 'V', tensor_dim ) 101 | elif self.operation in ["fraction_round", "gcd", "fraction_determinant","modular_add","modular_mul","elliptic"]: 102 | self.output_encoder = encoders.PositionalInts(params.base) 103 | else: 104 | self.output_encoder = encoders.SymbolicInts(0, 1) 105 | self.input_encoder = encoders.NumberArray(params, max_dim, 'V', tensor_dim) 106 | assert not self.export_pred or isinstance(self.output_encoder, (encoders.SymbolicInts, encoders.PositionalInts)) 107 | 108 | self.generator = generators.Sequence(params, dims) 109 | 110 | # vocabulary 111 | self.words = SPECIAL_WORDS + sorted(list( 112 | set(self.input_encoder.symbols+self.output_encoder.symbols) 113 | )) 114 | self.id2word = {i: s for i, s in enumerate(self.words)} 115 | self.word2id = {s: i for i, s in self.id2word.items()} 116 | assert len(self.words) == len(set(self.words)) 117 | 118 | # number of words / indices 119 | self.n_words = params.n_words = len(self.words) 120 | self.eos_index = params.eos_index = 0 121 | self.pad_index = params.pad_index = 1 122 | self.sep_index = params.sep_index = 2 123 | 124 | logger.info(f"words: {self.word2id}") 125 | 126 | def input_to_infix(self, lst): 127 | return ' '.join(lst) 128 | 129 | def output_to_infix(self, lst): 130 | return ' '.join(lst) 131 | 132 | def gen_expr(self, data_type=None): 133 | """ 134 | Generate pairs of problems and solutions. 135 | Encode this as a prefix sentence 136 | """ 137 | gen = self.generator.generate(self.rng, data_type) 138 | if gen is None: 139 | return None 140 | x_data, y_data = gen 141 | # encode input 142 | x = self.input_encoder.encode(x_data) 143 | # encode output 144 | y = self.output_encoder.encode(y_data) 145 | if self.max_len > 0 and (len(x) >= self.max_len or len(y) >= self.max_len): 146 | return None 147 | return x, y 148 | 149 | def decode_class(self, i): 150 | """ 151 | The code class splits the test data in to subgroups by code_class 152 | """ 153 | if i>=1000: 154 | return str(i//1000)+"-"+str(i%1000) 155 | return str(i) 156 | 157 | def code_class(self, xi, yi): 158 | """ 159 | The code class splits the test data in to subgroups by code_class 160 | This is passed to the evaluator, so it needs to be an integer 161 | """ 162 | if self.export_pred: 163 | v = self.output_encoder.decode(yi) 164 | assert v is not None 165 | if v >= self.max_class: 166 | v = self.max_class 167 | return v 168 | 169 | if self.operation in ["fraction_add", "fraction_product", "fraction_simplify", "fraction_round", "fraction_determinant"]: 170 | return 0 171 | elif self.operation in ["gcd", "modular_add", "modular_mul"]: 172 | v = self.output_encoder.decode(yi) 173 | assert v is not None 174 | if v >= self.max_class: 175 | v = self.max_class 176 | return v 177 | else: 178 | v = self.output_encoder.decode(yi) 179 | assert v is not None 180 | if isinstance(self.output_encoder, encoders.NumberArray): 181 | v = v[0] 182 | return v % self.max_class 183 | 184 | def check_prediction(self, src, tgt, hyp): 185 | w = self.output_encoder.decode(hyp) 186 | if w is None: 187 | return -1,[],[], None if self.export_pred else -1,[],[] 188 | if len(hyp) == 0 or len(tgt) == 0: 189 | return -1,[],[], None if self.export_pred else -1,[],[] 190 | if hyp == tgt: 191 | return 2,[],[], w if self.export_pred else 2,[],[] 192 | 193 | a, b, c = self.generator.evaluate(self.input_encoder.decode(src), self.input_encoder.decode(tgt), w) 194 | return a, b, c, w if self.export_pred else a, b, c 195 | 196 | 197 | def create_train_iterator(self, task, data_path, params): 198 | """ 199 | Create a dataset for this environment. 200 | """ 201 | logger.info(f"Creating train iterator for {task} ...") 202 | 203 | dataset = EnvDataset( 204 | self, 205 | task, 206 | train=True, 207 | params=params, 208 | path=data_path, 209 | type = "train", 210 | ) 211 | return DataLoader( 212 | dataset, 213 | timeout=(0 if params.num_workers == 0 else 1800), 214 | batch_size=params.batch_size, 215 | num_workers=( 216 | params.num_workers 217 | if data_path is None or params.num_workers == 0 218 | else 1 219 | ), 220 | shuffle=False, 221 | collate_fn=dataset.collate_fn, 222 | ) 223 | 224 | def create_test_iterator( 225 | self, data_type, task, data_path, batch_size, params, size 226 | ): 227 | """ 228 | Create a dataset for this environment. 229 | """ 230 | #assert data_type in ["valid", "test"] or data_type[:4] == "test" 231 | logger.info(f"Creating {data_type} iterator for {task} ...") 232 | if data_path is None: 233 | path_iter = None 234 | elif data_type == "valid": 235 | path_iter = data_path[0] 236 | elif data_type == "test": 237 | path_iter = data_path[1] 238 | else: 239 | path_iter = data_path[int(data_type[4:])] 240 | dataset = EnvDataset( 241 | self, 242 | task, 243 | train=False, 244 | params=params, 245 | path=path_iter, 246 | size=size, 247 | type=data_type, 248 | ) 249 | return DataLoader( 250 | dataset, 251 | timeout=0, 252 | batch_size=batch_size, 253 | num_workers=1, 254 | shuffle=False, 255 | collate_fn=dataset.collate_fn, 256 | ) 257 | 258 | @staticmethod 259 | def register_args(parser): 260 | """ 261 | Register environment parameters. 262 | """ 263 | parser.add_argument( 264 | "--operation", type=str, default="data", help="Operation to perform" 265 | ) 266 | parser.add_argument( 267 | "--data_types", type=str, default="", help="Data type for input and out output separated by :, e.g. \"int[5]:range(2)\"" 268 | ) 269 | parser.add_argument( 270 | "--dim1", type=int, default=10, help="Lines of matrix" 271 | ) 272 | parser.add_argument( 273 | "--dim2", type=int, default=10, help="Columns of matrix" 274 | ) 275 | 276 | 277 | parser.add_argument( 278 | "--maxint", type=int, default=1000000, help="Maximum value of integers" 279 | ) 280 | parser.add_argument( 281 | "--minint", type=int, default=1, help="Minimum value of integers (uniform generation only)" 282 | ) 283 | 284 | 285 | 286 | parser.add_argument( 287 | "--two_classes", type=bool_flag, default=False, help="Two classes in train set" 288 | ) 289 | parser.add_argument( 290 | "--first_class_size", type=int, default=1000000, help="Standard deviation, in examples" 291 | ) 292 | parser.add_argument( 293 | "--first_class_prob", type=float, default=0.25, help="Proportion of repeated fixed examples in train set" 294 | ) 295 | 296 | parser.add_argument( 297 | "--base", type=int, default=1000, help="Encoding base" 298 | ) 299 | parser.add_argument( 300 | "--modulus", type=int, default=67, help="Modulus for modular operations" 301 | ) 302 | 303 | parser.add_argument( 304 | "--n_eval_metrics", type=int, default=0, help="number of eval metrics, returned by generator.evaluate()") 305 | 306 | parser.add_argument( 307 | "--n_error_metrics", type=int, default=0, help="number of error metrics, returned by generator.evaluate()") 308 | 309 | parser.add_argument( 310 | "--export_pred", type=bool_flag, default=False, help="export model predictions, returned by check_predictions()") 311 | 312 | parser.add_argument( 313 | "--max_class", type=int, default=101, help="Maximum class for reporting with error predictions" 314 | ) 315 | 316 | -------------------------------------------------------------------------------- /src/envs/encoders.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import numpy as np 3 | 4 | class Encoder(ABC): 5 | """ 6 | Base class for encoders, encodes and decodes matrices 7 | abstract methods for encoding/decoding numbers 8 | """ 9 | def __init__(self): 10 | pass 11 | 12 | @abstractmethod 13 | def encode(self, val): 14 | pass 15 | 16 | def decode(self, lst): 17 | v, p = self.parse(lst) 18 | if p == 0: 19 | return None 20 | return v 21 | 22 | class SymbolicInts(Encoder): 23 | """ 24 | one token per int from min to max (0 to 1 for binary, -10 to 10 for bounded ints, 0 to Q-1 for modular) 25 | optionally: add a prefix, e.g. E-100 E100 for exponents, N1 N5 for dimensions 26 | """ 27 | def __init__(self, min, max, prefix = ''): 28 | super().__init__() 29 | self.prefix = prefix 30 | self.symbols = [self.prefix + str(i) for i in range(min, max+1)] 31 | 32 | def encode(self, value): 33 | return [self.prefix+str(value)] 34 | 35 | def parse(self, lst): 36 | if len(lst) == 0 or (not lst[0] in self.symbols): 37 | return None, 0 38 | return int(lst[0][len(self.prefix):]), 1 39 | 40 | 41 | class PositionalInts(Encoder): 42 | """ 43 | Single integers, in base params.base (positive base), with the sign 44 | """ 45 | def __init__(self, base=10): 46 | super().__init__() 47 | self.base = base 48 | self.symbols = ['+', '-'] + [str(i) for i in range(self.base)] 49 | 50 | def encode(self, value): 51 | if value != 0: 52 | prefix = [] 53 | w = abs(value) 54 | while w > 0: 55 | prefix.append(str(w % self.base)) 56 | w = w // self.base 57 | prefix = prefix[::-1] 58 | else: 59 | prefix =['0'] 60 | prefix = (['+'] if value >= 0 else ['-']) + prefix 61 | return prefix 62 | 63 | def parse(self,lst): 64 | if len(lst) <= 1 or (lst[0] != '+' and lst[0] != '-'): 65 | return None, 0 66 | res = 0 67 | pos = 1 68 | for x in lst[1:]: 69 | if not (x.isdigit()): 70 | break 71 | res = res * self.base + int(x) 72 | pos += 1 73 | if pos < 2: return None, pos 74 | return -res if lst[0] == '-' else res, pos 75 | 76 | class NumberArray(Encoder): 77 | """ 78 | Array of integers, in base params.base (any shape) 79 | TODO modify to support float, complex (rationals), different subencoders 80 | """ 81 | def __init__(self, params, max_dim, dim_prefix, tensor_dim, code='pos_int'): 82 | super().__init__() 83 | self.tensor_dim = tensor_dim 84 | self.symbols = [] 85 | self.dimencoder = SymbolicInts(1, max_dim, dim_prefix) 86 | self.symbols.extend(self.dimencoder.symbols) 87 | if code == 'pos_int': 88 | self.subencoder = PositionalInts(params.base) 89 | else: 90 | self.subencoder = SymbolicInts(params.min_int, params.max_int) 91 | self.symbols.extend(self.subencoder.symbols) 92 | 93 | def encode(self, vector): 94 | lst = [] 95 | assert len(np.shape(vector)) == self.tensor_dim 96 | for d in np.shape(vector): 97 | lst.extend(self.dimencoder.encode(d)) 98 | for val in np.nditer(np.array(vector)): 99 | lst.extend(self.subencoder.encode(val)) 100 | return lst 101 | 102 | def decode(self, lst): 103 | shap = [] 104 | h = lst 105 | for _ in range(self.tensor_dim): 106 | v, _ = self.dimencoder.parse(h) 107 | if v is None: 108 | return None 109 | shap.append(v) 110 | h = h[1:] 111 | m = np.zeros(tuple(shap), dtype=int) 112 | for val in np.nditer(m, op_flags=['readwrite']): 113 | v, pos = self.subencoder.parse(h) 114 | if v is None: 115 | return None 116 | h = h[pos:] 117 | val[...] = v 118 | return m 119 | 120 | -------------------------------------------------------------------------------- /src/envs/generators.py: -------------------------------------------------------------------------------- 1 | 2 | from abc import ABC, abstractmethod 3 | import numpy as np 4 | import math 5 | from logging import getLogger 6 | 7 | logger = getLogger() 8 | 9 | 10 | class Generator(ABC): 11 | def __init__(self, params): 12 | super().__init__() 13 | 14 | @abstractmethod 15 | def generate(self, rng): 16 | pass 17 | 18 | @abstractmethod 19 | def evaluate(self, src, tgt, hyp): 20 | pass 21 | 22 | # empty for now 23 | class Sequence(Generator): 24 | def __init__(self, params, dims): 25 | super().__init__(params) 26 | 27 | self.operation = params.operation 28 | self.maxint = params.maxint 29 | self.minint = params.minint 30 | self.dims = dims 31 | self.modulus = params.modulus 32 | 33 | # integers from 1 to maxint, log uniform distribution 34 | def integer_loguniform_sequence(self, len, rng, type=None, max=None): 35 | maxint = self.maxint if max is None else max 36 | lgs = math.log10(maxint)*rng.rand(len) 37 | return np.int64(10 ** lgs) 38 | 39 | # integers from minint to maxint, uniform distribution 40 | def integer_sequence(self, len, rng, type=None, max=None): 41 | maxint = self.maxint if max is None else max 42 | return rng.randint(self. minint, maxint + 1, len) 43 | 44 | # integer (n,p) matrix, uniformly distributed coefficients between -maxint and maxint 45 | def integer_matrix(self, n, p, rng): 46 | maxint = (int)(self.maxint + 0.5) 47 | return rng.randint(- maxint, maxint + 1, (n, p)) 48 | 49 | def generate(self, rng, type=None): 50 | if self.operation in ["fraction_simplify","fraction_round"]: 51 | integers = self.integer_sequence(3, rng) 52 | if self.operation == "fraction_simplify": 53 | g = math.gcd(integers[1],integers[2]) 54 | if integers[0] == 1: 55 | integers[0] = rng.randint(2, self.maxint + 1) 56 | inp = [integers[0] * integers[1] // g, integers[0] * integers[2] // g ] 57 | out = [integers[1] // g , integers[2] // g] 58 | else: 59 | m1 = min(integers[1],integers[2]) 60 | m2 = max(integers[1],integers[2]) 61 | if m2 == m1: 62 | m1 = m2 - 1 63 | inp = [integers[0] * m2 + m1, m2] 64 | out = integers[0] 65 | return inp, out 66 | 67 | if self.operation in ["fraction_add", "fraction_compare", "fraction_determinant", "fraction_product"]: 68 | inp = self.integer_sequence(4, rng) 69 | if self.operation == "fraction_add": 70 | num = inp[0] * inp[3] + inp[1] * inp[2] 71 | den = inp[1] * inp[3] 72 | g = math.gcd(num, den) 73 | out = [int(num // g), int(den // g)] 74 | elif self.operation == "fraction_product": 75 | num = inp[0] * inp[2] 76 | den = inp[1] * inp[3] 77 | g = math.gcd(num, den) 78 | out = [int(num // g), int(den // g)] 79 | elif self.operation == "fraction_determinant": 80 | out = inp[0] * inp[3] - inp[1] * inp[2] 81 | else: 82 | cmp = inp[0] * inp[3] - inp[1] * inp[2] 83 | out = 1 if cmp > 0 else 0 84 | return inp, out 85 | if self.operation in ["modular_add","modular_mul"]: 86 | inp = self.integer_sequence(2, rng, type) 87 | out = (inp[0] + inp[1]) % self.modulus if self.operation =="modular_add" else (inp[0] * inp[1]) % self.modulus 88 | return inp, out 89 | if self.operation in ["gcd"]: 90 | inp = self.integer_sequence(2, rng, type) 91 | out = math.gcd(inp[0], inp[1]) 92 | return inp, out 93 | if self.operation == "matrix_rank": 94 | maxrank = min(self.dims[0], self.dims[1]) 95 | rank = rng.randint(1, maxrank + 1) 96 | 97 | P = self.integer_matrix(self.dims[0], rank, rng) 98 | Q = self.integer_matrix(rank, self.dims[1], rng) 99 | input = P @ Q 100 | check_rank = np.linalg.matrix_rank(input) 101 | if check_rank != rank: 102 | return None 103 | return input, rank 104 | 105 | return None 106 | 107 | def evaluate(self, src, tgt, hyp): 108 | 109 | return 0, [],[] 110 | 111 | -------------------------------------------------------------------------------- /src/evaluator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | from collections import OrderedDict 10 | import os 11 | import torch 12 | 13 | from .utils import to_cuda 14 | 15 | 16 | logger = getLogger() 17 | 18 | 19 | def idx_to_infix(env, idx, input=True): 20 | """ 21 | Convert an indexed prefix expression to SymPy. 22 | """ 23 | prefix = [env.id2word[wid] for wid in idx] 24 | infix = env.input_to_infix(prefix) if input else env.output_to_infix(prefix) 25 | return infix 26 | 27 | 28 | def check_hypothesis(eq): 29 | """ 30 | Check a hypothesis for a given equation and its solution. 31 | """ 32 | env = Evaluator.ENV 33 | 34 | src = [env.id2word[wid] for wid in eq["src"]] 35 | tgt = [env.id2word[wid] for wid in eq["tgt"]] 36 | hyp = [env.id2word[wid] for wid in eq["hyp"]] 37 | 38 | # update hypothesis 39 | eq["src"] = env.input_to_infix(src) 40 | eq["tgt"] = tgt 41 | eq["hyp"] = hyp 42 | if env.export_pred: 43 | try: 44 | m, s1, s2, nb = env.check_prediction(src, tgt, hyp) 45 | except Exception: 46 | m = -1 47 | s1 = [] 48 | s2 = [] 49 | nb = None 50 | else: 51 | try: 52 | m, s1, s2 = env.check_prediction(src, tgt, hyp) 53 | except Exception: 54 | m = -1 55 | s1 = [] 56 | s2 = [] 57 | eq["is_valid"] = m 58 | for i in range(env.n_eval_metrics): 59 | if m == 2: 60 | eq[f"is_valid{i+1}"] = 1 61 | elif m < 0: 62 | eq[f"is_valid{i+1}"] = 0 63 | else: 64 | eq[f"is_valid{i+1}"] = s1[i] 65 | for i in range(env.n_error_metrics): 66 | eq[f"is_error{i+1}"] = s2[i] if m in [0,1] else 0 67 | 68 | if env.export_pred: 69 | eq["pred"] = nb 70 | return eq 71 | 72 | 73 | class Evaluator(object): 74 | 75 | ENV = None 76 | 77 | def __init__(self, trainer): 78 | """ 79 | Initialize evaluator. 80 | """ 81 | self.trainer = trainer 82 | self.modules = trainer.modules 83 | self.params = trainer.params 84 | self.env = trainer.env 85 | Evaluator.ENV = trainer.env 86 | 87 | def run_all_evals(self): 88 | """ 89 | Run all evaluations. 90 | 91 | """ 92 | params = self.params 93 | scores = OrderedDict({"epoch": self.trainer.epoch}) 94 | 95 | # save statistics about generated data 96 | if params.export_data: 97 | scores["total"] = self.trainer.total_samples 98 | return scores 99 | 100 | data_type_list = ["valid"] 101 | if params.eval_data != '': 102 | l = len(params.eval_data.split(',')) 103 | for i in range(l): 104 | data_type_list.append("test"+(str(i+1) if i>0 else "")) 105 | 106 | with torch.no_grad(): 107 | for data_type in data_type_list: 108 | for task in params.tasks: 109 | if params.beam_eval: 110 | self.enc_dec_step_beam(data_type, task, scores) 111 | else: 112 | self.enc_dec_step(data_type, task, scores) 113 | return scores 114 | 115 | def enc_dec_step(self, data_type, task, scores): 116 | """ 117 | Encoding / decoding step. 118 | """ 119 | params = self.params 120 | env = self.env 121 | max_beam_length = params.max_output_len + 2 122 | if params.architecture != "decoder_only": 123 | encoder = ( 124 | self.modules["encoder"].module 125 | if params.multi_gpu 126 | else self.modules["encoder"] 127 | ) 128 | encoder.eval() 129 | if params.architecture != "encoder_only": 130 | decoder = ( 131 | self.modules["decoder"].module 132 | if params.multi_gpu 133 | else self.modules["decoder"] 134 | ) 135 | decoder.eval() 136 | assert params.eval_verbose in [0, 1,2] 137 | assert params.eval_verbose_print is False or params.eval_verbose > 0 138 | assert task in env.TRAINING_TASKS 139 | 140 | # evaluation details 141 | if params.eval_verbose: 142 | eval_path = os.path.join( 143 | params.dump_path, f"eval.{data_type}.{task}.{scores['epoch']}" 144 | ) 145 | f_export = open(eval_path, "w") 146 | logger.info(f"Writing evaluation results in {eval_path} ...") 147 | 148 | def display_logs(logs, offset): # FC A revoir 149 | """ 150 | Display detailed results about success / fails. 151 | """ 152 | if params.eval_verbose == 0: 153 | return 154 | for i, res in sorted(logs.items()): 155 | n_valid = sum([int(v) for _, _, v in res["hyps"]]) 156 | s = f"Equation {offset + i} ({n_valid}/{len(res['hyps'])})\n" 157 | s += f"src={res['src']}\ntgt={res['tgt']}\n" 158 | for hyp, score, valid in res["hyps"]: 159 | if score is None: 160 | s += f"{int(valid)} {hyp}\n" 161 | else: 162 | s += f"{int(valid)} {score :.3e} {hyp}\n" 163 | if params.eval_verbose_print: 164 | logger.info(s) 165 | f_export.write(s + "\n") 166 | f_export.flush() 167 | 168 | # stats 169 | xe_loss = 0 170 | n_valid = torch.zeros(10000, dtype=torch.long) 171 | n_total = torch.zeros(10000, dtype=torch.long) 172 | n_perfect_match = 0 173 | n_correct = 0 174 | n_perfect = 0 175 | if env.n_eval_metrics > 0: 176 | eval_metrics = torch.zeros(self.env.n_eval_metrics) 177 | if env.n_error_metrics > 0: 178 | error_metrics = torch.zeros(self.env.n_error_metrics) 179 | 180 | if env.export_pred: 181 | n_pairs = torch.zeros((env.max_class+1,env.max_class+1), dtype=torch.long) 182 | 183 | # iterator 184 | iterator = self.env.create_test_iterator( 185 | data_type, 186 | task, 187 | data_path=params.eval_data.split(',') if params.eval_data != "" else None, 188 | batch_size=params.batch_size_eval, 189 | params=params, 190 | size=params.eval_size, 191 | ) 192 | eval_size = len(iterator.dataset) 193 | 194 | for (x1, len1), (x2, len2), nb_ops in iterator: 195 | 196 | # cuda 197 | x1_, len1_, x2_, len2_ = to_cuda(x1, len1, x2, len2) 198 | # target words to predict 199 | if params.architecture != "encoder_only": 200 | alen = torch.arange(len2_.max(), dtype=torch.long, device=len2_.device) 201 | pred_mask = ( 202 | alen[:, None] < len2_[None] - 1 203 | ) # do not predict anything given the last target word 204 | y = x2_[1:].masked_select(pred_mask[:-1]) 205 | assert len(y) == (len2_ - 1).sum().item() 206 | else: 207 | alen = torch.arange(len1_.max(), dtype=torch.long, device=len2_.device) 208 | pred_mask = ( 209 | (alen[:, None] < len2_[None]) # & (alen[:, None] > torch.zeros_like(len2)[None]) 210 | ) 211 | y= torch.cat((x2_,torch.full((len1_.max()-len2_.max(),len2_.size(0)),self.env.eos_index,device=len2_.device)),0) 212 | y = y.masked_select(pred_mask) 213 | 214 | bs = len(len1_) 215 | 216 | # forward / loss 217 | if params.architecture == "encoder_decoder": 218 | if params.lstm: 219 | _, hidden = encoder("fwd", x=x1_, lengths=len1_, causal=False) 220 | decoded, _ = decoder( 221 | "fwd", 222 | x=x2_, 223 | lengths=len2_, 224 | causal=True, 225 | src_enc=hidden, 226 | ) 227 | word_scores, loss = decoder( 228 | "predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=True 229 | ) 230 | else: 231 | encoded = encoder("fwd", x=x1_, lengths=len1_, causal=False) 232 | decoded = decoder( 233 | "fwd", 234 | x=x2_, 235 | lengths=len2_, 236 | causal=True, 237 | src_enc=encoded.transpose(0, 1), 238 | src_len=len1_, 239 | ) 240 | word_scores, loss = decoder( 241 | "predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=True 242 | ) 243 | elif params.architecture == "encoder_only": 244 | encoded = encoder("fwd", x=x1_, lengths=len1_, causal=False) 245 | word_scores, loss = encoder( 246 | "predict", tensor=encoded, pred_mask=pred_mask, y=y, get_scores=True 247 | ) 248 | else: 249 | decoded = decoder("fwd", x=x2_, lengths=len2_, causal=True, src_enc=None, src_len=None) 250 | word_scores, loss = decoder( 251 | "predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=True 252 | ) 253 | 254 | # correct outputs per sequence / valid top-1 predictions 255 | t = torch.zeros_like(pred_mask, device=y.device) 256 | t[pred_mask] += word_scores.max(1)[1] == y 257 | if params.architecture == "encoder_only": 258 | valid = (t.sum(0) == len2_).cpu().long() 259 | else: 260 | valid = (t.sum(0) == len2_ - 1).cpu().long() 261 | n_perfect_match += valid.sum().item() 262 | 263 | # export evaluation details 264 | beam_log = {} 265 | for i in range(len(len1)): 266 | if params.architecture == "decoder_only": 267 | src = idx_to_infix(env, x2[1 : len1[i] - 1, i].tolist(), True) 268 | tgt = idx_to_infix(env, x2[len1[i] : len2[i] - 1, i].tolist(), False) 269 | else: 270 | out_offset = 0 if params.architecture == "encoder_only" else 1 271 | src = idx_to_infix(env, x1[0 : len1[i] - 1, i].tolist(), True) 272 | tgt = idx_to_infix(env, x2[out_offset : len2[i] - 1, i].tolist(), False) 273 | if valid[i]: 274 | beam_log[i] = {"src": src, "tgt": tgt, "hyps": [(tgt, None, True)]} 275 | if env.export_pred: 276 | result = nb_ops[i] if nb_ops[i] < env.max_class else env.max_class 277 | n_pairs[result][result] += 1 278 | 279 | # stats 280 | xe_loss += loss.item() * len(y) 281 | n_valid.index_add_(-1, nb_ops, valid) 282 | n_total.index_add_(-1, nb_ops, torch.ones_like(nb_ops)) 283 | 284 | # continue if everything is correct. if eval_verbose, perform 285 | # a full beam search, even on correct greedy generations 286 | if valid.sum() == len(valid) and params.eval_verbose < 2: 287 | display_logs(beam_log, offset=n_total.sum().item() - bs) 288 | continue 289 | 290 | # invalid top-1 predictions - check if there is a solution in the beam 291 | invalid_idx = (1 - valid).nonzero().view(-1) 292 | logger.info( 293 | f"({n_total.sum().item()}/{eval_size}) Found " 294 | f"{bs - len(invalid_idx)}/{bs} valid top-1 predictions. " 295 | f"Generating solutions ..." 296 | ) 297 | 298 | # generate 299 | if params.architecture == "encoder_decoder": 300 | if params.lstm: 301 | generated, _ = decoder.generate( 302 | hidden, 303 | len1_, 304 | max_len=max_beam_length, 305 | ) 306 | else: 307 | generated, _ = decoder.generate( 308 | encoded.transpose(0, 1), 309 | len1_, 310 | max_len=max_beam_length, 311 | ) 312 | generated=generated.transpose(0, 1) 313 | elif params.architecture == "encoder_only": 314 | generated = encoder.decode(x1_, len1_, max_beam_length).cpu() 315 | else: 316 | generated, _ = decoder.generate(x1_, len1_, max_beam_length) 317 | generated=generated.transpose(0, 1) 318 | # prepare inputs / hypotheses to check 319 | # if eval_verbose < 2, no beam search on equations solved greedily 320 | inputs = [] 321 | for i in range(len(generated)): 322 | if valid[i] and params.eval_verbose < 2: 323 | continue 324 | if params.architecture == "decoder_only": 325 | inputs.append( 326 | { 327 | "i": i, 328 | "src": x2[1 : len1[i] - 1, i].tolist(), 329 | "tgt": x2[len1[i] : len2[i] - 1, i].tolist(), 330 | "hyp": generated[i][len1[i]:].tolist(), 331 | "task": task, 332 | } 333 | ) 334 | else: 335 | out_offset = 0 if params.architecture == "encoder_only" else 1 336 | inputs.append( 337 | { 338 | "i": i, 339 | "src": x1[0 : len1[i] - 1, i].tolist(), 340 | "tgt": x2[out_offset : len2[i] - 1, i].tolist(), 341 | "hyp": generated[i][out_offset:].tolist(), 342 | "task": task, 343 | } 344 | ) 345 | 346 | # check hypotheses with multiprocessing 347 | outputs = [] 348 | #if params.windows is True: 349 | for inp in inputs: 350 | outputs.append(check_hypothesis(inp)) 351 | #else: 352 | # with ProcessPoolExecutor(max_workers=20) as executor: 353 | # for output in executor.map(check_hypothesis, inputs, chunksize=1): 354 | # outputs.append(output) 355 | 356 | # read results 357 | for i in range(bs): 358 | # select hypotheses associated to current equation 359 | gens = [o for o in outputs if o["i"] == i] 360 | assert (len(gens) == 0) == (valid[i] and params.eval_verbose < 2) 361 | assert (i in beam_log) == valid[i] 362 | if len(gens) == 0: 363 | continue 364 | 365 | assert len(gens) == 1 366 | # source / target 367 | gen = gens[0] 368 | src = gen["src"] 369 | tgt = gen["tgt"] 370 | beam_log[i] = {"src": src, "tgt": tgt, "hyps": []} 371 | 372 | # sanity check 373 | assert ( 374 | gen["src"] == src 375 | and gen["tgt"] == tgt 376 | and gen["i"] == i 377 | ) 378 | 379 | # if hypothesis is correct, and we did not find a correct one before 380 | is_valid = gen["is_valid"] 381 | is_b_valid = is_valid > 0 382 | if not valid[i]: 383 | if is_valid ==2: 384 | n_perfect += 1 385 | if is_valid >= 0: 386 | n_correct += 1 387 | if is_valid > 0: 388 | n_valid[nb_ops[i]] += 1 389 | valid[i] = 1 390 | for em in range(env.n_eval_metrics > 0): 391 | eval_metrics[em] += gen[f"is_valid{em+1}"] 392 | for em in range(env.n_error_metrics > 0): 393 | error_metrics[em] += gen[f"is_error{em+1}"] 394 | 395 | if env.export_pred: 396 | is_valid4 = gen["pred"] 397 | result = nb_ops[i] if nb_ops[i] < env.max_class else env.max_class 398 | prediction = env.max_class if (is_valid4 is None or is_valid4 > env.max_class) else is_valid4 399 | n_pairs[result][prediction] += 1 400 | 401 | 402 | # update beam log 403 | beam_log[i]["hyps"].append((gen["hyp"], None, is_b_valid)) # gen["score"], is_b_valid)) 404 | 405 | # valid solutions found with beam search 406 | logger.info( 407 | f" Found {valid.sum().item()}/{bs} solutions in beam hypotheses." 408 | ) 409 | 410 | # export evaluation details 411 | if params.eval_verbose: 412 | assert len(beam_log) == bs 413 | display_logs(beam_log, offset=n_total.sum().item() - bs) 414 | 415 | # evaluation details 416 | if params.eval_verbose: 417 | f_export.close() 418 | logger.info(f"Evaluation results written in {eval_path}") 419 | 420 | # log 421 | _n_valid = n_valid.sum().item() 422 | _n_total = n_total.sum().item() 423 | logger.info( 424 | f"{_n_valid}/{_n_total} ({100. * _n_valid / _n_total}%) " 425 | f"examples were evaluated correctly." 426 | ) 427 | 428 | # compute perplexity and prediction accuracy 429 | assert _n_total == eval_size 430 | scores[f"{data_type}_{task}_xe_loss"] = xe_loss / _n_total 431 | scores[f"{data_type}_{task}_acc"] = 100.0 * _n_valid / _n_total 432 | scores[f"{data_type}_{task}_perfect"] = 100.0 * (n_perfect_match + n_perfect) / _n_total 433 | scores[f"{data_type}_{task}_correct"] = ( 434 | 100.0 * (n_perfect_match + n_correct) / _n_total 435 | ) 436 | 437 | for em in range(env.n_eval_metrics > 0): 438 | scores[f"{data_type}_{task}_acc_eval{em+1}"] = 100.0*(n_perfect_match + eval_metrics[em]) / _n_total 439 | for em in range(env.n_error_metrics > 0): 440 | scores[f"{data_type}_{task}_acc_error{em+1}"] = 100.0*error_metrics[em] / _n_total 441 | 442 | 443 | # per class perplexity and prediction accuracy 444 | for i in range(len(n_total)): 445 | if n_total[i].item() == 0: 446 | continue 447 | e = env.decode_class(i) 448 | scores[f"{data_type}_{task}_acc_{e}"] = ( 449 | 100.0 * n_valid[i].item() / max(n_total[i].item(), 1) 450 | ) 451 | if n_valid[i].item() > 0: 452 | logger.info( 453 | f"{e}: {n_valid[i].item()} / {n_total[i].item()} " 454 | f"({100. * n_valid[i].item() / max(n_total[i].item(), 1):.2f}%)" 455 | ) 456 | if env.export_pred: 457 | logger.info(f"{data_type} predicted pairs") 458 | for i in range(env.max_class+1): 459 | for j in range(env.max_class+1): 460 | if n_pairs[i][j].item() >= 10: 461 | logger.info(f"{i}-{j}: {n_pairs[i][j].item()} ({100. * n_pairs[i][j].item() / n_pairs[i].sum().item():2f}%)") 462 | 463 | def enc_dec_step_beam(self, data_type, task, scores, size=None): 464 | """ 465 | Encoding / decoding step with beam generation and SymPy check. 466 | """ 467 | params = self.params 468 | env = self.env 469 | max_beam_length = params.max_output_len + 2 470 | encoder = ( 471 | self.modules["encoder"].module 472 | if params.multi_gpu 473 | else self.modules["encoder"] 474 | ) 475 | decoder = ( 476 | self.modules["decoder"].module 477 | if params.multi_gpu 478 | else self.modules["decoder"] 479 | ) 480 | encoder.eval() 481 | decoder.eval() 482 | assert params.eval_verbose in [0, 1, 2] 483 | assert params.eval_verbose_print is False or params.eval_verbose > 0 484 | assert task in env.TRAINING_TASKS 485 | 486 | # evaluation details 487 | if params.eval_verbose: 488 | eval_path = os.path.join( 489 | params.dump_path, f"eval.beam.{data_type}.{task}.{scores['epoch']}" 490 | ) 491 | f_export = open(eval_path, "w") 492 | logger.info(f"Writing evaluation results in {eval_path} ...") 493 | 494 | def display_logs(logs, offset): 495 | """ 496 | Display detailed results about success / fails. 497 | """ 498 | if params.eval_verbose == 0: 499 | return 500 | for i, res in sorted(logs.items()): 501 | n_valid = sum([int(v) for _, _, v in res["hyps"]]) 502 | s = f"Equation {offset + i} ({n_valid}/{len(res['hyps'])})\n" 503 | s += f"src={res['src']}\ntgt={res['tgt']}\n" 504 | for hyp, score, valid in res["hyps"]: 505 | if score is None: 506 | s += f"{int(valid)} {hyp}\n" 507 | else: 508 | s += f"{int(valid)} {score :.3e} {hyp}\n" 509 | if params.eval_verbose_print: 510 | logger.info(s) 511 | f_export.write(s + "\n") 512 | f_export.flush() 513 | 514 | # stats 515 | xe_loss = 0 516 | n_valid = torch.zeros(10000, dtype=torch.long) 517 | n_total = torch.zeros(10000, dtype=torch.long) 518 | n_perfect_match = 0 519 | n_perfect = 0 520 | n_correct = 0 521 | if env.n_eval_metrics > 0: 522 | eval_metrics = torch.zeros(self.env.n_eval_metrics) 523 | if env.n_error_metrics > 0: 524 | error_metrics = torch.zeros(self.env.n_error_metrics) 525 | 526 | # iterator 527 | iterator = env.create_test_iterator( 528 | data_type, 529 | task, 530 | data_path=params.eval_data.split(',') if params.eval_data != "" else None, 531 | batch_size=params.batch_size_eval, 532 | params=params, 533 | size=params.eval_size, 534 | ) 535 | eval_size = len(iterator.dataset) 536 | 537 | for (x1, len1), (x2, len2), nb_ops in iterator: 538 | 539 | # target words to predict 540 | alen = torch.arange(len2.max(), dtype=torch.long, device=len2.device) 541 | pred_mask = ( 542 | alen[:, None] < len2[None] - 1 543 | ) # do not predict anything given the last target word 544 | y = x2[1:].masked_select(pred_mask[:-1]) 545 | assert len(y) == (len2 - 1).sum().item() 546 | 547 | # cuda 548 | x1_, len1_, x2, len2, y = to_cuda(x1, len1, x2, len2, y) 549 | bs = len(len1) 550 | 551 | # forward 552 | if params.lstm: 553 | encoded, hidden = encoder("fwd", x=x1_, lengths=len1_, causal=False) 554 | decoded, _ = decoder( 555 | "fwd", 556 | x=x2, 557 | lengths=len2, 558 | causal=True, 559 | src_enc=hidden, 560 | ) 561 | else: 562 | encoded = encoder("fwd", x=x1_, lengths=len1_, causal=False) 563 | decoded = decoder( 564 | "fwd", 565 | x=x2, 566 | lengths=len2, 567 | causal=True, 568 | src_enc=encoded.transpose(0, 1), 569 | src_len=len1_, 570 | ) 571 | word_scores, loss = decoder( 572 | "predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=True 573 | ) 574 | 575 | # correct outputs per sequence / valid top-1 predictions 576 | t = torch.zeros_like(pred_mask, device=y.device) 577 | t[pred_mask] += word_scores.max(1)[1] == y 578 | valid = (t.sum(0) == len2 - 1).cpu().long() 579 | n_perfect_match += valid.sum().item() 580 | 581 | # save evaluation details 582 | beam_log = {} 583 | for i in range(len(len1)): 584 | src = idx_to_infix(env, x1[1 : len1[i] - 1, i].tolist(), True) 585 | tgt = idx_to_infix(env, x2[1 : len2[i] - 1, i].tolist(), False) 586 | if valid[i]: 587 | beam_log[i] = {"src": src, "tgt": tgt, "hyps": [(tgt, None, True)]} 588 | 589 | # stats 590 | xe_loss += loss.item() * len(y) 591 | n_valid.index_add_(-1, nb_ops, valid) 592 | n_total.index_add_(-1, nb_ops, torch.ones_like(nb_ops)) 593 | 594 | # continue if everything is correct. if eval_verbose, perform 595 | # a full beam search, even on correct greedy generations 596 | if valid.sum() == len(valid) and params.eval_verbose < 2: 597 | display_logs(beam_log, offset=n_total.sum().item() - bs) 598 | continue 599 | 600 | # invalid top-1 predictions - check if there is a solution in the beam 601 | invalid_idx = (1 - valid).nonzero().view(-1) 602 | logger.info( 603 | f"({n_total.sum().item()}/{eval_size}) Found " 604 | f"{bs - len(invalid_idx)}/{bs} valid top-1 predictions. " 605 | f"Generating solutions ..." 606 | ) 607 | 608 | # generate 609 | if params.lstm: 610 | _, _, generations = decoder.generate_beam( 611 | hidden, 612 | len1_, 613 | beam_size=params.beam_size, 614 | length_penalty=params.beam_length_penalty, 615 | early_stopping=params.beam_early_stopping, 616 | max_len=max_beam_length, 617 | ) 618 | else: 619 | _, _, generations = decoder.generate_beam( 620 | encoded.transpose(0, 1), 621 | len1_, 622 | beam_size=params.beam_size, 623 | length_penalty=params.beam_length_penalty, 624 | early_stopping=params.beam_early_stopping, 625 | max_len=max_beam_length, 626 | ) 627 | 628 | # prepare inputs / hypotheses to check 629 | # if eval_verbose < 2, no beam search on equations solved greedily 630 | inputs = [] 631 | for i in range(len(generations)): 632 | if valid[i] and params.eval_verbose < 2: 633 | continue 634 | for j, (score, hyp) in enumerate( 635 | sorted(generations[i].hyp, key=lambda x: x[0], reverse=True) 636 | ): 637 | inputs.append( 638 | { 639 | "i": i, 640 | "j": j, 641 | "score": score, 642 | "src": x1[1 : len1[i] - 1, i].tolist(), 643 | "tgt": x2[1 : len2[i] - 1, i].tolist(), 644 | "hyp": hyp[1:].tolist(), 645 | "task": task, 646 | } 647 | ) 648 | 649 | # check hypotheses with multiprocessing 650 | outputs = [] 651 | #if params.windows is True: 652 | for inp in inputs: 653 | outputs.append(check_hypothesis(inp)) 654 | #else: 655 | # with ProcessPoolExecutor(max_workers=20) as executor: 656 | # for output in executor.map(check_hypothesis, inputs, chunksize=1): 657 | # outputs.append(output) 658 | 659 | # read results 660 | for i in range(bs): 661 | 662 | # select hypotheses associated to current equation 663 | gens = sorted([o for o in outputs if o["i"] == i], key=lambda x: x["j"]) 664 | assert (len(gens) == 0) == (valid[i] and params.eval_verbose < 2) and ( 665 | i in beam_log 666 | ) == valid[i] 667 | if len(gens) == 0: 668 | continue 669 | 670 | # source / target 671 | src = gens[0]["src"] 672 | tgt = gens[0]["tgt"] 673 | beam_log[i] = {"src": src, "tgt": tgt, "hyps": []} 674 | 675 | curr_correct = 0 676 | curr_perfect = 0 677 | if env.n_eval_metrics > 0: 678 | curr_eval_metrics = torch.zeros(self.env.n_eval_metrics) 679 | if env.n_error_metrics > 0: 680 | curr_error_metrics = torch.zeros(self.env.n_error_metrics) 681 | curr_valid = 0 682 | 683 | # for each hypothesis 684 | for j, gen in enumerate(gens): 685 | 686 | # sanity check 687 | assert ( 688 | gen["src"] == src 689 | and gen["tgt"] == tgt 690 | and gen["i"] == i 691 | and gen["j"] == j 692 | ) 693 | 694 | # if hypothesis is correct, and we did not find a correct one before 695 | is_valid = gen["is_valid"] 696 | is_b_valid = is_valid > 0 697 | if not valid[i]: 698 | if is_valid ==2: 699 | curr_perfect = 1 700 | if is_valid >= 0: 701 | curr_correct = 1 702 | if is_valid > 0: 703 | curr_valid = 1 704 | 705 | for em in range(env.n_eval_metrics > 0): 706 | if gen[f"is_valid{em+1}"] == 1: 707 | curr_eval_metrics[em]=1 708 | 709 | # update beam log 710 | beam_log[i]["hyps"].append((gen["hyp"], gen["score"], is_b_valid)) 711 | 712 | if not valid[i]: 713 | n_correct += curr_correct 714 | n_perfect += curr_perfect 715 | for em in range(env.n_eval_metrics > 0): 716 | eval_metrics[em] += curr_eval_metrics[em] 717 | valid[i] = curr_valid 718 | n_valid[nb_ops[i]] += curr_valid 719 | 720 | # valid solutions found with beam search 721 | logger.info( 722 | f" Found {valid.sum().item()}/{bs} solutions in beam hypotheses." 723 | ) 724 | 725 | # export evaluation details 726 | if params.eval_verbose: 727 | assert len(beam_log) == bs 728 | display_logs(beam_log, offset=n_total.sum().item() - bs) 729 | 730 | # evaluation details 731 | if params.eval_verbose: 732 | f_export.close() 733 | logger.info(f"Evaluation results written in {eval_path}") 734 | 735 | # log 736 | _n_valid = n_valid.sum().item() 737 | _n_total = n_total.sum().item() 738 | logger.info( 739 | f"{_n_valid}/{_n_total} ({100. * _n_valid / _n_total}%) " 740 | f"equations were evaluated correctly." 741 | ) 742 | 743 | # compute perplexity and prediction accuracy 744 | assert _n_total == eval_size 745 | scores[f"{data_type}_{task}_xe_loss"] = xe_loss / _n_total 746 | scores[f"{data_type}_{task}_acc"] = 100.0 * _n_valid / _n_total 747 | scores[f"{data_type}_{task}_perfect"] = 100.0 * (n_perfect_match + n_perfect) / _n_total 748 | scores[f"{data_type}_{task}_correct"] = ( 749 | 100.0 * (n_perfect_match + n_correct) / _n_total 750 | ) 751 | for em in range(env.n_eval_metrics > 0): 752 | scores[f"{data_type}_{task}_acc_eval{em+1}"] = 100.0*(n_perfect_match + eval_metrics[em]) / _n_total 753 | 754 | # per class perplexity and prediction accuracy 755 | for i in range(len(n_total)): 756 | if n_total[i].item() == 0: 757 | continue 758 | e = env.decode_class(i) 759 | logger.info( 760 | f"{e}: {n_valid[i].item()} / {n_total[i].item()} " 761 | f"({100. * n_valid[i].item() / max(n_total[i].item(), 1)}%)" 762 | ) 763 | scores[f"{data_type}_{task}_acc_{e}"] = ( 764 | 100.0 * n_valid[i].item() / max(n_total[i].item(), 1) 765 | ) 766 | -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import logging 9 | import time 10 | from datetime import timedelta 11 | 12 | 13 | class LogFormatter(): 14 | 15 | def __init__(self): 16 | self.start_time = time.time() 17 | 18 | def format(self, record): 19 | elapsed_seconds = round(record.created - self.start_time) 20 | 21 | prefix = "%s - %s - %s" % ( 22 | record.levelname, 23 | time.strftime('%x %X'), 24 | timedelta(seconds=elapsed_seconds) 25 | ) 26 | message = record.getMessage() 27 | message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3)) 28 | return "%s - %s" % (prefix, message) if message else '' 29 | 30 | 31 | def create_logger(filepath, rank): 32 | """ 33 | Create a logger. 34 | Use a different log file for each process. 35 | """ 36 | # create log formatter 37 | log_formatter = LogFormatter() 38 | 39 | # create file handler and set level to debug 40 | if filepath is not None: 41 | if rank > 0: 42 | filepath = '%s-%i' % (filepath, rank) 43 | file_handler = logging.FileHandler(filepath, "a") 44 | file_handler.setLevel(logging.DEBUG) 45 | file_handler.setFormatter(log_formatter) 46 | 47 | # create console handler and set level to info 48 | console_handler = logging.StreamHandler() 49 | console_handler.setLevel(logging.INFO) 50 | console_handler.setFormatter(log_formatter) 51 | 52 | # create logger and set level to debug 53 | logger = logging.getLogger() 54 | logger.handlers = [] 55 | logger.setLevel(logging.DEBUG) 56 | logger.propagate = False 57 | if filepath is not None: 58 | logger.addHandler(file_handler) 59 | logger.addHandler(console_handler) 60 | 61 | # reset logger elapsed time 62 | def reset_time(): 63 | log_formatter.start_time = time.time() 64 | logger.reset_time = reset_time 65 | 66 | return logger 67 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import os 10 | import torch 11 | 12 | from .transformer import TransformerModel 13 | from .lstm import LSTMModel 14 | 15 | 16 | logger = getLogger() 17 | 18 | 19 | def check_model_params(params): 20 | """ 21 | Check models parameters. 22 | """ 23 | # model dimensions 24 | assert params.enc_emb_dim % params.n_enc_heads == 0 25 | assert params.dec_emb_dim % params.n_dec_heads == 0 26 | 27 | # reload a pretrained model 28 | if params.reload_model != "": 29 | assert os.path.isfile(params.reload_model) 30 | 31 | 32 | def build_modules(env, params): 33 | """ 34 | Build modules. 35 | """ 36 | assert params.architecture in ["encoder_decoder","encoder_only","decoder_only"] 37 | assert not params.lstm or params.architecture == "encoder_decoder" 38 | modules = {} 39 | if params.architecture == "encoder_decoder": 40 | if params.lstm: 41 | modules["encoder"] = LSTMModel( 42 | params, env.id2word, is_encoder=True, with_output=False 43 | ) 44 | modules["decoder"] = LSTMModel( 45 | params, env.id2word, is_encoder=False, with_output=True 46 | ) 47 | else: 48 | modules["encoder"] = TransformerModel( 49 | params, env.id2word, is_encoder=True, with_output=False 50 | ) 51 | modules["decoder"] = TransformerModel( 52 | params, env.id2word, is_encoder=False, with_output=True 53 | ) 54 | elif params.architecture == "encoder_only": 55 | modules["encoder"] = TransformerModel( 56 | params, env.id2word, is_encoder=True, with_output=True 57 | ) 58 | else: 59 | modules["decoder"] = TransformerModel( 60 | params, env.id2word, is_encoder=False, with_output=True 61 | ) 62 | 63 | 64 | # reload pretrained modules 65 | if params.reload_model != "": 66 | logger.info(f"Reloading modules from {params.reload_model} ...") 67 | reloaded = torch.load(params.reload_model) 68 | for k, v in modules.items(): 69 | assert k in reloaded 70 | if all([k2.startswith("module.") for k2 in reloaded[k].keys()]): 71 | reloaded[k] = { 72 | k2[len("module.") :]: v2 for k2, v2 in reloaded[k].items() 73 | } 74 | v.load_state_dict(reloaded[k]) 75 | 76 | # log 77 | logger.debug(f"{len(modules)}") 78 | 79 | for k, v in modules.items(): 80 | logger.debug(f"{v}: {v}") 81 | for k, v in modules.items(): 82 | logger.info( 83 | f"Number of parameters ({k}): {sum([p.numel() for p in v.parameters() if p.requires_grad])}" 84 | ) 85 | 86 | # cuda 87 | if not params.cpu: 88 | for v in modules.values(): 89 | v.cuda() 90 | 91 | return modules 92 | -------------------------------------------------------------------------------- /src/model/lstm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | # import math 10 | # import itertools 11 | # import numpy as np 12 | import torch 13 | # from torch._C import _set_backcompat_keepdim_warn 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from .transformer import Embedding 17 | from .transformer import BeamHypotheses 18 | 19 | 20 | logger = getLogger() 21 | 22 | 23 | def get_masks(slen, lengths): 24 | """ 25 | Generate hidden states mask, and optionally an attention mask. 26 | """ 27 | assert lengths.max().item() <= slen 28 | bs = lengths.size(0) 29 | alen = torch.arange(slen, dtype=torch.long, device=lengths.device) 30 | mask = alen < lengths[:, None] 31 | 32 | # sanity check 33 | assert mask.size() == (bs, slen) 34 | 35 | return mask 36 | 37 | 38 | class LSTMModel(nn.Module): 39 | def __init__(self, params, id2word, is_encoder, with_output): 40 | """ 41 | Transformer model (encoder or decoder). 42 | """ 43 | super().__init__() 44 | 45 | # encoder / decoder, output layer 46 | self.dtype = torch.half if params.fp16 else torch.float 47 | self.is_encoder = is_encoder 48 | self.is_decoder = not is_encoder 49 | self.with_output = with_output 50 | 51 | self.lstm = params.lstm 52 | self.GRU = params.GRU 53 | 54 | # dictionary 55 | self.n_words = params.n_words 56 | self.eos_index = params.eos_index 57 | self.pad_index = params.pad_index 58 | self.id2word = id2word 59 | assert len(self.id2word) == self.n_words 60 | 61 | # model parameters 62 | assert params.enc_emb_dim == params.dec_emb_dim 63 | self.dim = params.enc_emb_dim if is_encoder else params.dec_emb_dim # 512 by default 64 | self.src_dim = params.enc_emb_dim 65 | self.bidirectional = params.bidirectional 66 | self.hidden_dim = params.lstm_hidden_dim 67 | self.n_layers = params.n_enc_layers if is_encoder else params.n_dec_layers 68 | self.dropout = params.dropout 69 | 70 | # embeddings 71 | self.embeddings = Embedding(self.n_words, self.dim, padding_idx=self.pad_index) 72 | self.layer_norm_emb = nn.LayerNorm(self.dim, eps=1e-12) 73 | 74 | # transformer layers 75 | if self.GRU: 76 | self.layers = nn.GRU(self.dim, self.hidden_dim, self.n_layers, bidirectional=self.bidirectional, dropout=self.dropout, batch_first=True) 77 | else: 78 | self.layers = nn.LSTM(self.dim, self.hidden_dim, self.n_layers, bidirectional=self.bidirectional, dropout=self.dropout, batch_first=True) 79 | self.gru_out = nn.Linear(self.hidden_dim*(2 if self.bidirectional else 1) , self.dim, bias=True) 80 | 81 | # output layer 82 | if self.with_output: 83 | self.proj = nn.Linear(self.dim, params.n_words, bias=True) 84 | if params.share_inout_emb: 85 | self.proj.weight = self.embeddings.weight 86 | 87 | def forward(self, mode, **kwargs): 88 | """ 89 | Forward function with different forward modes. 90 | ### Small hack to handle PyTorch distributed. 91 | """ 92 | if mode == "fwd": 93 | return self.fwd(**kwargs) 94 | elif mode == "predict": 95 | return self.predict(**kwargs) 96 | else: 97 | raise Exception("Unknown mode: %s" % mode) 98 | 99 | def fwd( 100 | self, 101 | x, 102 | lengths, 103 | causal, 104 | src_enc=None, 105 | src_len=None, 106 | use_cache=False, 107 | ): 108 | """ 109 | Inputs: 110 | `x` LongTensor(slen, bs), containing word indices 111 | `lengths` LongTensor(bs), containing the length of each sentence 112 | `causal` Boolean, if True, the attention is only done over previous hidden states 113 | """ 114 | # lengths = (x != self.pad_index).float().sum(dim=1) 115 | # mask = x != self.pad_index 116 | 117 | # check inputs 118 | slen, bs = x.size() 119 | assert lengths.size(0) == bs 120 | assert lengths.max().item() <= slen 121 | x = x.transpose(0, 1) # batch size as dimension 0 122 | if src_enc is not None: 123 | assert self.is_decoder 124 | #print(np.shape(src_enc)) 125 | #assert src_enc.size(1) == bs 126 | 127 | # generate masks 128 | mask = get_masks(slen, lengths) 129 | 130 | # embeddings 131 | tensor = self.embeddings(x) 132 | tensor = self.layer_norm_emb(tensor) 133 | 134 | tensor = F.dropout(tensor, p=self.dropout, training=self.training) 135 | tensor *= mask.unsqueeze(-1).to(tensor.dtype) 136 | 137 | # transformer layers 138 | tensor, hidden = self.layers.forward(tensor, src_enc) 139 | tensor = self.gru_out(tensor) 140 | 141 | # move back sequence length to dimension 0 142 | tensor = tensor.transpose(0, 1) 143 | 144 | return tensor, hidden 145 | 146 | 147 | def predict(self, tensor, pred_mask, y, get_scores): 148 | """ 149 | Given the last hidden state, compute word scores and/or the loss. 150 | `pred_mask` is a ByteTensor of shape (slen, bs), filled with 1 when 151 | we need to predict a word 152 | `y` is a LongTensor of shape (pred_mask.sum(),) 153 | `get_scores` is a boolean specifying whether we need to return scores 154 | """ 155 | x = tensor[pred_mask.unsqueeze(-1).expand_as(tensor)].view(-1, self.dim) 156 | assert (y == self.pad_index).sum().item() == 0 157 | # print(np.shape(x)) 158 | scores = self.proj(x).view(-1, self.n_words) 159 | # print(np.shape(scores)) 160 | loss = F.cross_entropy(scores.float(), y, reduction="mean") 161 | return scores, loss 162 | 163 | def generate(self, src_enc, src_len, max_len=200, sample_temperature=None): 164 | """ 165 | Decode a sentence given initial start. 166 | `x`: 167 | - LongTensor(bs, slen) 168 | W1 W2 W3 169 | W1 W2 W3 W4 170 | `lengths`: 171 | - LongTensor(bs) [5, 6] 172 | """ 173 | 174 | # input batch 175 | if self.GRU: 176 | bs = src_enc.size(1) 177 | else: 178 | bs = src_enc[0].size(1) 179 | 180 | 181 | # generated sentences 182 | generated = src_len.new(max_len, bs) # upcoming output 183 | generated.fill_(self.pad_index) # fill upcoming ouput with 184 | generated[0].fill_(self.eos_index) # we use for everywhere 185 | 186 | # current position / max lengths / length of generated sentences / unfinished sentences 187 | cur_len = 1 188 | gen_len = src_len.clone().fill_(1) 189 | unfinished_sents = src_len.clone().fill_(1) 190 | 191 | # cache compute states 192 | self.cache = {"slen": 0} 193 | 194 | while cur_len < max_len: 195 | 196 | # compute word scores 197 | tensor, _ = self.forward( 198 | "fwd", 199 | x=generated[:cur_len], 200 | lengths=gen_len.new(bs).fill_(cur_len), 201 | causal=True, 202 | src_enc=src_enc, 203 | ) 204 | tensor = tensor[-1:,:,:] 205 | assert tensor.size() == (1, bs, self.dim), tensor.size() 206 | tensor = tensor.data[-1, :, :] #.to(self.dtype) # (bs, dim) 207 | scores = self.proj(tensor) # (bs, n_words) 208 | 209 | # select next words: sample or greedy 210 | if sample_temperature is None: 211 | next_words = torch.topk(scores, 1)[1].squeeze(1) 212 | else: 213 | next_words = torch.multinomial( 214 | F.softmax(scores.float() / sample_temperature, dim=1), 1 215 | ).squeeze(1) 216 | assert next_words.size() == (bs,) 217 | 218 | # update generations / lengths / finished sentences / current length 219 | generated[cur_len] = next_words * unfinished_sents + self.pad_index * ( 220 | 1 - unfinished_sents 221 | ) 222 | gen_len.add_(unfinished_sents) 223 | unfinished_sents.mul_(next_words.ne(self.eos_index).long()) 224 | cur_len = cur_len + 1 225 | 226 | # stop when there is a in each sentence, or if we exceed the maximul length 227 | if unfinished_sents.max() == 0: 228 | break 229 | 230 | # add to unfinished sentences 231 | if cur_len == max_len: 232 | generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index) 233 | 234 | # sanity check 235 | assert (generated == self.eos_index).sum() == 2 * bs 236 | 237 | return generated[:cur_len].cpu(), gen_len.cpu() 238 | 239 | def generate_beam( 240 | self, src_enc, src_len, beam_size, length_penalty, early_stopping, max_len=200 241 | ): 242 | """ 243 | Decode a sentence given initial start. 244 | `x`: 245 | - LongTensor(bs, slen) 246 | W1 W2 W3 247 | W1 W2 W3 W4 248 | `lengths`: 249 | - LongTensor(bs) [5, 6] 250 | """ 251 | 252 | # check inputs 253 | #assert src_enc.size(0) == src_len.size(0) 254 | assert beam_size == 1 255 | 256 | # batch size / number of words 257 | n_words = self.n_words 258 | if self.GRU: 259 | bs = src_enc.size(1) 260 | #src_enc = ( 261 | # src_enc.unsqueeze(1) 262 | # .expand((bs, beam_size) + src_enc.shape[1:]) 263 | # .contiguous() 264 | # .view((bs * beam_size,) + src_enc.shape[1:]) 265 | #) 266 | else: 267 | bs = src_enc[0].size(1) 268 | 269 | # expand to beam size the source latent representations / source lengths 270 | #src_enc = ( 271 | # src_enc[0].unsqueeze(1) 272 | # .expand((bs, beam_size) + src_enc[0].shape[1:]) 273 | # .contiguous() 274 | # .view((bs * beam_size,) + src_enc[0].shape[1:]), 275 | # src_enc[1].unsqueeze(1) 276 | # .expand((bs, beam_size) + src_enc[1].shape[1:]) 277 | # .contiguous() 278 | # .view((bs * beam_size,) + src_enc[1].shape[1:]) 279 | #) 280 | src_len = src_len.unsqueeze(1).expand(bs, beam_size).contiguous().view(-1) 281 | 282 | # generated sentences (batch with beam current hypotheses) 283 | generated = src_len.new(max_len, bs * beam_size) # upcoming output 284 | generated.fill_(self.pad_index) # fill upcoming ouput with 285 | generated[0].fill_(self.eos_index) # we use for everywhere 286 | 287 | # generated hypotheses 288 | generated_hyps = [ 289 | BeamHypotheses(beam_size, max_len, length_penalty, early_stopping) 290 | for _ in range(bs) 291 | ] 292 | 293 | # scores for each sentence in the beam 294 | beam_scores = src_len.new(bs, beam_size).float().fill_(0) 295 | beam_scores[:, 1:] = -1e9 296 | beam_scores = beam_scores.view(-1) 297 | 298 | # current position 299 | cur_len = 1 300 | 301 | # cache compute states 302 | self.cache = {"slen": 0} 303 | 304 | # done sentences 305 | done = [False for _ in range(bs)] 306 | 307 | while cur_len < max_len: 308 | 309 | # compute word scores 310 | tensor, _ = self.forward( 311 | "fwd", 312 | x=generated[:cur_len], 313 | lengths=src_len.new(bs * beam_size).fill_(cur_len), 314 | causal=True, 315 | src_enc=src_enc, 316 | ) 317 | tensor = tensor[-1:,:,:] 318 | assert tensor.size() == (1, bs * beam_size, self.dim), tensor.size() 319 | tensor = tensor.data[-1, :, :] # .to(self.dtype) # (bs * beam_size, dim) 320 | scores = self.proj(tensor) # (bs * beam_size, n_words) 321 | scores = F.log_softmax(scores.float(), dim=-1) # (bs * beam_size, n_words) 322 | assert scores.size() == (bs * beam_size, n_words) 323 | 324 | # select next words with scores 325 | _scores = scores + beam_scores[:, None].expand_as( 326 | scores 327 | ) # (bs * beam_size, n_words) 328 | _scores = _scores.view(bs, beam_size * n_words) # (bs, beam_size * n_words) 329 | 330 | next_scores, next_words = torch.topk( 331 | _scores, 2 * beam_size, dim=1, largest=True, sorted=True 332 | ) 333 | assert next_scores.size() == next_words.size() == (bs, 2 * beam_size) 334 | 335 | # next batch beam content 336 | # list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch) 337 | next_batch_beam = [] 338 | 339 | # for each sentence 340 | for sent_id in range(bs): 341 | 342 | # if we are done with this sentence 343 | done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done( 344 | next_scores[sent_id].max().item() 345 | ) 346 | if done[sent_id]: 347 | next_batch_beam.extend( 348 | [(0, self.pad_index, 0)] * beam_size 349 | ) # pad the batch 350 | continue 351 | 352 | # next sentence beam content 353 | next_sent_beam = [] 354 | 355 | # next words for this sentence 356 | for idx, value in zip(next_words[sent_id], next_scores[sent_id]): 357 | 358 | # get beam and word IDs 359 | beam_id = idx // n_words 360 | word_id = idx % n_words 361 | 362 | # end of sentence, or next word 363 | if word_id == self.eos_index or cur_len + 1 == max_len: 364 | generated_hyps[sent_id].add( 365 | generated[:cur_len, sent_id * beam_size + beam_id] 366 | .clone() 367 | .cpu(), 368 | value.item(), 369 | ) 370 | else: 371 | next_sent_beam.append( 372 | (value, word_id, sent_id * beam_size + beam_id) 373 | ) 374 | 375 | # the beam for next step is full 376 | if len(next_sent_beam) == beam_size: 377 | break 378 | 379 | # update next beam content 380 | assert len(next_sent_beam) == 0 if cur_len + 1 == max_len else beam_size 381 | if len(next_sent_beam) == 0: 382 | next_sent_beam = [ 383 | (0, self.pad_index, 0) 384 | ] * beam_size # pad the batch 385 | next_batch_beam.extend(next_sent_beam) 386 | assert len(next_batch_beam) == beam_size * (sent_id + 1) 387 | 388 | # sanity check / prepare next batch 389 | assert len(next_batch_beam) == bs * beam_size 390 | beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) 391 | beam_words = generated.new([x[1] for x in next_batch_beam]) 392 | beam_idx = src_len.new([x[2] for x in next_batch_beam]) 393 | 394 | # re-order batch and internal states 395 | generated = generated[:, beam_idx] 396 | generated[cur_len] = beam_words 397 | for k in self.cache.keys(): 398 | if k != "slen": 399 | self.cache[k] = ( 400 | self.cache[k][0][beam_idx], 401 | self.cache[k][1][beam_idx], 402 | ) 403 | 404 | # update current length 405 | cur_len = cur_len + 1 406 | 407 | # stop when we are done with each sentence 408 | if all(done): 409 | break 410 | 411 | # def get_coeffs(s): 412 | # roots = [int(s[i + 2]) for i, c in enumerate(s) if c == 'x'] 413 | # poly = np.poly1d(roots, r=True) 414 | # coeffs = list(poly.coefficients.astype(np.int64)) 415 | # return [c % 10 for c in coeffs], coeffs 416 | 417 | # visualize hypotheses 418 | # print([len(x) for x in generated_hyps], cur_len) 419 | # globals().update( locals() ); 420 | # !import code; code.interact(local=vars()) 421 | # for ii in range(bs): 422 | # for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True): 423 | # hh = " ".join(self.id2word[x] for x in ww.tolist()) 424 | # print(f"{ss:+.4f} {hh}") 425 | # # cc = get_coeffs(hh[4:]) 426 | # # print(f"{ss:+.4f} {hh} || {cc[0]} || {cc[1]}") 427 | # print("") 428 | 429 | # select the best hypotheses 430 | tgt_len = src_len.new(bs) 431 | best = [] 432 | 433 | for i, hypotheses in enumerate(generated_hyps): 434 | best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] 435 | tgt_len[i] = len(best_hyp) + 1 # +1 for the symbol 436 | best.append(best_hyp) 437 | 438 | # generate target batch 439 | decoded = src_len.new(tgt_len.max().item(), bs).fill_(self.pad_index) 440 | for i, hypo in enumerate(best): 441 | decoded[: tgt_len[i] - 1, i] = hypo 442 | decoded[tgt_len[i] - 1, i] = self.eos_index 443 | 444 | # sanity check 445 | assert (decoded == self.eos_index).sum() == 2 * bs 446 | 447 | return decoded, tgt_len, generated_hyps 448 | -------------------------------------------------------------------------------- /src/model/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import math 10 | import itertools 11 | import numpy as np 12 | import torch 13 | # from torch._C import _set_backcompat_keepdim_warn 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | 18 | N_MAX_POSITIONS = 4096 # maximum input sequence length 19 | 20 | 21 | logger = getLogger() 22 | 23 | 24 | def Embedding(num_embeddings, embedding_dim, padding_idx=None): 25 | m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) 26 | nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) 27 | if padding_idx is not None: 28 | nn.init.constant_(m.weight[padding_idx], 0) 29 | return m 30 | 31 | 32 | def create_sinusoidal_embeddings(n_pos, dim, out): 33 | position_enc = np.array( 34 | [ 35 | [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] 36 | for pos in range(n_pos) 37 | ] 38 | ) 39 | out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) 40 | out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) 41 | out.detach_() 42 | out.requires_grad = False 43 | 44 | 45 | def gelu(x): 46 | """ 47 | GELU activation 48 | https://arxiv.org/abs/1606.08415 49 | """ 50 | # return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 51 | return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0))) 52 | 53 | 54 | def get_masks(slen, lengths, causal): 55 | """ 56 | Generate hidden states mask, and optionally an attention mask. 57 | """ 58 | assert lengths.max().item() <= slen 59 | bs = lengths.size(0) 60 | alen = torch.arange(slen, dtype=torch.long, device=lengths.device) 61 | mask = alen < lengths[:, None] 62 | 63 | # attention mask is the same as mask, or triangular inferior attention (causal) 64 | if causal: 65 | attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None] 66 | else: 67 | attn_mask = mask 68 | 69 | # sanity check 70 | assert mask.size() == (bs, slen) 71 | assert causal is False or attn_mask.size() == (bs, slen, slen) 72 | 73 | return mask, attn_mask 74 | 75 | 76 | class MultiHeadAttention(nn.Module): 77 | 78 | NEW_ID = itertools.count() 79 | 80 | def __init__(self, n_heads, dim, src_dim, dropout, normalized_attention, xav_init=False): 81 | super().__init__() 82 | self.layer_id = next(MultiHeadAttention.NEW_ID) 83 | self.dim = dim 84 | self.src_dim = src_dim 85 | self.n_heads = n_heads 86 | self.dropout = dropout 87 | self.normalized_attention = normalized_attention 88 | assert self.dim % self.n_heads == 0 89 | 90 | self.q_lin = nn.Linear(dim, dim) 91 | self.k_lin = nn.Linear(src_dim, dim) 92 | self.v_lin = nn.Linear(src_dim, dim) 93 | self.out_lin = nn.Linear(dim, dim) 94 | if self.normalized_attention: 95 | self.attention_scale = nn.Parameter( 96 | torch.tensor(1.0 / math.sqrt(dim // n_heads)) 97 | ) 98 | if xav_init: 99 | gain = (1 / math.sqrt(2)) if self.src_dim == self.dim else 1.0 100 | nn.init.xavier_uniform_(self.q_lin.weight, gain=gain) 101 | nn.init.xavier_uniform_(self.k_lin.weight, gain=gain) 102 | nn.init.xavier_uniform_(self.v_lin.weight, gain=gain) 103 | nn.init.xavier_uniform_(self.out_lin.weight) 104 | nn.init.constant_(self.out_lin.bias, 0.0) 105 | 106 | def forward(self, input, mask, kv=None, use_cache=False, first_loop=True): 107 | """ 108 | Self-attention (if kv is None) 109 | or attention over source sentence (provided by kv). 110 | Input is (bs, qlen, dim) 111 | Mask is (bs, klen) (non-causal) or (bs, klen, klen) 112 | """ 113 | assert not (use_cache and self.cache is None) 114 | bs, qlen, dim = input.size() 115 | if kv is None: 116 | klen = qlen if not use_cache else self.cache["slen"] + qlen 117 | else: 118 | klen = kv.size(1) 119 | assert dim == self.dim, "Dimensions do not match: %s input vs %s configured" % ( 120 | dim, 121 | self.dim, 122 | ) 123 | n_heads = self.n_heads 124 | dim_per_head = dim // n_heads 125 | mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen) 126 | 127 | def shape(x): 128 | """ projection """ 129 | return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2) 130 | 131 | def unshape(x): 132 | """ compute context """ 133 | return ( 134 | x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) 135 | ) 136 | 137 | q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head) 138 | if kv is None: 139 | k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head) 140 | v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head) 141 | elif not use_cache or self.layer_id not in self.cache: 142 | k = v = kv 143 | k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head) 144 | v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head) 145 | 146 | if use_cache: 147 | if self.layer_id in self.cache: 148 | if kv is None and first_loop: 149 | k_, v_ = self.cache[self.layer_id] 150 | k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head) 151 | v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head) 152 | else: 153 | k, v = self.cache[self.layer_id] 154 | self.cache[self.layer_id] = (k, v) 155 | if self.normalized_attention: 156 | q = F.normalize(q, p=2, dim=-1) 157 | k = F.normalize(k, p=2, dim=-1) 158 | q = q * self.attention_scale 159 | else: 160 | q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head) 161 | scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen) 162 | mask = ( 163 | (mask == 0).view(mask_reshape).expand_as(scores) 164 | ) # (bs, n_heads, qlen, klen) 165 | scores.masked_fill_(mask, -float("inf")) # (bs, n_heads, qlen, klen) 166 | 167 | weights = F.softmax(scores.float(), dim=-1).type_as( 168 | scores 169 | ) # (bs, n_heads, qlen, klen) 170 | weights = F.dropout( 171 | weights, p=self.dropout, training=self.training 172 | ) # (bs, n_heads, qlen, klen) 173 | context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) 174 | context = unshape(context) # (bs, qlen, dim) 175 | 176 | if TransformerModel.STORE_OUTPUTS and not self.training: 177 | self.outputs = weights.detach().cpu() 178 | 179 | return self.out_lin(context) 180 | 181 | 182 | class TransformerFFN(nn.Module): 183 | def __init__(self, in_dim, dim_hidden, out_dim, hidden_layers, dropout, gelu_activation=False, xav_init=False): 184 | super().__init__() 185 | self.dropout = dropout 186 | self.hidden_layers = hidden_layers 187 | self.act = gelu if gelu_activation else F.relu 188 | self.midlin = nn.ModuleList() 189 | self.lin1 = nn.Linear(in_dim, dim_hidden) 190 | for i in range(1, self.hidden_layers): 191 | self.midlin.append(nn.Linear(dim_hidden, dim_hidden)) 192 | self.lin2 = nn.Linear(dim_hidden, out_dim) 193 | if xav_init: 194 | nn.init.xavier_uniform_(self.lin1.weight) 195 | nn.init.constant_(self.lin1.bias, 0.0) 196 | for mlin in self.midlin: 197 | nn.init.xavier_uniform_(mlin.weight) 198 | nn.init.constant_(mlin.bias, 0.0) 199 | nn.init.xavier_uniform_(self.lin2.weight) 200 | nn.init.constant_(self.lin2.bias, 0.0) 201 | 202 | def forward(self, input): 203 | x = self.lin1(input) 204 | x = self.act(x) 205 | x = F.dropout(x, p=self.dropout, training=self.training) 206 | for mlin in self.midlin: 207 | x = mlin(x) 208 | x = self.act(x) 209 | x = F.dropout(x, p=self.dropout, training=self.training) 210 | x = self.lin2(x) 211 | return x 212 | 213 | 214 | class Gate(nn.Module): 215 | def __init__(self, dimension, scalar, dropout, biased_gates, gate_bias): 216 | super().__init__() 217 | self.dropout = dropout 218 | self.gate1 = nn.Linear(dimension, 4 * dimension) 219 | self.gate2 = nn.Linear(4 * dimension, 1 if scalar else dimension) 220 | if biased_gates: 221 | self.gate2.bias.data.fill_(gate_bias) 222 | 223 | def forward(self, x): 224 | outp = self.gate1(x) 225 | outp = F.relu(outp) 226 | outp = F.dropout(outp, p=self.dropout, training=self.training) 227 | outp = self.gate2(outp) 228 | return torch.sigmoid(outp) 229 | 230 | 231 | class TransformerLayer(nn.Module): 232 | def __init__(self, params, is_encoder, gated=False): 233 | """ 234 | Transformer model (encoder or decoder). 235 | """ 236 | super().__init__() 237 | 238 | self.is_encoder = is_encoder 239 | self.is_decoder = not is_encoder 240 | 241 | # model parameters 242 | self.dim = params.enc_emb_dim if is_encoder else params.dec_emb_dim # 512 by default 243 | self.src_dim = params.enc_emb_dim 244 | self.hidden_dim = self.dim * 4 # 2048 by default 245 | self.n_hidden_layers = params.n_enc_hidden_layers if is_encoder else params.n_dec_hidden_layers 246 | self.n_heads = params.n_enc_heads if is_encoder else params.n_dec_heads # 8 by default 247 | self.n_layers = params.n_enc_layers if is_encoder else params.n_dec_layers 248 | self.dropout = params.dropout 249 | self.attention_dropout = params.attention_dropout 250 | self.gated = gated 251 | self.scalar_gate = params.scalar_gate 252 | 253 | assert ( 254 | self.dim % self.n_heads == 0 255 | ), "transformer dim must be a multiple of n_heads" 256 | self.self_attention = MultiHeadAttention( 257 | self.n_heads, 258 | self.dim, 259 | self.dim, 260 | dropout=self.attention_dropout, 261 | normalized_attention=params.norm_attention, 262 | ) 263 | self.layer_norm1 = nn.LayerNorm(self.dim, eps=1e-12) 264 | if self.is_decoder: 265 | self.layer_norm15 = nn.LayerNorm(self.dim, eps=1e-12) 266 | self.cross_attention = MultiHeadAttention( 267 | self.n_heads, 268 | self.dim, 269 | self.src_dim, 270 | dropout=self.attention_dropout, 271 | normalized_attention=params.norm_attention, 272 | ) 273 | self.ffn = TransformerFFN( 274 | self.dim, 275 | self.hidden_dim, 276 | self.dim, 277 | self.n_hidden_layers, 278 | dropout=self.dropout, 279 | gelu_activation = params.gelu_activation 280 | ) 281 | self.layer_norm2 = nn.LayerNorm(self.dim, eps=1e-12) 282 | if self.gated: 283 | self.gate = Gate(self.dim, self.scalar_gate, self.dropout, params.biased_gates, params.gate_bias) 284 | 285 | def forward(self, x, attn_mask, src_mask, src_enc, use_cache=False, cache=None, loop_count=1): 286 | tensor = x 287 | for i in range(loop_count): 288 | # self attention 289 | self.self_attention.cache = cache 290 | attn = self.self_attention(tensor, attn_mask, use_cache=use_cache, first_loop=(i==0)) 291 | attn = F.dropout(attn, p=self.dropout, training=self.training) 292 | output = tensor + attn 293 | output = self.layer_norm1(output) 294 | 295 | if self.gated: 296 | gate = self.gate(output) 297 | 298 | # encoder attention (for decoder only) 299 | if self.is_decoder and src_enc is not None: 300 | self.cross_attention.cache = cache 301 | attn = self.cross_attention( 302 | tensor, src_mask, kv=src_enc, use_cache=use_cache, first_loop=(i==0) 303 | ) 304 | attn = F.dropout(attn, p=self.dropout, training=self.training) 305 | output = output + attn 306 | output = self.layer_norm15(output) 307 | 308 | # FFN 309 | output = output + self.ffn(output) 310 | output = self.layer_norm2(output) 311 | if self.gated: 312 | tensor = gate * output + (1 - gate) * tensor 313 | else: 314 | tensor = output 315 | return tensor 316 | 317 | 318 | class AdaptiveHalt(nn.Module): 319 | def __init__(self, params, is_encoder, gated): 320 | super().__init__() 321 | self.dim = params.enc_emb_dim if is_encoder else params.dec_emb_dim # 512 by default 322 | self.max_loops = params.enc_loops if is_encoder else params.dec_loops 323 | assert params.act_threshold >= 0 324 | self.threshold = 1.0 - params.act_threshold 325 | self.halt_prob = nn.Linear(self.dim, 1) 326 | if params.act_biased: 327 | self.halt_prob.bias.data.fill_(params.act_bias) 328 | ponder = params.act_ponder_coupling 329 | self.ponder_coupling = ponder 330 | self.ponder_penalty = 0 331 | self.layer = TransformerLayer(params, is_encoder, gated) 332 | 333 | def forward(self, input, attn_mask, src_mask, src_enc, use_cache, cache, loop_count): 334 | bs = input.size(0) 335 | slen = input.size(1) 336 | shape = (bs, slen) 337 | halting_probability = torch.zeros(shape, device=input.device) 338 | remainders = torch.zeros_like(halting_probability) 339 | # n_updates = torch.zeros_like(halting_probability) 340 | acc_state = 0 341 | 342 | for i in range(self.max_loops): 343 | # stop probability for current state 344 | p = torch.squeeze(torch.sigmoid(self.halt_prob(input)), -1) 345 | # running tokens at step start 346 | still_running = torch.less(halting_probability, 1.0).float() 347 | # stopping this step 348 | new_halted = torch.greater(halting_probability + p * still_running, self.threshold).float() * still_running 349 | # running at step end 350 | still_running = torch.less_equal(halting_probability + p * still_running, self.threshold).float() * still_running 351 | 352 | halting_probability += p * still_running 353 | # R(t) in ACT paper 354 | remainders += new_halted * (1 - halting_probability) 355 | halting_probability += new_halted * remainders 356 | # N(t) in ACT paper (unused) 357 | # n_updates += still_running + new_halted 358 | # update state 359 | input = self.layer.forward(input, attn_mask, src_mask, src_enc, use_cache, cache, loop_count) 360 | # weighted final state 361 | update_weights = torch.unsqueeze(p * still_running + new_halted * remainders, -1) 362 | acc_state = (input * update_weights) + (acc_state * (1 - update_weights)) 363 | if still_running.sum() == 0: 364 | break 365 | 366 | remainders += torch.less(halting_probability, 1.0).float() * (1 - halting_probability) 367 | self.ponder_penalty = self.ponder_coupling * torch.mean(remainders) 368 | return acc_state 369 | 370 | 371 | 372 | class TransformerModel(nn.Module): 373 | 374 | STORE_OUTPUTS = False 375 | 376 | def __init__(self, params, id2word, is_encoder, with_output): 377 | """ 378 | Transformer model (encoder or decoder). 379 | """ 380 | super().__init__() 381 | 382 | # encoder / decoder, output layer 383 | self.dtype = torch.half if params.fp16 else torch.float 384 | self.is_encoder = is_encoder 385 | self.is_decoder = not is_encoder 386 | self.with_output = with_output 387 | self.decoder_only = params.architecture == "decoder_only" 388 | 389 | self.xav_init = params.xav_init 390 | 391 | # dictionary 392 | self.n_words = params.n_words 393 | self.eos_index = params.eos_index 394 | self.pad_index = params.pad_index 395 | self.sep_index = params.sep_index 396 | 397 | self.id2word = id2word 398 | assert len(self.id2word) == self.n_words 399 | 400 | # model parameters 401 | self.dim = params.enc_emb_dim if is_encoder else params.dec_emb_dim # 512 by default 402 | self.src_dim = params.enc_emb_dim 403 | self.max_src_len = params.max_src_len 404 | self.hidden_dim = self.dim * 4 # 2048 by default 405 | self.n_hidden_layers = params.n_enc_hidden_layers if is_encoder else params.n_dec_hidden_layers 406 | self.n_heads = params.n_enc_heads if is_encoder else params.n_dec_heads # 8 by default 407 | self.n_layers = params.n_enc_layers if is_encoder else params.n_dec_layers 408 | self.has_pos_emb = params.enc_has_pos_emb if is_encoder else params.dec_has_pos_emb 409 | self.dropout = params.dropout 410 | self.attention_dropout = params.attention_dropout 411 | self.norm_attention = params.norm_attention 412 | assert ( 413 | self.dim % self.n_heads == 0 414 | ), "transformer dim must be a multiple of n_heads" 415 | 416 | # iteration 417 | self.loop_idx = params.enc_loop_idx if is_encoder else params.dec_loop_idx 418 | assert self.loop_idx < self.n_layers, "loop idx must be lower than nr of layers" 419 | self.loops = params.enc_loops if is_encoder else params.dec_loops 420 | 421 | self.act = params.enc_act if is_encoder else params.dec_act 422 | assert (not self.act) or (self.loop_idx >= 0) 423 | 424 | # embeddings 425 | if self.has_pos_emb: 426 | self.position_embeddings = Embedding(N_MAX_POSITIONS, self.dim) 427 | if params.sinusoidal_embeddings: 428 | create_sinusoidal_embeddings( 429 | N_MAX_POSITIONS, self.dim, out=self.position_embeddings.weight 430 | ) 431 | self.embeddings = Embedding(self.n_words, self.dim, padding_idx=self.pad_index) 432 | self.layer_norm_emb = nn.LayerNorm(self.dim, eps=1e-12) 433 | 434 | # transformer layers 435 | self.layers = nn.ModuleList() 436 | for layer_id in range(self.n_layers): 437 | if params.enc_gated and self.is_encoder: 438 | gated = True 439 | elif params.dec_gated and self.is_decoder: 440 | gated = True 441 | elif params.gated and layer_id == self.loop_idx: 442 | gated = True 443 | else: 444 | gated = False 445 | if self.act and layer_id == self.loop_idx: 446 | self.layers.append(AdaptiveHalt(params, self.is_encoder, gated)) 447 | else: 448 | self.layers.append(TransformerLayer(params, self.is_encoder, gated)) 449 | 450 | self.cache = None 451 | 452 | # output layer 453 | if self.with_output: 454 | self.proj = nn.Linear(self.dim, params.n_words, bias=True) 455 | if self.xav_init: 456 | nn.init.xavier_uniform_(self.proj.weight) 457 | nn.init.constant_(self.proj.bias, 0.0) 458 | if params.share_inout_emb: 459 | self.proj.weight = self.embeddings.weight 460 | 461 | def forward(self, mode, **kwargs): 462 | """ 463 | Forward function with different forward modes. 464 | ### Small hack to handle PyTorch distributed. 465 | """ 466 | if mode == "fwd": 467 | return self.fwd(**kwargs) 468 | elif mode == "predict": 469 | return self.predict(**kwargs) 470 | else: 471 | raise Exception("Unknown mode: %s" % mode) 472 | 473 | def fwd( 474 | self, 475 | x, 476 | lengths, 477 | causal, 478 | src_enc=None, 479 | src_len=None, 480 | positions=None, 481 | use_cache=False, 482 | ): 483 | """ 484 | Inputs: 485 | `x` LongTensor(slen, bs), containing word indices 486 | `lengths` LongTensor(bs), containing the length of each sentence 487 | `causal` Boolean, if True, the attention is only done over previous hidden states 488 | `positions` LongTensor(slen, bs), containing word positions 489 | """ 490 | # lengths = (x != self.pad_index).float().sum(dim=1) 491 | # mask = x != self.pad_index 492 | 493 | # check inputs 494 | slen, bs = x.size() 495 | assert lengths.size(0) == bs 496 | assert lengths.max().item() <= slen 497 | x = x.transpose(0, 1) # batch size as dimension 0 498 | assert (src_enc is None) == (src_len is None) 499 | if src_enc is not None: 500 | assert self.is_decoder 501 | assert src_enc.size(0) == bs 502 | assert not (use_cache and self.cache is None) 503 | 504 | # generate masks 505 | mask, attn_mask = get_masks(slen, lengths, causal) 506 | src_mask = None 507 | if self.is_decoder and (src_enc is not None): 508 | if self.max_src_len > 0: 509 | src_mask = ( 510 | torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) 511 | < torch.clamp(src_len[:, None], max=self.max_src_len) 512 | ) 513 | else: 514 | src_mask = ( 515 | torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) 516 | < src_len[:, None] 517 | ) 518 | 519 | # positions 520 | if positions is None: 521 | positions = x.new(slen).long() 522 | positions = torch.arange(slen, out=positions).unsqueeze(0) 523 | else: 524 | assert positions.size() == (slen, bs) 525 | positions = positions.transpose(0, 1) 526 | 527 | # do not recompute cached elements 528 | if use_cache: 529 | _slen = slen - self.cache["slen"] 530 | x = x[:, -_slen:] 531 | positions = positions[:, -_slen:] 532 | mask = mask[:, -_slen:] 533 | attn_mask = attn_mask[:, -_slen:] 534 | 535 | # all layer outputs 536 | if TransformerModel.STORE_OUTPUTS and not self.training: 537 | self.outputs = [] 538 | 539 | # embeddings 540 | tensor = self.embeddings(x) 541 | if self.has_pos_emb: 542 | tensor = tensor + self.position_embeddings(positions).expand_as(tensor) 543 | tensor = self.layer_norm_emb(tensor) 544 | tensor = F.dropout(tensor, p=self.dropout, training=self.training) 545 | tensor *= mask.unsqueeze(-1).to(tensor.dtype) 546 | if TransformerModel.STORE_OUTPUTS and not self.training: 547 | self.outputs.append(tensor.detach().cpu()) 548 | 549 | # transformer layers 550 | for i in range(self.n_layers): 551 | loops = 1 552 | if self.loop_idx == -2 or self.loop_idx == i: 553 | loops = self.loops 554 | tensor = self.layers[i].forward(tensor, attn_mask, src_mask, src_enc, use_cache=use_cache, cache=self.cache, loop_count=loops) 555 | 556 | tensor *= mask.unsqueeze(-1).to(tensor.dtype) 557 | if TransformerModel.STORE_OUTPUTS and not self.training: 558 | self.outputs.append(tensor.detach().cpu()) 559 | 560 | # update cache length 561 | if use_cache: 562 | self.cache["slen"] += tensor.size(1) 563 | 564 | # move back sequence length to dimension 0 565 | tensor = tensor.transpose(0, 1) 566 | 567 | return tensor 568 | 569 | def predict(self, tensor, pred_mask, y, get_scores): 570 | """ 571 | Given the last hidden state, compute word scores and/or the loss. 572 | `pred_mask` is a ByteTensor of shape (slen, bs), filled with 1 when 573 | we need to predict a word 574 | `y` is a LongTensor of shape (pred_mask.sum(),) 575 | `get_scores` is a boolean specifying whether we need to return scores 576 | """ 577 | x = tensor[pred_mask.unsqueeze(-1).expand_as(tensor)].view(-1, self.dim) 578 | assert (y == self.pad_index).sum().item() == 0 579 | scores = self.proj(x).view(-1, self.n_words) 580 | loss = F.cross_entropy(scores.float(), y, reduction="mean") 581 | return scores, loss 582 | 583 | def decode(self, src_enc, src_len, exp_len): 584 | # input batch 585 | bs = len(src_len) 586 | assert src_enc.size(1) == bs 587 | 588 | # generated sentences 589 | max_len = src_enc.size(0) 590 | 591 | tensor = self.fwd(x=src_enc, lengths=src_len, causal=False) 592 | assert tensor.size() == (max_len, bs, self.dim) 593 | scores = self.proj(tensor) # (len, bs, n_words) 594 | # select next words: sample or greedy 595 | sample_temperature = None 596 | if sample_temperature is None: 597 | next_words = torch.topk(scores, 1)[1].squeeze(2) 598 | else: 599 | next_words = torch.multinomial( 600 | F.softmax(scores.float() / sample_temperature, dim=1), 1 601 | ).squeeze(2) 602 | assert next_words.size() == (max_len, bs,) 603 | next_words = next_words[:exp_len] 604 | return next_words.transpose(0,1) 605 | 606 | 607 | def generate(self, src_enc, src_len, max_len=200, sample_temperature=None): 608 | """ 609 | Decode a sentence given initial start. 610 | `x`: 611 | - LongTensor(bs, slen) 612 | W1 W2 W3 613 | W1 W2 W3 W4 614 | `lengths`: 615 | - LongTensor(bs) [5, 6] 616 | `positions`: 617 | - False, for regular "arange" positions (LM) 618 | - True, to reset positions from the new generation (MT) 619 | """ 620 | 621 | # input batch 622 | bs = len(src_len) 623 | #assert src_enc.size(0) == bs 624 | 625 | # generated sentences 626 | generated = src_len.new(max_len, bs) # upcoming output 627 | generated.fill_(self.pad_index) # fill upcoming ouput with 628 | # current position / max lengths / length of generated sentences 629 | if self.decoder_only: 630 | max_src = src_len.max() 631 | generated[:max_src] = src_enc 632 | generated[max_src-1] = self.sep_index 633 | # cur_len = src_len.min() 634 | # gen_len = src_len.clone().fill_(cur_len) 635 | else: 636 | generated[0].fill_(self.eos_index) # we use for everywhere 637 | cur_len = 1 638 | gen_len = src_len.clone().fill_(1) 639 | 640 | # positions 641 | positions = src_len.new(max_len).long() 642 | positions = ( 643 | torch.arange(max_len, out=positions).unsqueeze(1).expand_as(generated) 644 | ) 645 | 646 | unfinished_sents = src_len.clone().fill_(1) 647 | 648 | # cache compute states 649 | self.cache = {"slen": 0} 650 | 651 | while cur_len < max_len: 652 | 653 | if self.decoder_only: 654 | # compute word scores 655 | tensor = self.forward( 656 | "fwd", 657 | x=generated[:cur_len], 658 | lengths=gen_len, 659 | positions=positions[:cur_len], 660 | causal=True, 661 | src_enc=None, 662 | src_len=None, 663 | use_cache=True, 664 | ) 665 | else: 666 | # compute word scores 667 | tensor = self.forward( 668 | "fwd", 669 | x=generated[:cur_len], 670 | lengths=gen_len, 671 | positions=positions[:cur_len], 672 | causal=True, 673 | src_enc=src_enc, 674 | src_len=src_len, 675 | use_cache=True, 676 | ) 677 | assert tensor.size() == (1, bs, self.dim) 678 | tensor = tensor.data[-1, :, :] # .to(self.dtype) # (bs, dim) 679 | scores = self.proj(tensor) # (bs, n_words) 680 | 681 | # select next words: sample or greedy 682 | if sample_temperature is None: 683 | next_words = torch.topk(scores, 1)[1].squeeze(1) 684 | else: 685 | next_words = torch.multinomial( 686 | F.softmax(scores.float() / sample_temperature, dim=1), 1 687 | ).squeeze(1) 688 | assert next_words.size() == (bs,) 689 | 690 | # update generations / lengths / finished sentences / current length 691 | if self.decoder_only: 692 | up_mask = (src_len <= cur_len) 693 | to_update = unfinished_sents * up_mask 694 | generated[cur_len] = next_words * to_update + (1 - to_update) * ( 695 | generated[cur_len] * ~up_mask + self.pad_index * up_mask 696 | ) 697 | gen_len.add_(unfinished_sents) 698 | unfinished_sents.mul_((~up_mask).long() + up_mask * next_words.ne(self.eos_index).long()) 699 | else: 700 | generated[cur_len] = next_words * unfinished_sents + self.pad_index * ( 701 | 1 - unfinished_sents 702 | ) 703 | gen_len.add_(unfinished_sents) 704 | unfinished_sents.mul_(next_words.ne(self.eos_index).long()) 705 | 706 | cur_len = cur_len + 1 707 | 708 | # stop when there is a in each sentence, or if we exceed the maximal length 709 | if unfinished_sents.max() == 0: 710 | break 711 | 712 | # add to unfinished sentences 713 | if cur_len == max_len: 714 | generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index) 715 | 716 | # sanity check 717 | assert (generated == self.eos_index).sum() == 2 * bs 718 | 719 | return generated[:cur_len].cpu(), gen_len.cpu() 720 | 721 | def generate_beam( 722 | self, src_enc, src_len, beam_size, length_penalty, early_stopping, max_len=200 723 | ): 724 | """ 725 | Decode a sentence given initial start. 726 | `x`: 727 | - LongTensor(bs, slen) 728 | W1 W2 W3 729 | W1 W2 W3 W4 730 | `lengths`: 731 | - LongTensor(bs) [5, 6] 732 | `positions`: 733 | - False, for regular "arange" positions (LM) 734 | - True, to reset positions from the new generation (MT) 735 | """ 736 | 737 | # check inputs 738 | assert src_enc.size(0) == src_len.size(0) 739 | assert beam_size >= 1 740 | 741 | # batch size / number of words 742 | bs = len(src_len) 743 | n_words = self.n_words 744 | 745 | # expand to beam size the source latent representations / source lengths 746 | src_enc = ( 747 | src_enc.unsqueeze(1) 748 | .expand((bs, beam_size) + src_enc.shape[1:]) 749 | .contiguous() 750 | .view((bs * beam_size,) + src_enc.shape[1:]) 751 | ) 752 | src_len = src_len.unsqueeze(1).expand(bs, beam_size).contiguous().view(-1) 753 | 754 | # generated sentences (batch with beam current hypotheses) 755 | generated = src_len.new(max_len, bs * beam_size) # upcoming output 756 | generated.fill_(self.pad_index) # fill upcoming ouput with 757 | generated[0].fill_(self.eos_index) # we use for everywhere 758 | 759 | # generated hypotheses 760 | generated_hyps = [ 761 | BeamHypotheses(beam_size, max_len, length_penalty, early_stopping) 762 | for _ in range(bs) 763 | ] 764 | 765 | # positions 766 | positions = src_len.new(max_len).long() 767 | positions = ( 768 | torch.arange(max_len, out=positions).unsqueeze(1).expand_as(generated) 769 | ) 770 | 771 | # scores for each sentence in the beam 772 | beam_scores = src_enc.new(bs, beam_size).float().fill_(0) 773 | beam_scores[:, 1:] = -1e9 774 | beam_scores = beam_scores.view(-1) 775 | 776 | # current position 777 | cur_len = 1 778 | 779 | # cache compute states 780 | self.cache = {"slen": 0} 781 | 782 | # done sentences 783 | done = [False for _ in range(bs)] 784 | 785 | while cur_len < max_len: 786 | 787 | # compute word scores 788 | tensor = self.forward( 789 | "fwd", 790 | x=generated[:cur_len], 791 | lengths=src_len.new(bs * beam_size).fill_(cur_len), 792 | positions=positions[:cur_len], 793 | causal=True, 794 | src_enc=src_enc, 795 | src_len=src_len, 796 | use_cache=True, 797 | ) 798 | 799 | assert tensor.size() == (1, bs * beam_size, self.dim) 800 | tensor = tensor.data[-1, :, :] # .to(self.dtype) # (bs * beam_size, dim) 801 | scores = self.proj(tensor) # (bs * beam_size, n_words) 802 | scores = F.log_softmax(scores.float(), dim=-1) # (bs * beam_size, n_words) 803 | assert scores.size() == (bs * beam_size, n_words) 804 | 805 | # select next words with scores 806 | _scores = scores + beam_scores[:, None].expand_as( 807 | scores 808 | ) # (bs * beam_size, n_words) 809 | _scores = _scores.view(bs, beam_size * n_words) # (bs, beam_size * n_words) 810 | 811 | next_scores, next_words = torch.topk( 812 | _scores, 2 * beam_size, dim=1, largest=True, sorted=True 813 | ) 814 | assert next_scores.size() == next_words.size() == (bs, 2 * beam_size) 815 | 816 | # next batch beam content 817 | # list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch) 818 | next_batch_beam = [] 819 | 820 | # for each sentence 821 | for sent_id in range(bs): 822 | 823 | # if we are done with this sentence 824 | done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done( 825 | next_scores[sent_id].max().item() 826 | ) 827 | if done[sent_id]: 828 | next_batch_beam.extend( 829 | [(0, self.pad_index, 0)] * beam_size 830 | ) # pad the batch 831 | continue 832 | 833 | # next sentence beam content 834 | next_sent_beam = [] 835 | 836 | # next words for this sentence 837 | for idx, value in zip(next_words[sent_id], next_scores[sent_id]): 838 | 839 | # get beam and word IDs 840 | beam_id = idx // n_words 841 | word_id = idx % n_words 842 | 843 | # end of sentence, or next word 844 | if word_id == self.eos_index or cur_len + 1 == max_len: 845 | generated_hyps[sent_id].add( 846 | generated[:cur_len, sent_id * beam_size + beam_id] 847 | .clone() 848 | .cpu(), 849 | value.item(), 850 | ) 851 | else: 852 | next_sent_beam.append( 853 | (value, word_id, sent_id * beam_size + beam_id) 854 | ) 855 | 856 | # the beam for next step is full 857 | if len(next_sent_beam) == beam_size: 858 | break 859 | 860 | # update next beam content 861 | assert len(next_sent_beam) == 0 if cur_len + 1 == max_len else beam_size 862 | if len(next_sent_beam) == 0: 863 | next_sent_beam = [ 864 | (0, self.pad_index, 0) 865 | ] * beam_size # pad the batch 866 | next_batch_beam.extend(next_sent_beam) 867 | assert len(next_batch_beam) == beam_size * (sent_id + 1) 868 | 869 | # sanity check / prepare next batch 870 | assert len(next_batch_beam) == bs * beam_size 871 | beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) 872 | beam_words = generated.new([x[1] for x in next_batch_beam]) 873 | beam_idx = src_len.new([x[2] for x in next_batch_beam]) 874 | 875 | # re-order batch and internal states 876 | generated = generated[:, beam_idx] 877 | generated[cur_len] = beam_words 878 | for k in self.cache.keys(): 879 | if k != "slen": 880 | self.cache[k] = ( 881 | self.cache[k][0][beam_idx], 882 | self.cache[k][1][beam_idx], 883 | ) 884 | 885 | # update current length 886 | cur_len = cur_len + 1 887 | 888 | # stop when we are done with each sentence 889 | if all(done): 890 | break 891 | 892 | # def get_coeffs(s): 893 | # roots = [int(s[i + 2]) for i, c in enumerate(s) if c == 'x'] 894 | # poly = np.poly1d(roots, r=True) 895 | # coeffs = list(poly.coefficients.astype(np.int64)) 896 | # return [c % 10 for c in coeffs], coeffs 897 | 898 | # visualize hypotheses 899 | # print([len(x) for x in generated_hyps], cur_len) 900 | # globals().update( locals() ); 901 | # !import code; code.interact(local=vars()) 902 | # for ii in range(bs): 903 | # for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True): 904 | # hh = " ".join(self.id2word[x] for x in ww.tolist()) 905 | # print(f"{ss:+.4f} {hh}") 906 | # # cc = get_coeffs(hh[4:]) 907 | # # print(f"{ss:+.4f} {hh} || {cc[0]} || {cc[1]}") 908 | # print("") 909 | 910 | # select the best hypotheses 911 | tgt_len = src_len.new(bs) 912 | best = [] 913 | 914 | for i, hypotheses in enumerate(generated_hyps): 915 | best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] 916 | tgt_len[i] = len(best_hyp) + 1 # +1 for the symbol 917 | best.append(best_hyp) 918 | 919 | # generate target batch 920 | decoded = src_len.new(tgt_len.max().item(), bs).fill_(self.pad_index) 921 | for i, hypo in enumerate(best): 922 | decoded[: tgt_len[i] - 1, i] = hypo 923 | decoded[tgt_len[i] - 1, i] = self.eos_index 924 | 925 | # sanity check 926 | assert (decoded == self.eos_index).sum() == 2 * bs 927 | 928 | return decoded, tgt_len, generated_hyps 929 | 930 | 931 | class BeamHypotheses(object): 932 | def __init__(self, n_hyp, max_len, length_penalty, early_stopping): 933 | """ 934 | Initialize n-best list of hypotheses. 935 | """ 936 | self.max_len = max_len - 1 # ignoring 937 | self.length_penalty = length_penalty 938 | self.early_stopping = early_stopping 939 | self.n_hyp = n_hyp 940 | self.hyp = [] 941 | self.worst_score = 1e9 942 | 943 | def __len__(self): 944 | """ 945 | Number of hypotheses in the list. 946 | """ 947 | return len(self.hyp) 948 | 949 | def add(self, hyp, sum_logprobs): 950 | """ 951 | Add a new hypothesis to the list. 952 | """ 953 | score = sum_logprobs / len(hyp) ** self.length_penalty 954 | if len(self) < self.n_hyp or score > self.worst_score: 955 | self.hyp.append((score, hyp)) 956 | if len(self) > self.n_hyp: 957 | sorted_scores = sorted( 958 | [(s, idx) for idx, (s, _) in enumerate(self.hyp)] 959 | ) 960 | del self.hyp[sorted_scores[0][1]] 961 | self.worst_score = sorted_scores[1][0] 962 | else: 963 | self.worst_score = min(score, self.worst_score) 964 | 965 | def is_done(self, best_sum_logprobs): 966 | """ 967 | If there are enough hypotheses and that none of the hypotheses being generated 968 | can become better than the worst one in the heap, 969 | then we are done with this sentence. 970 | """ 971 | if len(self) < self.n_hyp: 972 | return False 973 | elif self.early_stopping: 974 | return True 975 | else: 976 | return ( 977 | self.worst_score 978 | >= best_sum_logprobs / self.max_len ** self.length_penalty 979 | ) 980 | -------------------------------------------------------------------------------- /src/optim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import re 9 | import math 10 | import inspect 11 | 12 | import torch 13 | from torch import optim 14 | 15 | if not hasattr(inspect, 'getargspec'): 16 | inspect.getargspec = inspect.getfullargspec 17 | 18 | 19 | 20 | 21 | class Adam(optim.Optimizer): 22 | """ 23 | Same as https://github.com/pytorch/pytorch/blob/master/torch/optim/adam.py, 24 | without amsgrad, with step in a tensor, and states initialization in __init__. 25 | It was important to add `.item()` in `state['step'].item()`. 26 | """ 27 | 28 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 29 | if not 0.0 <= lr: 30 | raise ValueError("Invalid learning rate: {}".format(lr)) 31 | if not 0.0 <= eps: 32 | raise ValueError("Invalid epsilon value: {}".format(eps)) 33 | if not 0.0 <= betas[0] < 1.0: 34 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 35 | if not 0.0 <= betas[1] < 1.0: 36 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 37 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 38 | super().__init__(params, defaults) 39 | 40 | for group in self.param_groups: 41 | for p in group['params']: 42 | state = self.state[p] 43 | state['step'] = 0 # torch.zeros(1) 44 | state['exp_avg'] = torch.zeros_like(p.data) 45 | state['exp_avg_sq'] = torch.zeros_like(p.data) 46 | 47 | def __setstate__(self, state): 48 | super().__setstate__(state) 49 | 50 | def step(self, closure=None): 51 | """ 52 | Step. 53 | """ 54 | loss = None 55 | if closure is not None: 56 | loss = closure() 57 | 58 | for group in self.param_groups: 59 | for p in group['params']: 60 | if p.grad is None: 61 | continue 62 | grad = p.grad.data 63 | if grad.is_sparse: 64 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 65 | 66 | state = self.state[p] 67 | 68 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 69 | beta1, beta2 = group['betas'] 70 | 71 | state['step'] += 1 72 | 73 | # if group['weight_decay'] != 0: 74 | # grad.add_(group['weight_decay'], p.data) 75 | 76 | # Decay the first and second moment running average coefficient 77 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 78 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 79 | denom = exp_avg_sq.sqrt().add_(group['eps']) 80 | # denom = exp_avg_sq.sqrt().clamp_(min=group['eps']) 81 | 82 | bias_correction1 = 1 - beta1 ** state['step'] # .item() 83 | bias_correction2 = 1 - beta2 ** state['step'] # .item() 84 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 85 | 86 | if group['weight_decay'] != 0: 87 | p.data.add_(-group['weight_decay'] * group['lr'], p.data) 88 | 89 | p.data.addcdiv_(-step_size, exp_avg, denom) 90 | 91 | return loss 92 | 93 | 94 | class AdamWithWarmup(Adam): 95 | """ 96 | Adam with a warmup phase where we linearly increase the learning rate 97 | from some initial learning rate (`warmup-init-lr`) until the configured 98 | learning rate (`lr`). 99 | During warmup: 100 | lrs = torch.linspace(warmup_init_lr, lr, warmup_updates) 101 | lr = lrs[update_num] 102 | After warmup: 103 | lr = lr 104 | """ 105 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 106 | weight_decay=0, warmup_updates=4000, warmup_init_lr=1e-7): 107 | super().__init__( 108 | params, 109 | lr=warmup_init_lr, 110 | betas=betas, 111 | eps=eps, 112 | weight_decay=weight_decay, 113 | ) 114 | 115 | # linearly warmup for the first warmup_updates 116 | self.warmup_updates = warmup_updates 117 | self.warmup_init_lr = warmup_init_lr 118 | self.warmup_end_lr = lr 119 | self.lr_step = (lr - warmup_init_lr) / warmup_updates 120 | 121 | # total number of updates 122 | for param_group in self.param_groups: 123 | param_group['num_updates'] = 0 124 | 125 | def get_lr_for_step(self, num_updates): 126 | if num_updates < self.warmup_updates: 127 | return self.warmup_init_lr + num_updates * self.lr_step 128 | else: 129 | return self.warmup_end_lr 130 | 131 | def step(self, closure=None): 132 | super().step(closure) 133 | for param_group in self.param_groups: 134 | param_group['num_updates'] += 1 135 | param_group['lr'] = self.get_lr_for_step(param_group['num_updates']) 136 | 137 | 138 | 139 | class AdamInverseSqrtWithWarmup(Adam): 140 | """ 141 | Decay the LR based on the inverse square root of the update number. 142 | We also support a warmup phase where we linearly increase the learning rate 143 | from some initial learning rate (`warmup-init-lr`) until the configured 144 | learning rate (`lr`). Thereafter we decay proportional to the number of 145 | updates, with a decay factor set to align with the configured learning rate. 146 | During warmup: 147 | lrs = torch.linspace(warmup_init_lr, lr, warmup_updates) 148 | lr = lrs[update_num] 149 | After warmup: 150 | lr = decay_factor / sqrt(update_num) 151 | where 152 | decay_factor = lr * sqrt(warmup_updates) 153 | """ 154 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 155 | weight_decay=0, warmup_updates=4000, warmup_init_lr=1e-7, 156 | exp_factor=0.5): 157 | super().__init__( 158 | params, 159 | lr=warmup_init_lr, 160 | betas=betas, 161 | eps=eps, 162 | weight_decay=weight_decay, 163 | ) 164 | 165 | # linearly warmup for the first warmup_updates 166 | self.warmup_updates = warmup_updates 167 | self.warmup_init_lr = warmup_init_lr 168 | warmup_end_lr = lr 169 | self.lr_step = (warmup_end_lr - warmup_init_lr) / warmup_updates 170 | 171 | # then, decay prop. to the inverse square root of the update number 172 | self.exp_factor = exp_factor 173 | self.decay_factor = warmup_end_lr * warmup_updates ** self.exp_factor 174 | 175 | # total number of updates 176 | for param_group in self.param_groups: 177 | param_group['num_updates'] = 0 178 | 179 | def get_lr_for_step(self, num_updates): 180 | if num_updates < self.warmup_updates: 181 | return self.warmup_init_lr + num_updates * self.lr_step 182 | else: 183 | return self.decay_factor * (num_updates ** -self.exp_factor) 184 | 185 | def step(self, closure=None): 186 | super().step(closure) 187 | for param_group in self.param_groups: 188 | param_group['num_updates'] += 1 189 | param_group['lr'] = self.get_lr_for_step(param_group['num_updates']) 190 | 191 | 192 | class AdamCosineWithWarmup(Adam): 193 | """ 194 | Assign LR based on a cyclical schedule that follows the cosine function. 195 | See https://arxiv.org/pdf/1608.03983.pdf for details. 196 | We also support a warmup phase where we linearly increase the learning rate 197 | from some initial learning rate (``--warmup-init-lr``) until the configured 198 | learning rate (``--lr``). 199 | During warmup:: 200 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) 201 | lr = lrs[update_num] 202 | After warmup:: 203 | lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i)) 204 | where ``t_curr`` is current percentage of updates within the current period 205 | range and ``t_i`` is the current period range, which is scaled by ``t_mul`` 206 | after every iteration. 207 | """ 208 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 209 | weight_decay=0, warmup_updates=4000, warmup_init_lr=1e-7, 210 | min_lr=1e-9, init_period=1000000, period_mult=1, lr_shrink=0.75, 211 | lr_shrink_min=0.75, smooth=False): 212 | super().__init__( 213 | params, 214 | lr=warmup_init_lr, 215 | betas=betas, 216 | eps=eps, 217 | weight_decay=weight_decay, 218 | ) 219 | 220 | # linearly warmup for the first warmup_updates 221 | self.warmup_updates = warmup_updates 222 | self.warmup_init_lr = warmup_init_lr 223 | self.smooth = smooth 224 | warmup_end_lr = lr 225 | self.lr_step = (warmup_end_lr - warmup_init_lr) / warmup_updates 226 | 227 | # then, apply cosine scheduler 228 | self.min_lr = min_lr 229 | self.max_lr = lr 230 | self.period = init_period 231 | self.period_mult = period_mult 232 | self.lr_shrink = lr_shrink 233 | self.lr_shrink_min = lr_shrink_min 234 | 235 | assert not self.smooth or self.period_mult == 1 236 | 237 | # total number of updates 238 | for param_group in self.param_groups: 239 | param_group['num_updates'] = 0 240 | 241 | def get_lr_for_step(self, num_updates): 242 | if num_updates < self.warmup_updates: 243 | return self.warmup_init_lr + num_updates * self.lr_step 244 | else: 245 | t = num_updates - self.warmup_updates 246 | if self.period_mult == 1: 247 | if self.smooth: 248 | pid = math.floor(t / self.period - 1 / 2) 249 | else: 250 | pid = math.floor(t / self.period) 251 | t_i = self.period 252 | t_curr = t - (self.period * pid) 253 | else: 254 | pid = math.floor(math.log(1 - t / self.period * (1 - self.period_mult), self.period_mult)) 255 | t_i = self.period * (self.period_mult ** pid) 256 | t_curr = t - (1 - self.period_mult ** pid) / (1 - self.period_mult) * self.period 257 | lr_shrink = self.lr_shrink ** pid 258 | lr_shrink_min = self.lr_shrink_min ** pid 259 | min_lr = self.min_lr * lr_shrink_min 260 | max_lr = self.max_lr * lr_shrink 261 | if max_lr < min_lr: 262 | max_lr = min_lr 263 | if self.smooth: 264 | return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(2 * math.pi * t_curr / t_i)) 265 | else: 266 | return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i)) 267 | 268 | 269 | def step(self, closure=None): 270 | super().step(closure) 271 | for param_group in self.param_groups: 272 | param_group['num_updates'] += 1 273 | param_group['lr'] = self.get_lr_for_step(param_group['num_updates']) 274 | 275 | 276 | def get_optimizer(parameters, s): 277 | """ 278 | Parse optimizer parameters. 279 | Input should be of the form: 280 | - "sgd,lr=0.01" 281 | - "adagrad,lr=0.1,lr_decay=0.05" 282 | """ 283 | if "," in s: 284 | method = s[:s.find(',')] 285 | optim_params = {} 286 | for x in s[s.find(',') + 1:].split(','): 287 | split = x.split('=') 288 | assert len(split) == 2 289 | assert re.match("^[+-]?(\d+(\.\d*)?|\.\d+)$", split[1]) is not None 290 | optim_params[split[0]] = float(split[1]) 291 | else: 292 | method = s 293 | optim_params = {} 294 | 295 | 296 | if method == 'adadelta': 297 | optim_fn = optim.Adadelta 298 | elif method == 'adagrad': 299 | optim_fn = optim.Adagrad 300 | elif method == 'adam': 301 | optim_fn = Adam 302 | elif method == 'adam_warmup': 303 | optim_fn = AdamWithWarmup 304 | elif method == 'adam_inverse_sqrt': 305 | optim_fn = AdamInverseSqrtWithWarmup 306 | elif method == 'adam_cosine': 307 | optim_fn = AdamCosineWithWarmup 308 | optim_params['smooth'] = False 309 | elif method == 'adam_smooth_cosine': 310 | optim_fn = AdamCosineWithWarmup 311 | optim_params['smooth'] = True 312 | elif method == "adamw": 313 | optim_fn = optim.AdamW 314 | elif method == 'adam_torch': 315 | optim_fn = optim.Adam 316 | elif method == 'adamax': 317 | optim_fn = optim.Adamax 318 | elif method == 'asgd': 319 | optim_fn = optim.ASGD 320 | elif method == 'rmsprop': 321 | optim_fn = optim.RMSprop 322 | elif method == 'rprop': 323 | optim_fn = optim.Rprop 324 | elif method == 'sgd': 325 | optim_fn = optim.SGD 326 | assert 'lr' in optim_params 327 | else: 328 | raise Exception('Unknown optimization method: "%s"' % method) 329 | 330 | # check that we give good parameters to the optimizer 331 | expected_args = inspect.getargspec(optim_fn.__init__)[0] 332 | assert expected_args[:2] == ['self', 'params'] 333 | if "betas" in expected_args[2:]: 334 | optim_params['betas'] = (optim_params.get('beta1', 0.9), optim_params.get('beta2', 0.999)) 335 | optim_params.pop('beta1', None) 336 | optim_params.pop('beta2', None) 337 | 338 | if not all(k in expected_args[2:] for k in optim_params.keys()): 339 | raise Exception('Unexpected parameters: expected "%s", got "%s"' % ( 340 | str(expected_args[2:]), str(optim_params.keys()))) 341 | 342 | return optim_fn(parameters, **optim_params) 343 | -------------------------------------------------------------------------------- /src/slurm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import os 10 | import sys 11 | import torch 12 | import socket 13 | import signal 14 | import subprocess 15 | 16 | 17 | logger = getLogger() 18 | 19 | 20 | def sig_handler(signum, frame): 21 | logger.warning("Signal handler called with signal " + str(signum)) 22 | prod_id = int(os.environ['SLURM_PROCID']) 23 | logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id)) 24 | if prod_id == 0: 25 | logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID']) 26 | os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID']) 27 | else: 28 | logger.warning("Not the master process, no need to requeue.") 29 | sys.exit(-1) 30 | 31 | 32 | def term_handler(signum, frame): 33 | logger.warning("Signal handler called with signal " + str(signum)) 34 | logger.warning("Bypassing SIGTERM.") 35 | 36 | 37 | def init_signal_handler(): 38 | """ 39 | Handle signals sent by SLURM for time limit / pre-emption. 40 | """ 41 | signal.signal(signal.SIGUSR1, sig_handler) 42 | signal.signal(signal.SIGTERM, term_handler) 43 | logger.warning("Signal handler installed.") 44 | 45 | 46 | def init_distributed_mode(params): 47 | """ 48 | Handle single and multi-GPU / multi-node / SLURM jobs. 49 | Initialize the following variables: 50 | - n_nodes 51 | - node_id 52 | - local_rank 53 | - global_rank 54 | - world_size 55 | """ 56 | params.is_slurm_job = 'SLURM_JOB_ID' in os.environ and not params.debug_slurm 57 | print("SLURM job: %s" % str(params.is_slurm_job)) 58 | 59 | # SLURM job 60 | if params.is_slurm_job: 61 | 62 | assert params.local_rank == -1 # on the cluster, this is handled by SLURM 63 | 64 | SLURM_VARIABLES = [ 65 | 'SLURM_JOB_ID', 66 | 'SLURM_JOB_NODELIST', 'SLURM_JOB_NUM_NODES', 'SLURM_NTASKS', 'SLURM_TASKS_PER_NODE', 67 | 'SLURM_MEM_PER_NODE', 'SLURM_MEM_PER_CPU', 68 | 'SLURM_NODEID', 'SLURM_PROCID', 'SLURM_LOCALID', 'SLURM_TASK_PID' 69 | ] 70 | 71 | PREFIX = "%i - " % int(os.environ['SLURM_PROCID']) 72 | for name in SLURM_VARIABLES: 73 | value = os.environ.get(name, None) 74 | print(PREFIX + "%s: %s" % (name, str(value))) 75 | 76 | # # job ID 77 | # params.job_id = os.environ['SLURM_JOB_ID'] 78 | 79 | # number of nodes / node ID 80 | params.n_nodes = int(os.environ['SLURM_JOB_NUM_NODES']) 81 | params.node_id = int(os.environ['SLURM_NODEID']) 82 | 83 | # local rank on the current node / global rank 84 | params.local_rank = int(os.environ['SLURM_LOCALID']) 85 | params.global_rank = int(os.environ['SLURM_PROCID']) 86 | 87 | # number of processes / GPUs per node 88 | params.world_size = int(os.environ['SLURM_NTASKS']) 89 | params.n_gpu_per_node = params.world_size // params.n_nodes 90 | 91 | # define master address and master port 92 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']]) 93 | params.master_addr = hostnames.split()[0].decode('utf-8') 94 | assert 10001 <= params.master_port <= 20000 or params.world_size == 1 95 | print(PREFIX + "Master address: %s" % params.master_addr) 96 | print(PREFIX + "Master port : %i" % params.master_port) 97 | 98 | # set environment variables for 'env://' 99 | os.environ['MASTER_ADDR'] = params.master_addr 100 | os.environ['MASTER_PORT'] = str(params.master_port) 101 | os.environ['WORLD_SIZE'] = str(params.world_size) 102 | os.environ['RANK'] = str(params.global_rank) 103 | 104 | # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch 105 | elif params.local_rank != -1: 106 | 107 | assert params.master_port == -1 108 | 109 | # read environment variables 110 | params.global_rank = int(os.environ['RANK']) 111 | params.world_size = int(os.environ['WORLD_SIZE']) 112 | params.n_gpu_per_node = int(os.environ['NGPU']) 113 | 114 | # number of nodes / node ID 115 | params.n_nodes = params.world_size // params.n_gpu_per_node 116 | params.node_id = params.global_rank // params.n_gpu_per_node 117 | 118 | # local job (single GPU) 119 | else: 120 | assert params.local_rank == -1 121 | assert params.master_port == -1 122 | params.n_nodes = 1 123 | params.node_id = 0 124 | params.local_rank = 0 125 | params.global_rank = 0 126 | params.world_size = 1 127 | params.n_gpu_per_node = 1 128 | 129 | # sanity checks 130 | assert params.n_nodes >= 1 131 | assert 0 <= params.node_id < params.n_nodes 132 | assert 0 <= params.local_rank <= params.global_rank < params.world_size 133 | assert params.world_size == params.n_nodes * params.n_gpu_per_node 134 | 135 | # define whether this is the master process / if we are in distributed mode 136 | params.is_master = params.node_id == 0 and params.local_rank == 0 137 | params.multi_node = params.n_nodes > 1 138 | params.multi_gpu = params.world_size > 1 139 | 140 | # summary 141 | PREFIX = "%i - " % params.global_rank 142 | print(PREFIX + "Number of nodes: %i" % params.n_nodes) 143 | print(PREFIX + "Node ID : %i" % params.node_id) 144 | print(PREFIX + "Local rank : %i" % params.local_rank) 145 | print(PREFIX + "Global rank : %i" % params.global_rank) 146 | print(PREFIX + "World size : %i" % params.world_size) 147 | print(PREFIX + "GPUs per node : %i" % params.n_gpu_per_node) 148 | print(PREFIX + "Master : %s" % str(params.is_master)) 149 | print(PREFIX + "Multi-node : %s" % str(params.multi_node)) 150 | print(PREFIX + "Multi-GPU : %s" % str(params.multi_gpu)) 151 | print(PREFIX + "Hostname : %s" % socket.gethostname()) 152 | 153 | # set GPU device 154 | if not params.cpu: 155 | if params.local_gpu != -1: 156 | torch.cuda.set_device(params.local_gpu) 157 | else: 158 | torch.cuda.set_device(params.local_rank) 159 | 160 | # initialize multi-GPU 161 | if params.multi_gpu: 162 | 163 | # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization 164 | # 'env://' will read these environment variables: 165 | # MASTER_PORT - required; has to be a free port on machine with rank 0 166 | # MASTER_ADDR - required (except for rank 0); address of rank 0 node 167 | # WORLD_SIZE - required; can be set either here, or in a call to init function 168 | # RANK - required; can be set either here, or in a call to init function 169 | 170 | print("Initializing PyTorch distributed ...") 171 | torch.distributed.init_process_group( 172 | init_method='env://', 173 | backend='nccl', 174 | ) 175 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import io 10 | import sys 11 | import time 12 | from logging import getLogger 13 | from collections import OrderedDict 14 | import numpy as np 15 | import torch 16 | from torch import nn 17 | from torch.nn.utils import clip_grad_norm_ 18 | 19 | from .optim import get_optimizer 20 | from .utils import to_cuda 21 | 22 | logger = getLogger() 23 | 24 | 25 | class Trainer(object): 26 | def __init__(self, modules, env, params): 27 | """ 28 | Initialize trainer. 29 | """ 30 | # modules / params 31 | self.modules = modules 32 | self.params = params 33 | self.env = env 34 | 35 | assert self.params.report_loss_every > 0 36 | 37 | # epoch / iteration size 38 | self.epoch_size = params.epoch_size 39 | self.total_samples = 0 40 | if self.epoch_size == -1: 41 | self.epoch_size = self.data 42 | assert self.epoch_size > 0 43 | 44 | # data iterators 45 | self.iterators = {} 46 | 47 | # set parameters 48 | self.set_parameters() 49 | 50 | # float16 / distributed (no AMP) 51 | assert params.amp >= 1 or not params.fp16 52 | assert params.amp >= 0 or params.accumulate_gradients == 1 53 | if params.multi_gpu: 54 | logger.info("Using nn.parallel.DistributedDataParallel ...") 55 | for k in self.modules.keys(): 56 | self.modules[k] = nn.parallel.DistributedDataParallel( 57 | self.modules[k], 58 | device_ids=[params.local_rank], 59 | output_device=params.local_rank, 60 | broadcast_buffers=True, 61 | ) 62 | 63 | # set optimizer 64 | self.set_optimizer() 65 | 66 | # float16 / distributed (AMP) 67 | self.scaler = None 68 | if params.amp >= 0: 69 | self.init_amp() 70 | # stopping criterion used for early stopping 71 | if params.stopping_criterion != "": 72 | split = params.stopping_criterion.split(",") 73 | assert len(split) == 2 and split[1].isdigit() 74 | self.decrease_counts_max = int(split[1]) 75 | self.decrease_counts = 0 76 | if split[0][0] == "_": 77 | self.stopping_criterion = (split[0][1:], False) 78 | else: 79 | self.stopping_criterion = (split[0], True) 80 | self.best_stopping_criterion = -1e12 if self.stopping_criterion[1] else 1e12 81 | else: 82 | self.stopping_criterion = None 83 | self.best_stopping_criterion = None 84 | 85 | # validation metrics 86 | self.metrics = [] 87 | metrics = [m for m in params.validation_metrics.split(",") if m != ""] 88 | for m in metrics: 89 | m = (m[1:], False) if m[0] == "_" else (m, True) 90 | self.metrics.append(m) 91 | self.best_metrics = { 92 | metric: (-1e12 if biggest else 1e12) for (metric, biggest) in self.metrics 93 | } 94 | 95 | # training statistics 96 | self.epoch = 0 97 | self.n_iter = 0 98 | self.n_total_iter = 0 99 | self.stats = OrderedDict( 100 | [("processed_e", 0)] 101 | + [("processed_w", 0)] 102 | + [("encoder_act_penalty", [])] 103 | + [("decoder_act_penalty", [])] 104 | + sum( 105 | [[(x, []), (f"{x}-AVG-STOP-PROBS", [])] for x in env.TRAINING_TASKS], [] 106 | ) 107 | ) 108 | self.last_time = time.time() 109 | 110 | # reload potential checkpoints 111 | self.reload_checkpoint() 112 | 113 | # file handler to export data 114 | if params.export_data: 115 | assert params.train_data == "" 116 | params.export_path_prefix = os.path.join(params.dump_path, "data.prefix") 117 | self.file_handler_prefix = io.open( 118 | params.export_path_prefix, mode="a", encoding="utf-8" 119 | ) 120 | logger.info( 121 | f"Data will be stored in prefix in: {params.export_path_prefix} ..." 122 | ) 123 | 124 | # reload exported data 125 | if params.train_data != "": 126 | assert params.num_workers in [0, 1] 127 | assert params.export_data is False 128 | self.data_path = params.train_data 129 | else: 130 | self.data_path = None 131 | 132 | # create data loaders 133 | if not params.eval_only: 134 | if params.env_base_seed < 0: 135 | params.env_base_seed = np.random.randint(1_000_000_000) 136 | self.dataloader = { 137 | task: iter(self.env.create_train_iterator(task, self.data_path, params)) 138 | for task in params.tasks 139 | } 140 | 141 | def set_parameters(self): 142 | """ 143 | Set parameters. 144 | """ 145 | self.parameters = {} 146 | named_params = [] 147 | for v in self.modules.values(): 148 | named_params.extend( 149 | [(k, p) for k, p in v.named_parameters() if p.requires_grad] 150 | ) 151 | self.parameters["model"] = [p for k, p in named_params] 152 | for k, v in self.parameters.items(): 153 | logger.info("Found %i parameters in %s." % (len(v), k)) 154 | assert len(v) >= 1 155 | 156 | def set_optimizer(self): 157 | """ 158 | Set optimizer. 159 | """ 160 | params = self.params 161 | self.optimizer = get_optimizer( 162 | self.parameters["model"], params.optimizer 163 | ) 164 | logger.info("Optimizer: %s" % type(self.optimizer)) 165 | 166 | def init_amp(self): 167 | """ 168 | Initialize AMP optimizer. 169 | """ 170 | params = self.params 171 | assert ( 172 | params.amp == 0 173 | and params.fp16 is False 174 | or params.amp in [1, 2, 3] 175 | and params.fp16 is True 176 | ) 177 | # mod_names = sorted(self.modules.keys()) # unused 178 | self.scaler = torch.cuda.amp.GradScaler() 179 | 180 | def optimize(self, loss): 181 | """ 182 | Optimize. 183 | """ 184 | # check NaN 185 | if (loss != loss).data.any(): 186 | logger.warning("NaN detected") 187 | # exit() 188 | 189 | params = self.params 190 | 191 | # optimizer 192 | optimizer = self.optimizer 193 | 194 | # regular optimization 195 | if params.amp == -1: 196 | optimizer.zero_grad() 197 | loss.backward() 198 | if params.clip_grad_norm > 0: 199 | clip_grad_norm_(self.parameters["model"], params.clip_grad_norm) 200 | optimizer.step() 201 | 202 | else: 203 | if params.accumulate_gradients > 1: 204 | loss = loss / params.accumulate_gradients 205 | self.scaler.scale(loss).backward() 206 | 207 | if (self.n_iter + 1) % params.accumulate_gradients == 0: 208 | if params.clip_grad_norm > 0: 209 | self.scaler.unscale_(optimizer) 210 | clip_grad_norm_(self.parameters["model"], params.clip_grad_norm) 211 | self.scaler.step(optimizer) 212 | self.scaler.update() 213 | optimizer.zero_grad() 214 | 215 | def iter(self): 216 | """ 217 | End of iteration. 218 | """ 219 | self.n_iter += 1 220 | self.n_total_iter += 1 221 | if self.n_total_iter % self.params.report_loss_every == 0: 222 | self.print_stats() 223 | 224 | def print_stats(self): 225 | """ 226 | Print statistics about the training. 227 | """ 228 | s_iter = "%7i - " % self.n_total_iter 229 | s_stat = " || ".join( 230 | [ 231 | "{}: {:7.4f}".format(k.upper().replace("_", "-"), np.mean(v)) 232 | for k, v in self.stats.items() 233 | if type(v) is list and len(v) > 0 234 | ] 235 | ) 236 | for k in self.stats.keys(): 237 | if type(self.stats[k]) is list: 238 | del self.stats[k][:] 239 | 240 | # learning rates 241 | s_lr = ( 242 | (" - LR: ") 243 | + " / ".join("{:.4e}".format(group["lr"]) for group in self.optimizer.param_groups) 244 | ) 245 | 246 | # processing speed 247 | new_time = time.time() 248 | diff = new_time - self.last_time 249 | s_speed = "{:7.2f} examples/s - {:8.2f} words/s - ".format( 250 | self.stats["processed_e"] * 1.0 / diff, 251 | self.stats["processed_w"] * 1.0 / diff, 252 | ) 253 | self.stats["processed_e"] = 0 254 | self.stats["processed_w"] = 0 255 | self.last_time = new_time 256 | 257 | # log speed + stats + learning rate 258 | logger.info(s_iter + s_speed + s_stat + s_lr) 259 | 260 | def save_checkpoint(self, name, include_optimizer=True): 261 | """ 262 | Save the model / checkpoints. 263 | """ 264 | if not self.params.is_master: 265 | return 266 | 267 | path = os.path.join(self.params.dump_path, "%s.pth" % name) 268 | logger.info("Saving %s to %s ..." % (name, path)) 269 | 270 | data = { 271 | "epoch": self.epoch, 272 | "n_total_iter": self.n_total_iter, 273 | "best_metrics": self.best_metrics, 274 | "best_stopping_criterion": self.best_stopping_criterion, 275 | "params": {k: v for k, v in self.params.__dict__.items()}, 276 | } 277 | 278 | for k, v in self.modules.items(): 279 | logger.warning(f"Saving {k} parameters ...") 280 | data[k] = v.state_dict() 281 | 282 | if include_optimizer: 283 | logger.warning("Saving optimizer ...") 284 | data["optimizer"] = self.optimizer.state_dict() 285 | if self.scaler is not None: 286 | data["scaler"] = self.scaler.state_dict() 287 | 288 | torch.save(data, path) 289 | 290 | def reload_checkpoint(self): 291 | """ 292 | Reload a checkpoint if we find one. 293 | """ 294 | checkpoint_path = os.path.join(self.params.dump_path, "checkpoint.pth") 295 | if not os.path.isfile(checkpoint_path): 296 | if self.params.reload_checkpoint == "": 297 | return 298 | else: 299 | checkpoint_path = self.params.reload_checkpoint 300 | assert os.path.isfile(checkpoint_path) 301 | 302 | logger.warning(f"Reloading checkpoint from {checkpoint_path} ...") 303 | data = torch.load(checkpoint_path, map_location="cpu") 304 | 305 | # reload model parameters 306 | for k, v in self.modules.items(): 307 | v.load_state_dict(data[k]) 308 | 309 | # reload optimizer 310 | # AMP checkpoint reloading is buggy, we cannot reload optimizer 311 | # instead, we only reload current iterations / learning rates 312 | logger.warning("Reloading checkpoint optimizer ...") 313 | self.optimizer.load_state_dict(data["optimizer"]) 314 | 315 | if self.params.fp16: 316 | logger.warning("Reloading gradient scaler ...") 317 | self.scaler.load_state_dict(data["scaler"]) 318 | else: 319 | assert self.scaler is None and "scaler" not in data 320 | 321 | # reload main metrics 322 | self.epoch = data["epoch"] + 1 323 | self.n_total_iter = data["n_total_iter"] 324 | self.best_metrics = data["best_metrics"] 325 | self.best_stopping_criterion = data["best_stopping_criterion"] 326 | logger.warning( 327 | f"Checkpoint reloaded. Resuming at epoch {self.epoch} / iteration {self.n_total_iter} ..." 328 | ) 329 | 330 | def save_periodic(self): 331 | """ 332 | Save the models periodically. 333 | """ 334 | if not self.params.is_master: 335 | return 336 | if ( 337 | self.params.save_periodic > 0 338 | and self.epoch % self.params.save_periodic == 0 339 | ): 340 | self.save_checkpoint("periodic-%i" % self.epoch) 341 | 342 | def save_best_model(self, scores): 343 | """ 344 | Save best models according to given validation metrics. 345 | """ 346 | if not self.params.is_master: 347 | return 348 | for metric, biggest in self.metrics: 349 | if metric not in scores: 350 | logger.warning('Metric "%s" not found in scores!' % metric) 351 | continue 352 | factor = 1 if biggest else -1 353 | if factor * scores[metric] > factor * self.best_metrics[metric]: 354 | self.best_metrics[metric] = scores[metric] 355 | logger.info("New best score for %s: %.6f" % (metric, scores[metric])) 356 | self.save_checkpoint("best-%s" % metric) 357 | 358 | def end_epoch(self, scores): 359 | """ 360 | End the epoch. 361 | """ 362 | # stop if the stopping criterion has not improved after a certain number of epochs 363 | if self.stopping_criterion is not None and ( 364 | self.params.is_master or not self.stopping_criterion[0].endswith("_mt_bleu") 365 | ): 366 | metric, biggest = self.stopping_criterion 367 | assert metric in scores, metric 368 | factor = 1 if biggest else -1 369 | if factor * scores[metric] > factor * self.best_stopping_criterion: 370 | self.best_stopping_criterion = scores[metric] 371 | logger.info( 372 | "New best validation score: %f" % self.best_stopping_criterion 373 | ) 374 | self.decrease_counts = 0 375 | else: 376 | logger.info( 377 | "Not a better validation score (%i / %i)." 378 | % (self.decrease_counts, self.decrease_counts_max) 379 | ) 380 | self.decrease_counts += 1 381 | if self.decrease_counts > self.decrease_counts_max: 382 | logger.info( 383 | "Stopping criterion has been below its best value for more " 384 | "than %i epochs. Ending the experiment..." 385 | % self.decrease_counts_max 386 | ) 387 | if self.params.multi_gpu and "SLURM_JOB_ID" in os.environ: 388 | os.system("scancel " + os.environ["SLURM_JOB_ID"]) 389 | exit() 390 | self.save_checkpoint("checkpoint") 391 | self.epoch += 1 392 | 393 | def get_batch(self, task): 394 | """ 395 | Return a training batch for a specific task. 396 | """ 397 | try: 398 | batch = next(self.dataloader[task]) 399 | except Exception as e: 400 | logger.error( 401 | "An unknown exception of type {0} occurred in line {1} when fetching batch. " 402 | "Arguments:{2!r}. Restarting ...".format( 403 | type(e).__name__, sys.exc_info()[-1].tb_lineno, e.args 404 | ) 405 | ) 406 | if self.params.is_slurm_job: 407 | if int(os.environ["SLURM_PROCID"]) == 0: 408 | logger.warning("Requeuing job " + os.environ["SLURM_JOB_ID"]) 409 | os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"]) 410 | else: 411 | logger.warning("Not the master process, no need to requeue.") 412 | raise 413 | 414 | return batch 415 | 416 | def export_data(self, task): 417 | """ 418 | Export data to the disk. 419 | """ 420 | env = self.env 421 | (x1, len1), (x2, len2), _ = self.get_batch(task) 422 | for i in range(len(len1)): 423 | # prefix assuming encoder decoder... 424 | # no bos 425 | prefix1 = [env.id2word[wid] for wid in x1[0: len1[i] - 1, i].tolist()] 426 | # bos 427 | prefix2 = [env.id2word[wid] for wid in x2[1: len2[i] - 1, i].tolist()] 428 | # save 429 | prefix1_str = " ".join(prefix1) 430 | prefix2_str = " ".join(prefix2) 431 | self.file_handler_prefix.write(f"{prefix1_str}\t{prefix2_str}\n") 432 | self.file_handler_prefix.flush() 433 | # self.EQUATIONS[(prefix1_str, prefix2_str)] = self.EQUATIONS.get((prefix1_str, prefix2_str), 0) + 1 434 | 435 | # number of processed sequences / words 436 | self.n_equations += self.params.batch_size 437 | self.total_samples += self.params.batch_size 438 | self.stats["processed_e"] += len1.size(0) 439 | self.stats["processed_w"] += (len1 + len2 - 2).sum().item() 440 | 441 | def enc_dec_step(self, task): 442 | """ 443 | Encoding / decoding step. 444 | """ 445 | params = self.params 446 | 447 | if params.architecture == "decoder_only": 448 | # batch 449 | (x2, len2), _ = self.get_batch(task) 450 | # cuda 451 | x2, len2 = to_cuda(x2, len2) 452 | else: 453 | # batch 454 | (x1, len1), (x2, len2), _ = self.get_batch(task) 455 | # cuda 456 | x1, len1, x2, len2 = to_cuda(x1, len1, x2, len2) 457 | 458 | # target words to predict 459 | if params.architecture != "encoder_only": 460 | alen = torch.arange(len2.max(), dtype=torch.long, device=len2.device) 461 | pred_mask = ( 462 | alen[:, None] < len2[None] - 1 463 | ) # do not predict anything given the last target word 464 | y = x2[1:].masked_select(pred_mask[:-1]) 465 | assert len(y) == (len2 - 1).sum().item() 466 | else: 467 | alen = torch.arange(len1.max(), dtype=torch.long, device=len2.device) 468 | pred_mask = ( 469 | (alen[:, None] < len2[None]) 470 | ) 471 | y= torch.cat((x2,torch.full((len1.max()-len2.max(),len2.size(0)),self.env.eos_index,device=len2.device)),0) 472 | y = y.masked_select(pred_mask) 473 | 474 | if params.architecture != "decoder_only": 475 | encoder = ( 476 | self.modules["encoder"].module 477 | if params.multi_gpu 478 | else self.modules["encoder"] 479 | ) 480 | encoder.train() 481 | 482 | if params.architecture != "encoder_only": 483 | decoder = ( 484 | self.modules["decoder"].module 485 | if params.multi_gpu 486 | else self.modules["decoder"] 487 | ) 488 | decoder.train() 489 | 490 | # forward / loss 491 | if params.architecture == "encoder_decoder": 492 | if params.lstm: 493 | if params.amp == -1: 494 | _, hidden = encoder("fwd", x=x1, lengths=len1, causal=False) 495 | decoded, _ = decoder( 496 | "fwd", 497 | x=x2, 498 | lengths=len2, 499 | causal=True, 500 | src_enc=hidden, 501 | ) 502 | _, loss = decoder( 503 | "predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=False 504 | ) 505 | else: 506 | with torch.cuda.amp.autocast(): 507 | _, hidden = encoder("fwd", x=x1, lengths=len1, causal=False) 508 | decoded, _ = decoder( 509 | "fwd", 510 | x=x2, 511 | lengths=len2, 512 | causal=True, 513 | src_enc=hidden, 514 | ) 515 | _, loss = decoder( 516 | "predict", 517 | tensor=decoded, 518 | pred_mask=pred_mask, 519 | y=y, 520 | get_scores=False, 521 | ) 522 | else: 523 | if params.amp == -1: 524 | encoded = encoder("fwd", x=x1, lengths=len1, causal=False) 525 | decoded = decoder( 526 | "fwd", 527 | x=x2, 528 | lengths=len2, 529 | causal=True, 530 | src_enc=encoded.transpose(0, 1), 531 | src_len=len1, 532 | ) 533 | _, loss = decoder( 534 | "predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=False 535 | ) 536 | else: 537 | with torch.cuda.amp.autocast(): 538 | encoded = encoder("fwd", x=x1, lengths=len1, causal=False) 539 | decoded = decoder( 540 | "fwd", 541 | x=x2, 542 | lengths=len2, 543 | causal=True, 544 | src_enc=encoded.transpose(0, 1), 545 | src_len=len1, 546 | ) 547 | _, loss = decoder( 548 | "predict", 549 | tensor=decoded, 550 | pred_mask=pred_mask, 551 | y=y, 552 | get_scores=False, 553 | ) 554 | if encoder.act: 555 | loss = loss + encoder.layers[encoder.loop_idx].ponder_penalty 556 | if decoder.act: 557 | loss = loss + decoder.layers[decoder.loop_idx].ponder_penalty 558 | elif params.architecture == "encoder_only": 559 | if params.amp == -1: 560 | encoded = encoder("fwd", x=x1, lengths=len1, causal=False) 561 | _, loss = encoder( 562 | "predict", tensor=encoded, pred_mask=pred_mask, y=y, get_scores=False 563 | ) 564 | else: 565 | with torch.cuda.amp.autocast(): 566 | encoded = encoder("fwd", x=x1, lengths=len1, causal=False) 567 | _, loss = encoder( 568 | "predict", 569 | tensor=encoded, 570 | pred_mask=pred_mask, 571 | y=y, 572 | get_scores=False, 573 | ) 574 | if encoder.act: 575 | loss = loss + encoder.layers[encoder.loop_idx].ponder_penalty 576 | else: 577 | if params.amp == -1: 578 | decoded = decoder( 579 | "fwd", 580 | x=x2, 581 | lengths=len2, 582 | causal=True, 583 | src_enc=None, 584 | src_len=None, 585 | ) 586 | _, loss = decoder( 587 | "predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=False 588 | ) 589 | else: 590 | with torch.cuda.amp.autocast(): 591 | decoded = decoder( 592 | "fwd", 593 | x=x2, 594 | lengths=len2, 595 | causal=True, 596 | src_enc=None, 597 | src_len=None, 598 | ) 599 | _, loss = decoder( 600 | "predict", 601 | tensor=decoded, 602 | pred_mask=pred_mask, 603 | y=y, 604 | get_scores=False, 605 | ) 606 | if decoder.act: 607 | loss = loss + decoder.layers[decoder.loop_idx].ponder_penalty 608 | 609 | 610 | self.stats[task].append(loss.item()) 611 | 612 | # optimize 613 | self.optimize(loss) 614 | 615 | # number of processed sequences / words 616 | self.n_equations += params.batch_size 617 | if params.architecture == "decoder_only": 618 | self.stats["processed_e"] += len2.size(0) 619 | self.stats["processed_w"] += (len2 - 3).sum().item() 620 | else: 621 | self.stats["processed_e"] += len1.size(0) 622 | self.stats["processed_w"] += (len1 + len2 - 2).sum().item() 623 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import re 10 | import sys 11 | import math 12 | import time 13 | import pickle 14 | import random 15 | import getpass 16 | import argparse 17 | import subprocess 18 | 19 | import errno 20 | import signal 21 | from functools import wraps, partial 22 | 23 | from .logger import create_logger 24 | 25 | 26 | FALSY_STRINGS = {'off', 'false', '0'} 27 | TRUTHY_STRINGS = {'on', 'true', '1'} 28 | 29 | DUMP_PATH = '/checkpoint/%s/dumped' % getpass.getuser() 30 | CUDA = True 31 | 32 | 33 | class AttrDict(dict): 34 | def __init__(self, *args, **kwargs): 35 | super(AttrDict, self).__init__(*args, **kwargs) 36 | self.__dict__ = self 37 | 38 | 39 | def bool_flag(s): 40 | """ 41 | Parse boolean arguments from the command line. 42 | """ 43 | if s.lower() in FALSY_STRINGS: 44 | return False 45 | elif s.lower() in TRUTHY_STRINGS: 46 | return True 47 | else: 48 | raise argparse.ArgumentTypeError("Invalid value for a boolean flag!") 49 | 50 | 51 | def initialize_exp(params): 52 | """ 53 | Initialize the experience: 54 | - dump parameters 55 | - create a logger 56 | """ 57 | # dump parameters 58 | get_dump_path(params) 59 | pickle.dump(params, open(os.path.join(params.dump_path, 'params.pkl'), 'wb')) 60 | 61 | # get running command 62 | command = ["python", sys.argv[0]] 63 | for x in sys.argv[1:]: 64 | if x.startswith('--'): 65 | assert '"' not in x and "'" not in x 66 | command.append(x) 67 | else: 68 | assert "'" not in x 69 | if re.match('^[a-zA-Z0-9_]+$', x): 70 | command.append("%s" % x) 71 | else: 72 | command.append("'%s'" % x) 73 | command = ' '.join(command) 74 | params.command = command + ' --exp_id "%s"' % params.exp_id 75 | 76 | # check experiment name 77 | assert len(params.exp_name.strip()) > 0 78 | 79 | # create a logger 80 | logger = create_logger(os.path.join(params.dump_path, 'train.log'), rank=getattr(params, 'global_rank', 0)) 81 | logger.info("============ Initialized logger ============") 82 | logger.info("\n".join("%s: %s" % (k, str(v)) 83 | for k, v in sorted(dict(vars(params)).items()))) 84 | logger.info("The experiment will be stored in %s\n" % params.dump_path) 85 | logger.info("Running command: %s" % command) 86 | logger.info("") 87 | return logger 88 | 89 | 90 | def get_dump_path(params): 91 | """ 92 | Create a directory to store the experiment. 93 | """ 94 | params.dump_path = DUMP_PATH if params.dump_path == '' else params.dump_path 95 | assert len(params.exp_name) > 0 96 | 97 | # create the sweep path if it does not exist 98 | sweep_path = os.path.join(params.dump_path, params.exp_name) 99 | if not os.path.exists(sweep_path): 100 | subprocess.Popen("mkdir -p %s" % sweep_path, shell=True).wait() 101 | 102 | # create an ID for the job if it is not given in the parameters. 103 | # if we run on the cluster, the job ID is the one of Chronos. 104 | # otherwise, it is randomly generated 105 | if params.exp_id == '': 106 | chronos_job_id = os.environ.get('CHRONOS_JOB_ID') 107 | slurm_job_id = os.environ.get('SLURM_JOB_ID') 108 | assert chronos_job_id is None or slurm_job_id is None 109 | exp_id = chronos_job_id if chronos_job_id is not None else slurm_job_id 110 | if exp_id is None: 111 | chars = 'abcdefghijklmnopqrstuvwxyz0123456789' 112 | while True: 113 | exp_id = ''.join(random.choice(chars) for _ in range(10)) 114 | if not os.path.isdir(os.path.join(sweep_path, exp_id)): 115 | break 116 | else: 117 | assert exp_id.isdigit() 118 | params.exp_id = exp_id 119 | 120 | # create the dump folder / update parameters 121 | params.dump_path = os.path.join(sweep_path, params.exp_id) 122 | if not os.path.isdir(params.dump_path): 123 | subprocess.Popen("mkdir -p %s" % params.dump_path, shell=True).wait() 124 | 125 | 126 | def to_cuda(*args): 127 | """ 128 | Move tensors to CUDA. 129 | """ 130 | if not CUDA: 131 | return args 132 | return [None if x is None else x.cuda() for x in args] 133 | 134 | 135 | class TimeoutError(BaseException): 136 | pass 137 | 138 | 139 | def timeout(seconds=10, error_message=os.strerror(errno.ETIME)): 140 | 141 | def decorator(func): 142 | 143 | def _handle_timeout(repeat_id, signum, frame): 144 | # logger.warning(f"Catched the signal ({repeat_id}) Setting signal handler {repeat_id + 1}") 145 | signal.signal(signal.SIGALRM, partial(_handle_timeout, repeat_id + 1)) 146 | signal.alarm(seconds) 147 | raise TimeoutError(error_message) 148 | 149 | def wrapper(*args, **kwargs): 150 | old_signal = signal.signal(signal.SIGALRM, partial(_handle_timeout, 0)) 151 | old_time_left = signal.alarm(seconds) 152 | assert type(old_time_left) is int and old_time_left >= 0 153 | if 0 < old_time_left < seconds: # do not exceed previous timer 154 | signal.alarm(old_time_left) 155 | start_time = time.time() 156 | try: 157 | result = func(*args, **kwargs) 158 | finally: 159 | if old_time_left == 0: 160 | signal.alarm(0) 161 | else: 162 | sub = time.time() - start_time 163 | signal.signal(signal.SIGALRM, old_signal) 164 | signal.alarm(max(0, math.ceil(old_time_left - sub))) 165 | return result 166 | 167 | return wraps(func)(wrapper) 168 | 169 | return decorator 170 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import json 9 | import random 10 | import argparse 11 | import numpy as np 12 | import torch 13 | import os 14 | import pickle 15 | 16 | import src 17 | from src.slurm import init_signal_handler, init_distributed_mode 18 | from src.utils import bool_flag, initialize_exp 19 | from src.model import check_model_params, build_modules 20 | from src.envs import ENVS, build_env 21 | from src.trainer import Trainer 22 | from src.evaluator import Evaluator 23 | 24 | 25 | np.seterr(all='raise') 26 | 27 | 28 | def get_parser(): 29 | """ 30 | Generate a parameters parser. 31 | """ 32 | # parse parameters 33 | parser = argparse.ArgumentParser(description="Language transfer") 34 | 35 | # main parameters 36 | parser.add_argument("--dump_path", type=str, default="", 37 | help="Experiment dump path") 38 | parser.add_argument("--exp_name", type=str, default="debug", 39 | help="Experiment name") 40 | parser.add_argument("--save_periodic", type=int, default=0, 41 | help="Save the model periodically (0 to disable)") 42 | parser.add_argument("--exp_id", type=str, default="", 43 | help="Experiment ID") 44 | 45 | parser.add_argument("--report_loss_every", type=int, default=200, 46 | help="Log train loss every n optimisation steps") 47 | 48 | parser.add_argument("--epoch_size", type=int, default=300000, 49 | help="Epoch size / evaluation frequency") 50 | parser.add_argument("--max_epoch", type=int, default=100000, 51 | help="Maximum epoch size") 52 | parser.add_argument("--stopping_criterion", type=str, default="", 53 | help="Stopping criterion, and number of non-increase before stopping the experiment") 54 | parser.add_argument("--validation_metrics", type=str, default="", 55 | help="Validation metrics") 56 | 57 | # model parameters 58 | parser.add_argument("--enc_emb_dim", type=int, default=256, 59 | help="Encoder embedding layer size") 60 | parser.add_argument("--dec_emb_dim", type=int, default=256, 61 | help="Decoder embedding layer size") 62 | parser.add_argument("--n_enc_layers", type=int, default=4, 63 | help="Number of Transformer layers in the encoder") 64 | parser.add_argument("--n_dec_layers", type=int, default=4, 65 | help="Number of Transformer layers in the decoder") 66 | parser.add_argument("--n_enc_heads", type=int, default=8, 67 | help="Number of Transformer encoder heads") 68 | parser.add_argument("--n_dec_heads", type=int, default=8, 69 | help="Number of Transformer decoder heads") 70 | parser.add_argument("--xav_init", type=bool_flag, default=False, 71 | help="Xavier initialization for transformer parameters") 72 | 73 | 74 | parser.add_argument("--n_enc_hidden_layers", type=int, default=1, 75 | help="Number of FFN layers in Transformer encoder") 76 | parser.add_argument("--n_dec_hidden_layers", type=int, default=1, 77 | help="Number of FFN layers in Transformer decoder") 78 | parser.add_argument("--gelu_activation", type=bool_flag, default=False, 79 | help="GELU initialization in FFN layers (else RELU)") 80 | parser.add_argument("--dropout", type=float, default=0, 81 | help="Dropout") 82 | 83 | parser.add_argument("--max_src_len", type=int, default=0, 84 | help="Maximum number of tokens to consider in encoder output") 85 | 86 | parser.add_argument("--norm_attention", type=bool_flag, default=False, 87 | help="Normalize attention and train temperaturee in Transformer") 88 | parser.add_argument("--attention_dropout", type=float, default=0, 89 | help="Dropout in the attention layer") 90 | 91 | parser.add_argument("--architecture", type=str, default="encoder_decoder", 92 | help="encoder_decoder, encoder_only or decoder_only (last 2 transformer only)") 93 | 94 | # lstm/GRU 95 | parser.add_argument("--lstm", type=bool_flag, default=False, 96 | help="LSTM or GRU") 97 | parser.add_argument("--GRU", type=bool_flag, default=False, 98 | help="GRU model") 99 | parser.add_argument("--bidirectional", type=bool_flag, default=False, 100 | help="bidirectional lstm") 101 | parser.add_argument("--lstm_hidden_dim", type=int, default=2048, 102 | help="hidden dimension for lstm") 103 | 104 | # embedding 105 | parser.add_argument("--share_inout_emb", type=bool_flag, default=True, 106 | help="Share input and output embeddings") 107 | parser.add_argument("--sinusoidal_embeddings", type=bool_flag, default=False, 108 | help="Use sinusoidal embeddings") 109 | 110 | parser.add_argument("--enc_has_pos_emb", type=bool_flag, default=True, 111 | help="Positional embedding in the encoder") 112 | parser.add_argument("--dec_has_pos_emb", type=bool_flag, default=True, 113 | help="Positional embedding in the decoder") 114 | 115 | 116 | 117 | # Loop layers 118 | parser.add_argument("--enc_loop_idx", type=int, default=-1, 119 | help="Index of the encoder shared weight layers (-1 for none, -2 for all)") 120 | parser.add_argument("--dec_loop_idx", type=int, default=-1, 121 | help="Index of the decoder shared weight layers (-1 for none, -2 for all)") 122 | parser.add_argument("--enc_loops", type=int, default=1, 123 | help="Fixed/max nr of train passes through the encoder loop") 124 | parser.add_argument("--dec_loops", type=int, default=1, 125 | help="Fixed/max nr of train passes through the decoder loop") 126 | 127 | 128 | # gates 129 | parser.add_argument("--gated", type=bool_flag, default=False, 130 | help="Gated loop layers") 131 | parser.add_argument("--enc_gated", type=bool_flag, default=False, 132 | help="All encoder layers gated") 133 | parser.add_argument("--dec_gated", type=bool_flag, default=False, 134 | help="All decoder layers gated") 135 | parser.add_argument("--scalar_gate", type=bool_flag, default=False, 136 | help="Scalar gates") 137 | parser.add_argument("--biased_gates", type=bool_flag, default=False, 138 | help="Biased gates") 139 | parser.add_argument("--gate_bias", type=int, default=0, 140 | help="Gate_bias") 141 | # ACT 142 | parser.add_argument("--enc_act", type=bool_flag, default=False, 143 | help="Encoder looped layer ACT") 144 | parser.add_argument("--dec_act", type=bool_flag, default=False, 145 | help="Decoder looped layer ACT") 146 | parser.add_argument("--act_threshold", type=float, default=0.01, 147 | help="Prob threshold for ACT") 148 | parser.add_argument("--act_ponder_coupling", type=float, default=0.01, 149 | help="Ponder loss coupling for ACT") 150 | parser.add_argument("--act_biased", type=bool_flag, default=False, 151 | help="ACT bias initialised") 152 | parser.add_argument("--act_bias", type=int, default=0, 153 | help="act bias") 154 | 155 | 156 | 157 | # technical parameters float16 / AMP API 158 | parser.add_argument("--fp16", type=bool_flag, default=False, 159 | help="Run model with float16") 160 | parser.add_argument("--amp", type=int, default=-1, 161 | help="Use AMP wrapper for float16 / distributed / gradient accumulation. Level of optimization. -1 to disable.") 162 | parser.add_argument("--num_workers", type=int, default=1, 163 | help="Number of CPU workers for DataLoader") 164 | parser.add_argument("--env_base_seed", type=int, default=-1, 165 | help="Base seed for environments (-1 to use timestamp seed)") 166 | 167 | # CPU / multi-gpu / multi-node 168 | parser.add_argument("--cpu", type=bool_flag, default=False, 169 | help="Run on CPU") 170 | parser.add_argument("--local_gpu", type=int, default=-1, 171 | help="Multi-GPU - Local GPU") 172 | parser.add_argument("--local_rank", type=int, default=-1, 173 | help="Multi-GPU - Local rank for torch.distributed.launch") 174 | parser.add_argument("--master_port", type=int, default=-1, 175 | help="Master port (for multi-node SLURM jobs)") 176 | # parser.add_argument("--windows", type=bool_flag, default=False, 177 | # help="Windows version (no multiprocessing for eval)") 178 | 179 | 180 | 181 | # training parameters 182 | 183 | parser.add_argument("--max_len", type=int, default=512, 184 | help="Maximum sequences length") 185 | parser.add_argument("--max_output_len", type=int, default=512, 186 | help="max length of output, beam max size") 187 | 188 | parser.add_argument("--eval_size", type=int, default=10000, 189 | help="Size of valid and test samples") 190 | parser.add_argument("--batch_size_eval", type=int, default=128, 191 | help="Number of sentences per batch during evaluation") 192 | 193 | 194 | parser.add_argument("--batch_size", type=int, default=32, 195 | help="Number of sentences per batch") 196 | parser.add_argument("--accumulate_gradients", type=int, default=1, 197 | help="Accumulate model gradients over N iterations (N times larger batch sizes)") 198 | parser.add_argument("--optimizer", type=str, default="adam,lr=0.0001", 199 | help="Optimizer (SGD / RMSprop / Adam, etc.)") 200 | parser.add_argument("--clip_grad_norm", type=float, default=5, 201 | help="Clip gradients norm (0 to disable)") 202 | 203 | # export data / reload it 204 | parser.add_argument("--export_data", type=bool_flag, default=False, 205 | help="Export data and disable training.") 206 | parser.add_argument("--train_data", type=str, default="", 207 | help="Load dataset from the disk") 208 | 209 | parser.add_argument("--reload_size", type=int, default=-1, 210 | help="Reloaded training set size (-1 for everything)") 211 | parser.add_argument("--batch_load", type=bool_flag, default=False, 212 | help="Load training set by batches (of size reload_size).") 213 | 214 | # environment parameters 215 | parser.add_argument("--env_name", type=str, default="arithmetic", 216 | help="Environment name") 217 | ENVS[parser.parse_known_args()[0].env_name].register_args(parser) 218 | 219 | # tasks 220 | parser.add_argument("--tasks", type=str, default="arithmetic", 221 | help="Tasks") 222 | 223 | # beam search configuration 224 | parser.add_argument("--beam_eval", type=bool_flag, default=False, 225 | help="Evaluate with beam search decoding.") 226 | parser.add_argument("--beam_eval_train", type=int, default=0, 227 | help="At training time, number of validation equations to test the model on using beam search (-1 for everything, 0 to disable)") 228 | parser.add_argument("--beam_size", type=int, default=1, 229 | help="Beam size, default = 1 (greedy decoding)") 230 | parser.add_argument("--beam_length_penalty", type=float, default=1, 231 | help="Length penalty, values < 1.0 favor shorter sentences, while values > 1.0 favor longer ones.") 232 | parser.add_argument("--beam_early_stopping", type=bool_flag, default=True, 233 | help="Early stopping, stop as soon as we have `beam_size` hypotheses, although longer ones may have better scores.") 234 | 235 | # reload pretrained model / checkpoint 236 | parser.add_argument("--reload_model", type=str, default="", 237 | help="Reload a pretrained model") 238 | parser.add_argument("--reload_checkpoint", type=str, default="", 239 | help="Reload a checkpoint") 240 | 241 | # evaluation 242 | parser.add_argument("--eval_only", type=bool_flag, default=False, 243 | help="Only run evaluations") 244 | parser.add_argument("--eval_from_exp", type=str, default="", 245 | help="Path of experiment to use") 246 | parser.add_argument("--eval_data", type=str, default="", 247 | help="Path of data to eval") 248 | parser.add_argument("--eval_verbose", type=int, default=0, 249 | help="Export evaluation details") 250 | parser.add_argument("--eval_verbose_print", type=bool_flag, default=False, 251 | help="Print evaluation details") 252 | 253 | # debug 254 | parser.add_argument("--debug_slurm", type=bool_flag, default=False, 255 | help="Debug multi-GPU / multi-node within a SLURM job") 256 | parser.add_argument("--debug", help="Enable all debug flags", 257 | action="store_true") 258 | 259 | 260 | return parser 261 | 262 | 263 | def main(params): 264 | 265 | # initialize the multi-GPU / multi-node training 266 | # initialize experiment / SLURM signal handler for time limit / pre-emption 267 | init_distributed_mode(params) 268 | logger = initialize_exp(params) 269 | if params.is_slurm_job: 270 | init_signal_handler() 271 | 272 | # CPU / CUDA 273 | if params.cpu: 274 | assert not params.multi_gpu 275 | else: 276 | assert torch.cuda.is_available() 277 | src.utils.CUDA = not params.cpu 278 | 279 | # build environment / modules / trainer / evaluator 280 | env = build_env(params) 281 | modules = build_modules(env, params) 282 | trainer = Trainer(modules, env, params) 283 | evaluator = Evaluator(trainer) 284 | 285 | # evaluation 286 | if params.eval_only: 287 | scores = evaluator.run_all_evals() 288 | for k, v in scores.items(): 289 | logger.info("%s -> %.6f" % (k, v)) 290 | logger.info("__log__:%s" % json.dumps(scores)) 291 | exit() 292 | 293 | # training 294 | for _ in range(params.max_epoch): 295 | 296 | logger.info("============ Starting epoch %i ... ============" % trainer.epoch) 297 | 298 | trainer.n_equations = 0 299 | 300 | while trainer.n_equations < trainer.epoch_size: 301 | 302 | # training steps 303 | for task_id in np.random.permutation(len(params.tasks)): 304 | task = params.tasks[task_id] 305 | if params.export_data: 306 | trainer.export_data(task) 307 | else: 308 | trainer.enc_dec_step(task) 309 | trainer.iter() 310 | 311 | logger.info(f"Memory allocated: {torch.cuda.memory_allocated(0)/(1024*1024):.2f}MB, reserved: {torch.cuda.memory_reserved(0)/(1024*1024):.2f}MB") 312 | 313 | 314 | logger.info("============ End of epoch %i ============" % trainer.epoch) 315 | 316 | # evaluate perplexity 317 | scores = evaluator.run_all_evals() 318 | logger.info(f"Memory allocated: {torch.cuda.memory_allocated(0)/(1024*1024):.2f}MB, reserved: {torch.cuda.memory_reserved(0)/(1024*1024):.2f}MB") 319 | 320 | # print / JSON log 321 | # for k, v in scores.items(): 322 | # logger.info("%s -> %.6f" % (k, v)) 323 | if params.is_master: 324 | logger.info("__log__:%s" % json.dumps(scores)) 325 | 326 | # end of epoch 327 | trainer.save_best_model(scores) 328 | trainer.save_periodic() 329 | trainer.end_epoch(scores) 330 | 331 | 332 | if __name__ == '__main__': 333 | 334 | # generate parser / parse parameters 335 | parser = get_parser() 336 | params = parser.parse_args() 337 | if params.eval_only and params.eval_from_exp != "": 338 | # read params from pickle 339 | pickle_file = params.eval_from_exp + "/params.pkl" 340 | assert os.path.isfile(pickle_file) 341 | pk = pickle.load(open(pickle_file, 'rb')) 342 | pickled_args = pk.__dict__ 343 | del pickled_args['exp_id'] 344 | for p in params.__dict__: 345 | if p in pickled_args: 346 | params.__dict__[p] = pickled_args[p] 347 | 348 | params.eval_only = True 349 | params.reload_model = params.eval_from_exp + '/best-' + params.validation_metrics + '.pth' 350 | if not os.path.isfile(params.reload_model): 351 | params.reload_model = params.eval_from_exp + '/checkpoint.pth' 352 | params.eval_size = None 353 | params.train_data = "" 354 | params.is_slurm_job = False 355 | params.local_rank = -1 356 | 357 | # debug mode 358 | if params.debug: 359 | params.exp_name = 'debug' 360 | if params.exp_id == '': 361 | params.exp_id = 'debug_%08i' % random.randint(0, 100000000) 362 | params.debug_slurm = True 363 | 364 | # check parameters 365 | check_model_params(params) 366 | 367 | # run experiment 368 | main(params) 369 | --------------------------------------------------------------------------------