├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── PKM-layer.ipynb ├── README.md ├── generate-embeddings.ipynb ├── get-data-glue.sh ├── get-data-nmt.sh ├── get-data-para.sh ├── get-data-wiki.sh ├── get-data-xnli.sh ├── glue-xnli.py ├── install-tools.sh ├── prepare-glue.sh ├── prepare-xnli.sh ├── preprocess.py ├── setup.py ├── src ├── tools ├── README.md ├── lowercase_and_remove_accent.py ├── segment_th.py └── tokenize.sh ├── train.py ├── translate.py └── xlm ├── __init__.py ├── data ├── __init__.py ├── dataset.py ├── dictionary.py └── loader.py ├── evaluation ├── __init__.py ├── evaluator.py ├── glue.py ├── multi-bleu.perl └── xnli.py ├── logger.py ├── model ├── __init__.py ├── embedder.py ├── memory │ ├── __init__.py │ ├── memory.py │ ├── query.py │ └── utils.py ├── pretrain.py └── transformer.py ├── optim.py ├── slurm.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/python 2 | # Edit at https://www.gitignore.io/?templates=python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .venv 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | .dmypy.json 116 | dmypy.json 117 | 118 | # Pyre type checker 119 | .pyre/ 120 | 121 | ### Python Patch ### 122 | .venv/ 123 | 124 | # End of https://www.gitignore.io/api/python 125 | 126 | # tools, data, and dumped models 127 | # /tools 128 | /data 129 | /dumped 130 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 
4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to this repo 2 | 3 | ## Pull Requests 4 | 5 | In order to accept your pull request, we need you to submit a CLA. You only need 6 | to do this once to work on any of Facebook's open source projects. 7 | 8 | Complete your CLA here: 9 | 10 | ## Issues 11 | We use GitHub issues to track public bugs. Please ensure your description is 12 | clear and has sufficient instructions to be able to reproduce the issue. 13 | 14 | ## License 15 | By contributing to this repo, you agree that your contributions will be licensed 16 | under the LICENSE file in the root directory of this source tree. 17 | -------------------------------------------------------------------------------- /PKM-layer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Product-Key Memory (PKM)\n", 8 | "**Minimalist implementation of a Product-Key Memory layer** https://arxiv.org/abs/1907.05242\n", 9 | "\n", 10 | "This notebook contains a simple implementation of a PKM layer.\n", 11 | "
\n", 12 | "Overall, the PKM layer can be seen as a network with very high capacity that maps elements from $R^d$ to $R^n$, but very efficiently.\n", 13 | "
\n", 14 | "In particular, a 12-layer transformer model that leverages a PKM layer outperforms a 24-layer model without memory, and is almost twice faster at inference.\n", 15 | "\n", 16 | "A more detailed implementation can be found at https://github.com/facebookresearch/XLM/tree/master/xlm/model/memory,\n", 17 | "with options to make the query network more powerful, to shuffle the key indices, to compute the value scores differently\n", 18 | "than with a softmax, etc., but the code below is much simpler and implements a configuration that worked well in our experiments (and that we used to report the majority of our results).\n", 19 | "\n", 20 | "#### Note: at training time, we recommend to use a different optimizer for the values, as these are learned with sparse updates. In particular, we obtained our best performance with the Adam optimizer, and a constant learning rate of 1e-3 to learn the values, independently of the optimizer / learning rate used to learn the rest of the network." 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "metadata": { 27 | "collapsed": true 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "import math\n", 32 | "import numpy as np\n", 33 | "import torch\n", 34 | "from torch import nn\n", 35 | "from torch.nn import functional as F" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": { 42 | "collapsed": true 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "def get_uniform_keys(n_keys, dim, seed):\n", 47 | " \"\"\"\n", 48 | " Generate random uniform keys (same initialization as nn.Linear).\n", 49 | " \"\"\"\n", 50 | " rng = np.random.RandomState(seed)\n", 51 | " bound = 1 / math.sqrt(dim)\n", 52 | " keys = rng.uniform(-bound, bound, (n_keys, dim))\n", 53 | " return keys.astype(np.float32)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": { 60 | "collapsed": true 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "class HashingMemory(nn.Module):\n", 65 | "\n", 66 | " def __init__(self, input_dim, output_dim, params):\n", 67 | "\n", 68 | " super().__init__()\n", 69 | "\n", 70 | " # global parameters\n", 71 | " self.input_dim = input_dim\n", 72 | " self.output_dim = output_dim\n", 73 | " self.k_dim = params.k_dim\n", 74 | " self.v_dim = output_dim\n", 75 | " self.n_keys = params.n_keys\n", 76 | " self.size = self.n_keys ** 2\n", 77 | " self.heads = params.heads\n", 78 | " self.knn = params.knn\n", 79 | " assert self.k_dim >= 2 and self.k_dim % 2 == 0\n", 80 | "\n", 81 | " # dropout\n", 82 | " self.input_dropout = params.input_dropout\n", 83 | " self.query_dropout = params.query_dropout\n", 84 | " self.value_dropout = params.value_dropout\n", 85 | "\n", 86 | " # initialize keys / values\n", 87 | " self.initialize_keys()\n", 88 | " self.values = nn.EmbeddingBag(self.size, self.v_dim, mode='sum', sparse=params.sparse)\n", 89 | " nn.init.normal_(self.values.weight, mean=0, std=self.v_dim ** -0.5)\n", 90 | "\n", 91 | " # query network\n", 92 | " self.query_proj = nn.Sequential(*filter(None, [\n", 93 | " nn.Linear(self.input_dim, self.heads * self.k_dim, bias=True),\n", 94 | " nn.BatchNorm1d(self.heads * self.k_dim) if params.query_batchnorm else None\n", 95 | " ]))\n", 96 | "\n", 97 | " if params.query_batchnorm:\n", 98 | " print(\"WARNING: Applying batch normalization to queries improves the performance \"\n", 99 | " \"and memory usage. But if you use it, be sure that you use batches of \"\n", 100 | " \"sentences with the same size at training time (i.e. 
without padding). \"\n", 101 | " \"Otherwise, the padding token will result in incorrect mean/variance \"\n", 102 | " \"estimations in the BatchNorm layer.\\n\")\n", 103 | "\n", 104 | " def initialize_keys(self):\n", 105 | " \"\"\"\n", 106 | " Create two subkey sets per head.\n", 107 | " `self.keys` is of shape (heads, 2, n_keys, k_dim // 2)\n", 108 | " \"\"\"\n", 109 | " half = self.k_dim // 2\n", 110 | " keys = nn.Parameter(torch.from_numpy(np.array([\n", 111 | " get_uniform_keys(self.n_keys, half, seed=(2 * i + j))\n", 112 | " for i in range(self.heads)\n", 113 | " for j in range(2)\n", 114 | " ])).view(self.heads, 2, self.n_keys, half))\n", 115 | " self.keys = nn.Parameter(keys)\n", 116 | "\n", 117 | " def _get_indices(self, query, subkeys):\n", 118 | " \"\"\"\n", 119 | " Generate scores and indices for a specific head.\n", 120 | " \"\"\"\n", 121 | " assert query.dim() == 2 and query.size(1) == self.k_dim\n", 122 | " bs = query.size(0)\n", 123 | " knn = self.knn\n", 124 | " half = self.k_dim // 2\n", 125 | " n_keys = len(subkeys[0])\n", 126 | "\n", 127 | " # split query for product quantization\n", 128 | " q1 = query[:, :half] # (bs,half)\n", 129 | " q2 = query[:, half:] # (bs,half)\n", 130 | "\n", 131 | " # compute indices with associated scores\n", 132 | " scores1 = F.linear(q1, subkeys[0], bias=None) # (bs,n_keys)\n", 133 | " scores2 = F.linear(q2, subkeys[1], bias=None) # (bs,n_keys)\n", 134 | " scores1, indices1 = scores1.topk(knn, dim=1) # (bs,knn)\n", 135 | " scores2, indices2 = scores2.topk(knn, dim=1) # (bs,knn)\n", 136 | "\n", 137 | " # cartesian product on best candidate keys\n", 138 | " all_scores = (\n", 139 | " scores1.view(bs, knn, 1).expand(bs, knn, knn) +\n", 140 | " scores2.view(bs, 1, knn).expand(bs, knn, knn)\n", 141 | " ).view(bs, -1) # (bs,knn**2)\n", 142 | " all_indices = (\n", 143 | " indices1.view(bs, knn, 1).expand(bs, knn, knn) * n_keys +\n", 144 | " indices2.view(bs, 1, knn).expand(bs, knn, knn)\n", 145 | " ).view(bs, -1) # (bs,knn**2)\n", 146 | "\n", 147 | " # select best scores with associated indices\n", 148 | " scores, best_indices = torch.topk(all_scores, k=knn, dim=1) # (bs,knn)\n", 149 | " indices = all_indices.gather(1, best_indices) # (bs,knn)\n", 150 | "\n", 151 | " assert scores.shape == indices.shape == (bs, knn)\n", 152 | " return scores, indices\n", 153 | "\n", 154 | " def get_indices(self, query):\n", 155 | " \"\"\"\n", 156 | " Generate scores and indices.\n", 157 | " \"\"\"\n", 158 | " assert query.dim() == 2 and query.size(1) == self.k_dim\n", 159 | " query = query.view(-1, self.heads, self.k_dim)\n", 160 | " bs = len(query)\n", 161 | " outputs = [self._get_indices(query[:, i], self.keys[i]) for i in range(self.heads)]\n", 162 | " s = torch.cat([s.view(bs, 1, self.knn) for s, _ in outputs], 1) # (bs,heads,knn)\n", 163 | " i = torch.cat([i.view(bs, 1, self.knn) for _, i in outputs], 1) # (bs,heads,knn)\n", 164 | " return s.view(-1, self.knn), i.view(-1, self.knn)\n", 165 | "\n", 166 | " def forward(self, input):\n", 167 | " \"\"\"\n", 168 | " Read from the memory.\n", 169 | " \"\"\"\n", 170 | " # input dimensions\n", 171 | " assert input.shape[-1] == self.input_dim\n", 172 | " prefix_shape = input.shape[:-1]\n", 173 | " bs = np.prod(prefix_shape)\n", 174 | "\n", 175 | " # compute query\n", 176 | " input = F.dropout(input, p=self.input_dropout, training=self.training) # (...,i_dim)\n", 177 | " query = self.query_proj(input.contiguous().view(-1, self.input_dim)) # (bs,heads*k_dim)\n", 178 | " query = query.view(bs * self.heads, self.k_dim) # 
(bs*heads,k_dim)\n", 179 | " query = F.dropout(query, p=self.query_dropout, training=self.training) # (bs*heads,k_dim)\n", 180 | " assert query.shape == (bs * self.heads, self.k_dim)\n", 181 | "\n", 182 | " # retrieve indices and scores\n", 183 | " scores, indices = self.get_indices(query) # (bs*heads,knn)\n", 184 | " scores = F.softmax(scores.float(), dim=-1).type_as(scores) # (bs*heads,knn)\n", 185 | "\n", 186 | " # merge heads / knn (since we sum heads)\n", 187 | " indices = indices.view(bs, self.heads * self.knn) # (bs,heads*knn)\n", 188 | " scores = scores.view(bs, self.heads * self.knn) # (bs,heads*knn)\n", 189 | "\n", 190 | " # weighted sum of values\n", 191 | " output = self.values(indices, per_sample_weights=scores) # (bs,v_dim)\n", 192 | " output = F.dropout(output, p=self.value_dropout, training=self.training)# (bs,v_dim)\n", 193 | "\n", 194 | " # reshape output\n", 195 | " if len(prefix_shape) >= 2:\n", 196 | " output = output.view(prefix_shape + (self.v_dim,)) # (...,v_dim)\n", 197 | "\n", 198 | " return output\n", 199 | "\n", 200 | " @staticmethod\n", 201 | " def register_args(parser):\n", 202 | " \"\"\"\n", 203 | " Register memory parameters.\n", 204 | " \"\"\"\n", 205 | " # memory parameters\n", 206 | " parser.add_argument(\"--sparse\", type=bool_flag, default=False,\n", 207 | " help=\"Perform sparse updates for the values\")\n", 208 | " parser.add_argument(\"--k_dim\", type=int, default=256,\n", 209 | " help=\"Memory keys dimension\")\n", 210 | " parser.add_argument(\"--heads\", type=int, default=4,\n", 211 | " help=\"Number of memory heads\")\n", 212 | " parser.add_argument(\"--knn\", type=int, default=32,\n", 213 | " help=\"Number of memory slots to read / update - k-NN to the query\")\n", 214 | " parser.add_argument(\"--n_keys\", type=int, default=512,\n", 215 | " help=\"Number of keys\")\n", 216 | " parser.add_argument(\"--query_batchnorm\", type=bool_flag, default=False,\n", 217 | " help=\"Query MLP batch norm\")\n", 218 | "\n", 219 | " # dropout\n", 220 | " parser.add_argument(\"--input_dropout\", type=float, default=0,\n", 221 | " help=\"Input dropout\")\n", 222 | " parser.add_argument(\"--query_dropout\", type=float, default=0,\n", 223 | " help=\"Query dropout\")\n", 224 | " parser.add_argument(\"--value_dropout\", type=float, default=0,\n", 225 | " help=\"Value dropout\")" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 4, 231 | "metadata": { 232 | "collapsed": true 233 | }, 234 | "outputs": [], 235 | "source": [ 236 | "class AttrDict(dict):\n", 237 | " def __init__(self, *args, **kwargs):\n", 238 | " super(AttrDict, self).__init__(*args, **kwargs)\n", 239 | " self.__dict__ = self\n", 240 | "\n", 241 | "\n", 242 | "params = AttrDict({\n", 243 | " \"sparse\": False,\n", 244 | " \"k_dim\": 128,\n", 245 | " \"heads\": 4,\n", 246 | " \"knn\": 32,\n", 247 | " \"n_keys\": 512, # the memory will have (n_keys ** 2) values\n", 248 | " \"query_batchnorm\": True,\n", 249 | " \"input_dropout\": 0,\n", 250 | " \"query_dropout\": 0,\n", 251 | " \"value_dropout\": 0,\n", 252 | "})" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 5, 258 | "metadata": {}, 259 | "outputs": [ 260 | { 261 | "name": "stdout", 262 | "output_type": "stream", 263 | "text": [ 264 | "WARNING: Applying batch normalization to queries improves the performance and memory usage. But if you use it, be sure that you use batches of sentences with the same size at training time (i.e. without padding). 
Otherwise, the padding token will result in incorrect mean/variance estimations in the BatchNorm layer.\n", 265 | "\n", 266 | "HashingMemory(\n", 267 | " (values): EmbeddingBag(262144, 100, mode=sum)\n", 268 | " (query_proj): Sequential(\n", 269 | " (0): Linear(in_features=50, out_features=512, bias=True)\n", 270 | " (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 271 | " )\n", 272 | ")\n" 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "device = 'cuda' # cpu / cuda\n", 278 | "input_dim = 50\n", 279 | "output_dim = 100\n", 280 | "memory = HashingMemory(input_dim, output_dim, params).to(device=device)\n", 281 | "print(memory)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 6, 287 | "metadata": {}, 288 | "outputs": [ 289 | { 290 | "name": "stdout", 291 | "output_type": "stream", 292 | "text": [ 293 | "0.14277362823486328\n", 294 | "torch.Size([2, 3, 4, 100])\n" 295 | ] 296 | } 297 | ], 298 | "source": [ 299 | "x = torch.randn(2, 3, 4, input_dim).to(device=device)\n", 300 | "output = memory(x)\n", 301 | "print(output.sum().item())\n", 302 | "print(output.shape)" 303 | ] 304 | } 305 | ], 306 | "metadata": { 307 | "kernelspec": { 308 | "display_name": "Python 3", 309 | "language": "python", 310 | "name": "python3" 311 | }, 312 | "language_info": { 313 | "codemirror_mode": { 314 | "name": "ipython", 315 | "version": 3 316 | }, 317 | "file_extension": ".py", 318 | "mimetype": "text/x-python", 319 | "name": "python", 320 | "nbconvert_exporter": "python", 321 | "pygments_lexer": "ipython3", 322 | "version": "3.6.4" 323 | } 324 | }, 325 | "nbformat": 4, 326 | "nbformat_minor": 2 327 | } 328 | -------------------------------------------------------------------------------- /generate-embeddings.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Copyright (c) 2019-present, Facebook, Inc.\n", 10 | "# All rights reserved.\n", 11 | "#\n", 12 | "# This source code is licensed under the license found in the\n", 13 | "# LICENSE file in the root directory of this source tree.\n", 14 | "#" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "#\n", 24 | "# Code to generate sentence representations from a pretrained model.\n", 25 | "# This can be used to initialize a cross-lingual classifier, for instance.\n", 26 | "#" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 3, 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "name": "stderr", 36 | "output_type": "stream", 37 | "text": [ 38 | "FAISS library was not found.\n", 39 | "FAISS not available. 
Switching to standard nearest neighbors search implementation.\n" 40 | ] 41 | } 42 | ], 43 | "source": [ 44 | "import os\n", 45 | "import torch\n", 46 | "\n", 47 | "from xlm.utils import AttrDict\n", 48 | "from xlm.data.dictionary import Dictionary, BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD, MASK_WORD\n", 49 | "from xlm.model.transformer import TransformerModel" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "## Reload a pretrained model" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 4, 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "name": "stdout", 66 | "output_type": "stream", 67 | "text": [ 68 | "Supported languages: af, als, am, an, ang, ar, arz, ast, az, bar, be, bg, bn, br, bs, ca, ceb, ckb, cs, cy, da, de, el, en, eo, es, et, eu, fa, fi, fr, fy, ga, gan, gl, gu, he, hi, hr, hu, hy, ia, id, is, it, ja, jv, ka, kk, kn, ko, ku, la, lb, lt, lv, mk, ml, mn, mr, ms, my, nds, ne, nl, nn, no, oc, pl, pt, ro, ru, scn, sco, sh, si, simple, sk, sl, sq, sr, sv, sw, ta, te, th, tl, tr, tt, uk, ur, uz, vi, war, wuu, yi, zh, zh_classical, zh_min_nan, zh_yue\n" 69 | ] 70 | } 71 | ], 72 | "source": [ 73 | "model_path = 'models/mlm_100_1280.pth'\n", 74 | "reloaded = torch.load(model_path)\n", 75 | "params = AttrDict(reloaded['params'])\n", 76 | "print(\"Supported languages: %s\" % \", \".join(params.lang2id.keys()))" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "## Build dictionary / update parameters / build model" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "data": { 93 | "text/plain": [ 94 | "" 95 | ] 96 | }, 97 | "execution_count": 5, 98 | "metadata": {}, 99 | "output_type": "execute_result" 100 | } 101 | ], 102 | "source": [ 103 | "# build dictionary / update parameters\n", 104 | "dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])\n", 105 | "params.n_words = len(dico)\n", 106 | "params.bos_index = dico.index(BOS_WORD)\n", 107 | "params.eos_index = dico.index(EOS_WORD)\n", 108 | "params.pad_index = dico.index(PAD_WORD)\n", 109 | "params.unk_index = dico.index(UNK_WORD)\n", 110 | "params.mask_index = dico.index(MASK_WORD)\n", 111 | "\n", 112 | "# build model / reload weights\n", 113 | "model = TransformerModel(params, dico, True, True)\n", 114 | "model.eval()\n", 115 | "model.load_state_dict(reloaded['model'])" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "\n", 123 | "## Get sentence representations" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "Sentences have to be in the BPE format, i.e. tokenized sentences on which you applied fastBPE.\n", 131 | "\n", 132 | "Below you can see an example for English, French, Spanish, German, Arabic and Chinese sentences." 
133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 6, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "# Below is one way to bpe-ize sentences\n", 142 | "codes = \"\" # path to the codes of the model\n", 143 | "fastbpe = os.path.join(os.getcwd(), 'tools/fastBPE/fast')\n", 144 | "\n", 145 | "def to_bpe(sentences):\n", 146 | " # write sentences to tmp file\n", 147 | " with open('/tmp/sentences.bpe', 'w') as fwrite:\n", 148 | " for sent in sentences:\n", 149 | " fwrite.write(sent + '\\n')\n", 150 | " \n", 151 | " # apply bpe to tmp file\n", 152 | " os.system('%s applybpe /tmp/sentences.bpe /tmp/sentences %s' % (fastbpe, codes))\n", 153 | " \n", 154 | " # load bpe-ized sentences\n", 155 | " sentences_bpe = []\n", 156 | " with open('/tmp/sentences.bpe') as f:\n", 157 | " for line in f:\n", 158 | " sentences_bpe.append(line.rstrip())\n", 159 | " \n", 160 | " return sentences_bpe\n" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 7, 166 | "metadata": {}, 167 | "outputs": [ 168 | { 169 | "name": "stdout", 170 | "output_type": "stream", 171 | "text": [ 172 | "once he had worn tr@@ end@@ y italian le@@ ather sho@@ es and je@@ ans from paris that had cost three hundred euros .\n", 173 | "\n", 174 | "Le français est la seule langue étrang@@ ère propo@@ sée dans le système é@@ duc@@ atif .\n", 175 | "\n", 176 | "El cad@@ mio produce efectos tó@@ x@@ icos en los organismos vivos , aun en concentra@@ ciones muy pequeñas .\n", 177 | "\n", 178 | "Nach dem Zweiten Weltkrieg verbre@@ it@@ ete sich Bon@@ sai als Hob@@ by in der ganzen Welt .\n", 179 | "\n", 180 | "وقد فاز في الانتخابات في الج@@ ولة الثانية من التص@@ ويت من قبل سيدي ولد الشيخ عبد الله ، مع أحمد ولد دا@@ دا@@ ه في المرتبة الثانية .\n", 181 | "\n", 182 | "羅@@ 伯特 · 皮@@ 爾 斯 生於 186@@ 3年 , 在 英國 曼@@ 徹@@ 斯特 學習 而 成為 一 位 工程@@ 師 . 193@@ 3年 , 皮@@ 爾@@ 斯 在 直@@ 布@@ 羅@@ 陀@@ 去世 .\n", 183 | "Number of out-of-vocab words: 0/144\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "# Below are already BPE-ized sentences\n", 189 | "\n", 190 | "# list of (sentences, lang)\n", 191 | "sentences = [\n", 192 | " 'once he had worn trendy italian leather shoes and jeans from paris that had cost three hundred euros .', # en\n", 193 | " 'Le français est la seule langue étrangère proposée dans le système éducatif .', # fr\n", 194 | " 'El cadmio produce efectos tóxicos en los organismos vivos , aun en concentraciones muy pequeñas .', # es\n", 195 | " 'Nach dem Zweiten Weltkrieg verbreitete sich Bonsai als Hobby in der ganzen Welt .', # de\n", 196 | " 'وقد فاز في الانتخابات في الجولة الثانية من التصويت من قبل سيدي ولد الشيخ عبد الله ، مع أحمد ولد داداه في المرتبة الثانية .', # ar\n", 197 | " '羅伯特 · 皮爾 斯 生於 1863年 , 在 英國 曼徹斯特 學習 而 成為 一 位 工程師 . 
1933年 , 皮爾斯 在 直布羅陀去世 .', # zh\n", 198 | "]\n", 199 | "\n", 200 | "# bpe-ize sentences\n", 201 | "sentences = to_bpe(sentences)\n", 202 | "print('\\n\\n'.join(sentences))\n", 203 | "\n", 204 | "# check how many tokens are OOV\n", 205 | "n_w = len([w for w in ' '.join(sentences).split()])\n", 206 | "n_oov = len([w for w in ' '.join(sentences).split() if w not in dico.word2id])\n", 207 | "print('Number of out-of-vocab words: %s/%s' % (n_oov, n_w))\n", 208 | "\n", 209 | "# add sentence delimiters\n", 210 | "sentences = [((' %s ' % sent.strip()).split()) for sent in sentences]" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": {}, 216 | "source": [ 217 | "### Create batch" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 8, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "bs = len(sentences)\n", 227 | "slen = max([len(sent) for sent in sentences])\n", 228 | "\n", 229 | "word_ids = torch.LongTensor(slen, bs).fill_(params.pad_index)\n", 230 | "for i in range(len(sentences)):\n", 231 | " sent = torch.LongTensor([dico.index(w) for w in sentences[i]])\n", 232 | " word_ids[:len(sent), i] = sent\n", 233 | "\n", 234 | "lengths = torch.LongTensor([len(sent) for sent in sentences])\n", 235 | " \n", 236 | "# NOTE: No more language id (removed it in a later version)\n", 237 | "# langs = torch.LongTensor([params.lang2id[lang] for _, lang in sentences]).unsqueeze(0).expand(slen, bs) if params.n_langs > 1 else None\n", 238 | "langs = None\n" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [ 245 | "### Forward" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 9, 251 | "metadata": {}, 252 | "outputs": [ 253 | { 254 | "name": "stdout", 255 | "output_type": "stream", 256 | "text": [ 257 | "torch.Size([38, 6, 1280])\n" 258 | ] 259 | } 260 | ], 261 | "source": [ 262 | "tensor = model('fwd', x=word_ids, lengths=lengths, langs=langs, causal=False).contiguous()\n", 263 | "print(tensor.size())" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "The variable `tensor` is of shape `(sequence_length, batch_size, model_dimension)`.\n", 271 | "\n", 272 | "`tensor[0]` is a tensor of shape `(batch_size, model_dimension)` that corresponds to the first hidden state of the last layer of each sentence.\n", 273 | "\n", 274 | "This is this vector that we use to finetune on the GLUE and XNLI tasks." 275 | ] 276 | } 277 | ], 278 | "metadata": { 279 | "kernelspec": { 280 | "display_name": "Python 3", 281 | "language": "python", 282 | "name": "python3" 283 | }, 284 | "language_info": { 285 | "codemirror_mode": { 286 | "name": "ipython", 287 | "version": 3 288 | }, 289 | "file_extension": ".py", 290 | "mimetype": "text/x-python", 291 | "name": "python", 292 | "nbconvert_exporter": "python", 293 | "pygments_lexer": "ipython3", 294 | "version": "3.7.3" 295 | } 296 | }, 297 | "nbformat": 4, 298 | "nbformat_minor": 2 299 | } 300 | -------------------------------------------------------------------------------- /get-data-glue.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
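#
# Usage: ./get-data-glue.sh
# (No arguments. Based on the task blocks below, this downloads and tokenizes
#  the SST-2, STS-B, MNLI, QNLI and QQP GLUE tasks into data/glue, producing
#  one *.xlm.tsv file per split.)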
6 | # 7 | 8 | set -e 9 | 10 | # data paths 11 | MAIN_PATH=$PWD 12 | OUTPATH=$PWD/data/glue 13 | URLPATH=https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2F 14 | 15 | # tools paths 16 | TOOLS_PATH=$PWD/tools 17 | TOKENIZE=$TOOLS_PATH/tokenize.sh 18 | MOSES=$TOOLS_PATH/mosesdecoder 19 | REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl 20 | NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl 21 | REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl 22 | LOWER_REMOVE_ACCENT=$TOOLS_PATH/lowercase_and_remove_accent.py 23 | 24 | 25 | # install tools 26 | ./install-tools.sh 27 | 28 | # create directories 29 | # rm -r $OUTPATH 30 | mkdir -p $OUTPATH 31 | 32 | 33 | # SST-2 34 | if [ ! -d $OUTPATH/SST-2 ]; then 35 | if [ ! -f $OUTPATH/SST-2zip ]; then 36 | wget -c "${URLPATH}SST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8" -P $OUTPATH 37 | fi 38 | unzip $OUTPATH/*SST-2* -d $OUTPATH 39 | for split in train dev 40 | do 41 | sed '1d' $OUTPATH/SST-2/${split}.tsv | cut -f1 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l en | $REM_NON_PRINT_CHAR > $OUTPATH/SST-2/${split}.x 42 | sed '1d' $OUTPATH/SST-2/${split}.tsv | cut -f2 > $OUTPATH/SST-2/${split}.y 43 | paste $OUTPATH/SST-2/${split}.x $OUTPATH/SST-2/${split}.y > $OUTPATH/SST-2/${split}.xlm.tsv 44 | rm $OUTPATH/SST-2/${split}.x $OUTPATH/SST-2/${split}.y 45 | done 46 | sed '1d' $OUTPATH/SST-2/test.tsv | cut -f2 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l en | $REM_NON_PRINT_CHAR > $OUTPATH/SST-2/test.xlm.tsv 47 | rm $OUTPATH/*SST-2.zip* 48 | 49 | fi 50 | 51 | # SST-B 52 | if [ ! -d $OUTPATH/STS-B ]; then 53 | if [ ! -f $OUTPATH/STS-B.zip ]; then 54 | wget -c "${URLPATH}STS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5" -P $OUTPATH 55 | fi 56 | unzip $OUTPATH/*STS-B* -d $OUTPATH 57 | for split in train dev test 58 | do 59 | sed '1d' $OUTPATH/STS-B/${split}.tsv | cut -f8 | $TOKENIZE en | python $LOWER_REMOVE_ACCENT > $OUTPATH/STS-B/${split}.x1 60 | sed '1d' $OUTPATH/STS-B/${split}.tsv | cut -f9 | $TOKENIZE en | python $LOWER_REMOVE_ACCENT > $OUTPATH/STS-B/${split}.x2 61 | if [ "$split" != "test" ]; then 62 | sed '1d' $OUTPATH/STS-B/${split}.tsv | cut -f10 > $OUTPATH/STS-B/${split}.y 63 | paste $OUTPATH/STS-B/${split}.x1 $OUTPATH/STS-B/${split}.x2 $OUTPATH/STS-B/${split}.y > $OUTPATH/STS-B/${split}.xlm.tsv 64 | rm $OUTPATH/STS-B/${split}.x1 $OUTPATH/STS-B/${split}.x2 $OUTPATH/STS-B/${split}.y 65 | else 66 | paste $OUTPATH/STS-B/${split}.x1 $OUTPATH/STS-B/${split}.x2 > $OUTPATH/STS-B/${split}.xlm.tsv 67 | rm $OUTPATH/STS-B/${split}.x1 $OUTPATH/STS-B/${split}.x2 68 | fi 69 | done 70 | rm $OUTPATH/*STS-B.zip* 71 | 72 | fi 73 | 74 | # MNLI 75 | if [ ! -d $OUTPATH/MNLI ]; then 76 | if [ ! 
-f $OUTPATH/MNLI.zip ]; then 77 | wget -c "${URLPATH}MNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce" -P $OUTPATH 78 | fi 79 | unzip $OUTPATH/*MNLI* -d $OUTPATH 80 | mv $OUTPATH/MNLI/dev_matched.tsv $OUTPATH/MNLI/dev.tsv 81 | mv $OUTPATH/MNLI/test_matched.tsv $OUTPATH/MNLI/test.tsv 82 | for split in train dev test 83 | do 84 | sed '1d' $OUTPATH/MNLI/${split}.tsv | cut -f9 | $TOKENIZE en | python $LOWER_REMOVE_ACCENT > $OUTPATH/MNLI/${split}.x1 85 | sed '1d' $OUTPATH/MNLI/${split}.tsv | cut -f10 | $TOKENIZE en | python $LOWER_REMOVE_ACCENT > $OUTPATH/MNLI/${split}.x2 86 | sed '1d' $OUTPATH/MNLI/${split}.tsv | cut -f12 > $OUTPATH/MNLI/${split}.y 87 | paste $OUTPATH/MNLI/${split}.x1 $OUTPATH/MNLI/${split}.x2 $OUTPATH/MNLI/${split}.y > $OUTPATH/MNLI/${split}.xlm.tsv 88 | rm $OUTPATH/MNLI/${split}.x1 $OUTPATH/MNLI/${split}.x2 $OUTPATH/MNLI/${split}.y 89 | done 90 | rm $OUTPATH/*MNLI.zip* 91 | mv $OUTPATH/MNLI $OUTPATH/MNLI-m 92 | 93 | fi 94 | 95 | # QNLI 96 | if [ ! -d $OUTPATH/QNLI ]; then 97 | if [ ! -f $OUTPATH/QNLIv2.zip ]; then 98 | wget -c "${URLPATH}QNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601" -P $OUTPATH 99 | fi 100 | unzip $OUTPATH/*QNLIv2* -d $OUTPATH 101 | for split in train dev test 102 | do 103 | sed '1d' $OUTPATH/QNLI/${split}.tsv | cut -f2 | $TOKENIZE en | python $LOWER_REMOVE_ACCENT > $OUTPATH/QNLI/${split}.x1 104 | sed '1d' $OUTPATH/QNLI/${split}.tsv | cut -f3 | $TOKENIZE en | python $LOWER_REMOVE_ACCENT > $OUTPATH/QNLI/${split}.x2 105 | if [ "$split" != "test" ]; then 106 | sed '1d' $OUTPATH/QNLI/${split}.tsv | cut -f4 > $OUTPATH/QNLI/${split}.y 107 | paste $OUTPATH/QNLI/${split}.x1 $OUTPATH/QNLI/${split}.x2 $OUTPATH/QNLI/${split}.y > $OUTPATH/QNLI/${split}.xlm.tsv 108 | rm $OUTPATH/QNLI/${split}.x1 $OUTPATH/QNLI/${split}.x2 $OUTPATH/QNLI/${split}.y 109 | else 110 | paste $OUTPATH/QNLI/${split}.x1 $OUTPATH/QNLI/${split}.x2 > $OUTPATH/QNLI/${split}.xlm.tsv 111 | rm $OUTPATH/QNLI/${split}.x1 $OUTPATH/QNLI/${split}.x2 112 | fi 113 | done 114 | rm $OUTPATH/*QNLIv2.zip* 115 | 116 | fi 117 | 118 | # QQP 119 | if [ ! -d $OUTPATH/QQP ]; then 120 | if [ ! 
-f $OUTPATH/QQP.zip ]; then 121 | wget -c "${URLPATH}QQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5" -P $OUTPATH 122 | fi 123 | unzip $OUTPATH/*QQP* -d $OUTPATH 124 | for split in train dev test 125 | do 126 | if [ "$split" != "test" ]; then 127 | sed '1d' $OUTPATH/QQP/${split}.tsv | cut -f4 | $TOKENIZE en | python $LOWER_REMOVE_ACCENT > $OUTPATH/QQP/${split}.x1 128 | sed '1d' $OUTPATH/QQP/${split}.tsv | cut -f5 | $TOKENIZE en | python $LOWER_REMOVE_ACCENT > $OUTPATH/QQP/${split}.x2 129 | sed '1d' $OUTPATH/QQP/${split}.tsv | cut -f6 > $OUTPATH/QQP/${split}.y 130 | paste $OUTPATH/QQP/${split}.x1 $OUTPATH/QQP/${split}.x2 $OUTPATH/QQP/${split}.y > $OUTPATH/QQP/${split}.xlm.tsv 131 | rm $OUTPATH/QQP/${split}.x1 $OUTPATH/QQP/${split}.x2 $OUTPATH/QQP/${split}.y 132 | else 133 | sed '1d' $OUTPATH/QQP/${split}.tsv | cut -f2 | $TOKENIZE en | python $LOWER_REMOVE_ACCENT > $OUTPATH/QQP/${split}.x1 134 | sed '1d' $OUTPATH/QQP/${split}.tsv | cut -f3 | $TOKENIZE en | python $LOWER_REMOVE_ACCENT > $OUTPATH/QQP/${split}.x2 135 | paste $OUTPATH/QQP/${split}.x1 $OUTPATH/QQP/${split}.x2 > $OUTPATH/QQP/${split}.xlm.tsv 136 | rm $OUTPATH/QQP/${split}.x1 $OUTPATH/QQP/${split}.x2 137 | fi 138 | done 139 | rm $OUTPATH/*QQP.zip* 140 | 141 | fi 142 | 143 | -------------------------------------------------------------------------------- /get-data-para.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | # 9 | # Usage: ./get-data-para.sh $lg_pair 10 | # 11 | 12 | set -e 13 | 14 | pair=$1 # input language pair 15 | 16 | # data paths 17 | MAIN_PATH=$PWD 18 | PARA_PATH=$PWD/data/para 19 | 20 | # tools paths 21 | TOOLS_PATH=$PWD/tools 22 | TOKENIZE=$TOOLS_PATH/tokenize.sh 23 | LOWER_REMOVE_ACCENT=$TOOLS_PATH/lowercase_and_remove_accent.py 24 | 25 | # install tools 26 | ./install-tools.sh 27 | 28 | # create directories 29 | mkdir -p $PARA_PATH 30 | 31 | 32 | # 33 | # Download and uncompress data 34 | # 35 | 36 | # ar-en 37 | if [ $pair == "ar-en" ]; then 38 | # OpenSubtitles 2018 39 | # wget -c http://opus.nlpl.eu/download.php?f=OpenSubtitles2018%2Far-en.txt.zip -P $PARA_PATH 40 | # MultiUN 41 | wget -c http://opus.nlpl.eu/download.php?f=MultiUN%2Far-en.txt.zip -P $PARA_PATH 42 | unzip -u $PARA_PATH/download.php?f=MultiUN%2Far-en.txt.zip -d $PARA_PATH 43 | fi 44 | 45 | # bg-en 46 | if [ $pair == "bg-en" ]; then 47 | # OpenSubtitles 2018 48 | # wget -c http://opus.nlpl.eu/download.php?f=OpenSubtitles2018%2Fbg-en.txt.zip -P $PARA_PATH 49 | # EU Bookshop 50 | wget -c http://opus.nlpl.eu/download.php?f=EUbookshop%2Fbg-en.txt.zip -P $PARA_PATH 51 | unzip -u $PARA_PATH/download.php?f=EUbookshop%2Fbg-en.txt.zip -d $PARA_PATH 52 | # Europarl 53 | wget -c http://opus.nlpl.eu/download.php?f=Europarl%2Fbg-en.txt.zip -P $PARA_PATH 54 | unzip -u $PARA_PATH/download.php?f=Europarl%2Fbg-en.txt.zip -d $PARA_PATH 55 | fi 56 | 57 | # de-en 58 | if [ $pair == "de-en" ]; then 59 | # OpenSubtitles 2018 60 | # wget -c http://opus.nlpl.eu/download.php?f=OpenSubtitles2018%2Fde-en.txt.zip -P $PARA_PATH 61 | # EU Bookshop 62 | wget -c http://opus.nlpl.eu/download.php?f=EUbookshop%2Fde-en.txt.zip -P $PARA_PATH 63 | unzip -u $PARA_PATH/download.php?f=EUbookshop%2Fde-en.txt.zip -d $PARA_PATH 64 | fi 65 | 66 | # el-en 67 | if [ $pair == "el-en" ]; then 68 | # OpenSubtitles 2018 69 | # 
wget -c http://opus.nlpl.eu/download.php?f=OpenSubtitles2018%2Fel-en.txt.zip -P $PARA_PATH 70 | # EU Bookshop 71 | wget -c http://opus.nlpl.eu/download.php?f=EUbookshop%2Fel-en.txt.zip -P $PARA_PATH 72 | unzip -u $PARA_PATH/download.php?f=EUbookshop%2Fel-en.txt.zip -d $PARA_PATH 73 | fi 74 | 75 | # en-es 76 | if [ $pair == "en-es" ]; then 77 | # OpenSubtitles 2018 78 | # wget -c http://opus.nlpl.eu/download.php?f=OpenSubtitles2018%2Fen-es.txt.zip -P $PARA_PATH 79 | # EU Bookshop 80 | # wget -c http://opus.nlpl.eu/download.php?f=EUbookshop%2Fen-es.txt.zip -P $PARA_PATH 81 | # MultiUN 82 | wget -c https://object.pouta.csc.fi/OPUS-MultiUN/v1/moses/en-es.txt.zip -P $PARA_PATH 83 | unzip -u $PARA_PATH/en-es.txt.zip -d $PARA_PATH 84 | fi 85 | 86 | # en-fr 87 | if [ $pair == "en-fr" ]; then 88 | echo "Download parallel data for English-Hindi" 89 | # OpenSubtitles 2018 90 | # wget -c http://opus.nlpl.eu/download.php?f=OpenSubtitles2018%2Fen-fr.txt.zip -P $PARA_PATH 91 | # EU Bookshop 92 | # wget -c http://opus.nlpl.eu/download.php?f=EUbookshop%2Fen-fr.txt.zip -P $PARA_PATH 93 | # MultiUN 94 | wget -c https://object.pouta.csc.fi/OPUS-MultiUN/v1/moses/en-fr.txt.zip -P $PARA_PATH 95 | unzip -u $PARA_PATH/en-fr.txt.zip -d $PARA_PATH 96 | fi 97 | 98 | # en-hi 99 | if [ $pair == "en-hi" ]; then 100 | echo "Download parallel data for English-Hindi" 101 | # IIT Bombay English-Hindi Parallel Corpus 102 | wget -c http://www.cfilt.iitb.ac.in/iitb_parallel/iitb_corpus_download/parallel.tgz -P $PARA_PATH 103 | tar -xvf $PARA_PATH/parallel.tgz -d $PARA_PATH 104 | fi 105 | 106 | # en-ru 107 | if [ $pair == "en-ru" ]; then 108 | echo "Download parallel data for English-Russian" 109 | # OpenSubtitles 2018 110 | # wget -c http://opus.nlpl.eu/download.php?f=OpenSubtitles2018%2Fen-ru.txt.zip -P $PARA_PATH 111 | # MultiUN 112 | wget -c http://opus.nlpl.eu/download.php?f=MultiUN%2Fen-ru.txt.zip -P $PARA_PATH 113 | unzip -u download.php?f=MultiUN%2Fen-ru.txt.zip -d $PARA_PATH 114 | fi 115 | 116 | # en-sw 117 | if [ $pair == "en-sw" ]; then 118 | echo "Download parallel data for English-Swahili" 119 | # Tanzil 120 | wget -c http://opus.nlpl.eu/download.php?f=Tanzil%2Fen-sw.txt.zip -P $PARA_PATH 121 | unzip -u download.php?f=Tanzil%2Fen-sw.txt.zip -d $PARA_PATH 122 | # GlobalVoices 123 | wget -c http://opus.nlpl.eu/download.php?f=GlobalVoices%2Fen-sw.txt.zip -P $PARA_PATH 124 | unzip -u download.php?f=GlobalVoices%2Fen-sw.txt.zip -d $PARA_PATH 125 | fi 126 | 127 | # en-th 128 | if [ $pair == "en-th" ]; then 129 | echo "Download parallel data for English-Thai" 130 | # OpenSubtitles 2018 131 | wget -c http://opus.nlpl.eu/download.php?f=OpenSubtitles2018%2Fen-th.txt.zip -P $PARA_PATH 132 | unzip -u $PARA_PATH/download.php?f=OpenSubtitles2018%2Fen-th.txt.zip -d $PARA_PATH 133 | fi 134 | 135 | # en-tr 136 | if [ $pair == "en-tr" ]; then 137 | echo "Download parallel data for English-Turkish" 138 | # OpenSubtitles 2018 139 | # wget -c http://opus.nlpl.eu/download.php?f=OpenSubtitles2018%2Fen-tr.txt.zip -P $PARA_PATH 140 | # SETIMES2 141 | wget -c http://opus.nlpl.eu/download.php?f=SETIMES2%2Fen-tr.txt.zip -P $PARA_PATH 142 | unzip -u $PARA_PATH/download.php?f=SETIMES2%2Fen-tr.txt.zip -d $PARA_PATH 143 | # Wikipedia 144 | wget -c http://opus.nlpl.eu/download.php?f=Wikipedia%2Fen-tr.txt.zip -P $PARA_PATH 145 | unzip -u $PARA_PATH/download.php?f=Wikipedia%2Fen-tr.txt.zip -d $PARA_PATH 146 | # TED 147 | wget -c https://object.pouta.csc.fi/OPUS-TED2013/v1.1/moses/en-tr.txt.zip -P $PARA_PATH 148 | unzip -u $PARA_PATH/en-tr.txt.zip -d 
$PARA_PATH 149 | fi 150 | 151 | # en-ur 152 | if [ $pair == "en-ur" ]; then 153 | echo "Download parallel data for English-Urdu" 154 | # OpenSubtitles 2018 155 | # wget -c http://opus.nlpl.eu/download.php?f=OpenSubtitles2018%2Fen-ur.txt.zip -P $PARA_PATH 156 | # Tanzil 157 | wget -c http://opus.nlpl.eu/download.php?f=Tanzil%2Fen-ur.txt.zip -P $PARA_PATH 158 | unzip -u $PARA_PATH/download.php?f=Tanzil%2Fen-ur.txt.zip -d $PARA_PATH 159 | fi 160 | 161 | # en-vi 162 | if [ $pair == "en-vi" ]; then 163 | echo "Download parallel data for English-Vietnamese" 164 | # OpenSubtitles 2018 165 | wget -c http://opus.nlpl.eu/download.php?f=OpenSubtitles2018%2Fen-vi.txt.zip -P $PARA_PATH 166 | unzip -u $PARA_PATH/download.php?f=OpenSubtitles2018%2Fen-vi.txt.zip -d $PARA_PATH 167 | fi 168 | 169 | # en-zh 170 | if [ $pair == "en-zh" ]; then 171 | echo "Download parallel data for English-Chinese" 172 | # OpenSubtitles 2016 173 | # wget -c http://opus.nlpl.eu/download.php?f=OpenSubtitles2016%2Fen-zh.txt.zip -P $PARA_PATH 174 | # MultiUN 175 | wget -c http://opus.nlpl.eu/download.php?f=MultiUN%2Fen-zh.txt.zip -P $PARA_PATH 176 | unzip -u $PARA_PATH/download.php?f=MultiUN%2Fen-zh.txt.zip -d $PARA_PATH 177 | fi 178 | 179 | 180 | # 181 | # Tokenize and preprocess data 182 | # 183 | 184 | # tokenize 185 | for lg in $(echo $pair | sed -e 's/\-/ /g'); do 186 | if [ ! -f $PARA_PATH/$pair.$lg.all ]; then 187 | cat $PARA_PATH/*.$pair.$lg | $TOKENIZE $lg | python $LOWER_REMOVE_ACCENT > $PARA_PATH/$pair.$lg.all 188 | fi 189 | done 190 | 191 | # split into train / valid / test 192 | split_data() { 193 | get_seeded_random() { 194 | seed="$1"; openssl enc -aes-256-ctr -pass pass:"$seed" -nosalt /dev/null 195 | }; 196 | NLINES=`wc -l $1 | awk -F " " '{print $1}'`; 197 | NTRAIN=$((NLINES - 10000)); 198 | NVAL=$((NTRAIN + 5000)); 199 | shuf --random-source=<(get_seeded_random 42) $1 | head -$NTRAIN > $2; 200 | shuf --random-source=<(get_seeded_random 42) $1 | head -$NVAL | tail -5000 > $3; 201 | shuf --random-source=<(get_seeded_random 42) $1 | tail -5000 > $4; 202 | } 203 | for lg in $(echo $pair | sed -e 's/\-/ /g'); do 204 | split_data $PARA_PATH/$pair.$lg.all $PARA_PATH/$pair.$lg.train $PARA_PATH/$pair.$lg.valid $PARA_PATH/$pair.$lg.test 205 | done 206 | 207 | -------------------------------------------------------------------------------- /get-data-wiki.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | # 9 | # Usage: ./get-data-wiki.sh $lg 10 | # 11 | 12 | set -e 13 | 14 | lg=$1 # input language 15 | 16 | # data path 17 | MAIN_PATH=$PWD 18 | WIKI_PATH=$PWD/data/wiki 19 | 20 | # tools paths 21 | TOOLS_PATH=$PWD/tools 22 | TOKENIZE=$TOOLS_PATH/tokenize.sh 23 | LOWER_REMOVE_ACCENT=$TOOLS_PATH/lowercase_and_remove_accent.py 24 | 25 | # Wiki data 26 | WIKI_DUMP_NAME=${lg}wiki-latest-pages-articles.xml.bz2 27 | WIKI_DUMP_LINK=https://dumps.wikimedia.org/${lg}wiki/latest/$WIKI_DUMP_NAME 28 | 29 | # install tools 30 | ./install-tools.sh 31 | 32 | # create Wiki paths 33 | mkdir -p $WIKI_PATH/bz2 34 | mkdir -p $WIKI_PATH/txt 35 | 36 | # download Wikipedia dump 37 | echo "Downloading $lg Wikipedia dump from $WIKI_DUMP_LINK ..." 
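# Note: Wikipedia dumps for high-resource languages can be several GB; wget -c
# resumes a partial download, so this step can safely be re-run if interrupted.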
38 | wget -c $WIKI_DUMP_LINK -P $WIKI_PATH/bz2/ 39 | echo "Downloaded $WIKI_DUMP_NAME in $WIKI_PATH/bz2/$WIKI_DUMP_NAME" 40 | 41 | # extract and tokenize Wiki data 42 | cd $MAIN_PATH 43 | echo "*** Cleaning and tokenizing $lg Wikipedia dump ... ***" 44 | if [ ! -f $WIKI_PATH/txt/$lg.all ]; then 45 | python $TOOLS_PATH/wikiextractor/WikiExtractor.py $WIKI_PATH/bz2/$WIKI_DUMP_NAME --processes 8 -q -o - \ 46 | | sed "/^\s*\$/d" \ 47 | | grep -v "^\$" \ 49 | | $TOKENIZE $lg \ 50 | | python $LOWER_REMOVE_ACCENT \ 51 | > $WIKI_PATH/txt/$lg.all 52 | fi 53 | echo "*** Tokenized (+ lowercase + accent-removal) $lg Wikipedia dump to $WIKI_PATH/txt/train.${lg} ***" 54 | 55 | # split into train / valid / test 56 | echo "*** Split into train / valid / test ***" 57 | split_data() { 58 | get_seeded_random() { 59 | seed="$1"; openssl enc -aes-256-ctr -pass pass:"$seed" -nosalt /dev/null 60 | }; 61 | NLINES=`wc -l $1 | awk -F " " '{print $1}'`; 62 | NTRAIN=$((NLINES - 10000)); 63 | NVAL=$((NTRAIN + 5000)); 64 | shuf --random-source=<(get_seeded_random 42) $1 | head -$NTRAIN > $2; 65 | shuf --random-source=<(get_seeded_random 42) $1 | head -$NVAL | tail -5000 > $3; 66 | shuf --random-source=<(get_seeded_random 42) $1 | tail -5000 > $4; 67 | } 68 | split_data $WIKI_PATH/txt/$lg.all $WIKI_PATH/txt/$lg.train $WIKI_PATH/txt/$lg.valid $WIKI_PATH/txt/$lg.test 69 | 70 | -------------------------------------------------------------------------------- /get-data-xnli.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | # 9 | # Usage: ./get-data-xnli.sh 10 | # 11 | 12 | set -e 13 | 14 | # data paths 15 | MAIN_PATH=$PWD 16 | OUTPATH=$PWD/data/xnli 17 | XNLI_PATH=$PWD/data/xnli/XNLI-1.0 18 | 19 | # tools paths 20 | TOOLS_PATH=$PWD/tools 21 | TOKENIZE=$TOOLS_PATH/tokenize.sh 22 | LOWER_REMOVE_ACCENT=$TOOLS_PATH/lowercase_and_remove_accent.py 23 | 24 | # install tools 25 | ./install-tools.sh 26 | 27 | # create directories 28 | mkdir -p $OUTPATH 29 | 30 | # download data 31 | if [ ! -d $OUTPATH/XNLI-MT-1.0 ]; then 32 | if [ ! -f $OUTPATH/XNLI-MT-1.0.zip ]; then 33 | wget -c https://dl.fbaipublicfiles.com/XNLI/XNLI-MT-1.0.zip -P $OUTPATH 34 | fi 35 | unzip $OUTPATH/XNLI-MT-1.0.zip -d $OUTPATH 36 | fi 37 | if [ ! -d $OUTPATH/XNLI-1.0 ]; then 38 | if [ ! 
-f $OUTPATH/XNLI-1.0.zip ]; then 39 | wget -c https://dl.fbaipublicfiles.com/XNLI/XNLI-1.0.zip -P $OUTPATH 40 | fi 41 | unzip $OUTPATH/XNLI-1.0.zip -d $OUTPATH 42 | fi 43 | 44 | # English train set 45 | echo "*** Preparing English train set ****" 46 | echo -e "premise\thypo\tlabel" > $XNLI_PATH/en.train 47 | sed '1d' $OUTPATH/XNLI-MT-1.0/multinli/multinli.train.en.tsv | cut -f1 | python $LOWER_REMOVE_ACCENT > $XNLI_PATH/train.f1 48 | sed '1d' $OUTPATH/XNLI-MT-1.0/multinli/multinli.train.en.tsv | cut -f2 | python $LOWER_REMOVE_ACCENT > $XNLI_PATH/train.f2 49 | sed '1d' $OUTPATH/XNLI-MT-1.0/multinli/multinli.train.en.tsv | cut -f3 | sed 's/contradictory/contradiction/g' > $XNLI_PATH/train.f3 50 | paste $XNLI_PATH/train.f1 $XNLI_PATH/train.f2 $XNLI_PATH/train.f3 >> $XNLI_PATH/en.train 51 | 52 | rm $XNLI_PATH/train.f1 $XNLI_PATH/train.f2 $XNLI_PATH/train.f3 53 | 54 | 55 | # validation and test sets 56 | for lg in ar bg de el en es fr hi ru sw th tr ur vi zh; do 57 | 58 | echo "*** Preparing $lg validation and test sets ***" 59 | echo -e "premise\thypo\tlabel" > $XNLI_PATH/$lg.valid 60 | echo -e "premise\thypo\tlabel" > $XNLI_PATH/$lg.test 61 | 62 | # label 63 | awk -v lg=$lg '$1==lg' $XNLI_PATH/xnli.dev.tsv | cut -f2 > $XNLI_PATH/dev.f2 64 | awk -v lg=$lg '$1==lg' $XNLI_PATH/xnli.test.tsv | cut -f2 > $XNLI_PATH/test.f2 65 | 66 | # premise/hypothesis 67 | awk -v lg=$lg '$1==lg' $XNLI_PATH/xnli.dev.tsv | cut -f7 | $TOKENIZE $lg | python $LOWER_REMOVE_ACCENT > $XNLI_PATH/dev.f7 68 | awk -v lg=$lg '$1==lg' $XNLI_PATH/xnli.dev.tsv | cut -f8 | $TOKENIZE $lg | python $LOWER_REMOVE_ACCENT > $XNLI_PATH/dev.f8 69 | awk -v lg=$lg '$1==lg' $XNLI_PATH/xnli.test.tsv | cut -f7 | $TOKENIZE $lg | python $LOWER_REMOVE_ACCENT > $XNLI_PATH/test.f7 70 | awk -v lg=$lg '$1==lg' $XNLI_PATH/xnli.test.tsv | cut -f8 | $TOKENIZE $lg | python $LOWER_REMOVE_ACCENT > $XNLI_PATH/test.f8 71 | 72 | paste $XNLI_PATH/dev.f7 $XNLI_PATH/dev.f8 $XNLI_PATH/dev.f2 >> $XNLI_PATH/$lg.valid 73 | paste $XNLI_PATH/test.f7 $XNLI_PATH/test.f8 $XNLI_PATH/test.f2 >> $XNLI_PATH/$lg.test 74 | 75 | rm $XNLI_PATH/*.f2 $XNLI_PATH/*.f7 $XNLI_PATH/*.f8 76 | done 77 | -------------------------------------------------------------------------------- /glue-xnli.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
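#
# Illustrative invocation (the flags correspond to the argparse options defined
# below; the model/data paths and hyper-parameter values are placeholders, not
# recommended settings):
#
#   python glue-xnli.py --exp_name test_xnli --dump_path ./dumped \
#     --model_path /path/to/pretrained_model.pth --data_path ./data/processed/XLM15 \
#     --transfer_tasks XNLI --optimizer_e adam,lr=0.000025 --optimizer_p adam,lr=0.000025 \
#     --finetune_layers "0:_1" --batch_size 8 --n_epochs 250 --epoch_size 20000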
6 | # 7 | 8 | import os 9 | import argparse 10 | 11 | from xlm.utils import bool_flag, initialize_exp 12 | from xlm.evaluation.glue import GLUE 13 | from xlm.evaluation.xnli import XNLI 14 | from xlm.model.embedder import SentenceEmbedder 15 | 16 | 17 | GLUE_TASKS = ['MNLI-m', 'MNLI-mm', 'QQP', 'QNLI', 'SST-2', 'CoLA', 'MRPC', 'RTE', 'STS-B', 'WNLI', 'AX_MNLI-m'] 18 | XNLI_TASKS = ['XNLI'] 19 | TASKS = GLUE_TASKS + XNLI_TASKS 20 | 21 | 22 | # parse parameters 23 | parser = argparse.ArgumentParser(description='Train on GLUE or XNLI') 24 | 25 | # main parameters 26 | parser.add_argument("--exp_name", type=str, default="", 27 | help="Experiment name") 28 | parser.add_argument("--dump_path", type=str, default="", 29 | help="Experiment dump path") 30 | parser.add_argument("--exp_id", type=str, default="", 31 | help="Experiment ID") 32 | 33 | # evaluation task / pretrained model 34 | parser.add_argument("--transfer_tasks", type=str, default="", 35 | help="Transfer tasks, example: 'MNLI-m,RTE,XNLI' ") 36 | parser.add_argument("--model_path", type=str, default="", 37 | help="Model location") 38 | 39 | # data 40 | parser.add_argument("--data_path", type=str, default="", 41 | help="Data path") 42 | parser.add_argument("--max_vocab", type=int, default=-1, 43 | help="Maximum vocabulary size (-1 to disable)") 44 | parser.add_argument("--min_count", type=int, default=0, 45 | help="Minimum vocabulary count") 46 | 47 | # batch parameters 48 | parser.add_argument("--max_len", type=int, default=256, 49 | help="Maximum length of sentences (after BPE)") 50 | parser.add_argument("--group_by_size", type=bool_flag, default=False, 51 | help="Sort sentences by size during the training") 52 | parser.add_argument("--batch_size", type=int, default=32, 53 | help="Number of sentences per batch") 54 | parser.add_argument("--max_batch_size", type=int, default=0, 55 | help="Maximum number of sentences per batch (used in combination with tokens_per_batch, 0 to disable)") 56 | parser.add_argument("--tokens_per_batch", type=int, default=-1, 57 | help="Number of tokens per batch") 58 | 59 | # model / optimization 60 | parser.add_argument("--finetune_layers", type=str, default='0:_1', 61 | help="Layers to finetune. 
0 = embeddings, _1 = last encoder layer") 62 | parser.add_argument("--weighted_training", type=bool_flag, default=False, 63 | help="Use a weighted loss during training") 64 | parser.add_argument("--dropout", type=float, default=0, 65 | help="Fine-tuning dropout") 66 | parser.add_argument("--optimizer_e", type=str, default="adam,lr=0.0001", 67 | help="Embedder (pretrained model) optimizer") 68 | parser.add_argument("--optimizer_p", type=str, default="adam,lr=0.0001", 69 | help="Projection (classifier) optimizer") 70 | parser.add_argument("--n_epochs", type=int, default=100, 71 | help="Maximum number of epochs") 72 | parser.add_argument("--epoch_size", type=int, default=-1, 73 | help="Epoch size (-1 for full pass over the dataset)") 74 | 75 | # debug 76 | parser.add_argument("--debug_train", type=bool_flag, default=False, 77 | help="Use valid sets for train sets (faster loading)") 78 | parser.add_argument("--debug_slurm", type=bool_flag, default=False, 79 | help="Debug multi-GPU / multi-node within a SLURM job") 80 | 81 | # parse parameters 82 | params = parser.parse_args() 83 | if params.tokens_per_batch > -1: 84 | params.group_by_size = True 85 | 86 | # check parameters 87 | assert os.path.isdir(params.data_path) 88 | assert os.path.isfile(params.model_path) 89 | 90 | # tasks 91 | params.transfer_tasks = params.transfer_tasks.split(',') 92 | assert len(params.transfer_tasks) > 0 93 | assert all([task in TASKS for task in params.transfer_tasks]) 94 | 95 | # reload pretrained model 96 | embedder = SentenceEmbedder.reload(params.model_path, params) 97 | 98 | # reload langs from pretrained model 99 | params.n_langs = embedder.pretrain_params['n_langs'] 100 | params.id2lang = embedder.pretrain_params['id2lang'] 101 | params.lang2id = embedder.pretrain_params['lang2id'] 102 | 103 | # initialize the experiment / build sentence embedder 104 | logger = initialize_exp(params) 105 | scores = {} 106 | 107 | # prepare trainers / evaluators 108 | glue = GLUE(embedder, scores, params) 109 | xnli = XNLI(embedder, scores, params) 110 | 111 | # run 112 | for task in params.transfer_tasks: 113 | if task in GLUE_TASKS: 114 | glue.run(task) 115 | if task in XNLI_TASKS: 116 | xnli.run() 117 | -------------------------------------------------------------------------------- /install-tools.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | set -e 9 | 10 | lg=$1 # input language 11 | 12 | # data path 13 | MAIN_PATH=$PWD 14 | TOOLS_PATH=$PWD/tools 15 | 16 | # tools 17 | MOSES_DIR=$TOOLS_PATH/mosesdecoder 18 | FASTBPE_DIR=$TOOLS_PATH/fastBPE 19 | FASTBPE=$FASTBPE_DIR/fast 20 | WMT16_SCRIPTS=$TOOLS_PATH/wmt16-scripts 21 | 22 | # tools path 23 | mkdir -p $TOOLS_PATH 24 | 25 | # 26 | # Download and install tools 27 | # 28 | 29 | cd $TOOLS_PATH 30 | 31 | # Download Moses 32 | if [ ! -d "$MOSES_DIR" ]; then 33 | echo "Cloning Moses from GitHub repository..." 34 | git clone https://github.com/moses-smt/mosesdecoder.git 35 | fi 36 | 37 | # Download fastBPE 38 | if [ ! -d "$FASTBPE_DIR" ]; then 39 | echo "Cloning fastBPE from GitHub repository..." 40 | git clone https://github.com/glample/fastBPE 41 | fi 42 | 43 | # Compile fastBPE 44 | if [ ! -f "$FASTBPE" ]; then 45 | echo "Compiling fastBPE..." 
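  # fastBPE is compiled from source below, so a C++11-capable g++ (with pthread
  # support) must be available on the PATH.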
46 | cd fastBPE 47 | g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast 48 | cd .. 49 | fi 50 | 51 | # Download Sennrich's tools 52 | if [ ! -d "$WMT16_SCRIPTS" ]; then 53 | echo "Cloning WMT16 preprocessing scripts..." 54 | git clone https://github.com/rsennrich/wmt16-scripts.git 55 | fi 56 | 57 | # Download WikiExtractor 58 | if [ ! -d $TOOLS_PATH/wikiextractor ]; then 59 | echo "Cloning WikiExtractor from GitHub repository..." 60 | git clone https://github.com/attardi/wikiextractor.git 61 | fi 62 | 63 | # # Chinese segmenter 64 | # if ! ls $TOOLS_PATH/stanford-segmenter-* 1> /dev/null 2>&1; then 65 | # echo "Stanford segmenter not found at $TOOLS_PATH/stanford-segmenter-*" 66 | # echo "Please install Stanford segmenter in $TOOLS_PATH" 67 | # exit 1 68 | # fi 69 | # 70 | # # Thai tokenizer 71 | # if ! python -c 'import pkgutil; exit(not pkgutil.find_loader("pythainlp"))'; then 72 | # echo "pythainlp package not found in python" 73 | # echo "Please install pythainlp (pip install pythainlp)" 74 | # exit 1 75 | # fi 76 | # 77 | -------------------------------------------------------------------------------- /prepare-glue.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | # 9 | # Usage: ./prepare-glue.sh 10 | # 11 | 12 | set -e 13 | 14 | # data paths 15 | MAIN_PATH=$PWD 16 | OUTPATH=$PWD/data/glue 17 | TOOLS_PATH=$PWD/tools 18 | PROCESSED_PATH=$PWD/data/processed/XLM_en 19 | CODES_PATH=$MAIN_PATH/codes_en 20 | VOCAB_PATH=$MAIN_PATH/vocab_en 21 | FASTBPE=$TOOLS_PATH/fastBPE/fast 22 | 23 | 24 | # Get BPE codes and vocab (MODIFY if needed) 25 | wget -c https://dl.fbaipublicfiles.com/XLM/codes_en -P $MAIN_PATH 26 | wget -c https://dl.fbaipublicfiles.com/XLM/vocab_en -P $MAIN_PATH 27 | 28 | # apply BPE codes and binarize the GLUE corpora 29 | glue_tasks="MNLI-m QNLI QQP SST-2 STS-B" # TODO: missing MRPC 30 | 31 | for task in $glue_tasks 32 | do 33 | if [ ! -d $PROCESSED_PATH/eval/$task ]; then 34 | mkdir -p $PROCESSED_PATH/eval/$task 35 | else 36 | rm -r $PROCESSED_PATH/eval/$task/* 37 | fi 38 | for splt in train dev test 39 | do 40 | FPATH=$OUTPATH/${task}/${splt}.xlm.tsv 41 | 42 | cut -f1 $FPATH > ${FPATH}.f1 43 | $FASTBPE applybpe $PROCESSED_PATH/eval/$task/${splt}.s1 ${FPATH}.f1 $CODES_PATH 44 | python preprocess.py $VOCAB_PATH $PROCESSED_PATH/eval/$task/${splt}.s1 45 | rm ${FPATH}.f1 46 | 47 | if [ "$task" != "CoLA" ] && [ "$task" != "SST-2" ] 48 | then 49 | cut -f2 $FPATH > ${FPATH}.f2 50 | $FASTBPE applybpe $PROCESSED_PATH/eval/$task/${splt}.s2 ${FPATH}.f2 $CODES_PATH 51 | python preprocess.py $VOCAB_PATH $PROCESSED_PATH/eval/$task/${splt}.s2 52 | rm ${FPATH}.f2 53 | if [ "$splt" != "test" ] || [ "$task" = "MRPC" ] 54 | then 55 | cut -f3 $FPATH > $PROCESSED_PATH/eval/$task/${splt}.label 56 | fi 57 | else 58 | if [ "$splt" != "test" ] 59 | then 60 | cut -f2 $FPATH > $PROCESSED_PATH/eval/$task/${splt}.label 61 | fi 62 | fi 63 | done 64 | done 65 | -------------------------------------------------------------------------------- /prepare-xnli.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | # 9 | # This script is meant to prepare data to reproduce XNLI experiments 10 | # Just modify the "code" and "vocab" path for your own model 11 | # 12 | 13 | set -e 14 | 15 | pair=$1 # input language pair 16 | 17 | # data paths 18 | MAIN_PATH=$PWD 19 | PARA_PATH=$PWD/data/para 20 | TOOLS_PATH=$PWD/tools 21 | WIKI_PATH=$PWD/data/wiki 22 | XNLI_PATH=$PWD/data/xnli/XNLI-1.0 23 | PROCESSED_PATH=$PWD/data/processed/XLM15 24 | CODES_PATH=$MAIN_PATH/codes_xnli_15 25 | VOCAB_PATH=$MAIN_PATH/vocab_xnli_15 26 | FASTBPE=$TOOLS_PATH/fastBPE/fast 27 | 28 | 29 | # Get BPE codes and vocab 30 | wget -c https://dl.fbaipublicfiles.com/XLM/codes_xnli_15 -P $MAIN_PATH 31 | wget -c https://dl.fbaipublicfiles.com/XLM/vocab_xnli_15 -P $MAIN_PATH 32 | 33 | 34 | ## Prepare monolingual data 35 | # apply BPE codes and binarize the monolingual corpora 36 | for lg in ar bg de el en es fr hi ru sw th tr ur vi zh; do 37 | for split in train valid test; do 38 | $FASTBPE applybpe $PROCESSED_PATH/$lg.$split $WIKI_PATH/txt/$lg.$split $CODES_PATH 39 | python preprocess.py $VOCAB_PATH $PROCESSED_PATH/$lg.$split 40 | done 41 | done 42 | 43 | ## Prepare parallel data 44 | # apply BPE codes and binarize the parallel corpora 45 | for pair in ar-en bg-en de-en el-en en-es en-fr en-hi en-ru en-sw en-th en-tr en-ur en-vi en-zh; do 46 | for lg in $(echo $pair | sed -e 's/\-/ /g'); do 47 | for split in train valid test; do 48 | $FASTBPE applybpe $PROCESSED_PATH/$pair.$lg.$split $PARA_PATH/$pair.$lg.$split $CODES_PATH 49 | python preprocess.py $VOCAB_PATH $PROCESSED_PATH/$pair.$lg.$split 50 | done 51 | done 52 | done 53 | 54 | ## Prepare XNLI data 55 | rm -rf $PROCESSED_PATH/eval/XNLI 56 | mkdir -p $PROCESSED_PATH/eval/XNLI 57 | # apply BPE codes and binarize the XNLI corpora 58 | for lg in ar bg de el en es fr hi ru sw th tr ur vi zh; do 59 | for splt in train valid test; do 60 | if [ "$splt" = "train" ] && [ "$lg" != "en" ]; then 61 | continue 62 | fi 63 | sed '1d' $XNLI_PATH/${lg}.${splt} | cut -f1 > $PROCESSED_PATH/eval/XNLI/f1.tok 64 | sed '1d' $XNLI_PATH/${lg}.${splt} | cut -f2 > $PROCESSED_PATH/eval/XNLI/f2.tok 65 | sed '1d' $XNLI_PATH/${lg}.${splt} | cut -f3 > $PROCESSED_PATH/eval/XNLI/${splt}.label.${lg} 66 | 67 | $FASTBPE applybpe $PROCESSED_PATH/eval/XNLI/${splt}.s1.${lg} $PROCESSED_PATH/eval/XNLI/f1.tok ${CODES_PATH} 68 | $FASTBPE applybpe $PROCESSED_PATH/eval/XNLI/${splt}.s2.${lg} $PROCESSED_PATH/eval/XNLI/f2.tok ${CODES_PATH} 69 | 70 | python preprocess.py ${VOCAB_PATH} $PROCESSED_PATH/eval/XNLI/${splt}.s1.${lg} 71 | python preprocess.py ${VOCAB_PATH} $PROCESSED_PATH/eval/XNLI/${splt}.s2.${lg} 72 | 73 | rm $PROCESSED_PATH/eval/XNLI/*.tok* 74 | done 75 | done -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) 2019-present, Facebook, Inc. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 
9 | # 10 | 11 | 12 | """ 13 | Example: python preprocess.py data/vocab.txt data/train.txt 14 | vocab.txt: each line contains a word and its count 15 | """ 16 | 17 | import os 18 | import sys 19 | 20 | from xlm.logger import create_logger 21 | from xlm.data.dictionary import Dictionary 22 | 23 | 24 | if __name__ == '__main__': 25 | 26 | logger = create_logger(None, 0) 27 | 28 | voc_path = sys.argv[1] 29 | txt_path = sys.argv[2] 30 | bin_path = sys.argv[2] + '.pth' 31 | assert os.path.isfile(voc_path) 32 | assert os.path.isfile(txt_path) 33 | 34 | dico = Dictionary.read_vocab(voc_path) 35 | logger.info("") 36 | 37 | data = Dictionary.index_data(txt_path, bin_path, dico) 38 | logger.info("%i words (%i unique) in %i sentences." % ( 39 | len(data['sentences']) - len(data['positions']), 40 | len(data['dico']), 41 | len(data['positions']) 42 | )) 43 | if len(data['unk_words']) > 0: 44 | logger.info("%i unknown words (%i unique), covering %.2f%% of the data." % ( 45 | sum(data['unk_words'].values()), 46 | len(data['unk_words']), 47 | sum(data['unk_words'].values()) * 100. / (len(data['sentences']) - len(data['positions'])) 48 | )) 49 | if len(data['unk_words']) < 30: 50 | for w, c in sorted(data['unk_words'].items(), key=lambda x: x[1])[::-1]: 51 | logger.info("%s: %i" % (w, c)) 52 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import setup, find_packages 3 | 4 | 5 | setup( 6 | name='xlm', 7 | version='1.0', 8 | description='Cross-lingual language model pretraining (XLM)', 9 | author='Guillaume Lample, Alexis Conneau', 10 | author_email='glample@fb.com, aconneau@fb.com', 11 | packages=find_packages(), 12 | ) 13 | -------------------------------------------------------------------------------- /src: -------------------------------------------------------------------------------- 1 | xlm -------------------------------------------------------------------------------- /tools/README.md: -------------------------------------------------------------------------------- 1 | # Tools 2 | 3 | In `XLM/tools/`, you will need to install the following tools: 4 | 5 | ## Tokenizers 6 | 7 | [Moses](https://github.com/moses-smt/mosesdecoder/tree/master/scripts/tokenizer) tokenizer: 8 | ``` 9 | git clone https://github.com/moses-smt/mosesdecoder 10 | ``` 11 | 12 | Thai [PythaiNLP](https://github.com/PyThaiNLP/pythainlp) tokenizer: 13 | ``` 14 | pip install pythainlp 15 | ``` 16 | 17 | Japanese [KyTea](http://www.phontron.com/kytea) tokenizer: 18 | ``` 19 | wget http://www.phontron.com/kytea/download/kytea-0.4.7.tar.gz 20 | tar -xzf kytea-0.4.7.tar.gz 21 | cd kytea-0.4.7 22 | ./configure 23 | make 24 | make install 25 | kytea --help 26 | ``` 27 | 28 | Chinese Stanford segmenter: 29 | ``` 30 | wget https://nlp.stanford.edu/software/stanford-segmenter-2018-10-16.zip 31 | unzip stanford-segmenter-2018-10-16.zip 32 | ``` 33 | 34 | ## fastBPE 35 | 36 | ``` 37 | git clone https://github.com/glample/fastBPE 38 | cd fastBPE 39 | g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast 40 | ``` 41 | -------------------------------------------------------------------------------- /tools/lowercase_and_remove_accent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | 8 | import sys 9 | import unicodedata 10 | import six 11 | 12 | 13 | def convert_to_unicode(text): 14 | """ 15 | Converts `text` to Unicode (if it's not already), assuming UTF-8 input. 16 | """ 17 | # six_ensure_text is copied from https://github.com/benjaminp/six 18 | def six_ensure_text(s, encoding='utf-8', errors='strict'): 19 | if isinstance(s, six.binary_type): 20 | return s.decode(encoding, errors) 21 | elif isinstance(s, six.text_type): 22 | return s 23 | else: 24 | raise TypeError("not expecting type '%s'" % type(s)) 25 | 26 | return six_ensure_text(text, encoding="utf-8", errors="ignore") 27 | 28 | 29 | def run_strip_accents(text): 30 | """ 31 | Strips accents from a piece of text. 32 | """ 33 | text = unicodedata.normalize("NFD", text) 34 | output = [] 35 | for char in text: 36 | cat = unicodedata.category(char) 37 | if cat == "Mn": 38 | continue 39 | output.append(char) 40 | return "".join(output) 41 | 42 | 43 | for line in sys.stdin: 44 | line = convert_to_unicode(line.rstrip().lower()) 45 | line = run_strip_accents(line) 46 | print(u'%s' % line.lower()) 47 | -------------------------------------------------------------------------------- /tools/segment_th.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import sys 9 | from pythainlp.tokenize import word_tokenize 10 | 11 | for line in sys.stdin.readlines(): 12 | line = line.rstrip('\n') 13 | print(' '.join(word_tokenize(line))) 14 | -------------------------------------------------------------------------------- /tools/tokenize.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | # Tokenize text data in various languages 9 | # Usage: e.g. cat wiki.ar | tokenize.sh ar 10 | 11 | set -e 12 | 13 | N_THREADS=8 14 | 15 | lg=$1 16 | TOOLS_PATH=$PWD/tools 17 | 18 | # moses 19 | MOSES=$TOOLS_PATH/mosesdecoder 20 | REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl 21 | NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl 22 | REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl 23 | TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl 24 | 25 | # Chinese 26 | if [ "$lg" = "zh" ]; then 27 | $TOOLS_PATH/stanford-segmenter-*/segment.sh pku /dev/stdin UTF-8 0 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR 28 | # Thai 29 | elif [ "$lg" = "th" ]; then 30 | cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | python $TOOLS_PATH/segment_th.py 31 | # Japanese 32 | elif [ "$lg" = "ja" ]; then 33 | cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | kytea -notags 34 | # other languages 35 | else 36 | cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape -threads $N_THREADS -l $lg 37 | fi 38 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import json 9 | import random 10 | import argparse 11 | 12 | from xlm.slurm import init_signal_handler, init_distributed_mode 13 | from xlm.data.loader import check_data_params, load_data 14 | from xlm.utils import bool_flag, initialize_exp, set_sampling_probs, shuf_order 15 | from xlm.model import check_model_params, build_model 16 | from xlm.model.memory import HashingMemory 17 | from xlm.trainer import SingleTrainer, EncDecTrainer 18 | from xlm.evaluation.evaluator import SingleEvaluator, EncDecEvaluator 19 | 20 | 21 | def get_parser(): 22 | """ 23 | Generate a parameters parser. 24 | """ 25 | # parse parameters 26 | parser = argparse.ArgumentParser(description="Language transfer") 27 | 28 | # main parameters 29 | parser.add_argument("--dump_path", type=str, default="./dumped/", 30 | help="Experiment dump path") 31 | parser.add_argument("--exp_name", type=str, default="", 32 | help="Experiment name") 33 | parser.add_argument("--save_periodic", type=int, default=0, 34 | help="Save the model periodically (0 to disable)") 35 | parser.add_argument("--exp_id", type=str, default="", 36 | help="Experiment ID") 37 | 38 | # float16 / AMP API 39 | parser.add_argument("--fp16", type=bool_flag, default=False, 40 | help="Run model with float16") 41 | parser.add_argument("--amp", type=int, default=-1, 42 | help="Use AMP wrapper for float16 / distributed / gradient accumulation. Level of optimization. -1 to disable.") 43 | 44 | # only use an encoder (use a specific decoder for machine translation) 45 | parser.add_argument("--encoder_only", type=bool_flag, default=True, 46 | help="Only use an encoder") 47 | 48 | # model parameters 49 | parser.add_argument("--emb_dim", type=int, default=512, 50 | help="Embedding layer size") 51 | parser.add_argument("--n_layers", type=int, default=4, 52 | help="Number of Transformer layers") 53 | parser.add_argument("--n_heads", type=int, default=8, 54 | help="Number of Transformer heads") 55 | parser.add_argument("--dropout", type=float, default=0, 56 | help="Dropout") 57 | parser.add_argument("--attention_dropout", type=float, default=0, 58 | help="Dropout in the attention layer") 59 | parser.add_argument("--gelu_activation", type=bool_flag, default=False, 60 | help="Use a GELU activation instead of ReLU") 61 | parser.add_argument("--share_inout_emb", type=bool_flag, default=True, 62 | help="Share input and output embeddings") 63 | parser.add_argument("--sinusoidal_embeddings", type=bool_flag, default=False, 64 | help="Use sinusoidal embeddings") 65 | parser.add_argument("--use_lang_emb", type=bool_flag, default=True, 66 | help="Use language embedding") 67 | 68 | # memory parameters 69 | parser.add_argument("--use_memory", type=bool_flag, default=False, 70 | help="Use an external memory") 71 | if parser.parse_known_args()[0].use_memory: 72 | HashingMemory.register_args(parser) 73 | parser.add_argument("--mem_enc_positions", type=str, default="", 74 | help="Memory positions in the encoder ('4' for inside layer 4, '7,10+' for inside layer 7 and after layer 10)") 75 | parser.add_argument("--mem_dec_positions", type=str, default="", 76 | help="Memory positions in the decoder. 
Same syntax as `mem_enc_positions`.") 77 | 78 | # adaptive softmax 79 | parser.add_argument("--asm", type=bool_flag, default=False, 80 | help="Use adaptive softmax") 81 | if parser.parse_known_args()[0].asm: 82 | parser.add_argument("--asm_cutoffs", type=str, default="8000,20000", 83 | help="Adaptive softmax cutoffs") 84 | parser.add_argument("--asm_div_value", type=float, default=4, 85 | help="Adaptive softmax cluster sizes ratio") 86 | 87 | # causal language modeling task parameters 88 | parser.add_argument("--context_size", type=int, default=0, 89 | help="Context size (0 means that the first elements in sequences won't have any context)") 90 | 91 | # masked language modeling task parameters 92 | parser.add_argument("--word_pred", type=float, default=0.15, 93 | help="Fraction of words for which we need to make a prediction") 94 | parser.add_argument("--sample_alpha", type=float, default=0, 95 | help="Exponent for transforming word counts to probabilities (~word2vec sampling)") 96 | parser.add_argument("--word_mask_keep_rand", type=str, default="0.8,0.1,0.1", 97 | help="Fraction of words to mask out / keep / randomize, among the words to predict") 98 | 99 | # input sentence noise 100 | parser.add_argument("--word_shuffle", type=float, default=0, 101 | help="Randomly shuffle input words (0 to disable)") 102 | parser.add_argument("--word_dropout", type=float, default=0, 103 | help="Randomly dropout input words (0 to disable)") 104 | parser.add_argument("--word_blank", type=float, default=0, 105 | help="Randomly blank input words (0 to disable)") 106 | 107 | # data 108 | parser.add_argument("--data_path", type=str, default="", 109 | help="Data path") 110 | parser.add_argument("--lgs", type=str, default="", 111 | help="Languages (lg1-lg2-lg3 .. ex: en-fr-es-de)") 112 | parser.add_argument("--max_vocab", type=int, default=-1, 113 | help="Maximum vocabulary size (-1 to disable)") 114 | parser.add_argument("--min_count", type=int, default=0, 115 | help="Minimum vocabulary count") 116 | parser.add_argument("--lg_sampling_factor", type=float, default=-1, 117 | help="Language sampling factor") 118 | 119 | # batch parameters 120 | parser.add_argument("--bptt", type=int, default=256, 121 | help="Sequence length") 122 | parser.add_argument("--max_len", type=int, default=100, 123 | help="Maximum length of sentences (after BPE)") 124 | parser.add_argument("--group_by_size", type=bool_flag, default=True, 125 | help="Sort sentences by size during the training") 126 | parser.add_argument("--batch_size", type=int, default=32, 127 | help="Number of sentences per batch") 128 | parser.add_argument("--max_batch_size", type=int, default=0, 129 | help="Maximum number of sentences per batch (used in combination with tokens_per_batch, 0 to disable)") 130 | parser.add_argument("--tokens_per_batch", type=int, default=-1, 131 | help="Number of tokens per batch") 132 | 133 | # training parameters 134 | parser.add_argument("--split_data", type=bool_flag, default=False, 135 | help="Split data across workers of a same node") 136 | parser.add_argument("--optimizer", type=str, default="adam,lr=0.0001", 137 | help="Optimizer (SGD / RMSprop / Adam, etc.)") 138 | parser.add_argument("--clip_grad_norm", type=float, default=5, 139 | help="Clip gradients norm (0 to disable)") 140 | parser.add_argument("--epoch_size", type=int, default=100000, 141 | help="Epoch size / evaluation frequency (-1 for parallel data size)") 142 | parser.add_argument("--max_epoch", type=int, default=100000, 143 | help="Maximum epoch size") 144 | 
parser.add_argument("--stopping_criterion", type=str, default="", 145 | help="Stopping criterion, and number of non-increase before stopping the experiment") 146 | parser.add_argument("--validation_metrics", type=str, default="", 147 | help="Validation metrics") 148 | parser.add_argument("--accumulate_gradients", type=int, default=1, 149 | help="Accumulate model gradients over N iterations (N times larger batch sizes)") 150 | 151 | # training coefficients 152 | parser.add_argument("--lambda_mlm", type=str, default="1", 153 | help="Prediction coefficient (MLM)") 154 | parser.add_argument("--lambda_clm", type=str, default="1", 155 | help="Causal coefficient (LM)") 156 | parser.add_argument("--lambda_pc", type=str, default="1", 157 | help="PC coefficient") 158 | parser.add_argument("--lambda_ae", type=str, default="1", 159 | help="AE coefficient") 160 | parser.add_argument("--lambda_mt", type=str, default="1", 161 | help="MT coefficient") 162 | parser.add_argument("--lambda_bt", type=str, default="1", 163 | help="BT coefficient") 164 | 165 | # training steps 166 | parser.add_argument("--clm_steps", type=str, default="", 167 | help="Causal prediction steps (CLM)") 168 | parser.add_argument("--mlm_steps", type=str, default="", 169 | help="Masked prediction steps (MLM / TLM)") 170 | parser.add_argument("--mt_steps", type=str, default="", 171 | help="Machine translation steps") 172 | parser.add_argument("--ae_steps", type=str, default="", 173 | help="Denoising auto-encoder steps") 174 | parser.add_argument("--bt_steps", type=str, default="", 175 | help="Back-translation steps") 176 | parser.add_argument("--pc_steps", type=str, default="", 177 | help="Parallel classification steps") 178 | 179 | # reload pretrained embeddings / pretrained model / checkpoint 180 | parser.add_argument("--reload_emb", type=str, default="", 181 | help="Reload pretrained word embeddings") 182 | parser.add_argument("--reload_model", type=str, default="", 183 | help="Reload a pretrained model") 184 | parser.add_argument("--reload_checkpoint", type=str, default="", 185 | help="Reload a checkpoint") 186 | 187 | # beam search (for MT only) 188 | parser.add_argument("--beam_size", type=int, default=1, 189 | help="Beam size, default = 1 (greedy decoding)") 190 | parser.add_argument("--length_penalty", type=float, default=1, 191 | help="Length penalty, values < 1.0 favor shorter sentences, while values > 1.0 favor longer ones.") 192 | parser.add_argument("--early_stopping", type=bool_flag, default=False, 193 | help="Early stopping, stop as soon as we have `beam_size` hypotheses, although longer ones may have better scores.") 194 | 195 | # evaluation 196 | parser.add_argument("--eval_bleu", type=bool_flag, default=False, 197 | help="Evaluate BLEU score during MT training") 198 | parser.add_argument("--eval_only", type=bool_flag, default=False, 199 | help="Only run evaluations") 200 | 201 | # debug 202 | parser.add_argument("--debug_train", type=bool_flag, default=False, 203 | help="Use valid sets for train sets (faster loading)") 204 | parser.add_argument("--debug_slurm", type=bool_flag, default=False, 205 | help="Debug multi-GPU / multi-node within a SLURM job") 206 | parser.add_argument("--debug", help="Enable all debug flags", 207 | action="store_true") 208 | 209 | # multi-gpu / multi-node 210 | parser.add_argument("--local_rank", type=int, default=-1, 211 | help="Multi-GPU - Local rank") 212 | parser.add_argument("--master_port", type=int, default=-1, 213 | help="Master port (for multi-node SLURM jobs)") 214 | 215 | return 
parser 216 | 217 | 218 | def main(params): 219 | 220 | # initialize the multi-GPU / multi-node training 221 | init_distributed_mode(params) 222 | 223 | # initialize the experiment 224 | logger = initialize_exp(params) 225 | 226 | # initialize SLURM signal handler for time limit / pre-emption 227 | init_signal_handler() 228 | 229 | # load data 230 | data = load_data(params) 231 | 232 | # build model 233 | if params.encoder_only: 234 | model = build_model(params, data['dico']) 235 | else: 236 | encoder, decoder = build_model(params, data['dico']) 237 | 238 | # build trainer, reload potential checkpoints / build evaluator 239 | if params.encoder_only: 240 | trainer = SingleTrainer(model, data, params) 241 | evaluator = SingleEvaluator(trainer, data, params) 242 | else: 243 | trainer = EncDecTrainer(encoder, decoder, data, params) 244 | evaluator = EncDecEvaluator(trainer, data, params) 245 | 246 | # evaluation 247 | if params.eval_only: 248 | scores = evaluator.run_all_evals(trainer) 249 | for k, v in scores.items(): 250 | logger.info("%s -> %.6f" % (k, v)) 251 | logger.info("__log__:%s" % json.dumps(scores)) 252 | exit() 253 | 254 | # set sampling probabilities for training 255 | set_sampling_probs(data, params) 256 | 257 | # language model training 258 | for _ in range(params.max_epoch): 259 | 260 | logger.info("============ Starting epoch %i ... ============" % trainer.epoch) 261 | 262 | trainer.n_sentences = 0 263 | 264 | while trainer.n_sentences < trainer.epoch_size: 265 | 266 | # CLM steps 267 | for lang1, lang2 in shuf_order(params.clm_steps, params): 268 | trainer.clm_step(lang1, lang2, params.lambda_clm) 269 | 270 | # MLM steps (also includes TLM if lang2 is not None) 271 | for lang1, lang2 in shuf_order(params.mlm_steps, params): 272 | trainer.mlm_step(lang1, lang2, params.lambda_mlm) 273 | 274 | # parallel classification steps 275 | for lang1, lang2 in shuf_order(params.pc_steps, params): 276 | trainer.pc_step(lang1, lang2, params.lambda_pc) 277 | 278 | # denoising auto-encoder steps 279 | for lang in shuf_order(params.ae_steps): 280 | trainer.mt_step(lang, lang, params.lambda_ae) 281 | 282 | # machine translation steps 283 | for lang1, lang2 in shuf_order(params.mt_steps, params): 284 | trainer.mt_step(lang1, lang2, params.lambda_mt) 285 | 286 | # back-translation steps 287 | for lang1, lang2, lang3 in shuf_order(params.bt_steps): 288 | trainer.bt_step(lang1, lang2, lang3, params.lambda_bt) 289 | 290 | trainer.iter() 291 | 292 | logger.info("============ End of epoch %i ============" % trainer.epoch) 293 | 294 | # evaluate perplexity 295 | scores = evaluator.run_all_evals(trainer) 296 | 297 | # print / JSON log 298 | for k, v in scores.items(): 299 | logger.info("%s -> %.6f" % (k, v)) 300 | if params.is_master: 301 | logger.info("__log__:%s" % json.dumps(scores)) 302 | 303 | # end of epoch 304 | trainer.save_best_model(scores) 305 | trainer.save_periodic() 306 | trainer.end_epoch(scores) 307 | 308 | 309 | if __name__ == '__main__': 310 | 311 | # generate parser / parse parameters 312 | parser = get_parser() 313 | params = parser.parse_args() 314 | 315 | # debug mode 316 | if params.debug: 317 | params.exp_name = 'debug' 318 | params.exp_id = 'debug_%08i' % random.randint(0, 100000000) 319 | params.debug_slurm = True 320 | params.debug_train = True 321 | 322 | # check parameters 323 | check_data_params(params) 324 | check_model_params(params) 325 | 326 | # run experiment 327 | main(params) 328 | -------------------------------------------------------------------------------- 
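The `--mem_enc_positions` / `--mem_dec_positions` options above use a compact syntax for placing `HashingMemory` layers ('4' for inside layer 4, '7,10+' for inside layer 7 and after layer 10). The snippet below is a minimal sketch of how such a string can be decoded into (layer, after_layer) pairs; `parse_mem_positions` is an illustrative helper, not part of the repository, and the parsing actually performed inside `xlm/model` may differ.

```
# Minimal sketch (not the repository's implementation) of decoding the memory
# position syntax documented in train.py: a bare index means "inside layer i",
# a trailing '+' means "after layer i".
def parse_mem_positions(spec):
    """'7,10+' -> [(7, False), (10, True)]  (layer index, after_layer flag)."""
    positions = []
    for item in spec.split(','):
        item = item.strip()
        if not item:
            continue
        after = item.endswith('+')
        positions.append((int(item.rstrip('+')), after))
    return positions


if __name__ == '__main__':
    print(parse_mem_positions('4'))      # [(4, False)]
    print(parse_mem_positions('7,10+'))  # [(7, False), (10, True)]
```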
/translate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Translate sentences from the input stream. 8 | # The model will be faster if sentences are sorted by length. 9 | # Input sentences must use the same tokenization and BPE codes as the ones used in the model. 10 | # 11 | # Usage: 12 | # cat source_sentences.bpe | \ 13 | # python translate.py --exp_name translate \ 14 | # --src_lang en --tgt_lang fr \ 15 | # --model_path trained_model.pth --output_path output 16 | # 17 | 18 | import os 19 | import io 20 | import sys 21 | import argparse 22 | import torch 23 | 24 | from xlm.utils import AttrDict 25 | from xlm.utils import bool_flag, initialize_exp 26 | from xlm.data.dictionary import Dictionary 27 | from xlm.model.transformer import TransformerModel 28 | 29 | 30 | def get_parser(): 31 | """ 32 | Generate a parameters parser. 33 | """ 34 | # parse parameters 35 | parser = argparse.ArgumentParser(description="Translate sentences") 36 | 37 | # main parameters 38 | parser.add_argument("--dump_path", type=str, default="./dumped/", help="Experiment dump path") 39 | parser.add_argument("--exp_name", type=str, default="", help="Experiment name") 40 | parser.add_argument("--exp_id", type=str, default="", help="Experiment ID") 41 | parser.add_argument("--batch_size", type=int, default=32, help="Number of sentences per batch") 42 | 43 | # model / output paths 44 | parser.add_argument("--model_path", type=str, default="", help="Model path") 45 | parser.add_argument("--output_path", type=str, default="", help="Output path") 46 | 47 | # parser.add_argument("--max_vocab", type=int, default=-1, help="Maximum vocabulary size (-1 to disable)") 48 | # parser.add_argument("--min_count", type=int, default=0, help="Minimum vocabulary count") 49 | 50 | # source language / target language 51 | parser.add_argument("--src_lang", type=str, default="", help="Source language") 52 | parser.add_argument("--tgt_lang", type=str, default="", help="Target language") 53 | 54 | return parser 55 | 56 | 57 | def main(params): 58 | 59 | # initialize the experiment 60 | logger = initialize_exp(params) 61 | 62 | # generate parser / parse parameters 63 | parser = get_parser() 64 | params = parser.parse_args() 65 | reloaded = torch.load(params.model_path) 66 | model_params = AttrDict(reloaded['params']) 67 | logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys())) 68 | 69 | # update dictionary parameters 70 | for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']: 71 | setattr(params, name, getattr(model_params, name)) 72 | 73 | # build dictionary / build encoder / build decoder / reload weights 74 | dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts']) 75 | encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval() 76 | decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval() 77 | encoder.load_state_dict(reloaded['encoder']) 78 | decoder.load_state_dict(reloaded['decoder']) 79 | params.src_id = model_params.lang2id[params.src_lang] 80 | params.tgt_id = model_params.lang2id[params.tgt_lang] 81 | 82 | # read sentences from stdin 83 | src_sent = [] 84 | for line in sys.stdin.readlines(): 85 | assert
len(line.strip().split()) > 0 86 | src_sent.append(line) 87 | logger.info("Read %i sentences from stdin. Translating ..." % len(src_sent)) 88 | 89 | f = io.open(params.output_path, 'w', encoding='utf-8') 90 | 91 | for i in range(0, len(src_sent), params.batch_size): 92 | 93 | # prepare batch 94 | word_ids = [torch.LongTensor([dico.index(w) for w in s.strip().split()]) 95 | for s in src_sent[i:i + params.batch_size]] 96 | lengths = torch.LongTensor([len(s) + 2 for s in word_ids]) 97 | batch = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(params.pad_index) 98 | batch[0] = params.eos_index 99 | for j, s in enumerate(word_ids): 100 | if lengths[j] > 2: # if sentence not empty 101 | batch[1:lengths[j] - 1, j].copy_(s) 102 | batch[lengths[j] - 1, j] = params.eos_index 103 | langs = batch.clone().fill_(params.src_id) 104 | 105 | # encode source batch and translate it 106 | encoded = encoder('fwd', x=batch.cuda(), lengths=lengths.cuda(), langs=langs.cuda(), causal=False) 107 | encoded = encoded.transpose(0, 1) 108 | decoded, dec_lengths = decoder.generate(encoded, lengths.cuda(), params.tgt_id, max_len=int(1.5 * lengths.max().item() + 10)) 109 | 110 | # convert sentences to words 111 | for j in range(decoded.size(1)): 112 | 113 | # remove delimiters 114 | sent = decoded[:, j] 115 | delimiters = (sent == params.eos_index).nonzero().view(-1) 116 | assert len(delimiters) >= 1 and delimiters[0].item() == 0 117 | sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]] 118 | 119 | # output translation 120 | source = src_sent[i + j].strip() 121 | target = " ".join([dico[sent[k].item()] for k in range(len(sent))]) 122 | sys.stderr.write("%i / %i: %s -> %s\n" % (i + j, len(src_sent), source, target)) 123 | f.write(target + "\n") 124 | 125 | f.close() 126 | 127 | 128 | if __name__ == '__main__': 129 | 130 | # generate parser / parse parameters 131 | parser = get_parser() 132 | params = parser.parse_args() 133 | 134 | # check parameters 135 | assert os.path.isfile(params.model_path) 136 | assert params.src_lang != '' and params.tgt_lang != '' and params.src_lang != params.tgt_lang 137 | assert params.output_path and not os.path.isfile(params.output_path) 138 | 139 | # translate 140 | with torch.no_grad(): 141 | main(params) 142 | -------------------------------------------------------------------------------- /xlm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/XLM/cd281d32612d145c6742b4d3f048f80df8669c30/xlm/__init__.py -------------------------------------------------------------------------------- /xlm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/XLM/cd281d32612d145c6742b4d3f048f80df8669c30/xlm/data/__init__.py -------------------------------------------------------------------------------- /xlm/data/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import math 10 | import numpy as np 11 | import torch 12 | 13 | 14 | logger = getLogger() 15 | 16 | 17 | class StreamDataset(object): 18 | 19 | def __init__(self, sent, pos, bs, params): 20 | """ 21 | Prepare batches for data iterator. 
22 | """ 23 | bptt = params.bptt 24 | self.eos = params.eos_index 25 | 26 | # checks 27 | assert len(pos) == (sent == self.eos).sum() 28 | assert len(pos) == (sent[pos[:, 1]] == self.eos).sum() 29 | 30 | n_tokens = len(sent) 31 | n_batches = math.ceil(n_tokens / (bs * bptt)) 32 | t_size = n_batches * bptt * bs 33 | 34 | buffer = np.zeros(t_size, dtype=sent.dtype) + self.eos 35 | buffer[t_size - n_tokens:] = sent 36 | buffer = buffer.reshape((bs, n_batches * bptt)).T 37 | self.data = np.zeros((n_batches * bptt + 1, bs), dtype=sent.dtype) + self.eos 38 | self.data[1:] = buffer 39 | 40 | self.bptt = bptt 41 | self.n_tokens = n_tokens 42 | self.n_batches = n_batches 43 | self.n_sentences = len(pos) 44 | self.lengths = torch.LongTensor(bs).fill_(bptt) 45 | 46 | def __len__(self): 47 | """ 48 | Number of sentences in the dataset. 49 | """ 50 | return self.n_sentences 51 | 52 | def select_data(self, a, b): 53 | """ 54 | Only select a subset of the dataset. 55 | """ 56 | if not (0 <= a < b <= self.n_batches): 57 | logger.warning("Invalid split values: %i %i - %i" % (a, b, self.n_batches)) 58 | return 59 | assert 0 <= a < b <= self.n_batches 60 | logger.info("Selecting batches from %i to %i ..." % (a, b)) 61 | 62 | # sub-select 63 | self.data = self.data[a * self.bptt:b * self.bptt] 64 | self.n_batches = b - a 65 | self.n_sentences = (self.data == self.eos).sum().item() 66 | 67 | def get_iterator(self, shuffle, subsample=1): 68 | """ 69 | Return a sentences iterator. 70 | """ 71 | indexes = (np.random.permutation if shuffle else range)(self.n_batches // subsample) 72 | for i in indexes: 73 | a = self.bptt * i 74 | b = self.bptt * (i + 1) 75 | yield torch.from_numpy(self.data[a:b].astype(np.int64)), self.lengths 76 | 77 | 78 | class Dataset(object): 79 | 80 | def __init__(self, sent, pos, params): 81 | 82 | self.eos_index = params.eos_index 83 | self.pad_index = params.pad_index 84 | self.batch_size = params.batch_size 85 | self.tokens_per_batch = params.tokens_per_batch 86 | self.max_batch_size = params.max_batch_size 87 | 88 | self.sent = sent 89 | self.pos = pos 90 | self.lengths = self.pos[:, 1] - self.pos[:, 0] 91 | 92 | # check number of sentences 93 | assert len(self.pos) == (self.sent == self.eos_index).sum() 94 | 95 | # # remove empty sentences 96 | # self.remove_empty_sentences() 97 | 98 | # sanity checks 99 | self.check() 100 | 101 | def __len__(self): 102 | """ 103 | Number of sentences in the dataset. 104 | """ 105 | return len(self.pos) 106 | 107 | def check(self): 108 | """ 109 | Sanity checks. 110 | """ 111 | eos = self.eos_index 112 | assert len(self.pos) == (self.sent[self.pos[:, 1]] == eos).sum() # check sentences indices 113 | # assert self.lengths.min() > 0 # check empty sentences 114 | 115 | def batch_sentences(self, sentences): 116 | """ 117 | Take as input a list of n sentences (torch.LongTensor vectors) and return 118 | a tensor of size (slen, n) where slen is the length of the longest 119 | sentence, and a vector lengths containing the length of each sentence. 
120 | """ 121 | # sentences = sorted(sentences, key=lambda x: len(x), reverse=True) 122 | lengths = torch.LongTensor([len(s) + 2 for s in sentences]) 123 | sent = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(self.pad_index) 124 | 125 | sent[0] = self.eos_index 126 | for i, s in enumerate(sentences): 127 | if lengths[i] > 2: # if sentence not empty 128 | sent[1:lengths[i] - 1, i].copy_(torch.from_numpy(s.astype(np.int64))) 129 | sent[lengths[i] - 1, i] = self.eos_index 130 | 131 | return sent, lengths 132 | 133 | def remove_empty_sentences(self): 134 | """ 135 | Remove empty sentences. 136 | """ 137 | init_size = len(self.pos) 138 | indices = np.arange(len(self.pos)) 139 | indices = indices[self.lengths[indices] > 0] 140 | self.pos = self.pos[indices] 141 | self.lengths = self.pos[:, 1] - self.pos[:, 0] 142 | logger.info("Removed %i empty sentences." % (init_size - len(indices))) 143 | self.check() 144 | 145 | def remove_long_sentences(self, max_len): 146 | """ 147 | Remove sentences exceeding a certain length. 148 | """ 149 | assert max_len >= 0 150 | if max_len == 0: 151 | return 152 | init_size = len(self.pos) 153 | indices = np.arange(len(self.pos)) 154 | indices = indices[self.lengths[indices] <= max_len] 155 | self.pos = self.pos[indices] 156 | self.lengths = self.pos[:, 1] - self.pos[:, 0] 157 | logger.info("Removed %i too long sentences." % (init_size - len(indices))) 158 | self.check() 159 | 160 | def select_data(self, a, b): 161 | """ 162 | Only select a subset of the dataset. 163 | """ 164 | assert 0 <= a < b <= len(self.pos) 165 | logger.info("Selecting sentences from %i to %i ..." % (a, b)) 166 | 167 | # sub-select 168 | self.pos = self.pos[a:b] 169 | self.lengths = self.pos[:, 1] - self.pos[:, 0] 170 | 171 | # re-index 172 | min_pos = self.pos.min() 173 | max_pos = self.pos.max() 174 | self.pos -= min_pos 175 | self.sent = self.sent[min_pos:max_pos + 1] 176 | 177 | # sanity checks 178 | self.check() 179 | 180 | def get_batches_iterator(self, batches, return_indices): 181 | """ 182 | Return a sentences iterator, given the associated sentence batches. 183 | """ 184 | assert type(return_indices) is bool 185 | 186 | for sentence_ids in batches: 187 | if 0 < self.max_batch_size < len(sentence_ids): 188 | np.random.shuffle(sentence_ids) 189 | sentence_ids = sentence_ids[:self.max_batch_size] 190 | pos = self.pos[sentence_ids] 191 | sent = [self.sent[a:b] for a, b in pos] 192 | sent = self.batch_sentences(sent) 193 | yield (sent, sentence_ids) if return_indices else sent 194 | 195 | def get_iterator(self, shuffle, group_by_size=False, n_sentences=-1, seed=None, return_indices=False): 196 | """ 197 | Return a sentences iterator. 
198 | """ 199 | assert seed is None or shuffle is True and type(seed) is int 200 | rng = np.random.RandomState(seed) 201 | n_sentences = len(self.pos) if n_sentences == -1 else n_sentences 202 | assert 0 < n_sentences <= len(self.pos) 203 | assert type(shuffle) is bool and type(group_by_size) is bool 204 | assert group_by_size is False or shuffle is True 205 | 206 | # sentence lengths 207 | lengths = self.lengths + 2 208 | 209 | # select sentences to iterate over 210 | if shuffle: 211 | indices = rng.permutation(len(self.pos))[:n_sentences] 212 | else: 213 | indices = np.arange(n_sentences) 214 | 215 | # group sentences by lengths 216 | if group_by_size: 217 | indices = indices[np.argsort(lengths[indices], kind='mergesort')] 218 | 219 | # create batches - either have a fixed number of sentences, or a similar number of tokens 220 | if self.tokens_per_batch == -1: 221 | batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size)) 222 | else: 223 | batch_ids = np.cumsum(lengths[indices]) // self.tokens_per_batch 224 | _, bounds = np.unique(batch_ids, return_index=True) 225 | batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)] 226 | if bounds[-1] < len(indices): 227 | batches.append(indices[bounds[-1]:]) 228 | 229 | # optionally shuffle batches 230 | if shuffle: 231 | rng.shuffle(batches) 232 | 233 | # sanity checks 234 | assert n_sentences == sum([len(x) for x in batches]) 235 | assert lengths[indices].sum() == sum([lengths[x].sum() for x in batches]) 236 | # assert set.union(*[set(x.tolist()) for x in batches]) == set(range(n_sentences)) # slow 237 | 238 | # return the iterator 239 | return self.get_batches_iterator(batches, return_indices) 240 | 241 | 242 | class ParallelDataset(Dataset): 243 | 244 | def __init__(self, sent1, pos1, sent2, pos2, params): 245 | 246 | self.eos_index = params.eos_index 247 | self.pad_index = params.pad_index 248 | self.batch_size = params.batch_size 249 | self.tokens_per_batch = params.tokens_per_batch 250 | self.max_batch_size = params.max_batch_size 251 | 252 | self.sent1 = sent1 253 | self.sent2 = sent2 254 | self.pos1 = pos1 255 | self.pos2 = pos2 256 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0] 257 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0] 258 | 259 | # check number of sentences 260 | assert len(self.pos1) == (self.sent1 == self.eos_index).sum() 261 | assert len(self.pos2) == (self.sent2 == self.eos_index).sum() 262 | 263 | # remove empty sentences 264 | self.remove_empty_sentences() 265 | 266 | # sanity checks 267 | self.check() 268 | 269 | def __len__(self): 270 | """ 271 | Number of sentences in the dataset. 272 | """ 273 | return len(self.pos1) 274 | 275 | def check(self): 276 | """ 277 | Sanity checks. 278 | """ 279 | eos = self.eos_index 280 | assert len(self.pos1) == len(self.pos2) > 0 # check number of sentences 281 | assert len(self.pos1) == (self.sent1[self.pos1[:, 1]] == eos).sum() # check sentences indices 282 | assert len(self.pos2) == (self.sent2[self.pos2[:, 1]] == eos).sum() # check sentences indices 283 | assert eos <= self.sent1.min() < self.sent1.max() # check dictionary indices 284 | assert eos <= self.sent2.min() < self.sent2.max() # check dictionary indices 285 | assert self.lengths1.min() > 0 # check empty sentences 286 | assert self.lengths2.min() > 0 # check empty sentences 287 | 288 | def remove_empty_sentences(self): 289 | """ 290 | Remove empty sentences. 
291 | """ 292 | init_size = len(self.pos1) 293 | indices = np.arange(len(self.pos1)) 294 | indices = indices[self.lengths1[indices] > 0] 295 | indices = indices[self.lengths2[indices] > 0] 296 | self.pos1 = self.pos1[indices] 297 | self.pos2 = self.pos2[indices] 298 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0] 299 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0] 300 | logger.info("Removed %i empty sentences." % (init_size - len(indices))) 301 | self.check() 302 | 303 | def remove_long_sentences(self, max_len): 304 | """ 305 | Remove sentences exceeding a certain length. 306 | """ 307 | assert max_len >= 0 308 | if max_len == 0: 309 | return 310 | init_size = len(self.pos1) 311 | indices = np.arange(len(self.pos1)) 312 | indices = indices[self.lengths1[indices] <= max_len] 313 | indices = indices[self.lengths2[indices] <= max_len] 314 | self.pos1 = self.pos1[indices] 315 | self.pos2 = self.pos2[indices] 316 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0] 317 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0] 318 | logger.info("Removed %i too long sentences." % (init_size - len(indices))) 319 | self.check() 320 | 321 | def select_data(self, a, b): 322 | """ 323 | Only select a subset of the dataset. 324 | """ 325 | assert 0 <= a < b <= len(self.pos1) 326 | logger.info("Selecting sentences from %i to %i ..." % (a, b)) 327 | 328 | # sub-select 329 | self.pos1 = self.pos1[a:b] 330 | self.pos2 = self.pos2[a:b] 331 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0] 332 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0] 333 | 334 | # re-index 335 | min_pos1 = self.pos1.min() 336 | max_pos1 = self.pos1.max() 337 | min_pos2 = self.pos2.min() 338 | max_pos2 = self.pos2.max() 339 | self.pos1 -= min_pos1 340 | self.pos2 -= min_pos2 341 | self.sent1 = self.sent1[min_pos1:max_pos1 + 1] 342 | self.sent2 = self.sent2[min_pos2:max_pos2 + 1] 343 | 344 | # sanity checks 345 | self.check() 346 | 347 | def get_batches_iterator(self, batches, return_indices): 348 | """ 349 | Return a sentences iterator, given the associated sentence batches. 350 | """ 351 | assert type(return_indices) is bool 352 | 353 | for sentence_ids in batches: 354 | if 0 < self.max_batch_size < len(sentence_ids): 355 | np.random.shuffle(sentence_ids) 356 | sentence_ids = sentence_ids[:self.max_batch_size] 357 | pos1 = self.pos1[sentence_ids] 358 | pos2 = self.pos2[sentence_ids] 359 | sent1 = self.batch_sentences([self.sent1[a:b] for a, b in pos1]) 360 | sent2 = self.batch_sentences([self.sent2[a:b] for a, b in pos2]) 361 | yield (sent1, sent2, sentence_ids) if return_indices else (sent1, sent2) 362 | 363 | def get_iterator(self, shuffle, group_by_size=False, n_sentences=-1, return_indices=False): 364 | """ 365 | Return a sentences iterator. 
366 | """ 367 | n_sentences = len(self.pos1) if n_sentences == -1 else n_sentences 368 | assert 0 < n_sentences <= len(self.pos1) 369 | assert type(shuffle) is bool and type(group_by_size) is bool 370 | 371 | # sentence lengths 372 | lengths = self.lengths1 + self.lengths2 + 4 373 | 374 | # select sentences to iterate over 375 | if shuffle: 376 | indices = np.random.permutation(len(self.pos1))[:n_sentences] 377 | else: 378 | indices = np.arange(n_sentences) 379 | 380 | # group sentences by lengths 381 | if group_by_size: 382 | indices = indices[np.argsort(lengths[indices], kind='mergesort')] 383 | 384 | # create batches - either have a fixed number of sentences, or a similar number of tokens 385 | if self.tokens_per_batch == -1: 386 | batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size)) 387 | else: 388 | batch_ids = np.cumsum(lengths[indices]) // self.tokens_per_batch 389 | _, bounds = np.unique(batch_ids, return_index=True) 390 | batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)] 391 | if bounds[-1] < len(indices): 392 | batches.append(indices[bounds[-1]:]) 393 | 394 | # optionally shuffle batches 395 | if shuffle: 396 | np.random.shuffle(batches) 397 | 398 | # sanity checks 399 | assert n_sentences == sum([len(x) for x in batches]) 400 | assert lengths[indices].sum() == sum([lengths[x].sum() for x in batches]) 401 | # assert set.union(*[set(x.tolist()) for x in batches]) == set(range(n_sentences)) # slow 402 | 403 | # return the iterator 404 | return self.get_batches_iterator(batches, return_indices) 405 | -------------------------------------------------------------------------------- /xlm/data/dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import numpy as np 10 | import torch 11 | from logging import getLogger 12 | 13 | 14 | logger = getLogger() 15 | 16 | 17 | BOS_WORD = '' 18 | EOS_WORD = '' 19 | PAD_WORD = '' 20 | UNK_WORD = '' 21 | 22 | SPECIAL_WORD = '' 23 | SPECIAL_WORDS = 10 24 | 25 | SEP_WORD = SPECIAL_WORD % 0 26 | MASK_WORD = SPECIAL_WORD % 1 27 | 28 | 29 | class Dictionary(object): 30 | 31 | def __init__(self, id2word, word2id, counts): 32 | assert len(id2word) == len(word2id) == len(counts) 33 | self.id2word = id2word 34 | self.word2id = word2id 35 | self.counts = counts 36 | self.bos_index = word2id[BOS_WORD] 37 | self.eos_index = word2id[EOS_WORD] 38 | self.pad_index = word2id[PAD_WORD] 39 | self.unk_index = word2id[UNK_WORD] 40 | self.check_valid() 41 | 42 | def __len__(self): 43 | """ 44 | Returns the number of words in the dictionary. 45 | """ 46 | return len(self.id2word) 47 | 48 | def __getitem__(self, i): 49 | """ 50 | Returns the word of the specified index. 51 | """ 52 | return self.id2word[i] 53 | 54 | def __contains__(self, w): 55 | """ 56 | Returns whether a word is in the dictionary. 57 | """ 58 | return w in self.word2id 59 | 60 | def __eq__(self, y): 61 | """ 62 | Compare this dictionary with another one. 63 | """ 64 | self.check_valid() 65 | y.check_valid() 66 | if len(self.id2word) != len(y): 67 | return False 68 | return all(self.id2word[i] == y[i] for i in range(len(y))) 69 | 70 | def check_valid(self): 71 | """ 72 | Check that the dictionary is valid. 
73 | """ 74 | assert self.bos_index == 0 75 | assert self.eos_index == 1 76 | assert self.pad_index == 2 77 | assert self.unk_index == 3 78 | assert all(self.id2word[4 + i] == SPECIAL_WORD % i for i in range(SPECIAL_WORDS)) 79 | assert len(self.id2word) == len(self.word2id) == len(self.counts) 80 | assert set(self.word2id.keys()) == set(self.counts.keys()) 81 | for i in range(len(self.id2word)): 82 | assert self.word2id[self.id2word[i]] == i 83 | last_count = 1e18 84 | for i in range(4 + SPECIAL_WORDS, len(self.id2word) - 1): 85 | count = self.counts[self.id2word[i]] 86 | assert count <= last_count 87 | last_count = count 88 | 89 | def index(self, word, no_unk=False): 90 | """ 91 | Returns the index of the specified word. 92 | """ 93 | if no_unk: 94 | return self.word2id[word] 95 | else: 96 | return self.word2id.get(word, self.unk_index) 97 | 98 | def max_vocab(self, max_vocab): 99 | """ 100 | Limit the vocabulary size. 101 | """ 102 | assert max_vocab >= 1 103 | init_size = len(self) 104 | self.id2word = {k: v for k, v in self.id2word.items() if k < max_vocab} 105 | self.word2id = {v: k for k, v in self.id2word.items()} 106 | self.counts = {k: v for k, v in self.counts.items() if k in self.word2id} 107 | self.check_valid() 108 | logger.info("Maximum vocabulary size: %i. Dictionary size: %i -> %i (removed %i words)." 109 | % (max_vocab, init_size, len(self), init_size - len(self))) 110 | 111 | def min_count(self, min_count): 112 | """ 113 | Threshold on the word frequency counts. 114 | """ 115 | assert min_count >= 0 116 | init_size = len(self) 117 | self.id2word = {k: v for k, v in self.id2word.items() if self.counts[self.id2word[k]] >= min_count or k < 4 + SPECIAL_WORDS} 118 | self.word2id = {v: k for k, v in self.id2word.items()} 119 | self.counts = {k: v for k, v in self.counts.items() if k in self.word2id} 120 | self.check_valid() 121 | logger.info("Minimum frequency count: %i. Dictionary size: %i -> %i (removed %i words)." 122 | % (min_count, init_size, len(self), init_size - len(self))) 123 | 124 | @staticmethod 125 | def read_vocab(vocab_path): 126 | """ 127 | Create a dictionary from a vocabulary file. 128 | """ 129 | skipped = 0 130 | assert os.path.isfile(vocab_path), vocab_path 131 | word2id = {BOS_WORD: 0, EOS_WORD: 1, PAD_WORD: 2, UNK_WORD: 3} 132 | for i in range(SPECIAL_WORDS): 133 | word2id[SPECIAL_WORD % i] = 4 + i 134 | counts = {k: 0 for k in word2id.keys()} 135 | f = open(vocab_path, 'r', encoding='utf-8') 136 | for i, line in enumerate(f): 137 | if '\u2028' in line: 138 | skipped += 1 139 | continue 140 | line = line.rstrip().split() 141 | if len(line) != 2: 142 | skipped += 1 143 | continue 144 | assert len(line) == 2, (i, line) 145 | # assert line[0] not in word2id and line[1].isdigit(), (i, line) 146 | assert line[1].isdigit(), (i, line) 147 | if line[0] in word2id: 148 | skipped += 1 149 | print('%s already in vocab' % line[0]) 150 | continue 151 | if not line[1].isdigit(): 152 | skipped += 1 153 | print('Empty word at line %s with count %s' % (i, line)) 154 | continue 155 | word2id[line[0]] = 4 + SPECIAL_WORDS + i - skipped # shift because of extra words 156 | counts[line[0]] = int(line[1]) 157 | f.close() 158 | id2word = {v: k for k, v in word2id.items()} 159 | dico = Dictionary(id2word, word2id, counts) 160 | logger.info("Read %i words from the vocabulary file." % len(dico)) 161 | if skipped > 0: 162 | logger.warning("Skipped %i empty lines!" 
% skipped) 163 | return dico 164 | 165 | @staticmethod 166 | def index_data(path, bin_path, dico): 167 | """ 168 | Index sentences with a dictionary. 169 | """ 170 | if bin_path is not None and os.path.isfile(bin_path): 171 | print("Loading data from %s ..." % bin_path) 172 | data = torch.load(bin_path) 173 | assert dico == data['dico'] 174 | return data 175 | 176 | positions = [] 177 | sentences = [] 178 | unk_words = {} 179 | 180 | # index sentences 181 | f = open(path, 'r', encoding='utf-8') 182 | for i, line in enumerate(f): 183 | if i % 1000000 == 0 and i > 0: 184 | print(i) 185 | s = line.rstrip().split() 186 | # skip empty sentences 187 | if len(s) == 0: 188 | print("Empty sentence in line %i." % i) 189 | # index sentence words 190 | count_unk = 0 191 | indexed = [] 192 | for w in s: 193 | word_id = dico.index(w, no_unk=False) 194 | # if we find a special word which is not an unknown word, skip the sentence 195 | if 0 <= word_id < 4 + SPECIAL_WORDS and word_id != 3: 196 | logger.warning('Found unexpected special word "%s" (%i)!!' % (w, word_id)) 197 | continue 198 | assert word_id >= 0 199 | indexed.append(word_id) 200 | if word_id == dico.unk_index: 201 | unk_words[w] = unk_words.get(w, 0) + 1 202 | count_unk += 1 203 | # add sentence 204 | positions.append([len(sentences), len(sentences) + len(indexed)]) 205 | sentences.extend(indexed) 206 | sentences.append(1) # EOS index 207 | f.close() 208 | 209 | # tensorize data 210 | positions = np.int64(positions) 211 | if len(dico) < 1 << 16: 212 | sentences = np.uint16(sentences) 213 | elif len(dico) < 1 << 31: 214 | sentences = np.int32(sentences) 215 | else: 216 | raise Exception("Dictionary is too big.") 217 | assert sentences.min() >= 0 218 | data = { 219 | 'dico': dico, 220 | 'positions': positions, 221 | 'sentences': sentences, 222 | 'unk_words': unk_words, 223 | } 224 | if bin_path is not None: 225 | print("Saving the data to %s ..." % bin_path) 226 | torch.save(data, bin_path, pickle_protocol=4) 227 | 228 | return data 229 | -------------------------------------------------------------------------------- /xlm/data/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import os 10 | import numpy as np 11 | import torch 12 | 13 | from .dataset import StreamDataset, Dataset, ParallelDataset 14 | from .dictionary import BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD, MASK_WORD 15 | 16 | 17 | logger = getLogger() 18 | 19 | 20 | def process_binarized(data, params): 21 | """ 22 | Process a binarized dataset and log main statistics. 23 | """ 24 | dico = data['dico'] 25 | assert ((data['sentences'].dtype == np.uint16) and (len(dico) < 1 << 16) or 26 | (data['sentences'].dtype == np.int32) and (1 << 16 <= len(dico) < 1 << 31)) 27 | logger.info("%i words (%i unique) in %i sentences. %i unknown words (%i unique) covering %.2f%% of the data." % ( 28 | len(data['sentences']) - len(data['positions']), 29 | len(dico), len(data['positions']), 30 | sum(data['unk_words'].values()), len(data['unk_words']), 31 | 100. * sum(data['unk_words'].values()) / (len(data['sentences']) - len(data['positions'])) 32 | )) 33 | if params.max_vocab != -1: 34 | assert params.max_vocab > 0 35 | logger.info("Selecting %i most frequent words ..." 
% params.max_vocab) 36 | dico.max_vocab(params.max_vocab) 37 | data['sentences'][data['sentences'] >= params.max_vocab] = dico.index(UNK_WORD) 38 | unk_count = (data['sentences'] == dico.index(UNK_WORD)).sum() 39 | logger.info("Now %i unknown words covering %.2f%% of the data." 40 | % (unk_count, 100. * unk_count / (len(data['sentences']) - len(data['positions'])))) 41 | if params.min_count > 0: 42 | logger.info("Selecting words with >= %i occurrences ..." % params.min_count) 43 | dico.min_count(params.min_count) 44 | data['sentences'][data['sentences'] >= len(dico)] = dico.index(UNK_WORD) 45 | unk_count = (data['sentences'] == dico.index(UNK_WORD)).sum() 46 | logger.info("Now %i unknown words covering %.2f%% of the data." 47 | % (unk_count, 100. * unk_count / (len(data['sentences']) - len(data['positions'])))) 48 | if (data['sentences'].dtype == np.int32) and (len(dico) < 1 << 16): 49 | logger.info("Less than 65536 words. Moving data from int32 to uint16 ...") 50 | data['sentences'] = data['sentences'].astype(np.uint16) 51 | return data 52 | 53 | 54 | def load_binarized(path, params): 55 | """ 56 | Load a binarized dataset. 57 | """ 58 | assert path.endswith('.pth') 59 | if params.debug_train: 60 | path = path.replace('train', 'valid') 61 | if getattr(params, 'multi_gpu', False): 62 | split_path = '%s.%i.pth' % (path[:-4], params.local_rank) 63 | if os.path.isfile(split_path): 64 | assert params.split_data is False 65 | path = split_path 66 | assert os.path.isfile(path), path 67 | logger.info("Loading data from %s ..." % path) 68 | data = torch.load(path) 69 | data = process_binarized(data, params) 70 | return data 71 | 72 | 73 | def set_dico_parameters(params, data, dico): 74 | """ 75 | Update dictionary parameters. 76 | """ 77 | if 'dico' in data: 78 | assert data['dico'] == dico 79 | else: 80 | data['dico'] = dico 81 | 82 | n_words = len(dico) 83 | bos_index = dico.index(BOS_WORD) 84 | eos_index = dico.index(EOS_WORD) 85 | pad_index = dico.index(PAD_WORD) 86 | unk_index = dico.index(UNK_WORD) 87 | mask_index = dico.index(MASK_WORD) 88 | if hasattr(params, 'bos_index'): 89 | assert params.n_words == n_words 90 | assert params.bos_index == bos_index 91 | assert params.eos_index == eos_index 92 | assert params.pad_index == pad_index 93 | assert params.unk_index == unk_index 94 | assert params.mask_index == mask_index 95 | else: 96 | params.n_words = n_words 97 | params.bos_index = bos_index 98 | params.eos_index = eos_index 99 | params.pad_index = pad_index 100 | params.unk_index = unk_index 101 | params.mask_index = mask_index 102 | 103 | 104 | def load_mono_data(params, data): 105 | """ 106 | Load monolingual data. 
107 | """ 108 | data['mono'] = {} 109 | data['mono_stream'] = {} 110 | 111 | for lang in params.mono_dataset.keys(): 112 | 113 | logger.info('============ Monolingual data (%s)' % lang) 114 | 115 | assert lang in params.langs and lang not in data['mono'] 116 | data['mono'][lang] = {} 117 | data['mono_stream'][lang] = {} 118 | 119 | for splt in ['train', 'valid', 'test']: 120 | 121 | # no need to load training data for evaluation 122 | if splt == 'train' and params.eval_only: 123 | continue 124 | 125 | # load data / update dictionary parameters / update data 126 | mono_data = load_binarized(params.mono_dataset[lang][splt], params) 127 | set_dico_parameters(params, data, mono_data['dico']) 128 | 129 | # create stream dataset 130 | bs = params.batch_size if splt == 'train' else 1 131 | data['mono_stream'][lang][splt] = StreamDataset(mono_data['sentences'], mono_data['positions'], bs, params) 132 | 133 | # if there are several processes on the same machine, we can split the dataset 134 | if splt == 'train' and params.split_data and 1 < params.n_gpu_per_node <= data['mono_stream'][lang][splt].n_batches: 135 | n_batches = data['mono_stream'][lang][splt].n_batches // params.n_gpu_per_node 136 | a = n_batches * params.local_rank 137 | b = n_batches * params.local_rank + n_batches 138 | data['mono_stream'][lang][splt].select_data(a, b) 139 | 140 | # for denoising auto-encoding and online back-translation, we need a non-stream (batched) dataset 141 | if lang in params.ae_steps or lang in params.bt_src_langs: 142 | 143 | # create batched dataset 144 | dataset = Dataset(mono_data['sentences'], mono_data['positions'], params) 145 | 146 | # remove empty and too long sentences 147 | if splt == 'train': 148 | dataset.remove_empty_sentences() 149 | dataset.remove_long_sentences(params.max_len) 150 | 151 | # if there are several processes on the same machine, we can split the dataset 152 | if splt == 'train' and params.n_gpu_per_node > 1 and params.split_data: 153 | n_sent = len(dataset) // params.n_gpu_per_node 154 | a = n_sent * params.local_rank 155 | b = n_sent * params.local_rank + n_sent 156 | dataset.select_data(a, b) 157 | 158 | data['mono'][lang][splt] = dataset 159 | 160 | logger.info("") 161 | 162 | logger.info("") 163 | 164 | 165 | def load_para_data(params, data): 166 | """ 167 | Load parallel data. 
168 | """ 169 | data['para'] = {} 170 | 171 | required_para_train = set(params.clm_steps + params.mlm_steps + params.pc_steps + params.mt_steps) 172 | 173 | for src, tgt in params.para_dataset.keys(): 174 | 175 | logger.info('============ Parallel data (%s-%s)' % (src, tgt)) 176 | 177 | assert (src, tgt) not in data['para'] 178 | data['para'][(src, tgt)] = {} 179 | 180 | for splt in ['train', 'valid', 'test']: 181 | 182 | # no need to load training data for evaluation 183 | if splt == 'train' and params.eval_only: 184 | continue 185 | 186 | # for back-translation, we can't load training data 187 | if splt == 'train' and (src, tgt) not in required_para_train and (tgt, src) not in required_para_train: 188 | continue 189 | 190 | # load binarized datasets 191 | src_path, tgt_path = params.para_dataset[(src, tgt)][splt] 192 | src_data = load_binarized(src_path, params) 193 | tgt_data = load_binarized(tgt_path, params) 194 | 195 | # update dictionary parameters 196 | set_dico_parameters(params, data, src_data['dico']) 197 | set_dico_parameters(params, data, tgt_data['dico']) 198 | 199 | # create ParallelDataset 200 | dataset = ParallelDataset( 201 | src_data['sentences'], src_data['positions'], 202 | tgt_data['sentences'], tgt_data['positions'], 203 | params 204 | ) 205 | 206 | # remove empty and too long sentences 207 | if splt == 'train': 208 | dataset.remove_empty_sentences() 209 | dataset.remove_long_sentences(params.max_len) 210 | 211 | # for validation and test set, enumerate sentence per sentence 212 | if splt != 'train': 213 | dataset.tokens_per_batch = -1 214 | 215 | # if there are several processes on the same machine, we can split the dataset 216 | if splt == 'train' and params.n_gpu_per_node > 1 and params.split_data: 217 | n_sent = len(dataset) // params.n_gpu_per_node 218 | a = n_sent * params.local_rank 219 | b = n_sent * params.local_rank + n_sent 220 | dataset.select_data(a, b) 221 | 222 | data['para'][(src, tgt)][splt] = dataset 223 | logger.info("") 224 | 225 | logger.info("") 226 | 227 | 228 | def check_data_params(params): 229 | """ 230 | Check datasets parameters. 
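    Added example (argument values are assumed): how the comma-separated step
    arguments are parsed below.

        lgs = 'en-fr'
        langs = lgs.split('-')                           # ['en', 'fr']
        mlm = 'en,fr,en-fr'
        s = [x.split('-') for x in mlm.split(',') if len(x) > 0]
        mlm_steps = [(x[0], None) if len(x) == 1 else tuple(x) for x in s]
        # -> [('en', None), ('fr', None), ('en', 'fr')]: MLM on en, MLM on fr, TLM on en-fr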
231 | """ 232 | # data path 233 | assert os.path.isdir(params.data_path), params.data_path 234 | 235 | # check languages 236 | params.langs = params.lgs.split('-') if params.lgs != 'debug' else ['en'] 237 | assert len(params.langs) == len(set(params.langs)) >= 1 238 | # assert sorted(params.langs) == params.langs 239 | params.id2lang = {k: v for k, v in enumerate(sorted(params.langs))} 240 | params.lang2id = {k: v for v, k in params.id2lang.items()} 241 | params.n_langs = len(params.langs) 242 | 243 | # CLM steps 244 | clm_steps = [s.split('-') for s in params.clm_steps.split(',') if len(s) > 0] 245 | params.clm_steps = [(s[0], None) if len(s) == 1 else tuple(s) for s in clm_steps] 246 | assert all([(l1 in params.langs) and (l2 in params.langs or l2 is None) for l1, l2 in params.clm_steps]) 247 | assert len(params.clm_steps) == len(set(params.clm_steps)) 248 | 249 | # MLM / TLM steps 250 | mlm_steps = [s.split('-') for s in params.mlm_steps.split(',') if len(s) > 0] 251 | params.mlm_steps = [(s[0], None) if len(s) == 1 else tuple(s) for s in mlm_steps] 252 | assert all([(l1 in params.langs) and (l2 in params.langs or l2 is None) for l1, l2 in params.mlm_steps]) 253 | assert len(params.mlm_steps) == len(set(params.mlm_steps)) 254 | 255 | # parallel classification steps 256 | params.pc_steps = [tuple(s.split('-')) for s in params.pc_steps.split(',') if len(s) > 0] 257 | assert all([len(x) == 2 for x in params.pc_steps]) 258 | assert all([l1 in params.langs and l2 in params.langs for l1, l2 in params.pc_steps]) 259 | assert all([l1 != l2 for l1, l2 in params.pc_steps]) 260 | assert len(params.pc_steps) == len(set(params.pc_steps)) 261 | 262 | # machine translation steps 263 | params.mt_steps = [tuple(s.split('-')) for s in params.mt_steps.split(',') if len(s) > 0] 264 | assert all([len(x) == 2 for x in params.mt_steps]) 265 | assert all([l1 in params.langs and l2 in params.langs for l1, l2 in params.mt_steps]) 266 | assert all([l1 != l2 for l1, l2 in params.mt_steps]) 267 | assert len(params.mt_steps) == len(set(params.mt_steps)) 268 | assert len(params.mt_steps) == 0 or not params.encoder_only 269 | 270 | # denoising auto-encoder steps 271 | params.ae_steps = [s for s in params.ae_steps.split(',') if len(s) > 0] 272 | assert all([lang in params.langs for lang in params.ae_steps]) 273 | assert len(params.ae_steps) == len(set(params.ae_steps)) 274 | assert len(params.ae_steps) == 0 or not params.encoder_only 275 | 276 | # back-translation steps 277 | params.bt_steps = [tuple(s.split('-')) for s in params.bt_steps.split(',') if len(s) > 0] 278 | assert all([len(x) == 3 for x in params.bt_steps]) 279 | assert all([l1 in params.langs and l2 in params.langs and l3 in params.langs for l1, l2, l3 in params.bt_steps]) 280 | assert all([l1 == l3 and l1 != l2 for l1, l2, l3 in params.bt_steps]) 281 | assert len(params.bt_steps) == len(set(params.bt_steps)) 282 | assert len(params.bt_steps) == 0 or not params.encoder_only 283 | params.bt_src_langs = [l1 for l1, _, _ in params.bt_steps] 284 | 285 | # check monolingual datasets 286 | required_mono = set([l1 for l1, l2 in (params.mlm_steps + params.clm_steps) if l2 is None] + params.ae_steps + params.bt_src_langs) 287 | params.mono_dataset = { 288 | lang: { 289 | splt: os.path.join(params.data_path, '%s.%s.pth' % (splt, lang)) 290 | for splt in ['train', 'valid', 'test'] 291 | } for lang in params.langs if lang in required_mono 292 | } 293 | for paths in params.mono_dataset.values(): 294 | for p in paths.values(): 295 | if not os.path.isfile(p): 296 | 
logger.error(f"{p} not found") 297 | assert all([all([os.path.isfile(p) for p in paths.values()]) for paths in params.mono_dataset.values()]) 298 | 299 | # check parallel datasets 300 | required_para_train = set(params.clm_steps + params.mlm_steps + params.pc_steps + params.mt_steps) 301 | required_para = required_para_train | set([(l2, l3) for _, l2, l3 in params.bt_steps]) 302 | params.para_dataset = { 303 | (src, tgt): { 304 | splt: (os.path.join(params.data_path, '%s.%s-%s.%s.pth' % (splt, src, tgt, src)), 305 | os.path.join(params.data_path, '%s.%s-%s.%s.pth' % (splt, src, tgt, tgt))) 306 | for splt in ['train', 'valid', 'test'] 307 | if splt != 'train' or (src, tgt) in required_para_train or (tgt, src) in required_para_train 308 | } for src in params.langs for tgt in params.langs 309 | if src < tgt and ((src, tgt) in required_para or (tgt, src) in required_para) 310 | } 311 | for paths in params.para_dataset.values(): 312 | for p1, p2 in paths.values(): 313 | if not os.path.isfile(p1): 314 | logger.error(f"{p1} not found") 315 | if not os.path.isfile(p2): 316 | logger.error(f"{p2} not found") 317 | assert all([all([os.path.isfile(p1) and os.path.isfile(p2) for p1, p2 in paths.values()]) for paths in params.para_dataset.values()]) 318 | 319 | # check that we can evaluate on BLEU 320 | assert params.eval_bleu is False or len(params.mt_steps + params.bt_steps) > 0 321 | 322 | 323 | def load_data(params): 324 | """ 325 | Load monolingual data. 326 | The returned dictionary contains: 327 | - dico (dictionary) 328 | - vocab (FloatTensor) 329 | - train / valid / test (monolingual datasets) 330 | """ 331 | data = {} 332 | 333 | # monolingual datasets 334 | load_mono_data(params, data) 335 | 336 | # parallel datasets 337 | load_para_data(params, data) 338 | 339 | # monolingual data summary 340 | logger.info('============ Data summary') 341 | for lang, v in data['mono_stream'].items(): 342 | for data_set in v.keys(): 343 | logger.info('{: <18} - {: >5} - {: >12}:{: >10}'.format('Monolingual data', data_set, lang, len(v[data_set]))) 344 | 345 | # parallel data summary 346 | for (src, tgt), v in data['para'].items(): 347 | for data_set in v.keys(): 348 | logger.info('{: <18} - {: >5} - {: >12}:{: >10}'.format('Parallel data', data_set, '%s-%s' % (src, tgt), len(v[data_set]))) 349 | 350 | logger.info("") 351 | return data 352 | -------------------------------------------------------------------------------- /xlm/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/XLM/cd281d32612d145c6742b4d3f048f80df8669c30/xlm/evaluation/__init__.py -------------------------------------------------------------------------------- /xlm/evaluation/glue.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from logging import getLogger 9 | import os 10 | import copy 11 | import time 12 | import json 13 | from collections import OrderedDict 14 | 15 | import numpy as np 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | 20 | from scipy.stats import spearmanr, pearsonr 21 | from sklearn.metrics import f1_score, matthews_corrcoef 22 | 23 | from ..optim import get_optimizer 24 | from ..utils import concat_batches, truncate, to_cuda 25 | from ..data.dataset import Dataset, ParallelDataset 26 | from ..data.loader import load_binarized, set_dico_parameters 27 | 28 | 29 | N_CLASSES = { 30 | 'MNLI-m': 3, 31 | 'MNLI-mm': 3, 32 | 'QQP': 2, 33 | 'QNLI': 2, 34 | 'SST-2': 2, 35 | 'CoLA': 2, 36 | 'MRPC': 2, 37 | 'RTE': 2, 38 | 'STS-B': 1, 39 | 'WNLI': 2, 40 | 'AX_MNLI-m': 3, 41 | } 42 | 43 | 44 | logger = getLogger() 45 | 46 | 47 | class GLUE: 48 | 49 | def __init__(self, embedder, scores, params): 50 | """ 51 | Initialize GLUE trainer / evaluator. 52 | Initial `embedder` should be on CPU to save memory. 53 | """ 54 | self._embedder = embedder 55 | self.params = params 56 | self.scores = scores 57 | 58 | def get_iterator(self, splt): 59 | """ 60 | Build data iterator. 61 | """ 62 | return self.data[splt]['x'].get_iterator( 63 | shuffle=(splt == 'train'), 64 | return_indices=True, 65 | group_by_size=self.params.group_by_size 66 | ) 67 | 68 | def run(self, task): 69 | """ 70 | Run GLUE training / evaluation. 71 | """ 72 | params = self.params 73 | 74 | # task parameters 75 | self.task = task 76 | params.out_features = N_CLASSES[task] 77 | self.is_classif = task != 'STS-B' 78 | 79 | # load data 80 | self.data = self.load_data(task) 81 | if not self.data['dico'] == self._embedder.dico: 82 | raise Exception(("Dictionary in evaluation data (%i words) seems different than the one " + 83 | "in the pretrained model (%i words). Please verify you used the same dictionary, " + 84 | "and the same values for max_vocab and min_count.") % (len(self.data['dico']), len(self._embedder.dico))) 85 | 86 | # embedder 87 | self.embedder = copy.deepcopy(self._embedder) 88 | self.embedder.cuda() 89 | 90 | # projection layer 91 | self.proj = nn.Sequential(*[ 92 | nn.Dropout(params.dropout), 93 | nn.Linear(self.embedder.out_dim, params.out_features) 94 | ]).cuda() 95 | 96 | # optimizers 97 | self.optimizer_e = get_optimizer(list(self.embedder.get_parameters(params.finetune_layers)), params.optimizer_e) 98 | self.optimizer_p = get_optimizer(self.proj.parameters(), params.optimizer_p) 99 | 100 | # train and evaluate the model 101 | for epoch in range(params.n_epochs): 102 | 103 | # update epoch 104 | self.epoch = epoch 105 | 106 | # training 107 | logger.info("GLUE - %s - Training epoch %i ..." % (task, epoch)) 108 | self.train() 109 | 110 | # evaluation 111 | logger.info("GLUE - %s - Evaluating epoch %i ..." % (task, epoch)) 112 | with torch.no_grad(): 113 | scores = self.eval('valid') 114 | self.scores.update(scores) 115 | self.eval('test') 116 | 117 | def train(self): 118 | """ 119 | Finetune for one epoch on the training set. 
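    Added sketch (toy tensors, not from the original source): the objective below is
    cross-entropy for classification tasks and MSE for the STS-B regression task.

        import torch
        import torch.nn.functional as F

        output = torch.randn(4, 3)                       # (bs, n_classes) scores from self.proj
        y = torch.tensor([0, 2, 1, 1])
        loss = F.cross_entropy(output, y)                # when is_classif is True

        output = torch.randn(4, 1)                       # STS-B head has a single output
        y = torch.tensor([2.5, 0.0, 5.0, 3.2])
        loss = F.mse_loss(output.squeeze(1), y.float())  # when is_classif is False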
120 | """ 121 | params = self.params 122 | self.embedder.train() 123 | self.proj.train() 124 | 125 | # training variables 126 | losses = [] 127 | ns = 0 # number of sentences 128 | nw = 0 # number of words 129 | t = time.time() 130 | 131 | iterator = self.get_iterator('train') 132 | lang_id = params.lang2id['en'] 133 | 134 | while True: 135 | 136 | # batch 137 | try: 138 | batch = next(iterator) 139 | except StopIteration: 140 | break 141 | if self.n_sent == 1: 142 | (x, lengths), idx = batch 143 | x, lengths = truncate(x, lengths, params.max_len, params.eos_index) 144 | else: 145 | (sent1, len1), (sent2, len2), idx = batch 146 | sent1, len1 = truncate(sent1, len1, params.max_len, params.eos_index) 147 | sent2, len2 = truncate(sent2, len2, params.max_len, params.eos_index) 148 | x, lengths, _, _ = concat_batches(sent1, len1, lang_id, sent2, len2, lang_id, params.pad_index, params.eos_index, reset_positions=False) 149 | y = self.data['train']['y'][idx] 150 | bs = len(lengths) 151 | 152 | # cuda 153 | x, y, lengths = to_cuda(x, y, lengths) 154 | 155 | # loss 156 | output = self.proj(self.embedder.get_embeddings(x, lengths, positions=None, langs=None)) 157 | if self.is_classif: 158 | loss = F.cross_entropy(output, y, weight=self.weights) 159 | else: 160 | loss = F.mse_loss(output.squeeze(1), y.float()) 161 | 162 | # backward / optimization 163 | self.optimizer_e.zero_grad() 164 | self.optimizer_p.zero_grad() 165 | loss.backward() 166 | self.optimizer_e.step() 167 | self.optimizer_p.step() 168 | 169 | # update statistics 170 | ns += bs 171 | nw += lengths.sum().item() 172 | losses.append(loss.item()) 173 | 174 | # log 175 | if ns != 0 and ns % (10 * bs) < bs: 176 | logger.info( 177 | "GLUE - %s - Epoch %s - Train iter %7i - %.1f words/s - %s Loss: %.4f" 178 | % (self.task, self.epoch, ns, nw / (time.time() - t), 'XE' if self.is_classif else 'MSE', sum(losses) / len(losses)) 179 | ) 180 | nw, t = 0, time.time() 181 | losses = [] 182 | 183 | # epoch size 184 | if params.epoch_size != -1 and ns >= params.epoch_size: 185 | break 186 | 187 | def eval(self, splt): 188 | """ 189 | Evaluate on XNLI validation and test sets, for all languages. 
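    (This scores the current GLUE task on the given split, 'valid' or 'test'.)
    Added metric sketch with toy arrays (not from the original source): classification
    tasks report accuracy / F1 / Matthews correlation, STS-B reports Pearson / Spearman.

        import numpy as np
        from sklearn.metrics import f1_score, matthews_corrcoef

        pred = np.array([1, 0, 1, 1])
        gold = np.array([1, 0, 0, 1])
        acc = 100. * (pred == gold).sum() / len(pred)    # 75.0
        f1 = 100. * f1_score(gold, pred, average='binary')
        mc = 100. * matthews_corrcoef(gold, pred)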
190 | """ 191 | params = self.params 192 | self.embedder.eval() 193 | self.proj.eval() 194 | 195 | assert splt in ['valid', 'test'] 196 | has_labels = 'y' in self.data[splt] 197 | 198 | scores = OrderedDict({'epoch': self.epoch}) 199 | task = self.task.lower() 200 | 201 | idxs = [] # sentence indices 202 | prob = [] # probabilities 203 | pred = [] # predicted values 204 | gold = [] # real values 205 | 206 | lang_id = params.lang2id['en'] 207 | 208 | for batch in self.get_iterator(splt): 209 | 210 | # batch 211 | if self.n_sent == 1: 212 | (x, lengths), idx = batch 213 | # x, lengths = truncate(x, lengths, params.max_len, params.eos_index) 214 | else: 215 | (sent1, len1), (sent2, len2), idx = batch 216 | # sent1, len1 = truncate(sent1, len1, params.max_len, params.eos_index) 217 | # sent2, len2 = truncate(sent2, len2, params.max_len, params.eos_index) 218 | x, lengths, _, _ = concat_batches(sent1, len1, lang_id, sent2, len2, lang_id, params.pad_index, params.eos_index, reset_positions=False) 219 | y = self.data[splt]['y'][idx] if has_labels else None 220 | 221 | # cuda 222 | x, y, lengths = to_cuda(x, y, lengths) 223 | 224 | # prediction 225 | output = self.proj(self.embedder.get_embeddings(x, lengths, positions=None, langs=None)) 226 | p = output.data.max(1)[1] if self.is_classif else output.squeeze(1) 227 | idxs.append(idx) 228 | prob.append(output.cpu().numpy()) 229 | pred.append(p.cpu().numpy()) 230 | if has_labels: 231 | gold.append(y.cpu().numpy()) 232 | 233 | # indices / predictions 234 | idxs = np.concatenate(idxs) 235 | prob = np.concatenate(prob) 236 | pred = np.concatenate(pred) 237 | assert len(idxs) == len(pred), (len(idxs), len(pred)) 238 | assert idxs[-1] == len(idxs) - 1, (idxs[-1], len(idxs) - 1) 239 | 240 | # score the predictions if we have labels 241 | if has_labels: 242 | gold = np.concatenate(gold) 243 | prefix = f'{splt}_{task}' 244 | if self.is_classif: 245 | scores['%s_acc' % prefix] = 100. * (pred == gold).sum() / len(pred) 246 | scores['%s_f1' % prefix] = 100. * f1_score(gold, pred, average='binary' if params.out_features == 2 else 'micro') 247 | scores['%s_mc' % prefix] = 100. * matthews_corrcoef(gold, pred) 248 | else: 249 | scores['%s_prs' % prefix] = 100. * pearsonr(pred, gold)[0] 250 | scores['%s_spr' % prefix] = 100. 
* spearmanr(pred, gold)[0] 251 | logger.info("__log__:%s" % json.dumps(scores)) 252 | 253 | # output predictions 254 | pred_path = os.path.join(params.dump_path, f'{splt}.pred.{self.epoch}') 255 | with open(pred_path, 'w') as f: 256 | for i, p in zip(idxs, prob): 257 | f.write('%i\t%s\n' % (i, ','.join([str(x) for x in p]))) 258 | logger.info(f"Wrote {len(idxs)} {splt} predictions to {pred_path}") 259 | 260 | return scores 261 | 262 | def load_data(self, task): 263 | """ 264 | Load pair regression/classification bi-sentence tasks 265 | """ 266 | params = self.params 267 | data = {splt: {} for splt in ['train', 'valid', 'test']} 268 | dpath = os.path.join(params.data_path, 'eval', task) 269 | 270 | self.n_sent = 1 if task in ['SST-2', 'CoLA'] else 2 271 | 272 | for splt in ['train', 'valid', 'test']: 273 | 274 | # load data and dictionary 275 | data1 = load_binarized(os.path.join(dpath, '%s.s1.pth' % splt), params) 276 | data2 = load_binarized(os.path.join(dpath, '%s.s2.pth' % splt), params) if self.n_sent == 2 else None 277 | data['dico'] = data.get('dico', data1['dico']) 278 | 279 | # set dictionary parameters 280 | set_dico_parameters(params, data, data1['dico']) 281 | if self.n_sent == 2: 282 | set_dico_parameters(params, data, data2['dico']) 283 | 284 | # create dataset 285 | if self.n_sent == 1: 286 | data[splt]['x'] = Dataset(data1['sentences'], data1['positions'], params) 287 | else: 288 | data[splt]['x'] = ParallelDataset( 289 | data1['sentences'], data1['positions'], 290 | data2['sentences'], data2['positions'], 291 | params 292 | ) 293 | 294 | # load labels 295 | if splt != 'test' or task in ['MRPC']: 296 | # read labels from file 297 | with open(os.path.join(dpath, '%s.label' % splt), 'r') as f: 298 | lines = [l.rstrip() for l in f] 299 | # STS-B task 300 | if task == 'STS-B': 301 | assert all(0 <= float(x) <= 5 for x in lines) 302 | y = [float(l) for l in lines] 303 | # QQP 304 | elif task == 'QQP': 305 | UNK_LABEL = 0 306 | lab2id = {x: i for i, x in enumerate(sorted(set(lines) - set([''])))} 307 | y = [lab2id.get(x, UNK_LABEL) for x in lines] 308 | # other tasks 309 | else: 310 | lab2id = {x: i for i, x in enumerate(sorted(set(lines)))} 311 | y = [lab2id[x] for x in lines] 312 | data[splt]['y'] = torch.LongTensor(y) 313 | assert len(data[splt]['x']) == len(data[splt]['y']) 314 | 315 | # compute weights for weighted training 316 | if task != 'STS-B' and params.weighted_training: 317 | weights = torch.FloatTensor([ 318 | 1.0 / (data['train']['y'] == i).sum().item() 319 | for i in range(len(lab2id)) 320 | ]).cuda() 321 | self.weights = weights / weights.sum() 322 | else: 323 | self.weights = None 324 | 325 | return data 326 | -------------------------------------------------------------------------------- /xlm/evaluation/multi-bleu.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
5 | 6 | # $Id$ 7 | use warnings; 8 | use strict; 9 | 10 | my $lowercase = 0; 11 | if ($ARGV[0] eq "-lc") { 12 | $lowercase = 1; 13 | shift; 14 | } 15 | 16 | my $stem = $ARGV[0]; 17 | if (!defined $stem) { 18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 20 | exit(1); 21 | } 22 | 23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 24 | 25 | my @REF; 26 | my $ref=0; 27 | while(-e "$stem$ref") { 28 | &add_to_ref("$stem$ref",\@REF); 29 | $ref++; 30 | } 31 | &add_to_ref($stem,\@REF) if -e $stem; 32 | die("ERROR: could not find reference file $stem") unless scalar @REF; 33 | 34 | # add additional references explicitly specified on the command line 35 | shift; 36 | foreach my $stem (@ARGV) { 37 | &add_to_ref($stem,\@REF) if -e $stem; 38 | } 39 | 40 | 41 | 42 | sub add_to_ref { 43 | my ($file,$REF) = @_; 44 | my $s=0; 45 | if ($file =~ /.gz$/) { 46 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 47 | } else { 48 | open(REF,$file) or die "Can't read $file"; 49 | } 50 | while() { 51 | chop; 52 | push @{$$REF[$s++]}, $_; 53 | } 54 | close(REF); 55 | } 56 | 57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 58 | my $s=0; 59 | while() { 60 | chop; 61 | $_ = lc if $lowercase; 62 | my @WORD = split; 63 | my %REF_NGRAM = (); 64 | my $length_translation_this_sentence = scalar(@WORD); 65 | my ($closest_diff,$closest_length) = (9999,9999); 66 | foreach my $reference (@{$REF[$s]}) { 67 | # print "$s $_ <=> $reference\n"; 68 | $reference = lc($reference) if $lowercase; 69 | my @WORD = split(' ',$reference); 70 | my $length = scalar(@WORD); 71 | my $diff = abs($length_translation_this_sentence-$length); 72 | if ($diff < $closest_diff) { 73 | $closest_diff = $diff; 74 | $closest_length = $length; 75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 76 | } elsif ($diff == $closest_diff) { 77 | $closest_length = $length if $length < $closest_length; 78 | # from two references with the same closeness to me 79 | # take the *shorter* into account, not the "first" one. 80 | } 81 | for(my $n=1;$n<=4;$n++) { 82 | my %REF_NGRAM_N = (); 83 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 84 | my $ngram = "$n"; 85 | for(my $w=0;$w<$n;$w++) { 86 | $ngram .= " ".$WORD[$start+$w]; 87 | } 88 | $REF_NGRAM_N{$ngram}++; 89 | } 90 | foreach my $ngram (keys %REF_NGRAM_N) { 91 | if (!defined($REF_NGRAM{$ngram}) || 92 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 93 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 94 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 95 | } 96 | } 97 | } 98 | } 99 | $length_translation += $length_translation_this_sentence; 100 | $length_reference += $closest_length; 101 | for(my $n=1;$n<=4;$n++) { 102 | my %T_NGRAM = (); 103 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 104 | my $ngram = "$n"; 105 | for(my $w=0;$w<$n;$w++) { 106 | $ngram .= " ".$WORD[$start+$w]; 107 | } 108 | $T_NGRAM{$ngram}++; 109 | } 110 | foreach my $ngram (keys %T_NGRAM) { 111 | $ngram =~ /^(\d+) /; 112 | my $n = $1; 113 | # my $corr = 0; 114 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 115 | $TOTAL[$n] += $T_NGRAM{$ngram}; 116 | if (defined($REF_NGRAM{$ngram})) { 117 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 118 | $CORRECT[$n] += $T_NGRAM{$ngram}; 119 | # $corr = $T_NGRAM{$ngram}; 120 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 121 | } 122 | else { 123 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 124 | # $corr = $REF_NGRAM{$ngram}; 125 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 126 | } 127 | } 128 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 129 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 130 | } 131 | } 132 | $s++; 133 | } 134 | my $brevity_penalty = 1; 135 | my $bleu = 0; 136 | 137 | my @bleu=(); 138 | 139 | for(my $n=1;$n<=4;$n++) { 140 | if (defined ($TOTAL[$n])){ 141 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 142 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 143 | }else{ 144 | $bleu[$n]=0; 145 | } 146 | } 147 | 148 | if ($length_reference==0){ 149 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 150 | exit(1); 151 | } 152 | 153 | if ($length_translation<$length_reference) { 154 | $brevity_penalty = exp(1-$length_reference/$length_translation); 155 | } 156 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 157 | my_log( $bleu[2] ) + 158 | my_log( $bleu[3] ) + 159 | my_log( $bleu[4] ) ) / 4) ; 160 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 161 | 100*$bleu, 162 | 100*$bleu[1], 163 | 100*$bleu[2], 164 | 100*$bleu[3], 165 | 100*$bleu[4], 166 | $brevity_penalty, 167 | $length_translation / $length_reference, 168 | $length_translation, 169 | $length_reference; 170 | 171 | 172 | # print STDERR "It is in-advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n"; 173 | 174 | sub my_log { 175 | return -9999999999 unless $_[0]; 176 | return log($_[0]); 177 | } 178 | -------------------------------------------------------------------------------- /xlm/evaluation/xnli.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import os 10 | import copy 11 | import time 12 | import json 13 | from collections import OrderedDict 14 | 15 | import torch 16 | from torch import nn 17 | import torch.nn.functional as F 18 | 19 | from ..optim import get_optimizer 20 | from ..utils import concat_batches, truncate, to_cuda 21 | from ..data.dataset import ParallelDataset 22 | from ..data.loader import load_binarized, set_dico_parameters 23 | 24 | 25 | XNLI_LANGS = ['ar', 'bg', 'de', 'el', 'en', 'es', 'fr', 'hi', 'ru', 'sw', 'th', 'tr', 'ur', 'vi', 'zh'] 26 | 27 | 28 | logger = getLogger() 29 | 30 | 31 | class XNLI: 32 | 33 | def __init__(self, embedder, scores, params): 34 | """ 35 | Initialize XNLI trainer / evaluator. 36 | Initial `embedder` should be on CPU to save memory. 37 | """ 38 | self._embedder = embedder 39 | self.params = params 40 | self.scores = scores 41 | 42 | def get_iterator(self, splt, lang): 43 | """ 44 | Get a monolingual data iterator. 45 | """ 46 | assert splt in ['valid', 'test'] or splt == 'train' and lang == 'en' 47 | return self.data[lang][splt]['x'].get_iterator( 48 | shuffle=(splt == 'train'), 49 | group_by_size=self.params.group_by_size, 50 | return_indices=True 51 | ) 52 | 53 | def run(self): 54 | """ 55 | Run XNLI training / evaluation. 
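    Added sketch (emb_dim and dropout values are assumed): the classifier is a
    dropout + linear head over the sentence embedding, with 3 outputs for the
    XNLI labels (contradiction / neutral / entailment).

        import torch
        from torch import nn

        emb_dim, dropout = 1024, 0.1                     # assumed embedder.out_dim and params.dropout
        proj = nn.Sequential(nn.Dropout(dropout), nn.Linear(emb_dim, 3))
        sent_emb = torch.randn(8, emb_dim)               # (bs, out_dim) from the embedder
        logits = proj(sent_emb)                          # (8, 3) scores over the XNLI labels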
56 | """ 57 | params = self.params 58 | 59 | # load data 60 | self.data = self.load_data() 61 | if not self.data['dico'] == self._embedder.dico: 62 | raise Exception(("Dictionary in evaluation data (%i words) seems different than the one " + 63 | "in the pretrained model (%i words). Please verify you used the same dictionary, " + 64 | "and the same values for max_vocab and min_count.") % (len(self.data['dico']), len(self._embedder.dico))) 65 | 66 | # embedder 67 | self.embedder = copy.deepcopy(self._embedder) 68 | self.embedder.cuda() 69 | 70 | # projection layer 71 | self.proj = nn.Sequential(*[ 72 | nn.Dropout(params.dropout), 73 | nn.Linear(self.embedder.out_dim, 3) 74 | ]).cuda() 75 | 76 | # optimizers 77 | self.optimizer_e = get_optimizer(list(self.embedder.get_parameters(params.finetune_layers)), params.optimizer_e) 78 | self.optimizer_p = get_optimizer(self.proj.parameters(), params.optimizer_p) 79 | 80 | # train and evaluate the model 81 | for epoch in range(params.n_epochs): 82 | 83 | # update epoch 84 | self.epoch = epoch 85 | 86 | # training 87 | logger.info("XNLI - Training epoch %i ..." % epoch) 88 | self.train() 89 | 90 | # evaluation 91 | logger.info("XNLI - Evaluating epoch %i ..." % epoch) 92 | with torch.no_grad(): 93 | scores = self.eval() 94 | self.scores.update(scores) 95 | 96 | def train(self): 97 | """ 98 | Finetune for one epoch on the XNLI English training set. 99 | """ 100 | params = self.params 101 | self.embedder.train() 102 | self.proj.train() 103 | 104 | # training variables 105 | losses = [] 106 | ns = 0 # number of sentences 107 | nw = 0 # number of words 108 | t = time.time() 109 | 110 | iterator = self.get_iterator('train', 'en') 111 | lang_id = params.lang2id['en'] 112 | 113 | while True: 114 | 115 | # batch 116 | try: 117 | batch = next(iterator) 118 | except StopIteration: 119 | break 120 | (sent1, len1), (sent2, len2), idx = batch 121 | sent1, len1 = truncate(sent1, len1, params.max_len, params.eos_index) 122 | sent2, len2 = truncate(sent2, len2, params.max_len, params.eos_index) 123 | x, lengths, positions, langs = concat_batches( 124 | sent1, len1, lang_id, 125 | sent2, len2, lang_id, 126 | params.pad_index, 127 | params.eos_index, 128 | reset_positions=False 129 | ) 130 | y = self.data['en']['train']['y'][idx] 131 | bs = len(len1) 132 | 133 | # cuda 134 | x, y, lengths, positions, langs = to_cuda(x, y, lengths, positions, langs) 135 | 136 | # loss 137 | output = self.proj(self.embedder.get_embeddings(x, lengths, positions, langs)) 138 | loss = F.cross_entropy(output, y) 139 | 140 | # backward / optimization 141 | self.optimizer_e.zero_grad() 142 | self.optimizer_p.zero_grad() 143 | loss.backward() 144 | self.optimizer_e.step() 145 | self.optimizer_p.step() 146 | 147 | # update statistics 148 | ns += bs 149 | nw += lengths.sum().item() 150 | losses.append(loss.item()) 151 | 152 | # log 153 | if ns % (100 * bs) < bs: 154 | logger.info("XNLI - Epoch %i - Train iter %7i - %.1f words/s - Loss: %.4f" % (self.epoch, ns, nw / (time.time() - t), sum(losses) / len(losses))) 155 | nw, t = 0, time.time() 156 | losses = [] 157 | 158 | # epoch size 159 | if params.epoch_size != -1 and ns >= params.epoch_size: 160 | break 161 | 162 | def eval(self): 163 | """ 164 | Evaluate on XNLI validation and test sets, for all languages. 
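    Added toy example (not from the original source): accuracy is accumulated per
    (split, language) from the argmax predictions, as below.

        import torch
        predictions = torch.tensor([2, 0, 1, 2])
        y = torch.tensor([2, 1, 1, 2])
        valid = predictions.eq(y).sum().item()           # 3
        acc = 100.0 * valid / len(y)                     # 75.0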
165 | """ 166 | params = self.params 167 | self.embedder.eval() 168 | self.proj.eval() 169 | 170 | scores = OrderedDict({'epoch': self.epoch}) 171 | 172 | for splt in ['valid', 'test']: 173 | 174 | for lang in XNLI_LANGS: 175 | if lang not in params.lang2id: 176 | continue 177 | 178 | lang_id = params.lang2id[lang] 179 | valid = 0 180 | total = 0 181 | 182 | for batch in self.get_iterator(splt, lang): 183 | 184 | # batch 185 | (sent1, len1), (sent2, len2), idx = batch 186 | x, lengths, positions, langs = concat_batches( 187 | sent1, len1, lang_id, 188 | sent2, len2, lang_id, 189 | params.pad_index, 190 | params.eos_index, 191 | reset_positions=False 192 | ) 193 | y = self.data[lang][splt]['y'][idx] 194 | 195 | # cuda 196 | x, y, lengths, positions, langs = to_cuda(x, y, lengths, positions, langs) 197 | 198 | # forward 199 | output = self.proj(self.embedder.get_embeddings(x, lengths, positions, langs)) 200 | predictions = output.data.max(1)[1] 201 | 202 | # update statistics 203 | valid += predictions.eq(y).sum().item() 204 | total += len(len1) 205 | 206 | # compute accuracy 207 | acc = 100.0 * valid / total 208 | scores['xnli_%s_%s_acc' % (splt, lang)] = acc 209 | logger.info("XNLI - %s - %s - Epoch %i - Acc: %.1f%%" % (splt, lang, self.epoch, acc)) 210 | 211 | logger.info("__log__:%s" % json.dumps(scores)) 212 | return scores 213 | 214 | def load_data(self): 215 | """ 216 | Load XNLI cross-lingual classification data. 217 | """ 218 | params = self.params 219 | data = {lang: {splt: {} for splt in ['train', 'valid', 'test']} for lang in XNLI_LANGS} 220 | label2id = {'contradiction': 0, 'neutral': 1, 'entailment': 2} 221 | dpath = os.path.join(params.data_path, 'eval', 'XNLI') 222 | 223 | for splt in ['train', 'valid', 'test']: 224 | 225 | for lang in XNLI_LANGS: 226 | 227 | # only English has a training set 228 | if splt == 'train' and lang != 'en': 229 | del data[lang]['train'] 230 | continue 231 | 232 | # load data and dictionary 233 | data1 = load_binarized(os.path.join(dpath, '%s.s1.%s.pth' % (splt, lang)), params) 234 | data2 = load_binarized(os.path.join(dpath, '%s.s2.%s.pth' % (splt, lang)), params) 235 | data['dico'] = data.get('dico', data1['dico']) 236 | 237 | # set dictionary parameters 238 | set_dico_parameters(params, data, data1['dico']) 239 | set_dico_parameters(params, data, data2['dico']) 240 | 241 | # create dataset 242 | data[lang][splt]['x'] = ParallelDataset( 243 | data1['sentences'], data1['positions'], 244 | data2['sentences'], data2['positions'], 245 | params 246 | ) 247 | 248 | # load labels 249 | with open(os.path.join(dpath, '%s.label.%s' % (splt, lang)), 'r') as f: 250 | labels = [label2id[l.rstrip()] for l in f] 251 | data[lang][splt]['y'] = torch.LongTensor(labels) 252 | assert len(data[lang][splt]['x']) == len(data[lang][splt]['y']) 253 | 254 | return data 255 | -------------------------------------------------------------------------------- /xlm/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import logging 9 | import time 10 | from datetime import timedelta 11 | 12 | 13 | class LogFormatter(): 14 | 15 | def __init__(self): 16 | self.start_time = time.time() 17 | 18 | def format(self, record): 19 | elapsed_seconds = round(record.created - self.start_time) 20 | 21 | prefix = "%s - %s - %s" % ( 22 | record.levelname, 23 | time.strftime('%x %X'), 24 | timedelta(seconds=elapsed_seconds) 25 | ) 26 | message = record.getMessage() 27 | message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3)) 28 | return "%s - %s" % (prefix, message) if message else '' 29 | 30 | 31 | def create_logger(filepath, rank): 32 | """ 33 | Create a logger. 34 | Use a different log file for each process. 35 | """ 36 | # create log formatter 37 | log_formatter = LogFormatter() 38 | 39 | # create file handler and set level to debug 40 | if filepath is not None: 41 | if rank > 0: 42 | filepath = '%s-%i' % (filepath, rank) 43 | file_handler = logging.FileHandler(filepath, "a") 44 | file_handler.setLevel(logging.DEBUG) 45 | file_handler.setFormatter(log_formatter) 46 | 47 | # create console handler and set level to info 48 | console_handler = logging.StreamHandler() 49 | console_handler.setLevel(logging.INFO) 50 | console_handler.setFormatter(log_formatter) 51 | 52 | # create logger and set level to debug 53 | logger = logging.getLogger() 54 | logger.handlers = [] 55 | logger.setLevel(logging.DEBUG) 56 | logger.propagate = False 57 | if filepath is not None: 58 | logger.addHandler(file_handler) 59 | logger.addHandler(console_handler) 60 | 61 | # reset logger elapsed time 62 | def reset_time(): 63 | log_formatter.start_time = time.time() 64 | logger.reset_time = reset_time 65 | 66 | return logger 67 | -------------------------------------------------------------------------------- /xlm/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import os 10 | import torch 11 | 12 | from .pretrain import load_embeddings 13 | from .transformer import DECODER_ONLY_PARAMS, TransformerModel # , TRANSFORMER_LAYER_PARAMS 14 | from .memory import HashingMemory 15 | 16 | 17 | logger = getLogger() 18 | 19 | 20 | def check_model_params(params): 21 | """ 22 | Check models parameters. 
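    Added example (the value shown is a common BERT-style setting, assumed here):
    how --word_mask_keep_rand is split into the three MLM corruption probabilities.

        word_mask_keep_rand = '0.8,0.1,0.1'
        s = [float(x) for x in word_mask_keep_rand.split(',')]
        assert len(s) == 3 and sum(s) == 1
        word_mask, word_keep, word_rand = s              # 80% masked, 10% kept, 10% random token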
23 | """ 24 | # masked language modeling task parameters 25 | assert params.bptt >= 1 26 | assert 0 <= params.word_pred < 1 27 | assert 0 <= params.sample_alpha < 1 28 | s = params.word_mask_keep_rand.split(',') 29 | assert len(s) == 3 30 | s = [float(x) for x in s] 31 | assert all([0 <= x <= 1 for x in s]) and sum(s) == 1 32 | params.word_mask = s[0] 33 | params.word_keep = s[1] 34 | params.word_rand = s[2] 35 | 36 | # input sentence noise for DAE 37 | if len(params.ae_steps) == 0: 38 | assert params.word_shuffle == 0 39 | assert params.word_dropout == 0 40 | assert params.word_blank == 0 41 | else: 42 | assert params.word_shuffle == 0 or params.word_shuffle > 1 43 | assert 0 <= params.word_dropout < 1 44 | assert 0 <= params.word_blank < 1 45 | 46 | # model dimensions 47 | assert params.emb_dim % params.n_heads == 0 48 | 49 | # share input and output embeddings 50 | assert params.share_inout_emb is False or params.asm is False 51 | 52 | # adaptive softmax 53 | if params.asm: 54 | assert params.asm_div_value > 1 55 | s = params.asm_cutoffs.split(',') 56 | assert all([x.isdigit() for x in s]) 57 | params.asm_cutoffs = [int(x) for x in s] 58 | assert params.max_vocab == -1 or params.asm_cutoffs[-1] < params.max_vocab 59 | 60 | # memory 61 | if params.use_memory: 62 | HashingMemory.check_params(params) 63 | s_enc = [x for x in params.mem_enc_positions.split(',') if x != ''] 64 | s_dec = [x for x in params.mem_dec_positions.split(',') if x != ''] 65 | assert len(s_enc) == len(set(s_enc)) 66 | assert len(s_dec) == len(set(s_dec)) 67 | assert all(x.isdigit() or x[-1] == '+' and x[:-1].isdigit() for x in s_enc) 68 | assert all(x.isdigit() or x[-1] == '+' and x[:-1].isdigit() for x in s_dec) 69 | params.mem_enc_positions = [(int(x[:-1]), 'after') if x[-1] == '+' else (int(x), 'in') for x in s_enc] 70 | params.mem_dec_positions = [(int(x[:-1]), 'after') if x[-1] == '+' else (int(x), 'in') for x in s_dec] 71 | assert len(params.mem_enc_positions) + len(params.mem_dec_positions) > 0 72 | assert len(params.mem_enc_positions) == 0 or 0 <= min([x[0] for x in params.mem_enc_positions]) <= max([x[0] for x in params.mem_enc_positions]) <= params.n_layers - 1 73 | assert len(params.mem_dec_positions) == 0 or 0 <= min([x[0] for x in params.mem_dec_positions]) <= max([x[0] for x in params.mem_dec_positions]) <= params.n_layers - 1 74 | 75 | # reload pretrained word embeddings 76 | if params.reload_emb != '': 77 | assert os.path.isfile(params.reload_emb) 78 | 79 | # reload a pretrained model 80 | if params.reload_model != '': 81 | if params.encoder_only: 82 | assert os.path.isfile(params.reload_model) 83 | else: 84 | s = params.reload_model.split(',') 85 | assert len(s) == 2 86 | assert all([x == '' or os.path.isfile(x) for x in s]) 87 | 88 | 89 | def set_pretrain_emb(model, dico, word2id, embeddings): 90 | """ 91 | Pretrain word embeddings. 92 | """ 93 | n_found = 0 94 | with torch.no_grad(): 95 | for i in range(len(dico)): 96 | idx = word2id.get(dico[i], None) 97 | if idx is None: 98 | continue 99 | n_found += 1 100 | model.embeddings.weight[i] = embeddings[idx].cuda() 101 | model.pred_layer.proj.weight[i] = embeddings[idx].cuda() 102 | logger.info("Pretrained %i/%i words (%.3f%%)." 103 | % (n_found, len(dico), 100. * n_found / len(dico))) 104 | 105 | 106 | def build_model(params, dico): 107 | """ 108 | Build model. 
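    Added sketch (key name and shape are assumed): checkpoints saved from multi-GPU
    training prefix every parameter name with 'module.'; the reload code below strips
    it before calling load_state_dict.

        import torch
        reloaded = {'module.embeddings.weight': torch.zeros(10, 4)}
        if all(k.startswith('module.') for k in reloaded.keys()):
            reloaded = {k[len('module.'):]: v for k, v in reloaded.items()}
        list(reloaded)                                   # ['embeddings.weight']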
109 | """ 110 | if params.encoder_only: 111 | # build 112 | model = TransformerModel(params, dico, is_encoder=True, with_output=True) 113 | 114 | # reload pretrained word embeddings 115 | if params.reload_emb != '': 116 | word2id, embeddings = load_embeddings(params.reload_emb, params) 117 | set_pretrain_emb(model, dico, word2id, embeddings) 118 | 119 | # reload a pretrained model 120 | if params.reload_model != '': 121 | logger.info("Reloading model from %s ..." % params.reload_model) 122 | reloaded = torch.load(params.reload_model, map_location=lambda storage, loc: storage.cuda(params.local_rank))['model'] 123 | if all([k.startswith('module.') for k in reloaded.keys()]): 124 | reloaded = {k[len('module.'):]: v for k, v in reloaded.items()} 125 | 126 | # # HACK to reload models with less layers 127 | # for i in range(12, 24): 128 | # for k in TRANSFORMER_LAYER_PARAMS: 129 | # k = k % i 130 | # if k in model.state_dict() and k not in reloaded: 131 | # logger.warning("Parameter %s not found. Ignoring ..." % k) 132 | # reloaded[k] = model.state_dict()[k] 133 | 134 | model.load_state_dict(reloaded) 135 | 136 | logger.info("Model: {}".format(model)) 137 | logger.info("Number of parameters (model): %i" % sum([p.numel() for p in model.parameters() if p.requires_grad])) 138 | 139 | return model.cuda() 140 | 141 | else: 142 | # build 143 | encoder = TransformerModel(params, dico, is_encoder=True, with_output=True) # TODO: only output when necessary - len(params.clm_steps + params.mlm_steps) > 0 144 | decoder = TransformerModel(params, dico, is_encoder=False, with_output=True) 145 | 146 | # reload pretrained word embeddings 147 | if params.reload_emb != '': 148 | word2id, embeddings = load_embeddings(params.reload_emb, params) 149 | set_pretrain_emb(encoder, dico, word2id, embeddings) 150 | set_pretrain_emb(decoder, dico, word2id, embeddings) 151 | 152 | # reload a pretrained model 153 | if params.reload_model != '': 154 | enc_path, dec_path = params.reload_model.split(',') 155 | assert not (enc_path == '' and dec_path == '') 156 | 157 | # reload encoder 158 | if enc_path != '': 159 | logger.info("Reloading encoder from %s ..." % enc_path) 160 | enc_reload = torch.load(enc_path, map_location=lambda storage, loc: storage.cuda(params.local_rank)) 161 | enc_reload = enc_reload['model' if 'model' in enc_reload else 'encoder'] 162 | if all([k.startswith('module.') for k in enc_reload.keys()]): 163 | enc_reload = {k[len('module.'):]: v for k, v in enc_reload.items()} 164 | encoder.load_state_dict(enc_reload) 165 | 166 | # reload decoder 167 | if dec_path != '': 168 | logger.info("Reloading decoder from %s ..." % dec_path) 169 | dec_reload = torch.load(dec_path, map_location=lambda storage, loc: storage.cuda(params.local_rank)) 170 | dec_reload = dec_reload['model' if 'model' in dec_reload else 'decoder'] 171 | if all([k.startswith('module.') for k in dec_reload.keys()]): 172 | dec_reload = {k[len('module.'):]: v for k, v in dec_reload.items()} 173 | for i in range(params.n_layers): 174 | for name in DECODER_ONLY_PARAMS: 175 | if name % i not in dec_reload: 176 | logger.warning("Parameter %s not found." 
% (name % i)) 177 | dec_reload[name % i] = decoder.state_dict()[name % i] 178 | decoder.load_state_dict(dec_reload) 179 | 180 | logger.debug("Encoder: {}".format(encoder)) 181 | logger.debug("Decoder: {}".format(decoder)) 182 | logger.info("Number of parameters (encoder): %i" % sum([p.numel() for p in encoder.parameters() if p.requires_grad])) 183 | logger.info("Number of parameters (decoder): %i" % sum([p.numel() for p in decoder.parameters() if p.requires_grad])) 184 | 185 | return encoder.cuda(), decoder.cuda() 186 | -------------------------------------------------------------------------------- /xlm/model/embedder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import torch 10 | 11 | from .transformer import TransformerModel 12 | from ..data.dictionary import Dictionary, BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD, MASK_WORD 13 | from ..utils import AttrDict 14 | 15 | 16 | logger = getLogger() 17 | 18 | 19 | class SentenceEmbedder(object): 20 | 21 | @staticmethod 22 | def reload(path, params): 23 | """ 24 | Create a sentence embedder from a pretrained model. 25 | """ 26 | # reload model 27 | reloaded = torch.load(path) 28 | state_dict = reloaded['model'] 29 | 30 | # handle models from multi-GPU checkpoints 31 | if 'checkpoint' in path: 32 | state_dict = {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()} 33 | 34 | # reload dictionary and model parameters 35 | dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts']) 36 | pretrain_params = AttrDict(reloaded['params']) 37 | pretrain_params.n_words = len(dico) 38 | pretrain_params.bos_index = dico.index(BOS_WORD) 39 | pretrain_params.eos_index = dico.index(EOS_WORD) 40 | pretrain_params.pad_index = dico.index(PAD_WORD) 41 | pretrain_params.unk_index = dico.index(UNK_WORD) 42 | pretrain_params.mask_index = dico.index(MASK_WORD) 43 | 44 | # build model and reload weights 45 | model = TransformerModel(pretrain_params, dico, True, True) 46 | model.load_state_dict(state_dict) 47 | model.eval() 48 | 49 | # adding missing parameters 50 | params.max_batch_size = 0 51 | 52 | return SentenceEmbedder(model, dico, pretrain_params) 53 | 54 | def __init__(self, model, dico, pretrain_params): 55 | """ 56 | Wrapper on top of the different sentence embedders. 57 | Returns sequence-wise or single-vector sentence representations. 
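    Added example (n_layers is assumed): get_parameters() below expects an 'i:j'
    layer range where '_' encodes a minus sign, e.g. '0:_1' for "embeddings through
    the last layer".

        n_layers = 12
        layer_range = '0:_1'
        i, j = (int(x.replace('_', '-')) for x in layer_range.split(':'))
        i = n_layers + i + 1 if i < 0 else i             # 0  -> embeddings included
        j = n_layers + j + 1 if j < 0 else j             # 12 -> up to the last layer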
58 | """ 59 | self.pretrain_params = {k: v for k, v in pretrain_params.__dict__.items()} 60 | self.model = model 61 | self.dico = dico 62 | self.n_layers = model.n_layers 63 | self.out_dim = model.dim 64 | self.n_words = model.n_words 65 | 66 | def train(self): 67 | self.model.train() 68 | 69 | def eval(self): 70 | self.model.eval() 71 | 72 | def cuda(self): 73 | self.model.cuda() 74 | 75 | def get_parameters(self, layer_range): 76 | 77 | s = layer_range.split(':') 78 | assert len(s) == 2 79 | i, j = int(s[0].replace('_', '-')), int(s[1].replace('_', '-')) 80 | 81 | # negative indexing 82 | i = self.n_layers + i + 1 if i < 0 else i 83 | j = self.n_layers + j + 1 if j < 0 else j 84 | 85 | # sanity check 86 | assert 0 <= i <= self.n_layers 87 | assert 0 <= j <= self.n_layers 88 | 89 | if i > j: 90 | return [] 91 | 92 | parameters = [] 93 | 94 | # embeddings 95 | if i == 0: 96 | # embeddings 97 | parameters += self.model.embeddings.parameters() 98 | logger.info("Adding embedding parameters to optimizer") 99 | # positional embeddings 100 | if self.pretrain_params['sinusoidal_embeddings'] is False: 101 | parameters += self.model.position_embeddings.parameters() 102 | logger.info("Adding positional embedding parameters to optimizer") 103 | # language embeddings 104 | if hasattr(self.model, 'lang_embeddings'): 105 | parameters += self.model.lang_embeddings.parameters() 106 | logger.info("Adding language embedding parameters to optimizer") 107 | parameters += self.model.layer_norm_emb.parameters() 108 | # layers 109 | for l in range(max(i - 1, 0), j): 110 | parameters += self.model.attentions[l].parameters() 111 | parameters += self.model.layer_norm1[l].parameters() 112 | parameters += self.model.ffns[l].parameters() 113 | parameters += self.model.layer_norm2[l].parameters() 114 | logger.info("Adding layer-%s parameters to optimizer" % (l + 1)) 115 | 116 | logger.info("Optimizing on %i Transformer elements." % sum([p.nelement() for p in parameters])) 117 | 118 | return parameters 119 | 120 | def get_embeddings(self, x, lengths, positions=None, langs=None): 121 | """ 122 | Inputs: 123 | `x` : LongTensor of shape (slen, bs) 124 | `lengths` : LongTensor of shape (bs,) 125 | Outputs: 126 | `sent_emb` : FloatTensor of shape (bs, out_dim) 127 | With out_dim == emb_dim 128 | """ 129 | slen, bs = x.size() 130 | assert lengths.size(0) == bs and lengths.max().item() == slen 131 | 132 | # get transformer last hidden layer 133 | tensor = self.model('fwd', x=x, lengths=lengths, positions=positions, langs=langs, causal=False) 134 | assert tensor.size() == (slen, bs, self.out_dim) 135 | 136 | # single-vector sentence representation (first column of last layer) 137 | return tensor[0] 138 | -------------------------------------------------------------------------------- /xlm/model/memory/__init__.py: -------------------------------------------------------------------------------- 1 | from .memory import HashingMemory 2 | -------------------------------------------------------------------------------- /xlm/model/memory/query.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .utils import get_slices 5 | 6 | 7 | def mlp(sizes, bias=True, batchnorm=True, groups=1): 8 | """ 9 | Generate a feedforward neural network. 
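    Added usage sketch (sizes are assumed): with groups=1 this stacks
    Linear -> BatchNorm1d -> ReLU blocks, with no ReLU after the last layer.

        import torch
        net = mlp([512, 256, 128])           # Linear(512,256), BN, ReLU, Linear(256,128), BN
        x = torch.randn(8, 512)
        net(x).shape                         # torch.Size([8, 128])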
10 | """ 11 | assert len(sizes) >= 2 12 | pairs = [(sizes[i], sizes[i + 1]) for i in range(len(sizes) - 1)] 13 | layers = [] 14 | 15 | for i, (dim_in, dim_out) in enumerate(pairs): 16 | if groups == 1 or i == 0: 17 | layers.append(nn.Linear(dim_in, groups * dim_out, bias=bias)) 18 | else: 19 | layers.append(GroupedLinear(groups * dim_in, groups * dim_out, bias=bias, groups=groups)) 20 | if batchnorm: 21 | layers.append(nn.BatchNorm1d(groups * dim_out)) 22 | if i < len(pairs) - 1: 23 | layers.append(nn.ReLU()) 24 | 25 | return nn.Sequential(*layers) 26 | 27 | 28 | def convs(channel_sizes, kernel_sizes, bias=True, batchnorm=True, residual=False, groups=1): 29 | """ 30 | Generate a convolutional neural network. 31 | """ 32 | assert len(channel_sizes) >= 2 33 | assert len(channel_sizes) == len(kernel_sizes) + 1 34 | pairs = [(channel_sizes[i], channel_sizes[i + 1]) for i in range(len(channel_sizes) - 1)] 35 | layers = [] 36 | 37 | for i, (dim_in, dim_out) in enumerate(pairs): 38 | ks = (kernel_sizes[i], kernel_sizes[i]) 39 | in_group = 1 if i == 0 else groups 40 | _dim_in = dim_in * in_group 41 | _dim_out = dim_out * groups 42 | if not residual: 43 | layers.append(nn.Conv2d(_dim_in, _dim_out, ks, padding=[k // 2 for k in ks], bias=bias, groups=in_group)) 44 | if batchnorm: 45 | layers.append(nn.BatchNorm2d(_dim_out)) 46 | if i < len(pairs) - 1: 47 | layers.append(nn.ReLU()) 48 | else: 49 | layers.append(BottleneckResidualConv2d( 50 | _dim_in, _dim_out, ks, bias=bias, 51 | batchnorm=batchnorm, groups=in_group 52 | )) 53 | if i == len(pairs) - 1: 54 | layers.append(nn.Conv2d(_dim_out, _dim_out, (1, 1), bias=bias)) 55 | 56 | return nn.Sequential(*layers) 57 | 58 | 59 | class GroupedLinear(nn.Module): 60 | 61 | def __init__(self, in_features, out_features, bias=True, groups=1): 62 | 63 | super().__init__() 64 | self.in_features = in_features 65 | self.out_features = out_features 66 | self.groups = groups 67 | self.bias = bias 68 | assert groups > 1 69 | 70 | self.layer = nn.Conv1d(in_features, out_features, bias=bias, kernel_size=1, groups=groups) 71 | 72 | def forward(self, input): 73 | assert input.dim() == 2 and input.size(1) == self.in_features 74 | return self.layer(input.unsqueeze(2)).squeeze(2) 75 | 76 | def extra_repr(self): 77 | return 'in_features={}, out_features={}, groups={}, bias={}'.format( 78 | self.in_features, self.out_features, self.groups, self.bias is not None 79 | ) 80 | 81 | 82 | class BottleneckResidualConv2d(nn.Module): 83 | 84 | def __init__(self, input_channels, output_channels, kernel_size, bias=True, batchnorm=True, groups=1): 85 | 86 | super().__init__() 87 | hidden_channels = min(input_channels, output_channels) 88 | assert all(k % 2 == 1 for k in kernel_size) 89 | 90 | self.conv1 = nn.Conv2d(input_channels, hidden_channels, kernel_size, padding=[k // 2 for k in kernel_size], bias=bias, groups=groups) 91 | self.conv2 = nn.Conv2d(hidden_channels, output_channels, kernel_size, padding=[k // 2 for k in kernel_size], bias=bias, groups=groups) 92 | self.act = nn.ReLU() 93 | 94 | self.batchnorm = batchnorm 95 | if self.batchnorm: 96 | self.bn1 = nn.BatchNorm2d(hidden_channels) 97 | self.bn2 = nn.BatchNorm2d(output_channels) 98 | 99 | if input_channels == output_channels: 100 | self.residual = nn.Sequential() 101 | else: 102 | self.residual = nn.Conv2d(input_channels, output_channels, (1, 1), bias=False, groups=groups) 103 | 104 | def forward(self, input): 105 | x = self.conv1(input) 106 | x = self.bn1(x) if self.batchnorm else x 107 | x = self.act(x) 108 | x = 
self.conv2(x) 109 | x = self.bn2(x) if self.batchnorm else x 110 | x = self.act(x + self.residual(input)) 111 | return x 112 | 113 | 114 | class QueryIdentity(nn.Module): 115 | 116 | def __init__(self, input_dim, heads, shuffle_hidden): 117 | super().__init__() 118 | self.input_dim = input_dim 119 | self.heads = heads 120 | self.shuffle_query = shuffle_hidden 121 | assert shuffle_hidden is False or heads > 1 122 | assert shuffle_hidden is False or self.input_dim % (2 ** self.heads) == 0 123 | if shuffle_hidden: 124 | self.slices = {head_id: get_slices(input_dim, head_id) for head_id in range(heads)} 125 | 126 | def forward(self, input): 127 | """ 128 | Generate queries from hidden states by either 129 | repeating them or creating some shuffled version. 130 | """ 131 | assert input.shape[-1] == self.input_dim 132 | input = input.contiguous().view(-1, self.input_dim) if input.dim() > 2 else input 133 | bs = len(input) 134 | 135 | if self.heads == 1: 136 | query = input 137 | 138 | elif not self.shuffle_query: 139 | query = input.unsqueeze(1).repeat(1, self.heads, 1) 140 | query = query.view(bs * self.heads, self.input_dim) 141 | 142 | else: 143 | query = torch.cat([ 144 | input[:, a:b] 145 | for head_id in range(self.heads) 146 | for a, b in self.slices[head_id] 147 | ], 1).view(bs * self.heads, self.input_dim) 148 | 149 | assert query.shape == (bs * self.heads, self.input_dim) 150 | return query 151 | 152 | 153 | class QueryMLP(nn.Module): 154 | 155 | def __init__( 156 | self, input_dim, heads, k_dim, product_quantization, multi_query_net, 157 | sizes, bias=True, batchnorm=True, grouped_conv=False 158 | ): 159 | super().__init__() 160 | self.input_dim = input_dim 161 | self.heads = heads 162 | self.k_dim = k_dim 163 | self.sizes = sizes 164 | self.grouped_conv = grouped_conv 165 | assert not multi_query_net or product_quantization or heads >= 2 166 | assert sizes[0] == input_dim 167 | assert sizes[-1] == (k_dim // 2) if multi_query_net else (heads * k_dim) 168 | assert self.grouped_conv is False or len(sizes) > 2 169 | 170 | # number of required MLPs 171 | self.groups = (2 * heads) if multi_query_net else 1 172 | 173 | # MLPs 174 | if self.grouped_conv: 175 | self.query_mlps = mlp(sizes, bias=bias, batchnorm=batchnorm, groups=self.groups) 176 | elif len(self.sizes) == 2: 177 | sizes_ = list(sizes) 178 | sizes_[-1] = sizes_[-1] * self.groups 179 | self.query_mlps = mlp(sizes_, bias=bias, batchnorm=batchnorm, groups=1) 180 | else: 181 | self.query_mlps = nn.ModuleList([ 182 | mlp(sizes, bias=bias, batchnorm=batchnorm, groups=1) 183 | for _ in range(self.groups) 184 | ]) 185 | 186 | def forward(self, input): 187 | """ 188 | Compute queries using either grouped 1D convolutions or ModuleList + concat. 
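    Added shape sketch (dimensions are assumed): with multi_query_net=True,
    2 * heads small MLPs each produce half a key of dimension k_dim // 2, and the
    concatenated queries are reshaped to (bs * heads, k_dim).

        import torch
        q = QueryMLP(input_dim=512, heads=4, k_dim=256,
                     product_quantization=True, multi_query_net=True,
                     sizes=[512, 256, 128])              # sizes[-1] == k_dim // 2
        hidden = torch.randn(8, 512)                     # (bs, input_dim)
        q(hidden).shape                                  # torch.Size([32, 256]) == (bs * heads, k_dim)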
189 | """ 190 | assert input.shape[-1] == self.input_dim 191 | input = input.contiguous().view(-1, self.input_dim) if input.dim() > 2 else input 192 | bs = len(input) 193 | 194 | if self.grouped_conv or len(self.sizes) == 2: 195 | query = self.query_mlps(input) 196 | else: 197 | outputs = [m(input) for m in self.query_mlps] 198 | query = torch.cat(outputs, 1) if len(outputs) > 1 else outputs[0] 199 | 200 | assert query.shape == (bs, self.heads * self.k_dim) 201 | return query.view(bs * self.heads, self.k_dim) 202 | 203 | 204 | class QueryConv(nn.Module): 205 | 206 | def __init__( 207 | self, input_dim, heads, k_dim, product_quantization, multi_query_net, 208 | sizes, kernel_sizes, bias=True, batchnorm=True, 209 | residual=False, grouped_conv=False 210 | ): 211 | super().__init__() 212 | self.input_dim = input_dim 213 | self.heads = heads 214 | self.k_dim = k_dim 215 | self.sizes = sizes 216 | self.grouped_conv = grouped_conv 217 | assert not multi_query_net or product_quantization or heads >= 2 218 | assert sizes[0] == input_dim 219 | assert sizes[-1] == (k_dim // 2) if multi_query_net else (heads * k_dim) 220 | assert self.grouped_conv is False or len(sizes) > 2 221 | assert len(sizes) == len(kernel_sizes) + 1 >= 2 and all(ks % 2 == 1 for ks in kernel_sizes) 222 | 223 | # number of required CNNs 224 | self.groups = (2 * heads) if multi_query_net else 1 225 | 226 | # CNNs 227 | if self.grouped_conv: 228 | self.query_convs = convs(sizes, kernel_sizes, bias=bias, batchnorm=batchnorm, residual=residual, groups=self.groups) 229 | elif len(self.sizes) == 2: 230 | sizes_ = list(sizes) 231 | sizes_[-1] = sizes_[-1] * self.groups 232 | self.query_convs = convs(sizes_, kernel_sizes, bias=bias, batchnorm=batchnorm, residual=residual, groups=1) 233 | else: 234 | self.query_convs = nn.ModuleList([ 235 | convs(sizes, kernel_sizes, bias=bias, batchnorm=batchnorm, residual=residual, groups=1) 236 | for _ in range(self.groups) 237 | ]) 238 | 239 | def forward(self, input): 240 | 241 | bs, nf, h, w = input.shape 242 | assert nf == self.input_dim 243 | 244 | if self.grouped_conv or len(self.sizes) == 2: 245 | query = self.query_convs(input) 246 | else: 247 | outputs = [m(input) for m in self.query_convs] 248 | query = torch.cat(outputs, 1) if len(outputs) > 1 else outputs[0] 249 | 250 | assert query.shape == (bs, self.heads * self.k_dim, h, w) 251 | query = query.transpose(1, 3).contiguous().view(bs * w * h * self.heads, self.k_dim) 252 | return query 253 | -------------------------------------------------------------------------------- /xlm/model/memory/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | import numpy as np 4 | import torch 5 | 6 | 7 | # load FAISS GPU library if available (dramatically accelerates the nearest neighbor search) 8 | try: 9 | import faiss 10 | FAISS_AVAILABLE = hasattr(faiss, 'StandardGpuResources') 11 | except ImportError: 12 | FAISS_AVAILABLE = False 13 | sys.stderr.write("FAISS library was not found.\n") 14 | 15 | 16 | def get_gaussian_keys(n_keys, dim, normalized, seed): 17 | """ 18 | Generate random Gaussian keys. 19 | """ 20 | rng = np.random.RandomState(seed) 21 | X = rng.randn(n_keys, dim) 22 | if normalized: 23 | X /= np.linalg.norm(X, axis=1, keepdims=True) 24 | return X.astype(np.float32) 25 | 26 | 27 | def get_uniform_keys(n_keys, dim, normalized, seed): 28 | """ 29 | Generate random uniform keys (same initialization as nn.Linear). 
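    Added toy example (sizes are assumed): keys are drawn uniformly in
    [-1/sqrt(dim), 1/sqrt(dim)], the same bound as nn.Linear's default
    initialization, and optionally L2-normalized.

        keys = get_uniform_keys(n_keys=512, dim=128, normalized=False, seed=0)
        keys.shape                                       # (512, 128), dtype float32
        # every entry lies in [-1/sqrt(128), 1/sqrt(128)] up to float32 rounding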
30 | """ 31 | rng = np.random.RandomState(seed) 32 | bound = 1 / math.sqrt(dim) 33 | X = rng.uniform(-bound, bound, (n_keys, dim)) 34 | if normalized: 35 | X /= np.linalg.norm(X, axis=1, keepdims=True) 36 | return X.astype(np.float32) 37 | 38 | 39 | def get_slices(dim, head_id): 40 | """ 41 | Generate slices of hidden dimensions. 42 | Used when there are multiple heads and/or different set of keys, 43 | and that there is no query network. 44 | """ 45 | if head_id == 0: 46 | return [(0, dim)] 47 | offset = dim // (2 ** (head_id + 1)) 48 | starts = np.arange(0, dim, offset) 49 | slices1 = [(x, x + offset) for i, x in enumerate(starts) if i % 2 == 0] 50 | slices2 = [(x, x + offset) for i, x in enumerate(starts) if i % 2 == 1] 51 | return slices1 + slices2 52 | 53 | 54 | def cartesian_product(a, b): 55 | """ 56 | Compute the batched cartesian product between two matrices. 57 | Input: 58 | a: Tensor(n, d1) 59 | b: Tensor(n, d2) 60 | Output: 61 | output: Tensor(n, d1 * d2, 2) 62 | """ 63 | n1, d1 = a.shape 64 | n2, d2 = b.shape 65 | assert n1 == n2 66 | return torch.cat([ 67 | a.unsqueeze(-1).repeat(1, 1, d2).unsqueeze(-1), 68 | b.repeat(1, d1).view(n2, d1, d2).unsqueeze(-1) 69 | ], 3).view(n1, d1 * d2, 2) 70 | 71 | 72 | def swig_ptr_from_FloatTensor(x): 73 | assert x.is_contiguous() 74 | assert x.dtype == torch.float32 75 | return faiss.cast_integer_to_float_ptr(x.storage().data_ptr() + x.storage_offset() * 4) 76 | 77 | 78 | def swig_ptr_from_LongTensor(x): 79 | assert x.is_contiguous() 80 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype 81 | return faiss.cast_integer_to_long_ptr(x.storage().data_ptr() + x.storage_offset() * 8) 82 | 83 | 84 | def get_knn_pytorch(a, b, k, distance='dot_product'): 85 | """ 86 | Input: 87 | - matrix of size (m, d) (keys) 88 | - matrix of size (n, d) (queries) 89 | - number of nearest neighbors 90 | - distance metric 91 | Output: 92 | - `scores` matrix of size (n, k) with nearest neighors scores 93 | - `indices` matrix of size (n, k) with nearest neighors indices 94 | """ 95 | m, d = a.size() 96 | n, _ = b.size() 97 | assert b.size(1) == d 98 | assert k > 0 99 | assert distance in ['dot_product', 'cosine', 'l2'] 100 | 101 | with torch.no_grad(): 102 | 103 | if distance == 'dot_product': 104 | scores = a.mm(b.t()) # (m, n) 105 | 106 | elif distance == 'cosine': 107 | scores = a.mm(b.t()) # (m, n) 108 | scores /= (a.norm(2, 1)[:, None] + 1e-9) # (m, n) 109 | scores /= (b.norm(2, 1)[None, :] + 1e-9) # (m, n) 110 | 111 | elif distance == 'l2': 112 | scores = a.mm(b.t()) # (m, n) 113 | scores *= 2 # (m, n) 114 | scores -= (a ** 2).sum(1)[:, None] # (m, n) 115 | scores -= (b ** 2).sum(1)[None, :] # (m, n) 116 | 117 | scores, indices = scores.topk(k=k, dim=0, largest=True) # (k, n) 118 | scores = scores.t() # (n, k) 119 | indices = indices.t() # (n, k) 120 | 121 | return scores, indices 122 | 123 | 124 | def get_knn_faiss(xb, xq, k, distance='dot_product'): 125 | """ 126 | `metric` can be faiss.METRIC_INNER_PRODUCT or faiss.METRIC_L2 127 | https://github.com/facebookresearch/faiss/blob/master/gpu/test/test_pytorch_faiss.py 128 | """ 129 | assert xb.device == xq.device 130 | assert distance in ['dot_product', 'l2'] 131 | metric = faiss.METRIC_INNER_PRODUCT if distance == 'dot_product' else faiss.METRIC_L2 132 | 133 | xq_ptr = swig_ptr_from_FloatTensor(xq) 134 | xb_ptr = swig_ptr_from_FloatTensor(xb) 135 | 136 | nq, d1 = xq.size() 137 | nb, d2 = xb.size() 138 | assert d1 == d2 139 | 140 | D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) 141 | I = 
torch.empty(nq, k, device=xb.device, dtype=torch.int64) 142 | 143 | D_ptr = swig_ptr_from_FloatTensor(D) 144 | I_ptr = swig_ptr_from_LongTensor(I) 145 | 146 | faiss.bruteForceKnn( 147 | FAISS_RES, metric, 148 | xb_ptr, nb, 149 | xq_ptr, nq, 150 | d1, k, D_ptr, I_ptr 151 | ) 152 | 153 | return D, I 154 | 155 | 156 | if FAISS_AVAILABLE: 157 | FAISS_RES = faiss.StandardGpuResources() 158 | FAISS_RES.setDefaultNullStreamAllDevices() 159 | FAISS_RES.setTempMemory(1200 * 1024 * 1024) 160 | get_knn = get_knn_faiss 161 | else: 162 | sys.stderr.write("FAISS not available. Switching to standard nearest neighbors search implementation.\n") 163 | get_knn = get_knn_pytorch 164 | -------------------------------------------------------------------------------- /xlm/model/pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import io 10 | import numpy as np 11 | import torch 12 | 13 | 14 | logger = getLogger() 15 | 16 | 17 | def load_fasttext_model(path): 18 | """ 19 | Load a binarized fastText model. 20 | """ 21 | try: 22 | import fastText 23 | except ImportError: 24 | raise Exception("Unable to import fastText. Please install fastText for Python: " 25 | "https://github.com/facebookresearch/fastText") 26 | return fastText.load_model(path) 27 | 28 | 29 | def read_txt_embeddings(path, params): 30 | """ 31 | Reload pretrained embeddings from a text file. 32 | """ 33 | word2id = {} 34 | vectors = [] 35 | 36 | # load pretrained embeddings 37 | _emb_dim_file = params.emb_dim 38 | with io.open(path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f: 39 | for i, line in enumerate(f): 40 | if i == 0: 41 | split = line.split() 42 | assert len(split) == 2 43 | assert _emb_dim_file == int(split[1]) 44 | continue 45 | word, vect = line.rstrip().split(' ', 1) 46 | vect = np.fromstring(vect, sep=' ') 47 | if word in word2id: 48 | logger.warning("Word \"%s\" found twice!" % word) 49 | continue 50 | if not vect.shape == (_emb_dim_file,): 51 | logger.warning("Invalid dimension (%i) for word \"%s\" in line %i." 52 | % (vect.shape[0], word, i)) 53 | continue 54 | assert vect.shape == (_emb_dim_file,) 55 | word2id[word] = len(word2id) 56 | vectors.append(vect[None]) 57 | 58 | assert len(word2id) == len(vectors) 59 | logger.info("Loaded %i pretrained word embeddings from %s" % (len(vectors), path)) 60 | 61 | # compute new vocabulary / embeddings 62 | embeddings = np.concatenate(vectors, 0) 63 | embeddings = torch.from_numpy(embeddings).float() 64 | 65 | assert embeddings.size() == (len(word2id), params.emb_dim) 66 | return word2id, embeddings 67 | 68 | 69 | def load_bin_embeddings(path, params): 70 | """ 71 | Reload pretrained embeddings from a fastText binary file. 72 | """ 73 | model = load_fasttext_model(path) 74 | assert model.get_dimension() == params.emb_dim 75 | words = model.get_labels() 76 | logger.info("Loaded binary model from %s" % path) 77 | 78 | # compute new vocabulary / embeddings 79 | embeddings = np.concatenate([model.get_word_vector(w)[None] for w in words], 0) 80 | embeddings = torch.from_numpy(embeddings).float() 81 | word2id = {w: i for i, w in enumerate(words)} 82 | logger.info("Generated embeddings for %i words." 
% len(words)) 83 | 84 | assert embeddings.size() == (len(word2id), params.emb_dim) 85 | return word2id, embeddings 86 | 87 | 88 | def load_embeddings(path, params): 89 | """ 90 | Reload pretrained embeddings. 91 | """ 92 | if path.endswith('.bin'): 93 | return load_bin_embeddings(path, params) 94 | else: 95 | return read_txt_embeddings(path, params) 96 | -------------------------------------------------------------------------------- /xlm/optim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import re 9 | import math 10 | import inspect 11 | 12 | import torch 13 | from torch import optim 14 | 15 | 16 | class Adam(optim.Optimizer): 17 | """ 18 | Same as https://github.com/pytorch/pytorch/blob/master/torch/optim/adam.py, 19 | without amsgrad, with step in a tensor, and states initialization in __init__. 20 | It was important to add `.item()` in `state['step'].item()`. 21 | """ 22 | 23 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 24 | if not 0.0 <= lr: 25 | raise ValueError("Invalid learning rate: {}".format(lr)) 26 | if not 0.0 <= eps: 27 | raise ValueError("Invalid epsilon value: {}".format(eps)) 28 | if not 0.0 <= betas[0] < 1.0: 29 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 30 | if not 0.0 <= betas[1] < 1.0: 31 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 32 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 33 | super().__init__(params, defaults) 34 | 35 | for group in self.param_groups: 36 | for p in group['params']: 37 | state = self.state[p] 38 | state['step'] = 0 # torch.zeros(1) 39 | state['exp_avg'] = torch.zeros_like(p.data) 40 | state['exp_avg_sq'] = torch.zeros_like(p.data) 41 | 42 | def __setstate__(self, state): 43 | super().__setstate__(state) 44 | 45 | def step(self, closure=None): 46 | """ 47 | Step. 48 | """ 49 | loss = None 50 | if closure is not None: 51 | loss = closure() 52 | 53 | for group in self.param_groups: 54 | for p in group['params']: 55 | if p.grad is None: 56 | continue 57 | grad = p.grad.data 58 | if grad.is_sparse: 59 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 60 | 61 | state = self.state[p] 62 | 63 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 64 | beta1, beta2 = group['betas'] 65 | 66 | state['step'] += 1 67 | 68 | # if group['weight_decay'] != 0: 69 | # grad.add_(group['weight_decay'], p.data) 70 | 71 | # Decay the first and second moment running average coefficient 72 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 73 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 74 | denom = exp_avg_sq.sqrt().add_(group['eps']) 75 | # denom = exp_avg_sq.sqrt().clamp_(min=group['eps']) 76 | 77 | bias_correction1 = 1 - beta1 ** state['step'] # .item() 78 | bias_correction2 = 1 - beta2 ** state['step'] # .item() 79 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 80 | 81 | if group['weight_decay'] != 0: 82 | p.data.add_(-group['weight_decay'] * group['lr'], p.data) 83 | 84 | p.data.addcdiv_(-step_size, exp_avg, denom) 85 | 86 | return loss 87 | 88 | 89 | class AdamInverseSqrtWithWarmup(Adam): 90 | """ 91 | Decay the LR based on the inverse square root of the update number. 
92 | We also support a warmup phase where we linearly increase the learning rate 93 | from some initial learning rate (`warmup-init-lr`) until the configured 94 | learning rate (`lr`). Thereafter we decay proportional to the number of 95 | updates, with a decay factor set to align with the configured learning rate. 96 | During warmup: 97 | lrs = torch.linspace(warmup_init_lr, lr, warmup_updates) 98 | lr = lrs[update_num] 99 | After warmup: 100 | lr = decay_factor / sqrt(update_num) 101 | where 102 | decay_factor = lr * sqrt(warmup_updates) 103 | """ 104 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 105 | weight_decay=0, warmup_updates=4000, warmup_init_lr=1e-7, 106 | exp_factor=0.5): 107 | super().__init__( 108 | params, 109 | lr=warmup_init_lr, 110 | betas=betas, 111 | eps=eps, 112 | weight_decay=weight_decay, 113 | ) 114 | 115 | # linearly warmup for the first warmup_updates 116 | self.warmup_updates = warmup_updates 117 | self.warmup_init_lr = warmup_init_lr 118 | warmup_end_lr = lr 119 | self.lr_step = (warmup_end_lr - warmup_init_lr) / warmup_updates 120 | 121 | # then, decay prop. to the inverse square root of the update number 122 | self.exp_factor = exp_factor 123 | self.decay_factor = warmup_end_lr * warmup_updates ** self.exp_factor 124 | 125 | # total number of updates 126 | for param_group in self.param_groups: 127 | param_group['num_updates'] = 0 128 | 129 | def get_lr_for_step(self, num_updates): 130 | if num_updates < self.warmup_updates: 131 | return self.warmup_init_lr + num_updates * self.lr_step 132 | else: 133 | return self.decay_factor * (num_updates ** -self.exp_factor) 134 | 135 | def step(self, closure=None): 136 | super().step(closure) 137 | for param_group in self.param_groups: 138 | param_group['num_updates'] += 1 139 | param_group['lr'] = self.get_lr_for_step(param_group['num_updates']) 140 | 141 | 142 | class AdamCosineWithWarmup(Adam): 143 | """ 144 | Assign LR based on a cyclical schedule that follows the cosine function. 145 | See https://arxiv.org/pdf/1608.03983.pdf for details. 146 | We also support a warmup phase where we linearly increase the learning rate 147 | from some initial learning rate (``--warmup-init-lr``) until the configured 148 | learning rate (``--lr``). 149 | During warmup:: 150 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) 151 | lr = lrs[update_num] 152 | After warmup:: 153 | lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i)) 154 | where ``t_curr`` is current percentage of updates within the current period 155 | range and ``t_i`` is the current period range, which is scaled by ``t_mul`` 156 | after every iteration. 
157 | """ 158 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 159 | weight_decay=0, warmup_updates=4000, warmup_init_lr=1e-7, 160 | min_lr=1e-9, init_period=1000000, period_mult=1, lr_shrink=0.75): 161 | super().__init__( 162 | params, 163 | lr=warmup_init_lr, 164 | betas=betas, 165 | eps=eps, 166 | weight_decay=weight_decay, 167 | ) 168 | 169 | # linearly warmup for the first warmup_updates 170 | self.warmup_updates = warmup_updates 171 | self.warmup_init_lr = warmup_init_lr 172 | warmup_end_lr = lr 173 | self.lr_step = (warmup_end_lr - warmup_init_lr) / warmup_updates 174 | 175 | # then, apply cosine scheduler 176 | self.min_lr = min_lr 177 | self.max_lr = lr 178 | self.period = init_period 179 | self.period_mult = period_mult 180 | self.lr_shrink = lr_shrink 181 | 182 | # total number of updates 183 | for param_group in self.param_groups: 184 | param_group['num_updates'] = 0 185 | 186 | def get_lr_for_step(self, num_updates): 187 | if num_updates < self.warmup_updates: 188 | return self.warmup_init_lr + num_updates * self.lr_step 189 | else: 190 | t = num_updates - self.warmup_updates 191 | if self.period_mult == 1: 192 | pid = math.floor(t / self.period) 193 | t_i = self.period 194 | t_curr = t - (self.period * pid) 195 | else: 196 | pid = math.floor(math.log(1 - t / self.period * (1 - self.period_mult), self.period_mult)) 197 | t_i = self.period * (self.period_mult ** pid) 198 | t_curr = t - (1 - self.period_mult ** pid) / (1 - self.period_mult) * self.period 199 | lr_shrink = self.lr_shrink ** pid 200 | min_lr = self.min_lr * lr_shrink 201 | max_lr = self.max_lr * lr_shrink 202 | return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i)) 203 | 204 | def step(self, closure=None): 205 | super().step(closure) 206 | for param_group in self.param_groups: 207 | param_group['num_updates'] += 1 208 | param_group['lr'] = self.get_lr_for_step(param_group['num_updates']) 209 | 210 | 211 | def get_optimizer(parameters, s): 212 | """ 213 | Parse optimizer parameters. 
214 | Input should be of the form: 215 | - "sgd,lr=0.01" 216 | - "adagrad,lr=0.1,lr_decay=0.05" 217 | """ 218 | if "," in s: 219 | method = s[:s.find(',')] 220 | optim_params = {} 221 | for x in s[s.find(',') + 1:].split(','): 222 | split = x.split('=') 223 | assert len(split) == 2 224 | assert re.match("^[+-]?(\d+(\.\d*)?|\.\d+)$", split[1]) is not None 225 | optim_params[split[0]] = float(split[1]) 226 | else: 227 | method = s 228 | optim_params = {} 229 | 230 | if method == 'adadelta': 231 | optim_fn = optim.Adadelta 232 | elif method == 'adagrad': 233 | optim_fn = optim.Adagrad 234 | elif method == 'adam': 235 | optim_fn = Adam 236 | optim_params['betas'] = (optim_params.get('beta1', 0.9), optim_params.get('beta2', 0.999)) 237 | optim_params.pop('beta1', None) 238 | optim_params.pop('beta2', None) 239 | elif method == 'adam_inverse_sqrt': 240 | optim_fn = AdamInverseSqrtWithWarmup 241 | optim_params['betas'] = (optim_params.get('beta1', 0.9), optim_params.get('beta2', 0.999)) 242 | optim_params.pop('beta1', None) 243 | optim_params.pop('beta2', None) 244 | elif method == 'adam_cosine': 245 | optim_fn = AdamCosineWithWarmup 246 | optim_params['betas'] = (optim_params.get('beta1', 0.9), optim_params.get('beta2', 0.999)) 247 | optim_params.pop('beta1', None) 248 | optim_params.pop('beta2', None) 249 | elif method == 'adamax': 250 | optim_fn = optim.Adamax 251 | elif method == 'asgd': 252 | optim_fn = optim.ASGD 253 | elif method == 'rmsprop': 254 | optim_fn = optim.RMSprop 255 | elif method == 'rprop': 256 | optim_fn = optim.Rprop 257 | elif method == 'sgd': 258 | optim_fn = optim.SGD 259 | assert 'lr' in optim_params 260 | else: 261 | raise Exception('Unknown optimization method: "%s"' % method) 262 | 263 | # check that we give good parameters to the optimizer 264 | expected_args = inspect.getargspec(optim_fn.__init__)[0] 265 | assert expected_args[:2] == ['self', 'params'] 266 | if not all(k in expected_args[2:] for k in optim_params.keys()): 267 | raise Exception('Unexpected parameters: expected "%s", got "%s"' % ( 268 | str(expected_args[2:]), str(optim_params.keys()))) 269 | 270 | return optim_fn(parameters, **optim_params) 271 | -------------------------------------------------------------------------------- /xlm/slurm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import os 10 | import sys 11 | import torch 12 | import socket 13 | import signal 14 | import subprocess 15 | 16 | 17 | logger = getLogger() 18 | 19 | 20 | def sig_handler(signum, frame): 21 | logger.warning("Signal handler called with signal " + str(signum)) 22 | prod_id = int(os.environ['SLURM_PROCID']) 23 | logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id)) 24 | if prod_id == 0: 25 | logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID']) 26 | os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID']) 27 | else: 28 | logger.warning("Not the master process, no need to requeue.") 29 | sys.exit(-1) 30 | 31 | 32 | def term_handler(signum, frame): 33 | logger.warning("Signal handler called with signal " + str(signum)) 34 | logger.warning("Bypassing SIGTERM.") 35 | 36 | 37 | def init_signal_handler(): 38 | """ 39 | Handle signals sent by SLURM for time limit / pre-emption. 
40 | """ 41 | signal.signal(signal.SIGUSR1, sig_handler) 42 | signal.signal(signal.SIGTERM, term_handler) 43 | logger.warning("Signal handler installed.") 44 | 45 | 46 | def init_distributed_mode(params): 47 | """ 48 | Handle single and multi-GPU / multi-node / SLURM jobs. 49 | Initialize the following variables: 50 | - n_nodes 51 | - node_id 52 | - local_rank 53 | - global_rank 54 | - world_size 55 | """ 56 | params.is_slurm_job = 'SLURM_JOB_ID' in os.environ and not params.debug_slurm 57 | print("SLURM job: %s" % str(params.is_slurm_job)) 58 | 59 | # SLURM job 60 | if params.is_slurm_job: 61 | 62 | assert params.local_rank == -1 # on the cluster, this is handled by SLURM 63 | 64 | SLURM_VARIABLES = [ 65 | 'SLURM_JOB_ID', 66 | 'SLURM_JOB_NODELIST', 'SLURM_JOB_NUM_NODES', 'SLURM_NTASKS', 'SLURM_TASKS_PER_NODE', 67 | 'SLURM_MEM_PER_NODE', 'SLURM_MEM_PER_CPU', 68 | 'SLURM_NODEID', 'SLURM_PROCID', 'SLURM_LOCALID', 'SLURM_TASK_PID' 69 | ] 70 | 71 | PREFIX = "%i - " % int(os.environ['SLURM_PROCID']) 72 | for name in SLURM_VARIABLES: 73 | value = os.environ.get(name, None) 74 | print(PREFIX + "%s: %s" % (name, str(value))) 75 | 76 | # # job ID 77 | # params.job_id = os.environ['SLURM_JOB_ID'] 78 | 79 | # number of nodes / node ID 80 | params.n_nodes = int(os.environ['SLURM_JOB_NUM_NODES']) 81 | params.node_id = int(os.environ['SLURM_NODEID']) 82 | 83 | # local rank on the current node / global rank 84 | params.local_rank = int(os.environ['SLURM_LOCALID']) 85 | params.global_rank = int(os.environ['SLURM_PROCID']) 86 | 87 | # number of processes / GPUs per node 88 | params.world_size = int(os.environ['SLURM_NTASKS']) 89 | params.n_gpu_per_node = params.world_size // params.n_nodes 90 | 91 | # define master address and master port 92 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']]) 93 | params.master_addr = hostnames.split()[0].decode('utf-8') 94 | assert 10001 <= params.master_port <= 20000 or params.world_size == 1 95 | print(PREFIX + "Master address: %s" % params.master_addr) 96 | print(PREFIX + "Master port : %i" % params.master_port) 97 | 98 | # set environment variables for 'env://' 99 | os.environ['MASTER_ADDR'] = params.master_addr 100 | os.environ['MASTER_PORT'] = str(params.master_port) 101 | os.environ['WORLD_SIZE'] = str(params.world_size) 102 | os.environ['RANK'] = str(params.global_rank) 103 | 104 | # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch 105 | elif params.local_rank != -1: 106 | 107 | assert params.master_port == -1 108 | 109 | # read environment variables 110 | params.global_rank = int(os.environ['RANK']) 111 | params.world_size = int(os.environ['WORLD_SIZE']) 112 | params.n_gpu_per_node = int(os.environ['NGPU']) 113 | 114 | # number of nodes / node ID 115 | params.n_nodes = params.world_size // params.n_gpu_per_node 116 | params.node_id = params.global_rank // params.n_gpu_per_node 117 | 118 | # local job (single GPU) 119 | else: 120 | assert params.local_rank == -1 121 | assert params.master_port == -1 122 | params.n_nodes = 1 123 | params.node_id = 0 124 | params.local_rank = 0 125 | params.global_rank = 0 126 | params.world_size = 1 127 | params.n_gpu_per_node = 1 128 | 129 | # sanity checks 130 | assert params.n_nodes >= 1 131 | assert 0 <= params.node_id < params.n_nodes 132 | assert 0 <= params.local_rank <= params.global_rank < params.world_size 133 | assert params.world_size == params.n_nodes * params.n_gpu_per_node 134 | 135 | # define whether this is the master process / 
if we are in distributed mode 136 | params.is_master = params.node_id == 0 and params.local_rank == 0 137 | params.multi_node = params.n_nodes > 1 138 | params.multi_gpu = params.world_size > 1 139 | 140 | # summary 141 | PREFIX = "%i - " % params.global_rank 142 | print(PREFIX + "Number of nodes: %i" % params.n_nodes) 143 | print(PREFIX + "Node ID : %i" % params.node_id) 144 | print(PREFIX + "Local rank : %i" % params.local_rank) 145 | print(PREFIX + "Global rank : %i" % params.global_rank) 146 | print(PREFIX + "World size : %i" % params.world_size) 147 | print(PREFIX + "GPUs per node : %i" % params.n_gpu_per_node) 148 | print(PREFIX + "Master : %s" % str(params.is_master)) 149 | print(PREFIX + "Multi-node : %s" % str(params.multi_node)) 150 | print(PREFIX + "Multi-GPU : %s" % str(params.multi_gpu)) 151 | print(PREFIX + "Hostname : %s" % socket.gethostname()) 152 | 153 | # set GPU device 154 | torch.cuda.set_device(params.local_rank) 155 | 156 | # initialize multi-GPU 157 | if params.multi_gpu: 158 | 159 | # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization 160 | # 'env://' will read these environment variables: 161 | # MASTER_PORT - required; has to be a free port on machine with rank 0 162 | # MASTER_ADDR - required (except for rank 0); address of rank 0 node 163 | # WORLD_SIZE - required; can be set either here, or in a call to init function 164 | # RANK - required; can be set either here, or in a call to init function 165 | 166 | print("Initializing PyTorch distributed ...") 167 | torch.distributed.init_process_group( 168 | init_method='env://', 169 | backend='nccl', 170 | ) 171 | -------------------------------------------------------------------------------- /xlm/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import re 10 | import sys 11 | import pickle 12 | import random 13 | import getpass 14 | import argparse 15 | import subprocess 16 | import numpy as np 17 | import torch 18 | 19 | from .logger import create_logger 20 | 21 | 22 | FALSY_STRINGS = {'off', 'false', '0'} 23 | TRUTHY_STRINGS = {'on', 'true', '1'} 24 | 25 | DUMP_PATH = '/checkpoint/%s/dumped' % getpass.getuser() 26 | DYNAMIC_COEFF = ['lambda_clm', 'lambda_mlm', 'lambda_pc', 'lambda_ae', 'lambda_mt', 'lambda_bt'] 27 | 28 | 29 | class AttrDict(dict): 30 | def __init__(self, *args, **kwargs): 31 | super(AttrDict, self).__init__(*args, **kwargs) 32 | self.__dict__ = self 33 | 34 | 35 | def bool_flag(s): 36 | """ 37 | Parse boolean arguments from the command line. 
38 | """ 39 | if s.lower() in FALSY_STRINGS: 40 | return False 41 | elif s.lower() in TRUTHY_STRINGS: 42 | return True 43 | else: 44 | raise argparse.ArgumentTypeError("Invalid value for a boolean flag!") 45 | 46 | 47 | def initialize_exp(params): 48 | """ 49 | Initialize the experience: 50 | - dump parameters 51 | - create a logger 52 | """ 53 | # dump parameters 54 | get_dump_path(params) 55 | pickle.dump(params, open(os.path.join(params.dump_path, 'params.pkl'), 'wb')) 56 | 57 | # get running command 58 | command = ["python", sys.argv[0]] 59 | for x in sys.argv[1:]: 60 | if x.startswith('--'): 61 | assert '"' not in x and "'" not in x 62 | command.append(x) 63 | else: 64 | assert "'" not in x 65 | if re.match('^[a-zA-Z0-9_]+$', x): 66 | command.append("%s" % x) 67 | else: 68 | command.append("'%s'" % x) 69 | command = ' '.join(command) 70 | params.command = command + ' --exp_id "%s"' % params.exp_id 71 | 72 | # check experiment name 73 | assert len(params.exp_name.strip()) > 0 74 | 75 | # create a logger 76 | logger = create_logger(os.path.join(params.dump_path, 'train.log'), rank=getattr(params, 'global_rank', 0)) 77 | logger.info("============ Initialized logger ============") 78 | logger.info("\n".join("%s: %s" % (k, str(v)) 79 | for k, v in sorted(dict(vars(params)).items()))) 80 | logger.info("The experiment will be stored in %s\n" % params.dump_path) 81 | logger.info("Running command: %s" % command) 82 | logger.info("") 83 | return logger 84 | 85 | 86 | def get_dump_path(params): 87 | """ 88 | Create a directory to store the experiment. 89 | """ 90 | dump_path = DUMP_PATH if params.dump_path == '' else params.dump_path 91 | assert len(params.exp_name) > 0 92 | 93 | # create the sweep path if it does not exist 94 | sweep_path = os.path.join(dump_path, params.exp_name) 95 | if not os.path.exists(sweep_path): 96 | subprocess.Popen("mkdir -p %s" % sweep_path, shell=True).wait() 97 | 98 | # create an ID for the job if it is not given in the parameters. 99 | # if we run on the cluster, the job ID is the one of Chronos. 100 | # otherwise, it is randomly generated 101 | if params.exp_id == '': 102 | chronos_job_id = os.environ.get('CHRONOS_JOB_ID') 103 | slurm_job_id = os.environ.get('SLURM_JOB_ID') 104 | assert chronos_job_id is None or slurm_job_id is None 105 | exp_id = chronos_job_id if chronos_job_id is not None else slurm_job_id 106 | if exp_id is None: 107 | chars = 'abcdefghijklmnopqrstuvwxyz0123456789' 108 | while True: 109 | exp_id = ''.join(random.choice(chars) for _ in range(10)) 110 | if not os.path.isdir(os.path.join(sweep_path, exp_id)): 111 | break 112 | else: 113 | assert exp_id.isdigit() 114 | params.exp_id = exp_id 115 | 116 | # create the dump folder / update parameters 117 | params.dump_path = os.path.join(sweep_path, params.exp_id) 118 | if not os.path.isdir(params.dump_path): 119 | subprocess.Popen("mkdir -p %s" % params.dump_path, shell=True).wait() 120 | 121 | 122 | def to_cuda(*args): 123 | """ 124 | Move tensors to CUDA. 125 | """ 126 | return [None if x is None else x.cuda() for x in args] 127 | 128 | 129 | def restore_segmentation(path): 130 | """ 131 | Take a file segmented with BPE and restore it to its original segmentation. 132 | """ 133 | assert os.path.isfile(path) 134 | restore_cmd = "sed -i -r 's/(@@ )|(@@ ?$)//g' %s" 135 | subprocess.Popen(restore_cmd % path, shell=True).wait() 136 | 137 | 138 | def parse_lambda_config(params): 139 | """ 140 | Parse the configuration of lambda coefficient (for scheduling). 
141 | x = "3" # lambda will be a constant equal to x 142 | x = "0:1,1000:0" # lambda will start from 1 and linearly decrease to 0 during the first 1000 iterations 143 | x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000 iterations, then will linearly increase to 1 until iteration 2000 144 | """ 145 | for name in DYNAMIC_COEFF: 146 | x = getattr(params, name) 147 | split = x.split(',') 148 | if len(split) == 1: 149 | setattr(params, name, float(x)) 150 | setattr(params, name + '_config', None) 151 | else: 152 | split = [s.split(':') for s in split] 153 | assert all(len(s) == 2 for s in split) 154 | assert all(k.isdigit() for k, _ in split) 155 | assert all(int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1)) 156 | setattr(params, name, float(split[0][1])) 157 | setattr(params, name + '_config', [(int(k), float(v)) for k, v in split]) 158 | 159 | 160 | def get_lambda_value(config, n_iter): 161 | """ 162 | Compute a lambda value according to its schedule configuration. 163 | """ 164 | ranges = [i for i in range(len(config) - 1) if config[i][0] <= n_iter < config[i + 1][0]] 165 | if len(ranges) == 0: 166 | assert n_iter >= config[-1][0] 167 | return config[-1][1] 168 | assert len(ranges) == 1 169 | i = ranges[0] 170 | x_a, y_a = config[i] 171 | x_b, y_b = config[i + 1] 172 | return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a) 173 | 174 | 175 | def update_lambdas(params, n_iter): 176 | """ 177 | Update all lambda coefficients. 178 | """ 179 | for name in DYNAMIC_COEFF: 180 | config = getattr(params, name + '_config') 181 | if config is not None: 182 | setattr(params, name, get_lambda_value(config, n_iter)) 183 | 184 | 185 | def set_sampling_probs(data, params): 186 | """ 187 | Set the probability of sampling specific languages / language pairs during training. 188 | """ 189 | coeff = params.lg_sampling_factor 190 | if coeff == -1: 191 | return 192 | assert coeff > 0 193 | 194 | # monolingual data 195 | params.mono_list = [k for k, v in data['mono_stream'].items() if 'train' in v] 196 | if len(params.mono_list) > 0: 197 | probs = np.array([1.0 * len(data['mono_stream'][lang]['train']) for lang in params.mono_list]) 198 | probs /= probs.sum() 199 | probs = np.array([p ** coeff for p in probs]) 200 | probs /= probs.sum() 201 | params.mono_probs = probs 202 | 203 | # parallel data 204 | params.para_list = [k for k, v in data['para'].items() if 'train' in v] 205 | if len(params.para_list) > 0: 206 | probs = np.array([1.0 * len(data['para'][(l1, l2)]['train']) for (l1, l2) in params.para_list]) 207 | probs /= probs.sum() 208 | probs = np.array([p ** coeff for p in probs]) 209 | probs /= probs.sum() 210 | params.para_probs = probs 211 | 212 | 213 | def concat_batches(x1, len1, lang1_id, x2, len2, lang2_id, pad_idx, eos_idx, reset_positions): 214 | """ 215 | Concat batches with different languages. 
216 | """ 217 | assert reset_positions is False or lang1_id != lang2_id 218 | lengths = len1 + len2 219 | if not reset_positions: 220 | lengths -= 1 221 | slen, bs = lengths.max().item(), lengths.size(0) 222 | 223 | x = x1.new(slen, bs).fill_(pad_idx) 224 | x[:len1.max().item()].copy_(x1) 225 | positions = torch.arange(slen)[:, None].repeat(1, bs).to(x1.device) 226 | langs = x1.new(slen, bs).fill_(lang1_id) 227 | 228 | for i in range(bs): 229 | l1 = len1[i] if reset_positions else len1[i] - 1 230 | x[l1:l1 + len2[i], i].copy_(x2[:len2[i], i]) 231 | if reset_positions: 232 | positions[l1:, i] -= len1[i] 233 | langs[l1:, i] = lang2_id 234 | 235 | assert (x == eos_idx).long().sum().item() == (4 if reset_positions else 3) * bs 236 | 237 | return x, lengths, positions, langs 238 | 239 | 240 | def truncate(x, lengths, max_len, eos_index): 241 | """ 242 | Truncate long sentences. 243 | """ 244 | if lengths.max().item() > max_len: 245 | x = x[:max_len].clone() 246 | lengths = lengths.clone() 247 | for i in range(len(lengths)): 248 | if lengths[i] > max_len: 249 | lengths[i] = max_len 250 | x[max_len - 1, i] = eos_index 251 | return x, lengths 252 | 253 | 254 | def shuf_order(langs, params=None, n=5): 255 | """ 256 | Randomize training order. 257 | """ 258 | if len(langs) == 0: 259 | return [] 260 | 261 | if params is None: 262 | return [langs[i] for i in np.random.permutation(len(langs))] 263 | 264 | # sample monolingual and parallel languages separately 265 | mono = [l1 for l1, l2 in langs if l2 is None] 266 | para = [(l1, l2) for l1, l2 in langs if l2 is not None] 267 | 268 | # uniform / weighted sampling 269 | if params.lg_sampling_factor == -1: 270 | p_mono = None 271 | p_para = None 272 | else: 273 | p_mono = np.array([params.mono_probs[params.mono_list.index(k)] for k in mono]) 274 | p_para = np.array([params.para_probs[params.para_list.index(tuple(sorted(k)))] for k in para]) 275 | p_mono = p_mono / p_mono.sum() 276 | p_para = p_para / p_para.sum() 277 | 278 | s_mono = [mono[i] for i in np.random.choice(len(mono), size=min(n, len(mono)), p=p_mono, replace=True)] if len(mono) > 0 else [] 279 | s_para = [para[i] for i in np.random.choice(len(para), size=min(n, len(para)), p=p_para, replace=True)] if len(para) > 0 else [] 280 | 281 | assert len(s_mono) + len(s_para) > 0 282 | return [(lang, None) for lang in s_mono] + s_para 283 | 284 | 285 | def find_modules(module, module_name, module_instance, found): 286 | """ 287 | Recursively find all instances of a specific module inside a module. 288 | """ 289 | if isinstance(module, module_instance): 290 | found.append((module_name, module)) 291 | else: 292 | for name, child in module.named_children(): 293 | name = ('%s[%s]' if name.isdigit() else '%s.%s') % (module_name, name) 294 | find_modules(child, name, module_instance, found) 295 | --------------------------------------------------------------------------------