├── .gitignore ├── CONTRIBUTING.md ├── LICENSE.TXT ├── MASS-summarization ├── README.md ├── encode.py └── mass │ ├── __init__.py │ ├── bert_dictionary.py │ ├── learned_positional_embedding.py │ ├── masked_dataset.py │ ├── masked_s2s.py │ ├── s2s_model.py │ └── translation.py ├── MASS-supNMT ├── README.md ├── archi_mass_sup.png ├── archi_mass_sup_md.png ├── ft_mass_enzh.sh ├── generate_enzh_data.sh ├── mass │ ├── __init__.py │ ├── masked_language_pair_dataset.py │ ├── noisy_language_pair_dataset.py │ ├── xmasked_seq2seq.py │ └── xtransformer.py ├── run_mass_enzh.sh └── translate.sh ├── MASS-unsupNMT ├── filter_noisy_data.py ├── get-data-bilingual-enro-nmt.sh ├── get-data-gigaword.sh ├── get-data-nmt.sh ├── install-tools.sh ├── preprocess.py ├── src │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── dataset.py │ │ ├── dictionary.py │ │ └── loader.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── evaluator.py │ │ └── multi-bleu.perl │ ├── fp16.py │ ├── logger.py │ ├── model │ │ ├── __init__.py │ │ └── transformer.py │ ├── slurm.py │ ├── trainer.py │ └── utils.py ├── train.py ├── translate.py └── translate_ensemble.py ├── NOTICE.md ├── README.md ├── SECURITY.md └── figs └── mass.png /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.suo 8 | *.user 9 | *.userosscache 10 | *.sln.docstates 11 | 12 | # User-specific files (MonoDevelop/Xamarin Studio) 13 | *.userprefs 14 | 15 | # Build results 16 | [Dd]ebug/ 17 | [Dd]ebugPublic/ 18 | [Rr]elease/ 19 | [Rr]eleases/ 20 | x64/ 21 | x86/ 22 | bld/ 23 | [Bb]in/ 24 | [Oo]bj/ 25 | [Ll]og/ 26 | 27 | # Visual Studio 2015/2017 cache/options directory 28 | .vs/ 29 | # Uncomment if you have tasks that create the project's static files in wwwroot 30 | #wwwroot/ 31 | 32 | # Visual Studio 2017 auto generated files 33 | Generated\ Files/ 34 | 35 | # MSTest test Results 36 | [Tt]est[Rr]esult*/ 37 | [Bb]uild[Ll]og.* 38 | 39 | # NUNIT 40 | *.VisualState.xml 41 | TestResult.xml 42 | 43 | # Build Results of an ATL Project 44 | [Dd]ebugPS/ 45 | [Rr]eleasePS/ 46 | dlldata.c 47 | 48 | # Benchmark Results 49 | BenchmarkDotNet.Artifacts/ 50 | 51 | # .NET Core 52 | project.lock.json 53 | project.fragment.lock.json 54 | artifacts/ 55 | **/Properties/launchSettings.json 56 | 57 | # StyleCop 58 | StyleCopReport.xml 59 | 60 | # Files built by Visual Studio 61 | *_i.c 62 | *_p.c 63 | *_i.h 64 | *.ilk 65 | *.meta 66 | *.obj 67 | *.iobj 68 | *.pch 69 | *.pdb 70 | *.ipdb 71 | *.pgc 72 | *.pgd 73 | *.rsp 74 | *.sbr 75 | *.tlb 76 | *.tli 77 | *.tlh 78 | *.tmp 79 | *.tmp_proj 80 | *.log 81 | *.vspscc 82 | *.vssscc 83 | .builds 84 | *.pidb 85 | *.svclog 86 | *.scc 87 | 88 | # Chutzpah Test files 89 | _Chutzpah* 90 | 91 | # Visual C++ cache files 92 | ipch/ 93 | *.aps 94 | *.ncb 95 | *.opendb 96 | *.opensdf 97 | *.sdf 98 | *.cachefile 99 | *.VC.db 100 | *.VC.VC.opendb 101 | 102 | # Visual Studio profiler 103 | *.psess 104 | *.vsp 105 | *.vspx 106 | *.sap 107 | 108 | # Visual Studio Trace Files 109 | *.e2e 110 | 111 | # TFS 2012 Local Workspace 112 | $tf/ 113 | 114 | # Guidance Automation Toolkit 115 | *.gpState 116 | 117 | # ReSharper is a .NET coding add-in 118 | _ReSharper*/ 119 | *.[Rr]e[Ss]harper 120 | *.DotSettings.user 121 | 122 | # JustCode is a .NET coding add-in 123 | .JustCode 
124 | 125 | # TeamCity is a build add-in 126 | _TeamCity* 127 | 128 | # DotCover is a Code Coverage Tool 129 | *.dotCover 130 | 131 | # AxoCover is a Code Coverage Tool 132 | .axoCover/* 133 | !.axoCover/settings.json 134 | 135 | # Visual Studio code coverage results 136 | *.coverage 137 | *.coveragexml 138 | 139 | # NCrunch 140 | _NCrunch_* 141 | .*crunch*.local.xml 142 | nCrunchTemp_* 143 | 144 | # MightyMoose 145 | *.mm.* 146 | AutoTest.Net/ 147 | 148 | # Web workbench (sass) 149 | .sass-cache/ 150 | 151 | # Installshield output folder 152 | [Ee]xpress/ 153 | 154 | # DocProject is a documentation generator add-in 155 | DocProject/buildhelp/ 156 | DocProject/Help/*.HxT 157 | DocProject/Help/*.HxC 158 | DocProject/Help/*.hhc 159 | DocProject/Help/*.hhk 160 | DocProject/Help/*.hhp 161 | DocProject/Help/Html2 162 | DocProject/Help/html 163 | 164 | # Click-Once directory 165 | publish/ 166 | 167 | # Publish Web Output 168 | *.[Pp]ublish.xml 169 | *.azurePubxml 170 | # Note: Comment the next line if you want to checkin your web deploy settings, 171 | # but database connection strings (with potential passwords) will be unencrypted 172 | *.pubxml 173 | *.publishproj 174 | 175 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 176 | # checkin your Azure Web App publish settings, but sensitive information contained 177 | # in these scripts will be unencrypted 178 | PublishScripts/ 179 | 180 | # NuGet Packages 181 | *.nupkg 182 | # The packages folder can be ignored because of Package Restore 183 | **/[Pp]ackages/* 184 | # except build/, which is used as an MSBuild target. 185 | !**/[Pp]ackages/build/ 186 | # Uncomment if necessary however generally it will be regenerated when needed 187 | #!**/[Pp]ackages/repositories.config 188 | # NuGet v3's project.json files produces more ignorable files 189 | *.nuget.props 190 | *.nuget.targets 191 | 192 | # Microsoft Azure Build Output 193 | csx/ 194 | *.build.csdef 195 | 196 | # Microsoft Azure Emulator 197 | ecf/ 198 | rcf/ 199 | 200 | # Windows Store app package directories and files 201 | AppPackages/ 202 | BundleArtifacts/ 203 | Package.StoreAssociation.xml 204 | _pkginfo.txt 205 | *.appx 206 | 207 | # Visual Studio cache files 208 | # files ending in .cache can be ignored 209 | *.[Cc]ache 210 | # but keep track of directories ending in .cache 211 | !*.[Cc]ache/ 212 | 213 | # Others 214 | ClientBin/ 215 | ~$* 216 | *~ 217 | *.dbmdl 218 | *.dbproj.schemaview 219 | *.jfm 220 | *.pfx 221 | *.publishsettings 222 | orleans.codegen.cs 223 | 224 | # Including strong name files can present a security risk 225 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 226 | #*.snk 227 | 228 | # Since there are multiple workflows, uncomment next line to ignore bower_components 229 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 230 | #bower_components/ 231 | 232 | # RIA/Silverlight projects 233 | Generated_Code/ 234 | 235 | # Backup & report files from converting an old project file 236 | # to a newer Visual Studio version. 
Backup files are not needed, 237 | # because we have git ;-) 238 | _UpgradeReport_Files/ 239 | Backup*/ 240 | UpgradeLog*.XML 241 | UpgradeLog*.htm 242 | ServiceFabricBackup/ 243 | *.rptproj.bak 244 | 245 | # SQL Server files 246 | *.mdf 247 | *.ldf 248 | *.ndf 249 | 250 | # Business Intelligence projects 251 | *.rdl.data 252 | *.bim.layout 253 | *.bim_*.settings 254 | *.rptproj.rsuser 255 | 256 | # Microsoft Fakes 257 | FakesAssemblies/ 258 | 259 | # GhostDoc plugin setting file 260 | *.GhostDoc.xml 261 | 262 | # Node.js Tools for Visual Studio 263 | .ntvs_analysis.dat 264 | node_modules/ 265 | 266 | # Visual Studio 6 build log 267 | *.plg 268 | 269 | # Visual Studio 6 workspace options file 270 | *.opt 271 | 272 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 273 | *.vbw 274 | 275 | # Visual Studio LightSwitch build output 276 | **/*.HTMLClient/GeneratedArtifacts 277 | **/*.DesktopClient/GeneratedArtifacts 278 | **/*.DesktopClient/ModelManifest.xml 279 | **/*.Server/GeneratedArtifacts 280 | **/*.Server/ModelManifest.xml 281 | _Pvt_Extensions 282 | 283 | # Paket dependency manager 284 | .paket/paket.exe 285 | paket-files/ 286 | 287 | # FAKE - F# Make 288 | .fake/ 289 | 290 | # JetBrains Rider 291 | .idea/ 292 | *.sln.iml 293 | 294 | # CodeRush 295 | .cr/ 296 | 297 | # Python Tools for Visual Studio (PTVS) 298 | __pycache__/ 299 | *.pyc 300 | 301 | # Cake - Uncomment if you are using it 302 | # tools/** 303 | # !tools/packages.config 304 | 305 | # Tabs Studio 306 | *.tss 307 | 308 | # Telerik's JustMock configuration file 309 | *.jmconfig 310 | 311 | # BizTalk build output 312 | *.btp.cs 313 | *.btm.cs 314 | *.odx.cs 315 | *.xsd.cs 316 | 317 | # OpenCover UI analysis results 318 | OpenCover/ 319 | 320 | # Azure Stream Analytics local run output 321 | ASALocalRun/ 322 | 323 | # MSBuild Binary and Structured Log 324 | *.binlog 325 | 326 | # NVidia Nsight GPU debugger configuration file 327 | *.nvuser 328 | 329 | # MFractors (Xamarin productivity tool) working folder 330 | .mfractor/ 331 | 332 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | This project welcomes contributions and suggestions. Most contributions require you to 4 | agree to a Contributor License Agreement (CLA) declaring that you have the right to, 5 | and actually do, grant us the rights to use your contribution. For details, visit 6 | https://cla.microsoft.com. 7 | 8 | When you submit a pull request, a CLA-bot will automatically determine whether you need 9 | to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the 10 | instructions provided by the bot. You will only need to do this once across all repositories using our CLA. 11 | 12 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 13 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 14 | or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 
15 | 16 | -------------------------------------------------------------------------------- /LICENSE.TXT: -------------------------------------------------------------------------------- 1 | [MASS: Masked Sequence to Sequence Pre-training for Language Generation] 2 | 3 | Copyright (c) Microsoft Corporation 4 | 5 | All rights reserved. 6 | 7 | 8 | 9 | MIT License 10 | 11 | 12 | 13 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the Software), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 14 | 15 | 16 | 17 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 18 | 19 | 20 | 21 | THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 22 | 23 | 24 | -------------------------------------------------------------------------------- /MASS-summarization/README.md: -------------------------------------------------------------------------------- 1 | # MASS-SUM 2 | 3 | 6 | 7 | ## Dependency 8 | ``` 9 | pip install torch==1.0.0 10 | pip install fairseq==0.8.0 11 | ``` 12 | 13 | ## MODEL 14 | MASS uses the standard Transformer architecture. We denote L, H, and A as the number of layers, the hidden size, and the number of attention heads. 15 | 16 | | Model | Encoder | Decoder | Download | 17 | | :------| :-----|:-----|:-----| 18 | | MASS-base-uncased | 6L-768H-12A | 6L-768H-12A | [MODEL](https://modelrelease.blob.core.windows.net/mass/mass-base-uncased.tar.gz) | 19 | | MASS-middle-uncased | 6L-1024H-16A | 6L-1024H-16A | [MODEL](https://modelrelease.blob.core.windows.net/mass/mass-middle-uncased.tar.gz) | 20 | 21 | ## Results on Abstractive Summarization (9/27/2019) 22 | 23 | | Dataset | Params | RG-1 | RG-2 | RG-L | FT model | 24 | | ------| ----- | ---- | ---- | ---- | :----: | 25 | | CNN/Daily Mail | 123M | 42.12 | 19.50 | 39.01 | [MODEL](https://modelrelease.blob.core.windows.net/mass/cnndm_evaluation.tar.gz) | 26 | | Gigaword | 123M | 38.73 | 19.71 | 35.96 | [MODEL](https://modelrelease.blob.core.windows.net/mass/gigaword_evaluation.tar.gz) | 27 | | XSum | 123M | 39.75 | 17.24 | 31.95 | | 28 | | CNN/Daily Mail | 208M | 42.90 | 19.87 | 39.80 | | 29 | | Gigaword | 208M | 38.93 | 20.20 | 36.20 | | 30 | 31 | Evaluated by [files2rouge](https://github.com/pltrdy/files2rouge). `FT model` means `Fine-tuned model`. 32 | 33 | ## Pipeline for Pre-Training 34 | ### Download data 35 | Our model is trained on Wikipedia + BookCorpus. Here we use wikitext-103 to demonstrate how to process data. 36 | ``` 37 | wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip 38 | unzip wikitext-103-raw-v1.zip 39 | ``` 40 | 41 | ### Tokenize corpus 42 | We use the WordPiece vocabulary (from BERT) to tokenize the original text data directly. We provide a [script](encode.py) to process the data.
You need to run `pip install pytorch_transformers` first to generate the tokenized data. 43 | ``` 44 | mkdir -p mono 45 | for SPLIT in train valid test; do 46 | python encode.py \ 47 | --inputs wikitext-103-raw/wiki.${SPLIT}.raw \ 48 | --outputs mono/${SPLIT}.txt \ 49 | --workers 60; \ 50 | done 51 | ``` 52 | 53 | ### Binarized data 54 | ``` 55 | wget -c https://modelrelease.blob.core.windows.net/mass/mass-base-uncased.tar.gz 56 | tar -zxvf mass-base-uncased.tar.gz 57 | # Move dict.txt from the tar file to the data directory 58 | 59 | fairseq-preprocess \ 60 | --user-dir mass --only-source --task masked_s2s \ 61 | --trainpref mono/train.txt --validpref mono/valid.txt --testpref mono/test.txt \ 62 | --destdir processed --srcdict dict.txt --workers 60 63 | ``` 64 | 65 | ### Pre-training 66 | ``` 67 | TOKENS_PER_SAMPLE=512 68 | WARMUP_UPDATES=10000 69 | PEAK_LR=0.0005 70 | TOTAL_UPDATES=125000 71 | MAX_SENTENCES=8 72 | UPDATE_FREQ=16 73 | 74 | fairseq-train processed \ 75 | --user-dir mass --task masked_s2s --arch transformer_mass_base \ 76 | --sample-break-mode none \ 77 | --tokens-per-sample $TOKENS_PER_SAMPLE \ 78 | --criterion masked_lm \ 79 | --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-6 --clip-norm 0.0 \ 80 | --lr-scheduler polynomial_decay --lr $PEAK_LR --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_UPDATES \ 81 | --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \ 82 | --max-sentences $MAX_SENTENCES --update-freq $UPDATE_FREQ \ 83 | --ddp-backend=no_c10d 84 | ``` 85 | ## Pipeline for Fine-tuning (CNN / Daily Mail) 86 | 87 | ### Data 88 | Download, tokenize, and truncate the data from this [link](https://github.com/abisee/cnn-dailymail), and use the above [tokenization](#tokenize-corpus) to generate wordpiece-level data. Rename the suffixes `article` and `title` to `src` and `tgt`. Assume the tokenized data is under `cnndm/para`. 89 | ``` 90 | fairseq-preprocess \ 91 | --user-dir mass --task masked_s2s \ 92 | --source-lang src --target-lang tgt \ 93 | --trainpref cnndm/para/train --validpref cnndm/para/valid --testpref cnndm/para/test \ 94 | --destdir cnndm/processed --srcdict dict.txt --tgtdict dict.txt \ 95 | --workers 20 96 | ``` 97 | `dict.txt` is included in `mass-base-uncased.tar.gz`. A copy of the binarized data can be obtained from [here](https://modelrelease.blob.core.windows.net/mass/cnndm.tar.gz). 98 | 99 | 100 | ### Running 101 | ``` 102 | fairseq-train cnndm/processed/ \ 103 | --user-dir mass --task translation_mass --arch transformer_mass_base \ 104 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 105 | --lr 0.0005 --min-lr 1e-09 \ 106 | --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \ 107 | --weight-decay 0.0 \ 108 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ 109 | --update-freq 8 --max-tokens 4096 \ 110 | --ddp-backend=no_c10d --max-epoch 25 \ 111 | --max-source-positions 512 --max-target-positions 512 \ 112 | --skip-invalid-size-inputs-valid-test \ 113 | --load-from-pretrained-model mass-base-uncased.pt 114 | ``` 115 | `lr=0.0005` is not necessarily the optimal choice for every task; it was tuned on the dev set (among 1e-4, 2e-4, 5e-4). 116 | ### Inference 117 | ``` 118 | MODEL=checkpoints/checkpoint_best.pt 119 | fairseq-generate $DATADIR --path $MODEL \ 120 | --user-dir mass --task translation_mass \ 121 | --batch-size 64 --beam 5 --min-len 50 --no-repeat-ngram-size 3 \ 122 | --lenpen 1.0 123 | ``` 124 | `min-len` is sensitive to the task, and `lenpen` needs to be tuned on the dev set.
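As a rough sketch of that tuning (the candidate values and output file names below are illustrative, not taken from the original instructions), `lenpen` can be swept on the dev set with a simple loop, and each output then post-processed and scored with [files2rouge](https://github.com/pltrdy/files2rouge) to pick the best value:
```
MODEL=checkpoints/checkpoint_best.pt
# Hypothetical sweep over a few length penalties on the validation subset
for LENPEN in 0.8 1.0 1.2; do
    fairseq-generate $DATADIR --path $MODEL \
        --user-dir mass --task translation_mass --gen-subset valid \
        --batch-size 64 --beam 5 --min-len 50 --no-repeat-ngram-size 3 \
        --lenpen $LENPEN > gen_lenpen_$LENPEN.out
done
```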
Restore the results to word-level data by using `sed 's/ ##//g'`. 125 | 126 | 127 | ## Other questions 128 | 1. Q: I have encountered an error like `ModuleNotFoundError: No module named 'mass'` when running on multiple GPUs or multiple nodes. How can I solve it? 129 | A: It seems to be a bug in Python's `multiprocessing/spawn.py`. A direct solution is to move these three files to their corresponding folders in fairseq. For example: 130 | ``` 131 | mv bert_dictionary.py fairseq/fairseq/data/ 132 | mv masked_dataset.py fairseq/fairseq/data/ 133 | mv learned_positional_embedding.py fairseq/fairseq/modules/ 134 | # then modify fairseq/fairseq/data/__init__.py to import the above files 135 | ``` 136 | 137 | 146 | -------------------------------------------------------------------------------- /MASS-summarization/encode.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import contextlib 3 | import sys 4 | 5 | from collections import Counter 6 | from multiprocessing import Pool 7 | 8 | from pytorch_transformers import BertTokenizer 9 | 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument( 14 | "--inputs", 15 | nargs="+", 16 | default=['-'], 17 | help="input files to filter/encode", 18 | ) 19 | parser.add_argument( 20 | "--outputs", 21 | nargs="+", 22 | default=['-'], 23 | help="path to save encoded outputs", 24 | ) 25 | parser.add_argument( 26 | "--keep-empty", 27 | action="store_true", 28 | help="keep empty lines", 29 | ) 30 | parser.add_argument("--workers", type=int, default=20) 31 | args = parser.parse_args() 32 | 33 | assert len(args.inputs) == len(args.outputs), \ 34 | "number of input and output paths should match" 35 | 36 | with contextlib.ExitStack() as stack: 37 | inputs = [ 38 | stack.enter_context(open(input, "r", encoding="utf-8")) 39 | if input != "-" else sys.stdin 40 | for input in args.inputs 41 | ] 42 | outputs = [ 43 | stack.enter_context(open(output, "w", encoding="utf-8")) 44 | if output != "-" else sys.stdout 45 | for output in args.outputs 46 | ] 47 | 48 | encoder = MultiprocessingEncoder(args) 49 | pool = Pool(args.workers, initializer=encoder.initializer) 50 | encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100) 51 | 52 | stats = Counter() 53 | for i, (filt, enc_lines) in enumerate(encoded_lines, start=1): 54 | if filt == "PASS": 55 | for enc_line, output_h in zip(enc_lines, outputs): 56 | print(enc_line, file=output_h) 57 | else: 58 | stats["num_filtered_" + filt] += 1 59 | if i % 10000 == 0: 60 | print("processed {} lines".format(i), file=sys.stderr) 61 | 62 | for k, v in stats.most_common(): 63 | print("[{}] filtered {} lines".format(k, v), file=sys.stderr) 64 | 65 | 66 | class MultiprocessingEncoder(object): 67 | 68 | def __init__(self, args): 69 | self.args = args 70 | 71 | def initializer(self): 72 | global bpe 73 | bpe = BertTokenizer.from_pretrained('bert-base-uncased') 74 | 75 | def encode(self, line): 76 | global bpe 77 | subword = bpe._tokenize(line) 78 | return subword 79 | 80 | def decode(self, tokens): 81 | global bpe 82 | return bpe.decode(tokens) 83 | 84 | def encode_lines(self, lines): 85 | """ 86 | Encode a set of lines. All lines will be encoded together.
87 | """ 88 | enc_lines = [] 89 | for line in lines: 90 | line = line.strip() 91 | if len(line) == 0 and not self.args.keep_empty: 92 | return ["EMPTY", None] 93 | tokens = self.encode(line) 94 | enc_lines.append(" ".join(tokens)) 95 | return ["PASS", enc_lines] 96 | 97 | def decode_lines(self, lines): 98 | dec_lines = [] 99 | for line in lines: 100 | tokens = map(int, line.strip().split()) 101 | dec_lines.append(self.decode(tokens)) 102 | return ["PASS", dec_lines] 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 | -------------------------------------------------------------------------------- /MASS-summarization/mass/__init__.py: -------------------------------------------------------------------------------- 1 | from . import masked_s2s 2 | from . import s2s_model 3 | from . import translation 4 | -------------------------------------------------------------------------------- /MASS-summarization/mass/bert_dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from collections import Counter 7 | from multiprocessing import Pool 8 | import os 9 | 10 | import torch 11 | 12 | from fairseq.tokenizer import tokenize_line 13 | from fairseq.binarizer import safe_readline 14 | from fairseq.data import data_utils, Dictionary 15 | 16 | 17 | class BertDictionary(Dictionary): 18 | """A mapping from symbols to consecutive integers""" 19 | 20 | def __init__( 21 | self, 22 | pad='<pad>', 23 | eos='</s>', 24 | unk='<unk>', 25 | bos='<s>', 26 | extra_special_symbols=None, 27 | ): 28 | super().__init__(pad, eos, unk, bos, extra_special_symbols) 29 | 30 | @classmethod 31 | def load_from_file(cls, filename): 32 | d = cls() 33 | d.symbols = [] 34 | d.count = [] 35 | d.indices = {} 36 | 37 | with open(filename, 'r', encoding='utf-8', errors='ignore') as input_file: 38 | for line in input_file: 39 | k, v = line.split() 40 | d.add_symbol(k) 41 | 42 | d.unk_word = '[UNK]' 43 | d.pad_word = '[PAD]' 44 | d.eos_word = '[SEP]' 45 | d.bos_word = '[CLS]' 46 | 47 | d.bos_index = d.add_symbol('[CLS]') 48 | d.pad_index = d.add_symbol('[PAD]') 49 | d.eos_index = d.add_symbol('[SEP]') 50 | d.unk_index = d.add_symbol('[UNK]') 51 | 52 | d.nspecial = 999 53 | return d 54 | 55 | def save(self, f): 56 | """Stores dictionary into a text file""" 57 | ex_keys, ex_vals = self._get_meta() 58 | self._save(f, zip(ex_keys + self.symbols, ex_vals + self.count)) 59 | -------------------------------------------------------------------------------- /MASS-summarization/mass/learned_positional_embedding.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from fairseq import utils 4 | 5 | 6 | class LearnedPositionalEmbedding(nn.Embedding): 7 | """ 8 | This module learns positional embeddings up to a fixed maximum size. 9 | Padding ids are ignored by either offsetting based on padding_idx 10 | or by setting padding_idx to None and ensuring that the appropriate 11 | position ids are passed to the forward function.
12 | """ 13 | 14 | def __init__( 15 | self, 16 | num_embeddings: int, 17 | embedding_dim: int, 18 | padding_idx: int, 19 | ): 20 | super().__init__(num_embeddings, embedding_dim, padding_idx) 21 | self.onnx_trace = False 22 | 23 | def forward(self, input, incremental_state=None, positions=None): 24 | """Input is expected to be of size [bsz x seqlen].""" 25 | assert ( 26 | (positions is None) or (self.padding_idx is None) 27 | ), "If positions is pre-computed then padding_idx should not be set." 28 | 29 | if positions is None: 30 | if incremental_state is not None: 31 | # positions is the same for every token when decoding a single step 32 | # Without the int() cast, it doesn't work in some cases when exporting to ONNX 33 | positions = input.data.new(1, 1).fill_(int(self.padding_idx + input.size(1))) 34 | else: 35 | positions = utils.make_positions( 36 | input.data, self.padding_idx, onnx_trace=self.onnx_trace, 37 | ) 38 | return super().forward(positions) 39 | 40 | def max_positions(self): 41 | """Maximum number of supported positions.""" 42 | if self.padding_idx is not None: 43 | return self.num_embeddings - self.padding_idx - 1 44 | else: 45 | return self.num_embeddings 46 | 47 | def _forward(self, positions): 48 | return super().forward(positions) 49 | -------------------------------------------------------------------------------- /MASS-summarization/mass/masked_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | import time 5 | import math 6 | 7 | from fairseq import utils 8 | from fairseq.data import data_utils, LanguagePairDataset 9 | 10 | 11 | class MaskedLanguagePairDataset(LanguagePairDataset): 12 | """ Wrapper for masked language datasets 13 | (support monolingual and bilingual) 14 | 15 | For monolingual dataset: 16 | [x1, x2, x3, x4, x5] 17 | || 18 | VV 19 | [x1, _, _, x4, x5] => [x2, x3] 20 | 21 | default, _ will be replaced by 8:1:1 (mask, self, rand), 22 | """ 23 | def __init__( 24 | self, 25 | src, src_sizes, src_dict, 26 | tgt=None, tgt_sizes=None, tgt_dict=None, 27 | left_pad_source=True, left_pad_target=False, 28 | max_source_positions=1024, max_target_positions=1024, 29 | shuffle=True, mask_prob=0.15, pred_probs=None, block_size=64, 30 | ): 31 | self.src = src 32 | self.tgt = tgt 33 | self.src_sizes = src_sizes 34 | self.tgt_sizes = tgt_sizes 35 | self.src_dict = src_dict 36 | self.tgt_dict = tgt_dict 37 | self.left_pad_source = left_pad_source 38 | self.left_pad_target = left_pad_target 39 | self.shuffle = shuffle 40 | 41 | self.mask_prob = mask_prob 42 | self.pred_probs = pred_probs 43 | self.block_size = block_size 44 | 45 | def __getitem__(self, index): 46 | pkgs = {'id': index} 47 | tgt_item = self.tgt[index] if self.tgt is not None else None 48 | src_item = self.src[index] 49 | 50 | positions = np.arange(0, len(self.src[index])) 51 | masked_pos = [] 52 | for i in range(1, len(src_item), self.block_size): 53 | block = positions[i: i + self.block_size] 54 | masked_len = int(len(block) * self.mask_prob) 55 | masked_block_start = np.random.choice(block[:len(block) - int(masked_len) + 1], 1)[0] 56 | masked_pos.extend(positions[masked_block_start : masked_block_start + masked_len]) 57 | masked_pos = np.array(masked_pos) 58 | 59 | pkgs['target'] = src_item[masked_pos].clone() 60 | pkgs['prev_output_tokens'] = src_item[masked_pos - 1].clone() 61 | pkgs['positions'] = torch.LongTensor(masked_pos) + self.src_dict.pad_index 62 | src_item[masked_pos] = 
self.replace(src_item[masked_pos]) 63 | pkgs['source'] = src_item 64 | return pkgs 65 | 66 | def collate(self, samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False): 67 | if len(samples) == 0: 68 | return {} 69 | 70 | def merge(x, left_pad, move_eos_to_beginning=False): 71 | return data_utils.collate_tokens( 72 | x, pad_idx, eos_idx, left_pad, move_eos_to_beginning 73 | ) 74 | 75 | id = torch.LongTensor([s['id'] for s in samples]) 76 | source = merge([s['source'] for s in samples], left_pad=left_pad_source) 77 | src_lengths = torch.LongTensor([s['source'].numel() for s in samples]) 78 | 79 | prev_output_tokens = merge([s['prev_output_tokens'] for s in samples], left_pad=left_pad_target) 80 | positions = merge([s['positions'] for s in samples], left_pad=left_pad_target) 81 | target = merge([s['target'] for s in samples], left_pad=left_pad_target) 82 | ntokens = target.numel() 83 | 84 | batch = { 85 | 'id' : id, 86 | 'nsentences': len(samples), 87 | 'net_input' : { 88 | 'src_lengths': src_lengths, 89 | 'src_tokens' : source, 90 | 'prev_output_tokens': prev_output_tokens, 91 | 'positions' : positions, 92 | }, 93 | 'target' : target, 94 | 'ntokens': ntokens, 95 | } 96 | return batch 97 | 98 | def collater(self, samples): 99 | return self.collate(samples, self.src_dict.pad(), self.src_dict.eos()) 100 | 101 | def size(self, index): 102 | return self.src.sizes[index] 103 | 104 | def replace(self, x): 105 | _x_real = x 106 | _x_rand = _x_real.clone().random_(self.src_dict.nspecial, len(self.src_dict)) 107 | _x_mask = _x_real.clone().fill_(self.src_dict.index('[MASK]')) 108 | probs = torch.multinomial(self.pred_probs, len(x), replacement=True) 109 | _x = _x_mask * (probs == 0).long() + \ 110 | _x_real * (probs == 1).long() + \ 111 | _x_rand * (probs == 2).long() 112 | return _x 113 | -------------------------------------------------------------------------------- /MASS-summarization/mass/masked_s2s.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | import numpy as np 8 | 9 | import torch 10 | 11 | from collections import OrderedDict 12 | from fairseq import utils 13 | from fairseq.data import ( 14 | data_utils, 15 | Dictionary, 16 | TokenBlockDataset, 17 | ) 18 | from fairseq.tasks import FairseqTask, register_task 19 | from .masked_dataset import MaskedLanguagePairDataset 20 | from .bert_dictionary import BertDictionary 21 | 22 | 23 | @register_task('masked_s2s') 24 | class MaskedS2STask(FairseqTask): 25 | """ 26 | Train a sequence-to-sequence task 27 | 28 | Args: 29 | dictionary (~fairseq.data.Dictionary): the dictionary for the input of 30 | the language model 31 | """ 32 | 33 | @staticmethod 34 | def add_args(parser): 35 | """Add task-specific arguments to the parser.""" 36 | # fmt: off 37 | parser.add_argument('data', help='path to data directory') 38 | parser.add_argument('--sample-break-mode', default='none', 39 | choices=['none', 'complete', 'complete_doc', 'eos'], 40 | help='If omitted or "none", fills each sample with tokens-per-sample ' 41 | 'tokens. If set to "complete", splits samples only at the end ' 42 | 'of sentence, but may include multiple sentences per sample. ' 43 | '"complete_doc" is similar but respects doc boundaries. 
' 44 | 'If set to "eos", includes only one sentence per sample.') 45 | parser.add_argument('--tokens-per-sample', default=512, type=int, 46 | help='max number of tokens per sample for text dataset') 47 | parser.add_argument('--lazy-load', action='store_true', 48 | help='load the dataset lazily') 49 | parser.add_argument('--raw-text', default=False, action='store_true', 50 | help='load raw text dataset') 51 | 52 | parser.add_argument('--mask-s2s-prob', default=0.15, type=float, 53 | help='probability of replacing a token with mask') 54 | parser.add_argument('--mask-s2s-mask-keep-rand', default="0.8,0.1,0.1", type=str, 55 | help='Word prediction probability for decoder mask') 56 | 57 | # fmt: on 58 | 59 | def __init__(self, args, dictionary): 60 | super().__init__(args) 61 | self.dictionary = dictionary 62 | 63 | @classmethod 64 | def setup_task(cls, args, **kwargs): 65 | """Setup the task (e.g., load dictionaries). 66 | 67 | Args: 68 | args (argparse.Namespace): parsed command-line arguments 69 | """ 70 | if getattr(args, 'raw_text', False): 71 | utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw') 72 | args.dataset_impl = 'raw' 73 | elif getattr(args, 'lazy_load', False): 74 | utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy') 75 | args.dataset_impl = 'lazy' 76 | 77 | paths = args.data.split(':') 78 | 79 | dictionary = cls.load_dictionary(os.path.join(paths[0], 'dict.txt')) 80 | print('| dictionary: {} types'.format(len(dictionary))) 81 | return cls(args, dictionary) 82 | 83 | @classmethod 84 | def load_dictionary(cls, filename): 85 | return BertDictionary.load_from_file(filename) 86 | 87 | def train_step(self, sample, model, criterion, optimizer, ignore_grad=False): 88 | model.train() 89 | loss, sample_size, logging_output = criterion(model, sample) 90 | if ignore_grad: 91 | loss *= 0 92 | optimizer.backward(loss) 93 | return loss, sample_size, logging_output 94 | 95 | def valid_step(self, sample, model, criterion): 96 | model.eval() 97 | with torch.no_grad(): 98 | loss, sample_size, logging_output = criterion(model, sample) 99 | return loss, sample_size, logging_output 100 | 101 | def build_model(self, args): 102 | from fairseq import models 103 | model = models.build_model(args, self) 104 | return model 105 | 106 | def load_dataset(self, split, epoch=0, combine=False, **kwargs): 107 | """Load a given dataset split. 
108 | 109 | Args: 110 | split (str): name of the split (e.g., train, valid, test) 111 | """ 112 | paths = self.args.data.split(':') 113 | assert len(paths) > 0 114 | data_path = paths[epoch % len(paths)] 115 | split_path = os.path.join(data_path, split) 116 | 117 | dataset = data_utils.load_indexed_dataset( 118 | split_path, 119 | self.dictionary, 120 | self.args.dataset_impl, 121 | combine=combine, 122 | ) 123 | if dataset is None: 124 | raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path)) 125 | 126 | self.datasets[split] = self.build_s2s_dataset(dataset) 127 | 128 | def build_s2s_dataset(self, dataset): 129 | dataset = TokenBlockDataset( 130 | dataset, 131 | dataset.sizes, 132 | self.args.tokens_per_sample, 133 | pad=self.source_dictionary.pad(), 134 | eos=self.source_dictionary.eos(), 135 | break_mode=self.args.sample_break_mode, 136 | ) 137 | 138 | pred_probs = torch.FloatTensor([float(x) for x in self.args.mask_s2s_mask_keep_rand.split(',')]) 139 | 140 | s2s_dataset = MaskedLanguagePairDataset( 141 | dataset, dataset.sizes, self.source_dictionary, 142 | shuffle=True, mask_prob=self.args.mask_s2s_prob, 143 | pred_probs=pred_probs, 144 | ) 145 | return s2s_dataset 146 | 147 | def build_dataset_for_inference(self, src_tokens, src_lengths): 148 | raise NotImplementedError 149 | 150 | def inference_step(self, generator, models, sample, prefix_tokens=None): 151 | raise NotImplementedError 152 | 153 | @property 154 | def source_dictionary(self): 155 | """Return the :class:`~fairseq.data.Dictionary` for the language 156 | model.""" 157 | return self.dictionary 158 | 159 | @property 160 | def target_dictionary(self): 161 | """Return the :class:`~fairseq.data.Dictionary` for the language 162 | model.""" 163 | return self.dictionary 164 | 165 | def max_positions(self): 166 | max_positions = 1024 167 | if hasattr(self.args, 'max_positions'): 168 | max_positions = min(max_positions, self.args.max_positions) 169 | if hasattr(self.args, 'max_source_positions'): 170 | max_positions = min(max_positions, self.args.max_source_positions) 171 | if hasattr(self.args, 'max_target_positions'): 172 | max_positions = min(max_positions, self.args.max_target_positions) 173 | return (max_positions, max_positions) 174 | -------------------------------------------------------------------------------- /MASS-summarization/mass/translation.py: -------------------------------------------------------------------------------- 1 | #from fairseq.data import BertDictionary 2 | 3 | from fairseq.tasks import register_task 4 | from fairseq.tasks.translation import TranslationTask 5 | 6 | from .bert_dictionary import BertDictionary 7 | 8 | 9 | @register_task('translation_mass') 10 | class TranslationMASSTask(TranslationTask): 11 | def __init__(self, args, src_dict, tgt_dict): 12 | super().__init__(args, src_dict, tgt_dict) 13 | 14 | @classmethod 15 | def load_dictionary(cls, filename): 16 | return BertDictionary.load_from_file(filename) 17 | 18 | def max_positions(self): 19 | """Return the max sentence length allowed by the task.""" 20 | return (self.args.max_source_positions, self.args.max_target_positions) 21 | -------------------------------------------------------------------------------- /MASS-supNMT/README.md: -------------------------------------------------------------------------------- 1 | # MASS with Supervised Pre-training 2 | 3 | We implement MASS on [fairseq](https://github.com/pytorch/fairseq), in order to support the pre-training and fine-tuning for large scale supervised tasks, such 
as neural machine translation, text summarization, and grammatical error correction. Unsupervised pre-training usually works better in zero-resource or low-resource downstream tasks. However, there is plenty of supervised data for these tasks, which poses challenges for conventional unsupervised pre-training. Therefore, we design a new pre-training loss to support large-scale supervised tasks. 4 | 5 | We extend MASS to the supervised setting, where the supervised sentence pair (X, Y) is leveraged for pre-training. The sentence X is masked and fed into the encoder, and the decoder predicts the whole sentence Y. Some discrete tokens in the decoder input are also masked, to encourage the decoder to extract more information from the encoder side. 6 | ![img](archi_mass_sup_md.png) 7 | 8 | During pre-training, we combine the original MASS pre-training loss and the new supervised pre-training loss. During fine-tuning, we directly use supervised sentence pairs to fine-tune the pre-trained model. 9 | 10 | MASS on fairseq contains code for the following tasks: 11 | * [Neural Machine Translation](#neural-machine-translation) 12 | * [Text Summarization](#text-summarization) 13 | * [Grammatical Error Correction](#grammatical-error-correction) 14 | 15 | 16 | 17 | ## Prerequisites 18 | After downloading the repository, you need to install `fairseq` via `pip`: 19 | ``` 20 | pip install fairseq==0.7.1 21 | ``` 22 | 23 | 24 | ## Neural Machine Translation 25 | 26 | | Languages | Pre-trained Model | BPE codes | English-Dict | Chinese-Dict | 27 | |:-----------:|:-----------------:| :---------:| :------------:| :------------:| 28 | |En - Zh | [MODEL](https://modelrelease.blob.core.windows.net/mass/zhen_mass_pre-training.pt) | [CODE](https://modelrelease.blob.core.windows.net/mass/bpecode.zip) | [VOCAB](https://modelrelease.blob.core.windows.net/mass/dict.en.txt) | [VOCAB](https://modelrelease.blob.core.windows.net/mass/dict.zh.txt) 29 | 30 | We provide an example of how to pre-train and fine-tune on WMT English<->Chinese (En<->Zh) translation. 31 | 32 | 33 | ### Data Ready 34 | We first prepare the monolingual and bilingual sentences for Chinese and English respectively. The data directory looks like: 35 | 36 | ``` 37 | - data/ 38 | ├─ mono/ 39 | | ├─ train.en 40 | | ├─ train.zh 41 | | ├─ valid.en 42 | | ├─ valid.zh 43 | | ├─ dict.en.txt 44 | | └─ dict.zh.txt 45 | └─ para/ 46 | ├─ train.en 47 | ├─ train.zh 48 | ├─ valid.en 49 | ├─ valid.zh 50 | ├─ dict.en.txt 51 | └─ dict.zh.txt 52 | ``` 53 | The files under `mono` are monolingual data, while those under `para` are bilingual data. The `dict.en(zh).txt` files in the different directories should be identical, although the dictionaries for different languages can differ. Running the following commands generates the binarized data: 54 | 55 | ``` 56 | # Ensure the output directory exists 57 | data_dir=data/ 58 | mono_data_dir=$data_dir/mono/ 59 | para_data_dir=$data_dir/para/ 60 | save_dir=$data_dir/processed/ 61 | 62 | # set this relative path of MASS in your server 63 | user_dir=mass 64 | 65 | mkdir -p $data_dir $save_dir $mono_data_dir $para_data_dir 66 | 67 | 68 | # Generate Monolingual Data 69 | for lg in en zh 70 | do 71 | 72 | fairseq-preprocess \ 73 | --task cross_lingual_lm \ 74 | --srcdict $mono_data_dir/dict.$lg.txt \ 75 | --only-source \ 76 | --trainpref $mono_data_dir/train --validpref $mono_data_dir/valid \ 77 | --destdir $save_dir \ 78 | --workers 20 \ 79 | --source-lang $lg 80 | 81 | # Since we only have a source language, the output file has a None for the 82 | # target language.
Remove this 83 | 84 | for stage in train valid 85 | do 86 | mv $save_dir/$stage.$lg-None.$lg.bin $save_dir/$stage.$lg.bin 87 | mv $save_dir/$stage.$lg-None.$lg.idx $save_dir/$stage.$lg.idx 88 | done 89 | done 90 | 91 | # Generate Bilingual Data 92 | fairseq-preprocess \ 93 | --user-dir $user_dir \ 94 | --task xmasked_seq2seq \ 95 | --source-lang en --target-lang zh \ 96 | --trainpref $para_data_dir/train --validpref $para_data_dir/valid \ 97 | --destdir $save_dir \ 98 | --srcdict $para_data_dir/dict.en.txt \ 99 | --tgtdict $para_data_dir/dict.zh.txt 100 | ``` 101 | 102 | 103 | ### Pre-training 104 | We provide a simple demo of how to run MASS pre-training. 105 | ``` 106 | save_dir=checkpoints/mass/pre-training/ 107 | user_dir=mass 108 | data_dir=data/processed/ 109 | 110 | mkdir -p $save_dir 111 | 112 | fairseq-train $data_dir \ 113 | --user-dir $user_dir \ 114 | --save-dir $save_dir \ 115 | --task xmasked_seq2seq \ 116 | --source-langs en,zh \ 117 | --target-langs en,zh \ 118 | --langs en,zh \ 119 | --arch xtransformer \ 120 | --mass_steps en-en,zh-zh \ 121 | --memt_steps en-zh,zh-en \ 122 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 123 | --lr-scheduler inverse_sqrt --lr 0.00005 --min-lr 1e-09 \ 124 | --criterion label_smoothed_cross_entropy \ 125 | --max-tokens 4096 \ 126 | --dropout 0.1 --relu-dropout 0.1 --attention-dropout 0.1 \ 127 | --max-update 100000 \ 128 | --share-decoder-input-output-embed \ 129 | --valid-lang-pairs en-zh 130 | ``` 131 | We also provide a pre-training [script](run_mass_enzh.sh) which is used for our released model. 132 | 133 | ### Fine-tuning 134 | After the pre-training stage, we fine-tune the model on bilingual sentence pairs: 135 | ``` 136 | data_dir=data/processed 137 | save_dir=checkpoints/mass/fine_tune/ 138 | user_dir=mass 139 | model=checkpoints/mass/pre-training/checkpoint_last.pt # The path of the pre-trained model 140 | 141 | mkdir -p $save_dir 142 | 143 | fairseq-train $data_dir \ 144 | --user-dir $user_dir \ 145 | --task xmasked_seq2seq \ 146 | --source-langs zh --target-langs en \ 147 | --langs en,zh \ 148 | --arch xtransformer \ 149 | --mt_steps zh-en \ 150 | --save-dir $save_dir \ 151 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 152 | --lr-scheduler inverse_sqrt --lr-shrink 0.5 --lr 0.00005 --min-lr 1e-09 \ 153 | --criterion label_smoothed_cross_entropy \ 154 | --max-tokens 4096 \ 155 | --max-update 100000 --max-epoch 50 \ 156 | --dropout 0.1 --relu-dropout 0.1 --attention-dropout 0.1 \ 157 | --share-decoder-input-output-embed \ 158 | --valid-lang-pairs zh-en \ 159 | --reload_checkpoint $model 160 | ``` 161 | We also provide a fine-tuning [script](ft_mass_enzh.sh) which is used for our pre-trained model. 162 | 163 | ### Inference 164 | After the fine-tuning stage, you can generate translation results by using the [script](translate.sh) below: 165 | ``` 166 | model=checkpoints/mass/fine_tune/checkpoint_best.pt 167 | data_dir=data/processed 168 | user_dir=mass 169 | 170 | fairseq-generate $data_dir \ 171 | --user-dir $user_dir \ 172 | -s zh -t en \ 173 | --langs en,zh \ 174 | --source-langs zh --target-langs en \ 175 | --mt_steps zh-en \ 176 | --gen-subset valid \ 177 | --task xmasked_seq2seq \ 178 | --path $model \ 179 | --beam 5 --remove-bpe 180 | ``` 181 | 182 | ## Text Summarization 183 | 184 | ### Data Ready 185 | Download the [CNN/Daily Mail dataset](https://github.com/abisee/cnn-dailymail) and use Stanford CoreNLP to tokenize the data.
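For reference only (a sketch; the mapping file name is a placeholder), the CoreNLP tokenization step can be run roughly as in the cnn-dailymail preprocessing scripts:
```
# mapping.txt lists "input_story_path <TAB> output_path" pairs, one per line (placeholder file names)
java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines mapping.txt
```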
We truncate each article to 400 tokens and use BPE to process the data. 186 | We use a data processing pipeline similar to the NMT one to generate the binarized data. 187 | 188 | ### Pre-training 189 | Here is demo code showing how to run pre-training for text summarization: 190 | ``` 191 | save_dir=checkpoints/mass/pre-training/ 192 | user_dir=mass 193 | data_dir=data/processed/ 194 | 195 | mkdir -p $save_dir 196 | 197 | fairseq-train $data_dir \ 198 | --user-dir $user_dir \ 199 | --save-dir $save_dir \ 200 | --task xmasked_seq2seq \ 201 | --source-langs ar,ti \ 202 | --target-langs ar,ti \ 203 | --langs ar,ti \ 204 | --arch xtransformer \ 205 | --mass_steps ar-ar,ti-ti \ 206 | --memt_steps ar-ti \ 207 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 208 | --lr-scheduler inverse_sqrt --lr 0.0001 --min-lr 1e-09 \ 209 | --criterion label_smoothed_cross_entropy \ 210 | --max-tokens 4096 \ 211 | --dropout 0.1 --relu-dropout 0.1 --attention-dropout 0.1 \ 212 | --max-update 300000 \ 213 | --share-decoder-input-output-embed \ 214 | --valid-lang-pairs ar-ti \ 215 | --word_mask 0.15 216 | ``` 217 | Our experiments are still ongoing; we will summarize a better experimental setting in the future. 218 | 219 | ## Grammatical Error Correction 220 | To be updated soon 221 | 222 | ## Paper 223 | The paper for supervised pre-training will be available soon. 224 | -------------------------------------------------------------------------------- /MASS-supNMT/archi_mass_sup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MASS/779f22fc47c8a256d8bc04826ebe1c8307063cbe/MASS-supNMT/archi_mass_sup.png -------------------------------------------------------------------------------- /MASS-supNMT/archi_mass_sup_md.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MASS/779f22fc47c8a256d8bc04826ebe1c8307063cbe/MASS-supNMT/archi_mass_sup_md.png -------------------------------------------------------------------------------- /MASS-supNMT/ft_mass_enzh.sh: -------------------------------------------------------------------------------- 1 | data_dir=data/processed 2 | save_dir=checkpoints/mass/fine-tune 3 | user_dir=mass 4 | 5 | model=zhen_mass_pre-training.pt 6 | 7 | seed=1234 8 | max_tokens=2048 9 | update_freq=1 10 | dropout=0.1 11 | attention_heads=16 12 | embed_dim=1024 13 | ffn_embed_dim=4096 14 | encoder_layers=10 15 | decoder_layers=6 16 | 17 | mkdir -p $save_dir 18 | 19 | fairseq-train $data_dir \ 20 | --user-dir $user_dir \ 21 | --save-dir $save_dir \ 22 | --task xmasked_seq2seq \ 23 | --source-langs zh --target-langs en \ 24 | --langs en,zh \ 25 | --arch xtransformer \ 26 | --mt_steps zh-en \ 27 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 28 | --lr-scheduler inverse_sqrt --lr 0.00005 --min-lr 1e-09 \ 29 | --criterion label_smoothed_cross_entropy \ 30 | --lm-bias --lazy-load --seed ${seed} \ 31 | --log-format json \ 32 | --max-tokens ${max_tokens} --update-freq ${update_freq} \ 33 | --encoder-normalize-before --decoder-normalize-before \ 34 | --dropout ${dropout} --relu-dropout 0.1 --attention-dropout 0.1 \ 35 | --decoder-attention-heads ${attention_heads} --encoder-attention-heads ${attention_heads} \ 36 | --decoder-embed-dim ${embed_dim} --encoder-embed-dim ${embed_dim} \ 37 | --decoder-ffn-embed-dim ${ffn_embed_dim} --encoder-ffn-embed-dim ${ffn_embed_dim} \ 38 | --encoder-layers ${encoder_layers} --decoder-layers
${decoder_layers} \ 39 | --max-update 1000000 --max-epoch 50 \ 40 | --keep-interval-updates 100 --save-interval-updates 3000 --log-interval 50 \ 41 | --share-decoder-input-output-embed \ 42 | --valid-lang-pairs zh-en \ 43 | --reload-checkpoint $model \ 44 | --ddp-backend=no_c10d 45 | -------------------------------------------------------------------------------- /MASS-supNMT/generate_enzh_data.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | 4 | # Ensure the output directory exists 5 | data_dir=data/ 6 | mono_data_dir=$data_dir/mono/ 7 | para_data_dir=$data_dir/para/ 8 | save_dir=$data_dir/processed/ 9 | 10 | # set this relative path of MASS in your server 11 | user_dir=mass 12 | 13 | mkdir -p $data_dir $save_dir $mono_data_dir $para_data_dir 14 | 15 | 16 | # Generate Monolingual Data 17 | for lg in en zh 18 | do 19 | 20 | fairseq-preprocess \ 21 | --task cross_lingual_lm \ 22 | --srcdict $mono_data_dir/dict.$lg.txt \ 23 | --only-source \ 24 | --trainpref $mono_data_dir/train \ 25 | --validpref $mono_data_dir/valid \ 26 | --destdir $save_dir \ 27 | --workers 20 \ 28 | --source-lang $lg 29 | 30 | for stage in train valid 31 | do 32 | mv $save_dir/$stage.$lg-None.$lg.bin $save_dir/$stage.$lg.bin 33 | mv $save_dir/$stage.$lg-None.$lg.idx $save_dir/$stage.$lg.idx 34 | done 35 | done 36 | 37 | # Generate Bilingual Data 38 | fairseq-preprocess \ 39 | --user-dir $user_dir \ 40 | --task xmasked_seq2seq \ 41 | --source-lang en --target-lang zh \ 42 | --trainpref $para_data_dir/train --validpref $para_data_dir/valid \ 43 | --destdir $save_dir \ 44 | --srcdict $para_data_dir/dict.en.txt \ 45 | --tgtdict $para_data_dir/dict.zh.txt 46 | -------------------------------------------------------------------------------- /MASS-supNMT/mass/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | # 4 | 5 | from . import xmasked_seq2seq 6 | from . import xtransformer 7 | -------------------------------------------------------------------------------- /MASS-supNMT/mass/masked_language_pair_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
3 | # 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from fairseq import utils 9 | 10 | from fairseq.data import data_utils, FairseqDataset 11 | 12 | 13 | class MaskedLanguagePairDataset(FairseqDataset): 14 | """Masked Language Pair dataset (only support for single language) 15 | [x1, x2, x3, x4, x5] 16 | | 17 | V 18 | src: [x1, _, _, x4, x5] 19 | tgt: [x1, x2] => [x2, x3] 20 | """ 21 | 22 | def __init__( 23 | self, src, sizes, vocab, 24 | left_pad_source=True, left_pad_target=False, 25 | max_source_positions=1024, max_target_positions=1024, 26 | shuffle=True, lang_id=None, ratio=None, training=True, 27 | pred_probs=None, 28 | ): 29 | self.src = src 30 | self.sizes = np.array(sizes) 31 | self.vocab = vocab 32 | self.left_pad_source = left_pad_source 33 | self.left_pad_target = left_pad_target 34 | self.max_source_positions = max_source_positions 35 | self.max_target_positions = max_target_positions 36 | self.shuffle = shuffle 37 | self.lang_id = lang_id 38 | self.ratio = ratio 39 | self.training = training 40 | self.pred_probs = pred_probs 41 | 42 | def __getitem__(self, index): 43 | if self.training is False: 44 | source = [self.vocab.eos_index] + self.src[index].tolist() 45 | target, output = source[:-1], source[1:] 46 | else: 47 | src_item = self.src[index] 48 | src_list = [self.vocab.eos_index] + src_item.tolist() 49 | 50 | start, length = self.mask_interval(len(src_list)) 51 | output = src_list[start : start + length].copy() 52 | _target = src_list[start - 1 : start + length - 1].copy() 53 | 54 | target = [] 55 | for w in _target: 56 | target.append(self.random_word(w, self.pred_probs)) 57 | 58 | source = [] 59 | for i, w in enumerate(src_list[1:]): # to keep consistent with finetune 60 | if i >= start and i <= start + length: 61 | w = self.mask_word(w) 62 | if w is not None: 63 | source.append(w) 64 | 65 | assert len(target) == len(output) 66 | return { 67 | 'id': index, 68 | 'source': torch.LongTensor(source), 69 | 'target': torch.LongTensor(target), 70 | 'output': torch.LongTensor(output), 71 | } 72 | 73 | def __len__(self): 74 | return len(self.src) 75 | 76 | def _collate(self, samples, pad_idx, eos_idx, segment_label): 77 | 78 | def merge(key, left_pad): 79 | return data_utils.collate_tokens( 80 | [s[key] for s in samples], 81 | pad_idx, eos_idx, left_pad, 82 | ) 83 | 84 | id = torch.LongTensor([s['id'] for s in samples]) 85 | src_tokens = merge('source', left_pad=self.left_pad_source) 86 | src_lengths = torch.LongTensor([s['source'].numel() for s in samples]) 87 | src_lengths, sort_order = src_lengths.sort(descending=True) 88 | id = id.index_select(0, sort_order) 89 | src_tokens = src_tokens.index_select(0, sort_order) 90 | 91 | #ntokens = sum(len(s['source']) for s in samples) 92 | ntokens = sum(len(s['target']) for s in samples) 93 | 94 | prev_output_tokens = merge('target', left_pad=self.left_pad_target) 95 | prev_output_tokens = prev_output_tokens.index_select(0, sort_order) 96 | 97 | target = merge('output', left_pad=self.left_pad_target) 98 | target = target.index_select(0, sort_order) 99 | 100 | batch = { 101 | 'id': id, 102 | 'nsentences': len(samples), 103 | 'ntokens': ntokens, 104 | 'net_input': { 105 | 'src_tokens': src_tokens, 106 | 'src_lengths': src_lengths, 107 | }, 108 | 'target': target, 109 | } 110 | batch['net_input']['prev_output_tokens'] = prev_output_tokens 111 | return batch 112 | 113 | 114 | def collater(self, samples): 115 | return self._collate( 116 | samples, 117 | pad_idx=self.vocab.pad(), 118 | eos_idx=self.vocab.eos(), 119 | 
segment_label=self.lang_id, 120 | ) 121 | 122 | def get_dummy_batch( 123 | self, 124 | num_tokens, 125 | max_positions, 126 | tgt_len=128 127 | ): 128 | if isinstance(max_positions, float) or isinstance(max_positions, int): 129 | tgt_len = min(tgt_len, max_positions) 130 | source = self.vocab.dummy_sentence(tgt_len) 131 | target = self.vocab.dummy_sentence(tgt_len) 132 | bsz = max(num_tokens // tgt_len, 1) 133 | return self.collater([ 134 | { 135 | 'id': i, 136 | 'source': source, 137 | 'target': target, 138 | 'output': target, 139 | } 140 | for i in range(bsz) 141 | ]) 142 | 143 | def num_tokens(self, index): 144 | return self.sizes[index] 145 | 146 | def ordered_indices(self): 147 | """Return an ordered list of indices. Batches will be constructed based 148 | on this order.""" 149 | if self.shuffle: 150 | indices = np.random.permutation(len(self)) 151 | else: 152 | indices = np.arange(len(self)) 153 | return indices[np.argsort(self.sizes[indices], kind='mergesort')] 154 | 155 | @property 156 | def supports_prefetch(self): 157 | return ( 158 | getattr(self.src, 'supports_prefetch', False) and getattr(self.src, 'supports_prefetch', False) 159 | ) 160 | 161 | def prefetch(self, indices): 162 | self.src.prefetch(indices) 163 | 164 | def mask_start(self, end): 165 | p = np.random.random() 166 | if p >= 0.8: 167 | return 1 168 | elif p >= 0.6: 169 | return end 170 | else: 171 | return np.random.randint(1, end) 172 | 173 | def mask_word(self, w): 174 | p = np.random.random() 175 | if p >= 0.2: 176 | return self.vocab.mask_index 177 | elif p >= 0.1: 178 | return np.random.randint(self.vocab.nspecial, len(self.vocab)) 179 | else: 180 | return w 181 | 182 | def random_word(self, w, pred_probs): 183 | cands = [self.vocab.mask_index, np.random.randint(self.vocab.nspecial, len(self.vocab)), w] 184 | prob = torch.multinomial(self.pred_probs, 1, replacement=True) 185 | return cands[prob] 186 | 187 | def mask_interval(self, l): 188 | mask_length = round(l * self.ratio) 189 | mask_length = max(1, mask_length) 190 | mask_start = self.mask_start(l - mask_length) 191 | return mask_start, mask_length 192 | 193 | def size(self, index): 194 | return (self.sizes[index], int(round(self.sizes[index] * self.ratio))) 195 | -------------------------------------------------------------------------------- /MASS-supNMT/mass/noisy_language_pair_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
3 | # 4 | 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from fairseq import utils 10 | 11 | from fairseq.data import data_utils, FairseqDataset 12 | 13 | 14 | def collate( 15 | samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False, 16 | input_feeding=True 17 | ): 18 | if len(samples) == 0: 19 | return {} 20 | 21 | def merge(key, left_pad, move_eos_to_beginning=False): 22 | return data_utils.collate_tokens( 23 | [s[key] for s in samples], 24 | pad_idx, eos_idx, left_pad, move_eos_to_beginning, 25 | ) 26 | 27 | id = torch.LongTensor([s['id'] for s in samples]) 28 | src_tokens = merge('source', left_pad=left_pad_source) 29 | # sort by descending source length 30 | src_lengths = torch.LongTensor([s['source'].numel() for s in samples]) 31 | src_lengths, sort_order = src_lengths.sort(descending=True) 32 | id = id.index_select(0, sort_order) 33 | src_tokens = src_tokens.index_select(0, sort_order) 34 | 35 | prev_output_tokens = None 36 | target = None 37 | if samples[0].get('target', None) is not None: 38 | target = merge('target', left_pad=left_pad_target) 39 | target = target.index_select(0, sort_order) 40 | ntokens = sum(len(s['target']) for s in samples) 41 | 42 | if input_feeding: 43 | # we create a shifted version of targets for feeding the 44 | # previous output token(s) into the next decoder step 45 | prev_output_tokens = merge( 46 | 'target', 47 | left_pad=left_pad_target, 48 | move_eos_to_beginning=True, 49 | ) 50 | prev_output_tokens = prev_output_tokens.index_select(0, sort_order) 51 | else: 52 | ntokens = sum(len(s['source']) for s in samples) 53 | 54 | batch = { 55 | 'id': id, 56 | 'nsentences': len(samples), 57 | 'ntokens': ntokens, 58 | 'net_input': { 59 | 'src_tokens': src_tokens, 60 | 'src_lengths': src_lengths, 61 | }, 62 | 'target': target, 63 | } 64 | if prev_output_tokens is not None: 65 | batch['net_input']['prev_output_tokens'] = prev_output_tokens 66 | return batch 67 | 68 | 69 | def generate_dummy_batch(num_tokens, collate_fn, src_vocab, tgt_vocab, src_len=128, tgt_len=128): 70 | """Return a dummy batch with a given number of tokens.""" 71 | bsz = num_tokens // max(src_len, tgt_len) 72 | return collate_fn([ 73 | { 74 | 'id': i, 75 | 'source': src_vocab.dummy_sentence(src_len), 76 | 'target': tgt_vocab.dummy_sentence(tgt_len), 77 | 'output': tgt_vocab.dummy_sentence(tgt_len), 78 | } 79 | for i in range(bsz) 80 | ]) 81 | 82 | 83 | class NoisyLanguagePairDataset(FairseqDataset): 84 | """ 85 | [x1, x2, x3, x4, x5] [y1, y2, y3, y4, y5] 86 | | 87 | V 88 | [x1, _, x3, _, x5] [y1, y2, y3, y4, y5] 89 | """ 90 | def __init__( 91 | self, src, src_sizes, tgt, tgt_sizes, src_vocab, tgt_vocab, 92 | src_lang_id, tgt_lang_id, 93 | left_pad_source=True, left_pad_target=False, 94 | max_source_positions=1024, max_target_positions=1024, 95 | shuffle=True, input_feeding=True, ratio=0.50, 96 | pred_probs=None, 97 | ): 98 | self.src = src 99 | self.tgt = tgt 100 | self.src_sizes = np.array(src_sizes) 101 | self.tgt_sizes = np.array(tgt_sizes) 102 | self.src_lang_id = src_lang_id 103 | self.tgt_lang_id = tgt_lang_id 104 | self.src_vocab = src_vocab 105 | self.tgt_vocab = tgt_vocab 106 | self.left_pad_source = left_pad_source 107 | self.left_pad_target = left_pad_target 108 | self.max_source_positions = max_source_positions 109 | self.max_target_positions = max_target_positions 110 | self.shuffle = shuffle 111 | self.input_feeding = input_feeding 112 | self.ratio = ratio 113 | self.pred_probs = pred_probs 114 | 115 | def __getitem__(self, index): 116 | tgt_item = 
self.tgt[index] 117 | src_item = self.src[index] 118 | 119 | src_list = src_item.tolist() 120 | source = [] 121 | for i, w in enumerate(src_list): 122 | p = np.random.random() 123 | if i > 0 and i < len(src_list) - 1 and p <= self.ratio: 124 | source.append(self.src_vocab.mask_index) 125 | else: 126 | source.append(w) 127 | 128 | return { 129 | 'id': index, 130 | 'source': torch.LongTensor(source), 131 | 'target': tgt_item, 132 | } 133 | 134 | def __len__(self): 135 | return len(self.src) 136 | 137 | def collater(self, samples): 138 | """Merge a list of samples to form a mini-batch. 139 | 140 | Args: 141 | samples (List[dict]): samples to collate 142 | 143 | Returns: 144 | dict: a mini-batch with the following keys: 145 | 146 | - `id` (LongTensor): example IDs in the original input order 147 | - `ntokens` (int): total number of tokens in the batch 148 | - `net_input` (dict): the input to the Model, containing keys: 149 | 150 | - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in 151 | the source sentence of shape `(bsz, src_len)`. Padding will 152 | appear on the left if *left_pad_source* is ``True``. 153 | - `src_lengths` (LongTensor): 1D Tensor of the unpadded 154 | lengths of each source sentence of shape `(bsz)` 155 | - `prev_output_tokens` (LongTensor): a padded 2D Tensor of 156 | tokens in the target sentence, shifted right by one position 157 | for input feeding/teacher forcing, of shape `(bsz, 158 | tgt_len)`. This key will not be present if *input_feeding* 159 | is ``False``. Padding will appear on the left if 160 | *left_pad_target* is ``True``. 161 | 162 | - `target` (LongTensor): a padded 2D Tensor of tokens in the 163 | target sentence of shape `(bsz, tgt_len)`. Padding will appear 164 | on the left if *left_pad_target* is ``True``. 165 | """ 166 | return collate( 167 | samples, pad_idx=self.src_vocab.pad(), eos_idx=self.src_vocab.eos(), 168 | left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, 169 | input_feeding=self.input_feeding 170 | ) 171 | 172 | def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128): 173 | """Return a dummy batch with a given number of tokens.""" 174 | src_len, tgt_len = utils.resolve_max_positions( 175 | (src_len, tgt_len), 176 | max_positions, 177 | (self.max_source_positions, self.max_target_positions), 178 | ) 179 | return generate_dummy_batch(num_tokens, self.collater, self.src_vocab, self.tgt_vocab, src_len, tgt_len) 180 | 181 | def num_tokens(self, index): 182 | """Return the number of tokens in a sample. This value is used to 183 | enforce ``--max-tokens`` during batching.""" 184 | return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) 185 | 186 | def size(self, index): 187 | """Return an example's size as a float or tuple. This value is used when 188 | filtering a dataset with ``--max-positions``.""" 189 | return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) 190 | 191 | def ordered_indices(self): 192 | """Return an ordered list of indices. 
Batches will be constructed based 193 | on this order.""" 194 | if self.shuffle: 195 | indices = np.random.permutation(len(self)) 196 | else: 197 | indices = np.arange(len(self)) 198 | if self.tgt_sizes is not None: 199 | indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')] 200 | return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] 201 | 202 | @property 203 | def supports_prefetch(self): 204 | return ( 205 | getattr(self.src, 'supports_prefetch', False) 206 | and (getattr(self.tgt, 'supports_prefetch', False) or self.tgt is None) 207 | ) 208 | 209 | def prefetch(self, indices): 210 | self.src.prefetch(indices) 211 | if self.tgt is not None: 212 | self.tgt.prefetch(indices) 213 | -------------------------------------------------------------------------------- /MASS-supNMT/mass/xtransformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | # 4 | 5 | from collections import OrderedDict 6 | 7 | from fairseq import utils 8 | from fairseq.models import FairseqMultiModel, register_model, register_model_architecture, BaseFairseqModel 9 | 10 | from fairseq.models.transformer import ( 11 | base_architecture, 12 | Embedding, 13 | TransformerEncoder, 14 | TransformerDecoder, 15 | TransformerModel, 16 | ) 17 | 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | class XTransformerEncoder(TransformerEncoder): 23 | 24 | def __init__(self, args, dictionary, embed_tokens): 25 | super().__init__(args, dictionary, embed_tokens) 26 | self.mask_idx = dictionary.mask_index 27 | 28 | def forward(self, src_tokens, src_lengths): 29 | x = self.embed_scale * self.embed_tokens(src_tokens) 30 | if self.embed_positions is not None: 31 | x += self.embed_positions(src_tokens) 32 | x = F.dropout(x, p=self.dropout, training=self.training) 33 | 34 | # B x T x C -> T x B x C 35 | x = x.transpose(0, 1) 36 | 37 | # compute padding mask 38 | encoder_padding_mask = src_tokens.eq(self.padding_idx) | src_tokens.eq(self.mask_idx) 39 | if not encoder_padding_mask.any(): 40 | encoder_padding_mask = None 41 | 42 | # encoder layers 43 | for layer in self.layers: 44 | x = layer(x, encoder_padding_mask) 45 | 46 | if self.layer_norm: 47 | x = self.layer_norm(x) 48 | 49 | return { 50 | 'encoder_out': x, # T x B x C 51 | 'encoder_padding_mask': encoder_padding_mask, # B x T 52 | } 53 | 54 | 55 | class XTransformerDecoder(TransformerDecoder): 56 | 57 | def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): 58 | super().__init__(args, dictionary, embed_tokens, no_encoder_attn) 59 | 60 | def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, positions=None): 61 | # embed positions 62 | positions = self.embed_positions( 63 | prev_output_tokens, 64 | incremental_state=incremental_state, 65 | positions=positions, 66 | ) if self.embed_positions is not None else None 67 | 68 | if incremental_state is not None: 69 | prev_output_tokens = prev_output_tokens[:, -1:] 70 | if positions is not None: 71 | positions = positions[:, -1:] 72 | 73 | # embed tokens and positions 74 | x = self.embed_scale * self.embed_tokens(prev_output_tokens) 75 | 76 | if self.project_in_dim is not None: 77 | x = self.project_in_dim(x) 78 | 79 | if positions is not None: 80 | x += positions 81 | x = F.dropout(x, p=self.dropout, training=self.training) 82 | 83 | # B x T x C -> T x B x C 84 | x = x.transpose(0, 1) 85 | attn = None 86 | 
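# inner_states keeps the embedding output followed by each decoder layer's hidden state,
# so the extra output returned below holds len(self.layers) + 1 tensors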
87 | inner_states = [x] 88 | 89 | # decoder layers 90 | for layer in self.layers: 91 | x, attn = layer( 92 | x, 93 | encoder_out['encoder_out'] if encoder_out is not None else None, 94 | encoder_out['encoder_padding_mask'] if encoder_out is not None else None, 95 | incremental_state, 96 | self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None, 97 | ) 98 | inner_states.append(x) 99 | 100 | if self.layer_norm: 101 | x = self.layer_norm(x) 102 | 103 | # T x B x C -> B x T x C 104 | x = x.transpose(0, 1) 105 | 106 | if self.project_out_dim is not None: 107 | x = self.project_out_dim(x) 108 | 109 | if self.adaptive_softmax is None: 110 | # project back to size of vocabulary 111 | if self.share_input_output_embed: 112 | x = F.linear(x, self.embed_tokens.weight) 113 | else: 114 | x = F.linear(x, self.embed_out) 115 | 116 | return x, {'attn': attn, 'inner_states': inner_states} 117 | 118 | 119 | @register_model('xtransformer') 120 | class XTransformerModel(BaseFairseqModel): 121 | def __init__(self, encoders, decoders, eval_lang_pair=None): 122 | super().__init__() 123 | self.encoders = nn.ModuleDict(encoders) 124 | self.decoders = nn.ModuleDict(decoders) 125 | self.tgt_key = None 126 | if eval_lang_pair is not None: 127 | self.source_lang = eval_lang_pair.split('-')[0] 128 | self.target_lang = eval_lang_pair.split('-')[1] 129 | 130 | def get_normalized_probs(self, net_output, log_probs, sample=None): 131 | """Get normalized probabilities (or log probs) from a net's output.""" 132 | if hasattr(self, 'decoder'): 133 | return self.decoder.get_normalized_probs(net_output, log_probs, sample) 134 | elif hasattr(self, 'decoders'): 135 | return self.decoders[self.tgt_key].get_normalized_probs(net_output, log_probs, sample) 136 | elif torch.is_tensor(net_output): 137 | logits = net_output.float() 138 | if log_probs: 139 | return F.log_softmax(logits, dim=-1) 140 | else: 141 | return F.softmax(logits, dim=-1) 142 | raise NotImplementedError 143 | 144 | def max_positions(self): 145 | return None 146 | 147 | def max_decoder_positions(self): 148 | return min(decoder.max_positions() for decoder in self.decoders.values()) 149 | 150 | def forward(self, src_tokens, src_lengths, prev_output_tokens, src_key, tgt_key, positions=None): 151 | encoder_out = self.encoders[src_key](src_tokens, src_lengths) 152 | decoder_out = self.decoders[tgt_key](prev_output_tokens, encoder_out, positions=positions) 153 | self.tgt_key = tgt_key 154 | return decoder_out 155 | 156 | def add_args(parser): 157 | TransformerModel.add_args(parser) 158 | parser.add_argument('--share-encoders', action='store_true', 159 | help='share encoders across languages') 160 | parser.add_argument('--share-decoders', action='store_true', 161 | help='share decoders across languages') 162 | 163 | @classmethod 164 | def build_model(cls, args, task): 165 | langs = [lang for lang in args.langs] 166 | 167 | embed_tokens = {} 168 | for lang in langs: 169 | if len(embed_tokens) == 0 or args.share_all_embeddings is False: 170 | embed_token = build_embedding( 171 | task.dicts[lang], args.encoder_embed_dim, args.encoder_embed_path 172 | ) 173 | embed_tokens[lang] = embed_token 174 | else: 175 | embed_tokens[lang] = embed_tokens[langs[0]] 176 | 177 | args.share_decoder_input_output_embed = True 178 | encoders, decoders = {}, {} 179 | 180 | for lang in langs: 181 | encoder_embed_tokens = embed_tokens[lang] 182 | decoder_embed_tokens = encoder_embed_tokens 183 | if lang in args.source_langs: 184 | encoder = XTransformerEncoder(args, 
task.dicts[lang], encoder_embed_tokens) 185 | encoders[lang] = encoder 186 | if lang in args.target_langs: 187 | decoder = XTransformerDecoder(args, task.dicts[lang], decoder_embed_tokens) 188 | decoders[lang] = decoder 189 | return XTransformerModel(encoders, decoders, args.eval_lang_pair) 190 | 191 | @property 192 | def decoder(self): 193 | return self.decoders[self.target_lang] 194 | 195 | @property 196 | def encoder(self): 197 | return self.encoders[self.source_lang] 198 | 199 | 200 | @register_model_architecture('xtransformer', 'xtransformer') 201 | def base_x_transformer(args): 202 | base_architecture(args) 203 | 204 | 205 | def build_embedding(dictionary, embed_dim, path=None): 206 | num_embeddings = len(dictionary) 207 | padding_idx = dictionary.pad() 208 | emb = Embedding(num_embeddings, embed_dim, padding_idx) 209 | # if provided, load from preloaded dictionaries 210 | if path: 211 | embed_dict = utils.parse_embedding(path) 212 | utils.load_embedding(embed_dict, dictionary, emb) 213 | return emb 214 | -------------------------------------------------------------------------------- /MASS-supNMT/run_mass_enzh.sh: -------------------------------------------------------------------------------- 1 | data_dir=data/processed 2 | save_dir=checkpoints/mass/pretraining 3 | user_dir=mass 4 | 5 | seed=1234 6 | max_tokens=2048 # for 16GB GPUs 7 | update_freq=1 8 | dropout=0.1 9 | attention_heads=16 10 | embed_dim=1024 11 | ffn_embed_dim=4096 12 | encoder_layers=10 13 | decoder_layers=6 14 | word_mask=0.3 15 | 16 | mkdir -p $save_dir 17 | 18 | fairseq-train $data_dir \ 19 | --user-dir $user_dir \ 20 | --task xmasked_seq2seq \ 21 | --source-langs en,zh \ 22 | --target-langs en,zh \ 23 | --langs en,zh \ 24 | --arch xtransformer \ 25 | --mass_steps en-en,zh-zh \ 26 | --memt_steps en-zh,zh-en \ 27 | --save-dir $save_dir \ 28 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 29 | --lr-scheduler inverse_sqrt --lr 0.00005 --min-lr 1e-09 \ 30 | --criterion label_smoothed_cross_entropy \ 31 | --lm-bias --lazy-load --seed ${seed} \ 32 | --log-format json \ 33 | --max-tokens ${max_tokens} --update-freq ${update_freq} \ 34 | --encoder-normalize-before --decoder-normalize-before \ 35 | --dropout ${dropout} --relu-dropout 0.1 --attention-dropout 0.1 \ 36 | --decoder-attention-heads ${attention_heads} --encoder-attention-heads ${attention_heads} \ 37 | --decoder-embed-dim ${embed_dim} --encoder-embed-dim ${embed_dim} \ 38 | --decoder-ffn-embed-dim ${ffn_embed_dim} --encoder-ffn-embed-dim ${ffn_embed_dim} \ 39 | --encoder-layers ${encoder_layers} --decoder-layers ${decoder_layers} \ 40 | --max-update 100000000 --max-epoch 50 \ 41 | --keep-interval-updates 100 --save-interval-updates 3000 --log-interval 50 \ 42 | --share-decoder-input-output-embed \ 43 | --valid-lang-pairs en-zh \ 44 | --word_mask ${word_mask} \ 45 | --ddp-backend=no_c10d 46 | -------------------------------------------------------------------------------- /MASS-supNMT/translate.sh: -------------------------------------------------------------------------------- 1 | MODEL=zhen_mass_pre-training.pt 2 | 3 | fairseq-generate ./data/processed \ 4 | -s zh -t en \ 5 | --user-dir mass \ 6 | --langs en,zh \ 7 | --source-langs zh --target-langs en \ 8 | --mt_steps zh-en \ 9 | --gen-subset valid \ 10 | --task xmasked_seq2seq \ 11 | --path $MODEL \ 12 | --beam 5 \ 13 | --remove-bpe 14 | -------------------------------------------------------------------------------- /MASS-unsupNMT/filter_noisy_data.py: 
-------------------------------------------------------------------------------- 1 | import re 2 | import argparse 3 | 4 | from langdetect import detect 5 | from polyglot.detect import Detector 6 | 7 | def get_parser(): 8 | parser = argparse.ArgumentParser(description="Remove noisy data") 9 | 10 | parser.add_argument("--input", type=str, 11 | help="The path of input file") 12 | parser.add_argument("--lang", type=str, 13 | help="The language of input file") 14 | parser.add_argument("--output", type=str, default=None, 15 | help="The path of output file") 16 | 17 | return parser 18 | 19 | def detect_exist_url(text): 20 | urls = re.findall('http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\), ]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', text) 21 | url1 = re.findall('http[s]?//(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\), ]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', text) 22 | return len(urls) > 0 or len(url1) > 0 23 | 24 | def detect_lang(text, lang): 25 | try: 26 | for i, l in enumerate(Detector(text, quiet=True).languages): 27 | if l.code == lang and i == 0: 28 | return True 29 | if detect(text) == lang: 30 | return True 31 | return False 32 | except: 33 | return False 34 | 35 | def main(): 36 | parser = get_parser() 37 | args = parser.parse_args() 38 | 39 | count = 0 40 | allcount = 0 41 | f = None 42 | if args.output is not None: 43 | f = open(args.output, 'w') 44 | with open(args.input, encoding='utf-8') as input_file: 45 | for line in input_file: 46 | allcount += 1 47 | line = line.strip() 48 | if detect_exist_url(line) is False: 49 | if detect_lang(line, args.lang) is True: 50 | count += 1 51 | if args.output is not None: 52 | f.write(line + '\n') 53 | #print(line) 54 | if allcount % 1000000 == 0: 55 | print("{} sentences processed".format(allcount), count) 56 | print(count, allcount) 57 | 58 | if __name__ == "__main__": 59 | main() 60 | 61 | -------------------------------------------------------------------------------- /MASS-unsupNMT/get-data-bilingual-enro-nmt.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
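
filter_noisy_data.py above screens a monolingual file line by line, dropping lines that contain URLs or that neither polyglot nor langdetect assign to the expected language. For parallel corpora the same checks are usually applied to both sides jointly so that sentence alignment is preserved. The sketch below reuses that detection logic under this assumption; the keep_pair helper and the train.en / train.ro file names are illustrative rather than part of the repository.

import re
from langdetect import detect
from polyglot.detect import Detector

URL_RE = re.compile(r'http[s]?:?//\S+')

def looks_like(text, lang):
    """Same idea as detect_lang() above: accept polyglot's top guess or langdetect's verdict."""
    try:
        top = Detector(text, quiet=True).languages[0]
        return top.code == lang or detect(text) == lang
    except Exception:
        return False

def keep_pair(src, tgt, src_lang, tgt_lang):
    """Keep a sentence pair only if neither side has a URL and both sides match their language."""
    if URL_RE.search(src) or URL_RE.search(tgt):
        return False
    return looks_like(src, src_lang) and looks_like(tgt, tgt_lang)

if __name__ == "__main__":
    # hypothetical usage on line-aligned files train.en / train.ro
    with open("train.en", encoding="utf-8") as f_src, \
         open("train.ro", encoding="utf-8") as f_tgt, \
         open("train.clean.en", "w", encoding="utf-8") as o_src, \
         open("train.clean.ro", "w", encoding="utf-8") as o_tgt:
        for s, t in zip(f_src, f_tgt):
            if keep_pair(s.strip(), t.strip(), "en", "ro"):
                o_src.write(s)
                o_tgt.write(t)
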
3 | 4 | 5 | MAIN_PATH=$(pwd) 6 | 7 | N_MONO=10000000 # number of monolingual sentences for each language 8 | CODES=60000 # number of BPE codes 9 | N_THREADS=16 10 | 11 | SRC=en 12 | TGT=ro 13 | 14 | POSITIONAL=() 15 | while [[ $# -gt 0 ]] 16 | do 17 | key="$1" 18 | case $key in 19 | --reload_codes) 20 | RELOAD_CODES="$2"; shift 2;; 21 | --reload_vocab) 22 | RELOAD_VOCAB="$2"; shift 2;; 23 | *) 24 | POSITIONAL+=("$1") 25 | shift 26 | ;; 27 | esac 28 | done 29 | set -- "${POSITIONAL[@]}" 30 | 31 | # main paths 32 | MAIN_PATH=$PWD 33 | TOOLS_PATH=$PWD/tools 34 | DATA_PATH=$PWD/data 35 | MONO_PATH=$DATA_PATH/mono 36 | PARA_PATH=$DATA_PATH/para 37 | PROC_PATH=$DATA_PATH/processed/$SRC-$TGT 38 | 39 | # create paths 40 | mkdir -p $TOOLS_PATH 41 | mkdir -p $DATA_PATH 42 | mkdir -p $MONO_PATH 43 | mkdir -p $PARA_PATH 44 | mkdir -p $PROC_PATH 45 | 46 | # moses 47 | MOSES=$TOOLS_PATH/mosesdecoder 48 | REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl 49 | NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl 50 | REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl 51 | TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl 52 | INPUT_FROM_SGM=$MOSES/scripts/ems/support/input-from-sgm.perl 53 | 54 | # fastBPE 55 | FASTBPE_DIR=$TOOLS_PATH/fastBPE 56 | FASTBPE=$TOOLS_PATH/fastBPE/fast 57 | 58 | # Sennrich's WMT16 scripts for Romanian preprocessing 59 | WMT16_SCRIPTS=$TOOLS_PATH/wmt16-scripts 60 | NORMALIZE_ROMANIAN=$WMT16_SCRIPTS/preprocess/normalise-romanian.py 61 | REMOVE_DIACRITICS=$WMT16_SCRIPTS/preprocess/remove-diacritics.py 62 | 63 | # raw and tokenized files 64 | SRC_RAW=$MONO_PATH/$SRC/all.$SRC 65 | TGT_RAW=$MONO_PATH/$TGT/all.$TGT 66 | SRC_TOK=$SRC_RAW.tok 67 | TGT_TOK=$TGT_RAW.tok 68 | 69 | # BPE / vocab files 70 | BPE_CODES=$PROC_PATH/codes 71 | SRC_VOCAB=$PROC_PATH/vocab.$SRC 72 | TGT_VOCAB=$PROC_PATH/vocab.$TGT 73 | FULL_VOCAB=$PROC_PATH/vocab.$SRC-$TGT 74 | 75 | # train / valid / test monolingual BPE data 76 | SRC_TRAIN_BPE=$PROC_PATH/train.$SRC 77 | TGT_TRAIN_BPE=$PROC_PATH/train.$TGT 78 | SRC_VALID_BPE=$PROC_PATH/valid.$SRC 79 | TGT_VALID_BPE=$PROC_PATH/valid.$TGT 80 | SRC_TEST_BPE=$PROC_PATH/test.$SRC 81 | TGT_TEST_BPE=$PROC_PATH/test.$TGT 82 | 83 | # valid / test parallel BPE data 84 | PARA_SRC_TRAIN_BPE=$PROC_PATH/train.$SRC-$TGT.$SRC 85 | PARA_TGT_TRAIN_BPE=$PROC_PATH/train.$SRC-$TGT.$TGT 86 | PARA_SRC_VALID_BPE=$PROC_PATH/valid.$SRC-$TGT.$SRC 87 | PARA_TGT_VALID_BPE=$PROC_PATH/valid.$SRC-$TGT.$TGT 88 | PARA_SRC_TEST_BPE=$PROC_PATH/test.$SRC-$TGT.$SRC 89 | PARA_TGT_TEST_BPE=$PROC_PATH/test.$SRC-$TGT.$TGT 90 | 91 | # en-ro valid / test data 92 | PARA_SRC_TRAIN=$PARA_PATH/train.$SRC-$TGT.$SRC 93 | PARA_TGT_TRAIN=$PARA_PATH/train.$SRC-$TGT.$TGT 94 | PARA_SRC_VALID=$PARA_PATH/dev/newsdev2016-roen-ref.en 95 | PARA_TGT_VALID=$PARA_PATH/dev/newsdev2016-enro-ref.ro 96 | PARA_SRC_TEST=$PARA_PATH/dev/newstest2016-roen-ref.en 97 | PARA_TGT_TEST=$PARA_PATH/dev/newstest2016-enro-ref.ro 98 | 99 | # install tools 100 | ./install-tools.sh 101 | 102 | # 103 | # Download monolingual data 104 | # 105 | 106 | cd $MONO_PATH 107 | 108 | if [ "$SRC" == "en" -o "$TGT" == "en" ]; then 109 | echo "Downloading English monolingual data ..." 
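# only the 2007-2010 News Crawl shards are downloaded by default; the later years below
# stay commented out and can be enabled for a larger English monolingual corpus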
110 | mkdir -p $MONO_PATH/en 111 | cd $MONO_PATH/en 112 | wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2007.en.shuffled.gz 113 | wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2008.en.shuffled.gz 114 | wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2009.en.shuffled.gz 115 | wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2010.en.shuffled.gz 116 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2011.en.shuffled.gz 117 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.en.shuffled.gz 118 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.en.shuffled.gz 119 | # wget -c http://www.statmt.org/wmt15/training-monolingual-news-crawl-v2/news.2014.en.shuffled.v2.gz 120 | # wget -c http://data.statmt.org/wmt16/translation-task/news.2015.en.shuffled.gz 121 | # wget -c http://data.statmt.org/wmt17/translation-task/news.2016.en.shuffled.gz 122 | # wget -c http://data.statmt.org/wmt18/translation-task/news.2017.en.shuffled.deduped.gz 123 | fi 124 | 125 | if [ "$SRC" == "ro" -o "$TGT" == "ro" ]; then 126 | echo "Downloading Romanian monolingual data ..." 127 | mkdir -p $MONO_PATH/ro 128 | cd $MONO_PATH/ro 129 | wget -c http://data.statmt.org/wmt16/translation-task/news.2015.ro.shuffled.gz 130 | fi 131 | 132 | cd $MONO_PATH 133 | 134 | # decompress monolingual data 135 | for FILENAME in $SRC/news*gz $TGT/news*gz; do 136 | OUTPUT="${FILENAME::-3}" 137 | if [ ! -f "$OUTPUT" ]; then 138 | echo "Decompressing $FILENAME..." 139 | gunzip -k $FILENAME 140 | else 141 | echo "$OUTPUT already decompressed." 142 | fi 143 | done 144 | 145 | cd $PARA_PATH 146 | # 147 | # Download bilingual data 148 | # 149 | echo "Downloading parallel data..." 150 | if [[ ! -f "$PARA_SRC_TRAIN.raw" ]] && [[ ! -f "$PARA_TGT_TRAIN.raw" ]]; then 151 | wget -c http://data.statmt.org/wmt18/translation-task/dev.tgz 152 | wget http://www.statmt.org/europarl/v7/ro-en.tgz 153 | wget http://opus.lingfil.uu.se/download.php?f=SETIMES2/en-ro.txt.zip -O SETIMES2.ro-en.txt.zip 154 | wget -nc http://data.statmt.org/rsennrich/wmt16_backtranslations/ro-en/corpus.bt.ro-en.en.gz 155 | wget -nc http://data.statmt.org/rsennrich/wmt16_backtranslations/ro-en/corpus.bt.ro-en.ro.gz 156 | 157 | echo "Extracting parallel data..." 158 | tar -xzf dev.tgz 159 | tar -xf ro-en.tgz 160 | unzip SETIMES2.ro-en.txt.zip 161 | gzip -d corpus.bt.ro-en.en.gz corpus.bt.ro-en.ro.gz 162 | fi 163 | 164 | 165 | # concatenate bilingual data 166 | if ! [[ -f "$PARA_SRC_TRAIN.raw" ]]; then 167 | echo "Concatenating $SRC bilingual data ..." 168 | cat corpus.bt.ro-en.$SRC europarl-v7.ro-en.$SRC SETIMES.en-ro.$SRC > $PARA_SRC_TRAIN.raw 169 | fi 170 | 171 | if ! [[ -f "$PARA_TGT_TRAIN.raw" ]]; then 172 | echo "Concatenating $TGT bilingual data ..." 173 | cat corpus.bt.ro-en.$TGT europarl-v7.ro-en.$TGT SETIMES.en-ro.$TGT > $PARA_TGT_TRAIN.raw 174 | fi 175 | 176 | SRC_PREPROCESSING="$REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $SRC | $REM_NON_PRINT_CHAR | $TOKENIZER -l $SRC -no-escape -threads $N_THREADS" 177 | TGT_PREPROCESSING="$REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $TGT | $REM_NON_PRINT_CHAR | $NORMALIZE_ROMANIAN | $REMOVE_DIACRITICS | $TOKENIZER -l $TGT -no-escape -threads $N_THREADS" 178 | 179 | # check valid and test files are here 180 | if ! [[ -f "$PARA_SRC_VALID.sgm" ]]; then echo "$PARA_SRC_VALID.sgm is not found!"; exit; fi 181 | if ! 
[[ -f "$PARA_TGT_VALID.sgm" ]]; then echo "$PARA_TGT_VALID.sgm is not found!"; exit; fi 182 | if ! [[ -f "$PARA_SRC_TEST.sgm" ]]; then echo "$PARA_SRC_TEST.sgm is not found!"; exit; fi 183 | if ! [[ -f "$PARA_TGT_TEST.sgm" ]]; then echo "$PARA_TGT_TEST.sgm is not found!"; exit; fi 184 | 185 | echo "Tokenizing valid and test data..." 186 | eval "$INPUT_FROM_SGM < $PARA_SRC_VALID.sgm | $SRC_PREPROCESSING > $PARA_SRC_VALID" 187 | eval "$INPUT_FROM_SGM < $PARA_TGT_VALID.sgm | $TGT_PREPROCESSING > $PARA_TGT_VALID" 188 | eval "$INPUT_FROM_SGM < $PARA_SRC_TEST.sgm | $SRC_PREPROCESSING > $PARA_SRC_TEST" 189 | eval "$INPUT_FROM_SGM < $PARA_TGT_TEST.sgm | $TGT_PREPROCESSING > $PARA_TGT_TEST" 190 | 191 | cat $PARA_SRC_TRAIN.raw | $SRC_PREPROCESSING > $PARA_SRC_TRAIN 192 | cat $PARA_TGT_TRAIN.raw | $TGT_PREPROCESSING > $PARA_TGT_TRAIN 193 | 194 | cd $MONO_PATH 195 | 196 | if ! [[ -f "$SRC_RAW" ]]; then 197 | echo "Concatenating $SRC monolingual data..." 198 | cat $(ls $SRC/news*$SRC* | grep -v gz) | head -n $N_MONO > $SRC_RAW 199 | fi 200 | if ! [[ -f "$TGT_RAW" ]]; then 201 | echo "Concatenating $TGT monolingual data..." 202 | cat $(ls $TGT/news*$TGT* | grep -v gz) | head -n $N_MONO > $TGT_RAW 203 | cat $PARA_PATH/SETIMES.en-ro.$TGT $PARA_PATH/europarl-v7.ro-en.$TGT >> $TGT_RAW 204 | fi 205 | 206 | # tokenize data 207 | if ! [[ -f "$SRC_TOK" ]]; then 208 | echo "Tokenize $SRC monolingual data..." 209 | eval "cat $SRC_RAW | $SRC_PREPROCESSING > $SRC_TOK" 210 | fi 211 | 212 | if ! [[ -f "$TGT_TOK" ]]; then 213 | echo "Tokenize $TGT monolingual data..." 214 | eval "cat $TGT_RAW | $TGT_PREPROCESSING > $TGT_TOK" 215 | fi 216 | echo "$SRC monolingual data tokenized in: $SRC_TOK" 217 | echo "$TGT monolingual data tokenized in: $TGT_TOK" 218 | 219 | # reload BPE codes 220 | cd $MAIN_PATH 221 | if [ ! -f "$BPE_CODES" ] && [ -f "$RELOAD_CODES" ]; then 222 | echo "Reloading BPE codes from $RELOAD_CODES ..." 223 | cp $RELOAD_CODES $BPE_CODES 224 | fi 225 | 226 | # learn BPE codes 227 | if [ ! -f "$BPE_CODES" ]; then 228 | echo "Learning BPE codes..." 229 | $FASTBPE learnbpe $CODES $SRC_TOK $TGT_TOK > $BPE_CODES 230 | fi 231 | echo "BPE learned in $BPE_CODES" 232 | 233 | # apply BPE codes 234 | if ! [[ -f "$SRC_TRAIN_BPE" ]]; then 235 | echo "Applying $SRC BPE codes..." 236 | $FASTBPE applybpe $SRC_TRAIN_BPE $SRC_TOK $BPE_CODES 237 | fi 238 | if ! [[ -f "$TGT_TRAIN_BPE" ]]; then 239 | echo "Applying $TGT BPE codes..." 240 | $FASTBPE applybpe $TGT_TRAIN_BPE $TGT_TOK $BPE_CODES 241 | fi 242 | echo "BPE codes applied to $SRC in: $SRC_TRAIN_BPE" 243 | echo "BPE codes applied to $TGT in: $TGT_TRAIN_BPE" 244 | 245 | echo "Applying BPE to train, valid and test bilingual files..." 246 | $FASTBPE applybpe $PARA_SRC_TRAIN_BPE $PARA_SRC_TRAIN $BPE_CODES 247 | $FASTBPE applybpe $PARA_TGT_TRAIN_BPE $PARA_TGT_TRAIN $BPE_CODES 248 | $FASTBPE applybpe $PARA_SRC_VALID_BPE $PARA_SRC_VALID $BPE_CODES 249 | $FASTBPE applybpe $PARA_TGT_VALID_BPE $PARA_TGT_VALID $BPE_CODES 250 | $FASTBPE applybpe $PARA_SRC_TEST_BPE $PARA_SRC_TEST $BPE_CODES 251 | $FASTBPE applybpe $PARA_TGT_TEST_BPE $PARA_TGT_TEST $BPE_CODES 252 | 253 | # extract source and target vocabulary 254 | if ! [[ -f "$SRC_VOCAB" && -f "$TGT_VOCAB" ]]; then 255 | echo "Extracting vocabulary..." 256 | $FASTBPE getvocab $SRC_TRAIN_BPE > $SRC_VOCAB 257 | $FASTBPE getvocab $TGT_TRAIN_BPE > $TGT_VOCAB 258 | fi 259 | echo "$SRC vocab in: $SRC_VOCAB" 260 | echo "$TGT vocab in: $TGT_VOCAB" 261 | 262 | # reload full vocabulary 263 | cd $MAIN_PATH 264 | if [ ! 
-f "$FULL_VOCAB" ] && [ -f "$RELOAD_VOCAB" ]; then 265 | echo "Reloading vocabulary from $RELOAD_VOCAB ..." 266 | cp $RELOAD_VOCAB $FULL_VOCAB 267 | fi 268 | 269 | # extract full vocabulary 270 | if ! [[ -f "$FULL_VOCAB" ]]; then 271 | echo "Extracting vocabulary..." 272 | $FASTBPE getvocab $SRC_TRAIN_BPE $TGT_TRAIN_BPE > $FULL_VOCAB 273 | fi 274 | echo "Full vocab in: $FULL_VOCAB" 275 | 276 | # binarize data 277 | if ! [[ -f "$SRC_TRAIN_BPE.pth" ]]; then 278 | echo "Binarizing $SRC data..." 279 | $MAIN_PATH/preprocess.py $FULL_VOCAB $SRC_TRAIN_BPE 280 | fi 281 | if ! [[ -f "$TGT_TRAIN_BPE.pth" ]]; then 282 | echo "Binarizing $TGT data..." 283 | $MAIN_PATH/preprocess.py $FULL_VOCAB $TGT_TRAIN_BPE 284 | fi 285 | echo "$SRC binarized data in: $SRC_TRAIN_BPE.pth" 286 | echo "$TGT binarized data in: $TGT_TRAIN_BPE.pth" 287 | 288 | echo "Binarizing data..." 289 | rm -f $PARA_SRC_VALID_BPE.pth $PARA_TGT_VALID_BPE.pth $PARA_SRC_TEST_BPE.pth $PARA_TGT_TEST_BPE.pth 290 | rm -f $PARA_SRC_TRAIN_BPE.pth $PARA_TGT_TRAIN_BPE.pth 291 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_SRC_TRAIN_BPE 292 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_TGT_TRAIN_BPE 293 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_SRC_VALID_BPE 294 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_TGT_VALID_BPE 295 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_SRC_TEST_BPE 296 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_TGT_TEST_BPE 297 | # 298 | # Link monolingual validation and test data to parallel data 299 | # 300 | ln -sf $PARA_SRC_VALID_BPE.pth $SRC_VALID_BPE.pth 301 | ln -sf $PARA_TGT_VALID_BPE.pth $TGT_VALID_BPE.pth 302 | ln -sf $PARA_SRC_TEST_BPE.pth $SRC_TEST_BPE.pth 303 | ln -sf $PARA_TGT_TEST_BPE.pth $TGT_TEST_BPE.pth 304 | 305 | cd $PARA_PATH 306 | if [[ -f "dev.tgz" ]]; then 307 | rm dev.tgz ro-en.tgz SETIMES2.ro-en.txt.zip 308 | rm LICENSE README SETIMES.en-ro.* corpus.bt.ro-en.* europarl-v7.ro-en.* 309 | fi 310 | -------------------------------------------------------------------------------- /MASS-unsupNMT/get-data-gigaword.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | 4 | GIGAPATH=sumdata 5 | 6 | CODES=40000 7 | 8 | POSITIONAL=() 9 | while [[ $# -gt 0 ]] 10 | do 11 | key="$1" 12 | case $key in 13 | --gigapath) 14 | GIGAPATH="$2"; shift 2;; 15 | --reload_codes) 16 | RELOAD_CODES="$2"; shift 2;; 17 | --reload_vocab) 18 | RELOAD_VOCAB="$2"; shift 2;; 19 | --replace_ner) 20 | REPLACE_NER="$2"; shift 2;; 21 | --replace_unk) 22 | REPLACE_UNK="$2"; shift 2;; 23 | *) 24 | POSITIONAL+=("$1") 25 | shift 26 | ;; 27 | esac 28 | done 29 | set -- "${POSITIONAL[@]}" 30 | 31 | # Check parameters 32 | 33 | if [ "$RELOAD_CODES" != "" ] && [ ! -f "$RELOAD_CODES" ]; then echo "cannot locate BPE codes"; exit; fi 34 | if [ "$RELOAD_VOCAB" != "" ] && [ ! 
-f "$RELOAD_VOCAB" ]; then echo "cannot locate vocabulary"; exit; fi 35 | if [ "$RELOAD_CODES" == "" -a "$RELOAD_VOCAB" != "" -o "$RELOAD_CODES" != "" -a "$RELOAD_VOCAB" == "" ]; then echo "BPE codes should be provided if and only if vocabulary is also provided"; exit; fi 36 | 37 | # fastBPE 38 | FASTBPE_DIR=$TOOLS_PATH/fastBPE 39 | FASTBPE=$TOOLS_PATH/fastBPE/fast 40 | 41 | # main paths 42 | MAIN_PATH=$PWD 43 | TOOLS_PATH=$PWD/tools 44 | DATA_PATH=$PWD/data 45 | PARA_PATH=$DATA_PATH/para/ 46 | PROC_PATH=$DATA_PATH/processed/giga/ 47 | 48 | # create paths 49 | mkdir -p $TOOLS_PATH 50 | mkdir -p $DATA_PATH 51 | mkdir -p $PROC_PATH 52 | mkdir -p $PARA_PATH 53 | 54 | MOSES=$TOOLS_PATH/mosesdecoder 55 | 56 | # fastBPE 57 | FASTBPE_DIR=$TOOLS_PATH/fastBPE 58 | FASTBPE=$TOOLS_PATH/fastBPE/fast 59 | 60 | TRAIN_SRC_RAW=$GIGAPATH/train/train.article.txt 61 | TRAIN_TGT_RAW=$GIGAPATH/train/train.title.txt 62 | VALID_SRC_RAW=$GIGAPATH/train/valid.article.filter.txt 63 | VALID_TGT_RAW=$GIGAPATH/train/valid.title.filter.txt 64 | TEST_SRC_RAW=$GIGAPATH/Giga/input.txt 65 | TEST_TGT_RAW=$GIGAPATH/Giga/task1_ref0.txt 66 | 67 | TRAIN_SRC=$PARA_PATH/train.ar-ti.ar 68 | TRAIN_TGT=$PARA_PATH/train.ar-ti.ti 69 | VALID_SRC=$PARA_PATH/valid.ar-ti.ar 70 | VALID_TGT=$PARA_PATH/valid.ar-ti.ti 71 | TEST_SRC=$PARA_PATH/test.ar-ti.ar 72 | TEST_TGT=$PARA_PATH/test.ar-ti.ar 73 | 74 | TRAIN_SRC_BPE=$PROC_PATH/train.ar-ti.ar 75 | TRAIN_TGT_BPE=$PROC_PATH/train.ar-ti.ti 76 | VALID_SRC_BPE=$PROC_PATH/valid.ar-ti.ar 77 | VALID_TGT_BPE=$PROC_PATH/valid.ar-ti.ti 78 | TEST_SRC_BPE=$PROC_PATH/test.ar-ti.ar 79 | TEST_TGT_BPE=$PROC_PATH/test.ar-ti.ti 80 | 81 | BPE_CODES=$PROC_PATH/codes 82 | FULL_VOCAB=$PROC_PATH/vocab.ar-ti 83 | 84 | if [ ! -f $TRAIN_SRC_RAW ]; then 85 | gzip -d $TRAIN_SRC_RAW.gz 86 | fi 87 | 88 | if [ ! -f $TRAIN_TGT_RAW ]; then 89 | gzip -d $TRAIN_TGT_RAW.gz 90 | fi 91 | 92 | preprocess="" 93 | 94 | if [ "$REPLACE_NER" == "true" ] && [ "$REPLACE_UNK" == "true" ]; then 95 | preprocess="sed 's/#/1/g' | sed 's//unk/g' | sed 's/UNK/unk/g'" 96 | else 97 | if [ "$REPLACE_UNK" == "true" ]; then 98 | preprocess="sed 's//unk/g' | sed 's/UNK/unk/g'" 99 | fi 100 | if [ "$REPLACE_NER" == "true" ]; then 101 | preprocess="sed 's/#/1/g'" 102 | fi 103 | fi 104 | 105 | if ! [[ -f $TRAIN_SRC ]]; then 106 | eval "cat $TRAIN_SRC_RAW | $preprocess > $TRAIN_SRC" 107 | fi 108 | 109 | if ! [[ -f $TRAIN_TGT ]]; then 110 | eval "cat $TRAIN_TGT_RAW | $preprocess > $TRAIN_TGT" 111 | fi 112 | 113 | if [ ! -f "$BPE_CODES" ] && [ -f "$RELOAD_CODES" ]; then 114 | echo "Reloading BPE codes from $RELOAD_CODES ..." 115 | cp $RELOAD_CODES $BPE_CODES 116 | fi 117 | 118 | # learn BPE codes 119 | if [ ! -f "$BPE_CODES" ]; then 120 | echo "Learning BPE codes..." 121 | $FASTBPE learnbpe $CODES $TRAIN_SRC $TRAIN_TGT > $BPE_CODES 122 | fi 123 | echo "BPE learned in $BPE_CODES" 124 | 125 | if [ ! -f "$FULL_VOCAB" ] && [ -f "$RELOAD_VOCAB" ]; then 126 | echo "Reloading vocabulary from $RELOAD_VOCAB ..." 127 | cp $RELOAD_VOCAB $FULL_VOCAB 128 | fi 129 | 130 | if [ ! -f "$TRAIN_SRC_BPE" ]; then 131 | echo "Applying article BPE codes..." 132 | $FASTBPE applybpe $TRAIN_SRC_BPE $TRAIN_SRC $BPE_CODES 133 | fi 134 | 135 | if [ ! -f "$TRAIN_TGT_BPE" ]; then 136 | echo "Applying title BPE codes..." 137 | $FASTBPE applybpe $TRAIN_TGT_BPE $TRAIN_TGT $BPE_CODES 138 | fi 139 | 140 | # extract full vocabulary 141 | if ! [[ -f "$FULL_VOCAB" ]]; then 142 | echo "Extracting vocabulary..." 
143 | $FASTBPE getvocab $TRAIN_SRC_BPE $TRAIN_TGT_BPE > $FULL_VOCAB 144 | fi 145 | echo "Full vocab in: $FULL_VOCAB" 146 | 147 | eval "cat $VALID_SRC_RAW | $preprocess > $VALID_SRC" 148 | eval "cat $VALID_TGT_RAW | $preprocess > $VALID_TGT" 149 | eval "cat $TEST_SRC_RAW | $preprocess > $TEST_SRC" 150 | eval "cat $TEST_TGT_RAW | $preprocess > $TEST_TGT" 151 | 152 | $FASTBPE applybpe $VALID_SRC_BPE $VALID_SRC $BPE_CODES 153 | $FASTBPE applybpe $VALID_TGT_BPE $VALID_TGT $BPE_CODES 154 | $FASTBPE applybpe $TEST_SRC_BPE $TEST_SRC $BPE_CODES 155 | $FASTBPE applybpe $TEST_TGT_BPE $TEST_TGT $BPE_CODES 156 | 157 | python $MAIN_PATH/preprocess.py $FULL_VOCAB $TRAIN_SRC_BPE 158 | python $MAIN_PATH/preprocess.py $FULL_VOCAB $TRAIN_TGT_BPE 159 | python $MAIN_PATH/preprocess.py $FULL_VOCAB $VALID_SRC_BPE 160 | python $MAIN_PATH/preprocess.py $FULL_VOCAB $VALID_TGT_BPE 161 | python $MAIN_PATH/preprocess.py $FULL_VOCAB $TEST_SRC_BPE 162 | python $MAIN_PATH/preprocess.py $FULL_VOCAB $TEST_TGT_BPE 163 | -------------------------------------------------------------------------------- /MASS-unsupNMT/get-data-nmt.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # NOTICE FILE in the root directory of this source tree. 6 | # 7 | 8 | set -e 9 | 10 | 11 | # 12 | # Data preprocessing configuration 13 | # 14 | N_MONO=5000000 # number of monolingual sentences for each language 15 | CODES=60000 # number of BPE codes 16 | N_THREADS=16 # number of threads in data preprocessing 17 | 18 | 19 | # 20 | # Read arguments 21 | # 22 | POSITIONAL=() 23 | while [[ $# -gt 0 ]] 24 | do 25 | key="$1" 26 | case $key in 27 | --src) 28 | SRC="$2"; shift 2;; 29 | --tgt) 30 | TGT="$2"; shift 2;; 31 | --reload_codes) 32 | RELOAD_CODES="$2"; shift 2;; 33 | --reload_vocab) 34 | RELOAD_VOCAB="$2"; shift 2;; 35 | *) 36 | POSITIONAL+=("$1") 37 | shift 38 | ;; 39 | esac 40 | done 41 | set -- "${POSITIONAL[@]}" 42 | 43 | 44 | # 45 | # Check parameters 46 | # 47 | if [ "$SRC" == "" ]; then echo "--src not provided"; exit; fi 48 | if [ "$TGT" == "" ]; then echo "--tgt not provided"; exit; fi 49 | if [ "$SRC" != "de" -a "$SRC" != "en" -a "$SRC" != "fr" -a "$SRC" != "ro" ]; then echo "unknown source language"; exit; fi 50 | if [ "$TGT" != "de" -a "$TGT" != "en" -a "$TGT" != "fr" -a "$TGT" != "ro" ]; then echo "unknown target language"; exit; fi 51 | if [ "$SRC" == "$TGT" ]; then echo "source and target cannot be identical"; exit; fi 52 | if [ "$SRC" \> "$TGT" ]; then echo "please ensure SRC < TGT"; exit; fi 53 | if [ "$RELOAD_CODES" != "" ] && [ ! -f "$RELOAD_CODES" ]; then echo "cannot locate BPE codes"; exit; fi 54 | if [ "$RELOAD_VOCAB" != "" ] && [ ! 
-f "$RELOAD_VOCAB" ]; then echo "cannot locate vocabulary"; exit; fi 55 | if [ "$RELOAD_CODES" == "" -a "$RELOAD_VOCAB" != "" -o "$RELOAD_CODES" != "" -a "$RELOAD_VOCAB" == "" ]; then echo "BPE codes should be provided if and only if vocabulary is also provided"; exit; fi 56 | 57 | 58 | # 59 | # Initialize tools and data paths 60 | # 61 | 62 | # main paths 63 | MAIN_PATH=$PWD 64 | TOOLS_PATH=$PWD/tools 65 | DATA_PATH=$PWD/data 66 | MONO_PATH=$DATA_PATH/mono 67 | PARA_PATH=$DATA_PATH/para 68 | PROC_PATH=$DATA_PATH/processed/$SRC-$TGT 69 | 70 | # create paths 71 | mkdir -p $TOOLS_PATH 72 | mkdir -p $DATA_PATH 73 | mkdir -p $MONO_PATH 74 | mkdir -p $PARA_PATH 75 | mkdir -p $PROC_PATH 76 | 77 | # moses 78 | MOSES=$TOOLS_PATH/mosesdecoder 79 | REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl 80 | NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl 81 | REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl 82 | TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl 83 | INPUT_FROM_SGM=$MOSES/scripts/ems/support/input-from-sgm.perl 84 | 85 | # fastBPE 86 | FASTBPE_DIR=$TOOLS_PATH/fastBPE 87 | FASTBPE=$TOOLS_PATH/fastBPE/fast 88 | 89 | # Sennrich's WMT16 scripts for Romanian preprocessing 90 | WMT16_SCRIPTS=$TOOLS_PATH/wmt16-scripts 91 | NORMALIZE_ROMANIAN=$WMT16_SCRIPTS/preprocess/normalise-romanian.py 92 | REMOVE_DIACRITICS=$WMT16_SCRIPTS/preprocess/remove-diacritics.py 93 | 94 | # raw and tokenized files 95 | SRC_RAW=$MONO_PATH/$SRC/all.$SRC 96 | TGT_RAW=$MONO_PATH/$TGT/all.$TGT 97 | SRC_TOK=$SRC_RAW.tok 98 | TGT_TOK=$TGT_RAW.tok 99 | 100 | # BPE / vocab files 101 | BPE_CODES=$PROC_PATH/codes 102 | SRC_VOCAB=$PROC_PATH/vocab.$SRC 103 | TGT_VOCAB=$PROC_PATH/vocab.$TGT 104 | FULL_VOCAB=$PROC_PATH/vocab.$SRC-$TGT 105 | 106 | # train / valid / test monolingual BPE data 107 | SRC_TRAIN_BPE=$PROC_PATH/train.$SRC 108 | TGT_TRAIN_BPE=$PROC_PATH/train.$TGT 109 | SRC_VALID_BPE=$PROC_PATH/valid.$SRC 110 | TGT_VALID_BPE=$PROC_PATH/valid.$TGT 111 | SRC_TEST_BPE=$PROC_PATH/test.$SRC 112 | TGT_TEST_BPE=$PROC_PATH/test.$TGT 113 | 114 | # valid / test parallel BPE data 115 | PARA_SRC_VALID_BPE=$PROC_PATH/valid.$SRC-$TGT.$SRC 116 | PARA_TGT_VALID_BPE=$PROC_PATH/valid.$SRC-$TGT.$TGT 117 | PARA_SRC_TEST_BPE=$PROC_PATH/test.$SRC-$TGT.$SRC 118 | PARA_TGT_TEST_BPE=$PROC_PATH/test.$SRC-$TGT.$TGT 119 | 120 | # valid / test file raw data 121 | unset PARA_SRC_VALID PARA_TGT_VALID PARA_SRC_TEST PARA_TGT_TEST 122 | if [ "$SRC" == "en" -a "$TGT" == "fr" ]; then 123 | PARA_SRC_VALID=$PARA_PATH/dev/newstest2013-ref.en 124 | PARA_TGT_VALID=$PARA_PATH/dev/newstest2013-ref.fr 125 | PARA_SRC_TEST=$PARA_PATH/dev/newstest2014-fren-ref.en 126 | PARA_TGT_TEST=$PARA_PATH/dev/newstest2014-fren-ref.fr 127 | fi 128 | if [ "$SRC" == "de" -a "$TGT" == "en" ]; then 129 | PARA_SRC_VALID=$PARA_PATH/dev/newstest2013-ref.de 130 | PARA_TGT_VALID=$PARA_PATH/dev/newstest2013-ref.en 131 | PARA_SRC_TEST=$PARA_PATH/dev/newstest2016-ende-ref.de 132 | PARA_TGT_TEST=$PARA_PATH/dev/newstest2016-deen-ref.en 133 | # PARA_SRC_TEST=$PARA_PATH/dev/newstest2014-deen-ref.de 134 | # PARA_TGT_TEST=$PARA_PATH/dev/newstest2014-deen-ref.en 135 | fi 136 | if [ "$SRC" == "en" -a "$TGT" == "ro" ]; then 137 | PARA_SRC_VALID=$PARA_PATH/dev/newsdev2016-roen-ref.en 138 | PARA_TGT_VALID=$PARA_PATH/dev/newsdev2016-enro-ref.ro 139 | PARA_SRC_TEST=$PARA_PATH/dev/newstest2016-roen-ref.en 140 | PARA_TGT_TEST=$PARA_PATH/dev/newstest2016-enro-ref.ro 141 | fi 142 | 143 | # install tools 144 | ./install-tools.sh 145 | 146 | 147 
| # 148 | # Download monolingual data 149 | # 150 | 151 | cd $MONO_PATH 152 | 153 | if [ "$SRC" == "de" -o "$TGT" == "de" ]; then 154 | echo "Downloading German monolingual data ..." 155 | mkdir -p $MONO_PATH/de 156 | cd $MONO_PATH/de 157 | wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2007.de.shuffled.gz 158 | wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2008.de.shuffled.gz 159 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2009.de.shuffled.gz 160 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2010.de.shuffled.gz 161 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2011.de.shuffled.gz 162 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.de.shuffled.gz 163 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.de.shuffled.gz 164 | # wget -c http://www.statmt.org/wmt15/training-monolingual-news-crawl-v2/news.2014.de.shuffled.v2.gz 165 | # wget -c http://data.statmt.org/wmt16/translation-task/news.2015.de.shuffled.gz 166 | # wget -c http://data.statmt.org/wmt17/translation-task/news.2016.de.shuffled.gz 167 | # wget -c http://data.statmt.org/wmt18/translation-task/news.2017.de.shuffled.deduped.gz 168 | fi 169 | 170 | if [ "$SRC" == "en" -o "$TGT" == "en" ]; then 171 | echo "Downloading English monolingual data ..." 172 | mkdir -p $MONO_PATH/en 173 | cd $MONO_PATH/en 174 | wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2007.en.shuffled.gz 175 | wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2008.en.shuffled.gz 176 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2009.en.shuffled.gz 177 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2010.en.shuffled.gz 178 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2011.en.shuffled.gz 179 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.en.shuffled.gz 180 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.en.shuffled.gz 181 | # wget -c http://www.statmt.org/wmt15/training-monolingual-news-crawl-v2/news.2014.en.shuffled.v2.gz 182 | # wget -c http://data.statmt.org/wmt16/translation-task/news.2015.en.shuffled.gz 183 | # wget -c http://data.statmt.org/wmt17/translation-task/news.2016.en.shuffled.gz 184 | # wget -c http://data.statmt.org/wmt18/translation-task/news.2017.en.shuffled.deduped.gz 185 | fi 186 | 187 | if [ "$SRC" == "fr" -o "$TGT" == "fr" ]; then 188 | echo "Downloading French monolingual data ..." 
189 | mkdir -p $MONO_PATH/fr 190 | cd $MONO_PATH/fr 191 | wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2007.fr.shuffled.gz 192 | wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2008.fr.shuffled.gz 193 | wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2009.fr.shuffled.gz 194 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2010.fr.shuffled.gz 195 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2011.fr.shuffled.gz 196 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.fr.shuffled.gz 197 | # wget -c http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.fr.shuffled.gz 198 | # wget -c http://www.statmt.org/wmt15/training-monolingual-news-crawl-v2/news.2014.fr.shuffled.v2.gz 199 | # wget -c http://data.statmt.org/wmt17/translation-task/news.2015.fr.shuffled.gz 200 | # wget -c http://data.statmt.org/wmt17/translation-task/news.2016.fr.shuffled.gz 201 | # wget -c http://data.statmt.org/wmt17/translation-task/news.2017.fr.shuffled.gz 202 | fi 203 | 204 | if [ "$SRC" == "ro" -o "$TGT" == "ro" ]; then 205 | echo "Downloading Romanian monolingual data ..." 206 | mkdir -p $MONO_PATH/ro 207 | cd $MONO_PATH/ro 208 | wget -c http://data.statmt.org/wmt16/translation-task/news.2015.ro.shuffled.gz 209 | fi 210 | 211 | cd $MONO_PATH 212 | 213 | # decompress monolingual data 214 | for FILENAME in $SRC/news*gz $TGT/news*gz; do 215 | OUTPUT="${FILENAME::-3}" 216 | if [ ! -f "$OUTPUT" ]; then 217 | echo "Decompressing $FILENAME..." 218 | gunzip -k $FILENAME 219 | else 220 | echo "$OUTPUT already decompressed." 221 | fi 222 | done 223 | 224 | # concatenate monolingual data files 225 | if ! [[ -f "$SRC_RAW" ]]; then 226 | echo "Concatenating $SRC monolingual data..." 227 | cat $(ls $SRC/news*$SRC* | grep -v gz) | head -n $N_MONO > $SRC_RAW 228 | fi 229 | if ! [[ -f "$TGT_RAW" ]]; then 230 | echo "Concatenating $TGT monolingual data..." 231 | cat $(ls $TGT/news*$TGT* | grep -v gz) | head -n $N_MONO > $TGT_RAW 232 | fi 233 | echo "$SRC monolingual data concatenated in: $SRC_RAW" 234 | echo "$TGT monolingual data concatenated in: $TGT_RAW" 235 | 236 | # # check number of lines 237 | # if ! [[ "$(wc -l < $SRC_RAW)" -eq "$N_MONO" ]]; then echo "ERROR: Number of lines does not match! Be sure you have $N_MONO sentences in your $SRC monolingual data."; exit; fi 238 | # if ! [[ "$(wc -l < $TGT_RAW)" -eq "$N_MONO" ]]; then echo "ERROR: Number of lines does not match! Be sure you have $N_MONO sentences in your $TGT monolingual data."; exit; fi 239 | 240 | # preprocessing commands - special case for Romanian 241 | if [ "$SRC" == "ro" ]; then 242 | SRC_PREPROCESSING="$REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $SRC | $REM_NON_PRINT_CHAR | $NORMALIZE_ROMANIAN | $REMOVE_DIACRITICS | $TOKENIZER -l $SRC -no-escape -threads $N_THREADS" 243 | else 244 | SRC_PREPROCESSING="$REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $SRC | $REM_NON_PRINT_CHAR | $TOKENIZER -l $SRC -no-escape -threads $N_THREADS" 245 | fi 246 | if [ "$TGT" == "ro" ]; then 247 | TGT_PREPROCESSING="$REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $TGT | $REM_NON_PRINT_CHAR | $NORMALIZE_ROMANIAN | $REMOVE_DIACRITICS | $TOKENIZER -l $TGT -no-escape -threads $N_THREADS" 248 | else 249 | TGT_PREPROCESSING="$REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $TGT | $REM_NON_PRINT_CHAR | $TOKENIZER -l $TGT -no-escape -threads $N_THREADS" 250 | fi 251 | 252 | # tokenize data 253 | if ! 
[[ -f "$SRC_TOK" ]]; then 254 | echo "Tokenize $SRC monolingual data..." 255 | eval "cat $SRC_RAW | $SRC_PREPROCESSING > $SRC_TOK" 256 | fi 257 | 258 | if ! [[ -f "$TGT_TOK" ]]; then 259 | echo "Tokenize $TGT monolingual data..." 260 | eval "cat $TGT_RAW | $TGT_PREPROCESSING > $TGT_TOK" 261 | fi 262 | echo "$SRC monolingual data tokenized in: $SRC_TOK" 263 | echo "$TGT monolingual data tokenized in: $TGT_TOK" 264 | 265 | # reload BPE codes 266 | cd $MAIN_PATH 267 | if [ ! -f "$BPE_CODES" ] && [ -f "$RELOAD_CODES" ]; then 268 | echo "Reloading BPE codes from $RELOAD_CODES ..." 269 | cp $RELOAD_CODES $BPE_CODES 270 | fi 271 | 272 | # learn BPE codes 273 | if [ ! -f "$BPE_CODES" ]; then 274 | echo "Learning BPE codes..." 275 | $FASTBPE learnbpe $CODES $SRC_TOK $TGT_TOK > $BPE_CODES 276 | fi 277 | echo "BPE learned in $BPE_CODES" 278 | 279 | # apply BPE codes 280 | if ! [[ -f "$SRC_TRAIN_BPE" ]]; then 281 | echo "Applying $SRC BPE codes..." 282 | $FASTBPE applybpe $SRC_TRAIN_BPE $SRC_TOK $BPE_CODES 283 | fi 284 | if ! [[ -f "$TGT_TRAIN_BPE" ]]; then 285 | echo "Applying $TGT BPE codes..." 286 | $FASTBPE applybpe $TGT_TRAIN_BPE $TGT_TOK $BPE_CODES 287 | fi 288 | echo "BPE codes applied to $SRC in: $SRC_TRAIN_BPE" 289 | echo "BPE codes applied to $TGT in: $TGT_TRAIN_BPE" 290 | 291 | # extract source and target vocabulary 292 | if ! [[ -f "$SRC_VOCAB" && -f "$TGT_VOCAB" ]]; then 293 | echo "Extracting vocabulary..." 294 | $FASTBPE getvocab $SRC_TRAIN_BPE > $SRC_VOCAB 295 | $FASTBPE getvocab $TGT_TRAIN_BPE > $TGT_VOCAB 296 | fi 297 | echo "$SRC vocab in: $SRC_VOCAB" 298 | echo "$TGT vocab in: $TGT_VOCAB" 299 | 300 | # reload full vocabulary 301 | cd $MAIN_PATH 302 | if [ ! -f "$FULL_VOCAB" ] && [ -f "$RELOAD_VOCAB" ]; then 303 | echo "Reloading vocabulary from $RELOAD_VOCAB ..." 304 | cp $RELOAD_VOCAB $FULL_VOCAB 305 | fi 306 | 307 | # extract full vocabulary 308 | if ! [[ -f "$FULL_VOCAB" ]]; then 309 | echo "Extracting vocabulary..." 310 | $FASTBPE getvocab $SRC_TRAIN_BPE $TGT_TRAIN_BPE > $FULL_VOCAB 311 | fi 312 | echo "Full vocab in: $FULL_VOCAB" 313 | 314 | # binarize data 315 | if ! [[ -f "$SRC_TRAIN_BPE.pth" ]]; then 316 | echo "Binarizing $SRC data..." 317 | $MAIN_PATH/preprocess.py $FULL_VOCAB $SRC_TRAIN_BPE 318 | fi 319 | if ! [[ -f "$TGT_TRAIN_BPE.pth" ]]; then 320 | echo "Binarizing $TGT data..." 321 | $MAIN_PATH/preprocess.py $FULL_VOCAB $TGT_TRAIN_BPE 322 | fi 323 | echo "$SRC binarized data in: $SRC_TRAIN_BPE.pth" 324 | echo "$TGT binarized data in: $TGT_TRAIN_BPE.pth" 325 | 326 | 327 | # 328 | # Download parallel data (for evaluation only) 329 | # 330 | 331 | cd $PARA_PATH 332 | 333 | echo "Downloading parallel data..." 334 | wget -c http://data.statmt.org/wmt18/translation-task/dev.tgz 335 | 336 | echo "Extracting parallel data..." 337 | tar -xzf dev.tgz 338 | 339 | # check valid and test files are here 340 | if ! [[ -f "$PARA_SRC_VALID.sgm" ]]; then echo "$PARA_SRC_VALID.sgm is not found!"; exit; fi 341 | if ! [[ -f "$PARA_TGT_VALID.sgm" ]]; then echo "$PARA_TGT_VALID.sgm is not found!"; exit; fi 342 | if ! [[ -f "$PARA_SRC_TEST.sgm" ]]; then echo "$PARA_SRC_TEST.sgm is not found!"; exit; fi 343 | if ! [[ -f "$PARA_TGT_TEST.sgm" ]]; then echo "$PARA_TGT_TEST.sgm is not found!"; exit; fi 344 | 345 | echo "Tokenizing valid and test data..." 
346 | eval "$INPUT_FROM_SGM < $PARA_SRC_VALID.sgm | $SRC_PREPROCESSING > $PARA_SRC_VALID" 347 | eval "$INPUT_FROM_SGM < $PARA_TGT_VALID.sgm | $TGT_PREPROCESSING > $PARA_TGT_VALID" 348 | eval "$INPUT_FROM_SGM < $PARA_SRC_TEST.sgm | $SRC_PREPROCESSING > $PARA_SRC_TEST" 349 | eval "$INPUT_FROM_SGM < $PARA_TGT_TEST.sgm | $TGT_PREPROCESSING > $PARA_TGT_TEST" 350 | 351 | echo "Applying BPE to valid and test files..." 352 | $FASTBPE applybpe $PARA_SRC_VALID_BPE $PARA_SRC_VALID $BPE_CODES $SRC_VOCAB 353 | $FASTBPE applybpe $PARA_TGT_VALID_BPE $PARA_TGT_VALID $BPE_CODES $TGT_VOCAB 354 | $FASTBPE applybpe $PARA_SRC_TEST_BPE $PARA_SRC_TEST $BPE_CODES $SRC_VOCAB 355 | $FASTBPE applybpe $PARA_TGT_TEST_BPE $PARA_TGT_TEST $BPE_CODES $TGT_VOCAB 356 | 357 | echo "Binarizing data..." 358 | rm -f $PARA_SRC_VALID_BPE.pth $PARA_TGT_VALID_BPE.pth $PARA_SRC_TEST_BPE.pth $PARA_TGT_TEST_BPE.pth 359 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_SRC_VALID_BPE 360 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_TGT_VALID_BPE 361 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_SRC_TEST_BPE 362 | $MAIN_PATH/preprocess.py $FULL_VOCAB $PARA_TGT_TEST_BPE 363 | 364 | 365 | # 366 | # Link monolingual validation and test data to parallel data 367 | # 368 | ln -sf $PARA_SRC_VALID_BPE.pth $SRC_VALID_BPE.pth 369 | ln -sf $PARA_TGT_VALID_BPE.pth $TGT_VALID_BPE.pth 370 | ln -sf $PARA_SRC_TEST_BPE.pth $SRC_TEST_BPE.pth 371 | ln -sf $PARA_TGT_TEST_BPE.pth $TGT_TEST_BPE.pth 372 | 373 | 374 | # 375 | # Summary 376 | # 377 | echo "" 378 | echo "===== Data summary" 379 | echo "Monolingual training data:" 380 | echo " $SRC: $SRC_TRAIN_BPE.pth" 381 | echo " $TGT: $TGT_TRAIN_BPE.pth" 382 | echo "Monolingual validation data:" 383 | echo " $SRC: $SRC_VALID_BPE.pth" 384 | echo " $TGT: $TGT_VALID_BPE.pth" 385 | echo "Monolingual test data:" 386 | echo " $SRC: $SRC_TEST_BPE.pth" 387 | echo " $TGT: $TGT_TEST_BPE.pth" 388 | echo "Parallel validation data:" 389 | echo " $SRC: $PARA_SRC_VALID_BPE.pth" 390 | echo " $TGT: $PARA_TGT_VALID_BPE.pth" 391 | echo "Parallel test data:" 392 | echo " $SRC: $PARA_SRC_TEST_BPE.pth" 393 | echo " $TGT: $PARA_TGT_TEST_BPE.pth" 394 | echo "" 395 | -------------------------------------------------------------------------------- /MASS-unsupNMT/install-tools.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # NOTICE FILE in the root directory of this source tree. 6 | # 7 | 8 | set -e 9 | 10 | lg=$1 # input language 11 | 12 | # data path 13 | MAIN_PATH=$PWD 14 | TOOLS_PATH=$PWD/tools 15 | 16 | # tools 17 | MOSES_DIR=$TOOLS_PATH/mosesdecoder 18 | FASTBPE_DIR=$TOOLS_PATH/fastBPE 19 | FASTBPE=$FASTBPE_DIR/fast 20 | WMT16_SCRIPTS=$TOOLS_PATH/wmt16-scripts 21 | 22 | # tools path 23 | mkdir -p $TOOLS_PATH 24 | 25 | # 26 | # Download and install tools 27 | # 28 | 29 | cd $TOOLS_PATH 30 | 31 | # Download Moses 32 | if [ ! -d "$MOSES_DIR" ]; then 33 | echo "Cloning Moses from GitHub repository..." 34 | git clone https://github.com/moses-smt/mosesdecoder.git 35 | fi 36 | 37 | # Download fastBPE 38 | if [ ! -d "$FASTBPE_DIR" ]; then 39 | echo "Cloning fastBPE from GitHub repository..." 40 | git clone https://github.com/glample/fastBPE 41 | fi 42 | 43 | # Compile fastBPE 44 | if [ ! -f "$FASTBPE" ]; then 45 | echo "Compiling fastBPE..." 46 | cd fastBPE 47 | g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast 48 | cd .. 
49 | fi 50 | 51 | # Download Sennrich's tools 52 | if [ ! -d "$WMT16_SCRIPTS" ]; then 53 | echo "Cloning WMT16 preprocessing scripts..." 54 | git clone https://github.com/rsennrich/wmt16-scripts.git 55 | fi 56 | 57 | # Download WikiExtractor 58 | if [ ! -d $TOOLS_PATH/wikiextractor ]; then 59 | echo "Cloning WikiExtractor from GitHub repository..." 60 | git clone https://github.com/attardi/wikiextractor.git 61 | fi 62 | 63 | # # Chinese segmenter 64 | # if ! ls $TOOLS_PATH/stanford-segmenter-* 1> /dev/null 2>&1; then 65 | # echo "Stanford segmenter not found at $TOOLS_PATH/stanford-segmenter-*" 66 | # echo "Please install Stanford segmenter in $TOOLS_PATH" 67 | # exit 1 68 | # fi 69 | # 70 | # # Thai tokenizer 71 | # if ! python -c 'import pkgutil; exit(not pkgutil.find_loader("pythainlp"))'; then 72 | # echo "pythainlp package not found in python" 73 | # echo "Please install pythainlp (pip install pythainlp)" 74 | # exit 1 75 | # fi 76 | # 77 | -------------------------------------------------------------------------------- /MASS-unsupNMT/preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) 2019-present, Facebook, Inc. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # NOTICE FILE in the root directory of this source tree. 9 | # 10 | 11 | 12 | """ 13 | Example: python data/vocab.txt data/train.txt 14 | vocab.txt: 1stline=word, 2ndline=count 15 | """ 16 | 17 | import os 18 | import sys 19 | 20 | from src.logger import create_logger 21 | from src.data.dictionary import Dictionary 22 | 23 | 24 | if __name__ == '__main__': 25 | 26 | logger = create_logger(None, 0) 27 | 28 | voc_path = sys.argv[1] 29 | txt_path = sys.argv[2] 30 | bin_path = sys.argv[2] + '.pth' 31 | assert os.path.isfile(voc_path) 32 | assert os.path.isfile(txt_path) 33 | 34 | dico = Dictionary.read_vocab(voc_path) 35 | logger.info("") 36 | 37 | data = Dictionary.index_data(txt_path, bin_path, dico) 38 | logger.info("%i words (%i unique) in %i sentences." % ( 39 | len(data['sentences']) - len(data['positions']), 40 | len(data['dico']), 41 | len(data['positions']) 42 | )) 43 | if len(data['unk_words']) > 0: 44 | logger.info("%i unknown words (%i unique), covering %.2f%% of the data." % ( 45 | sum(data['unk_words'].values()), 46 | len(data['unk_words']), 47 | sum(data['unk_words'].values()) * 100. 
/ (len(data['sentences']) - len(data['positions'])) 48 | )) 49 | if len(data['unk_words']) < 30: 50 | for w, c in sorted(data['unk_words'].items(), key=lambda x: x[1])[::-1]: 51 | logger.info("%s: %i" % (w, c)) 52 | -------------------------------------------------------------------------------- /MASS-unsupNMT/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MASS/779f22fc47c8a256d8bc04826ebe1c8307063cbe/MASS-unsupNMT/src/__init__.py -------------------------------------------------------------------------------- /MASS-unsupNMT/src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MASS/779f22fc47c8a256d8bc04826ebe1c8307063cbe/MASS-unsupNMT/src/data/__init__.py -------------------------------------------------------------------------------- /MASS-unsupNMT/src/data/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # NOTICE FILE in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import math 10 | import numpy as np 11 | import torch 12 | 13 | 14 | logger = getLogger() 15 | 16 | 17 | class StreamDataset(object): 18 | 19 | def __init__(self, sent, pos, params): 20 | """ 21 | Prepare batches for data iterator. 22 | """ 23 | bptt = params.bptt 24 | bs = params.batch_size 25 | self.eos = params.eos_index 26 | 27 | # checks 28 | assert len(pos) == (sent == self.eos).sum() 29 | assert len(pos) == (sent[pos[:, 1]] == self.eos).sum() 30 | 31 | n_tokens = len(sent) 32 | n_batches = math.ceil(n_tokens / (bs * bptt)) 33 | t_size = n_batches * bptt * bs 34 | 35 | buffer = np.zeros(t_size, dtype=sent.dtype) + self.eos 36 | buffer[t_size - n_tokens:] = sent 37 | buffer = buffer.reshape((bs, n_batches * bptt)).T 38 | self.data = np.zeros((n_batches * bptt + 1, bs), dtype=sent.dtype) + self.eos 39 | self.data[1:] = buffer 40 | 41 | self.bptt = bptt 42 | self.n_tokens = n_tokens 43 | self.n_batches = n_batches 44 | self.n_sentences = len(pos) 45 | self.lengths = torch.LongTensor(bs).fill_(bptt) 46 | 47 | def __len__(self): 48 | """ 49 | Number of sentences in the dataset. 50 | """ 51 | return self.n_sentences 52 | 53 | def select_data(self, a, b): 54 | """ 55 | Only select a subset of the dataset. 56 | """ 57 | if not (0 <= a < b <= self.n_batches): 58 | logger.warning("Invalid split values: %i %i - %i" % (a, b, self.n_batches)) 59 | return 60 | assert 0 <= a < b <= self.n_batches 61 | logger.info("Selecting batches from %i to %i ..." % (a, b)) 62 | 63 | # sub-select 64 | self.data = self.data[a * self.bptt:b * self.bptt] 65 | self.n_batches = b - a 66 | self.n_sentences = (self.data == self.eos).sum().item() 67 | 68 | def get_iterator(self, shuffle, subsample=1): 69 | """ 70 | Return a sentences iterator. 
71 | """ 72 | indexes = (np.random.permutation if shuffle else range)(self.n_batches // subsample) 73 | for i in indexes: 74 | a = self.bptt * i 75 | b = self.bptt * (i + 1) 76 | yield torch.from_numpy(self.data[a:b].astype(np.int64)), self.lengths 77 | 78 | 79 | class Dataset(object): 80 | 81 | def __init__(self, sent, pos, params): 82 | 83 | self.eos_index = params.eos_index 84 | self.pad_index = params.pad_index 85 | self.batch_size = params.batch_size 86 | self.tokens_per_batch = params.tokens_per_batch 87 | self.max_batch_size = params.max_batch_size 88 | 89 | self.sent = sent 90 | self.pos = pos 91 | self.lengths = self.pos[:, 1] - self.pos[:, 0] 92 | 93 | # check number of sentences 94 | assert len(self.pos) == (self.sent == self.eos_index).sum() 95 | 96 | # # remove empty sentences 97 | # self.remove_empty_sentences() 98 | 99 | # sanity checks 100 | self.check() 101 | 102 | def __len__(self): 103 | """ 104 | Number of sentences in the dataset. 105 | """ 106 | return len(self.pos) 107 | 108 | def check(self): 109 | """ 110 | Sanity checks. 111 | """ 112 | eos = self.eos_index 113 | assert len(self.pos) == (self.sent[self.pos[:, 1]] == eos).sum() # check sentences indices 114 | # assert self.lengths.min() > 0 # check empty sentences 115 | 116 | def batch_sentences(self, sentences): 117 | """ 118 | Take as input a list of n sentences (torch.LongTensor vectors) and return 119 | a tensor of size (slen, n) where slen is the length of the longest 120 | sentence, and a vector lengths containing the length of each sentence. 121 | """ 122 | # sentences = sorted(sentences, key=lambda x: len(x), reverse=True) 123 | lengths = torch.LongTensor([len(s) + 2 for s in sentences]) 124 | sent = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(self.pad_index) 125 | 126 | sent[0] = self.eos_index 127 | for i, s in enumerate(sentences): 128 | if lengths[i] > 2: # if sentence not empty 129 | sent[1:lengths[i] - 1, i].copy_(torch.from_numpy(s.astype(np.int64))) 130 | sent[lengths[i] - 1, i] = self.eos_index 131 | 132 | return sent, lengths 133 | 134 | def remove_empty_sentences(self): 135 | """ 136 | Remove empty sentences. 137 | """ 138 | init_size = len(self.pos) 139 | indices = np.arange(len(self.pos)) 140 | indices = indices[self.lengths[indices] > 0] 141 | self.pos = self.pos[indices] 142 | self.lengths = self.pos[:, 1] - self.pos[:, 0] 143 | logger.info("Removed %i empty sentences." % (init_size - len(indices))) 144 | self.check() 145 | 146 | def remove_long_sentences(self, max_len): 147 | """ 148 | Remove sentences exceeding a certain length. 149 | """ 150 | assert max_len >= 0 151 | if max_len == 0: 152 | return 153 | init_size = len(self.pos) 154 | indices = np.arange(len(self.pos)) 155 | indices = indices[self.lengths[indices] <= max_len] 156 | self.pos = self.pos[indices] 157 | self.lengths = self.pos[:, 1] - self.pos[:, 0] 158 | logger.info("Removed %i too long sentences." % (init_size - len(indices))) 159 | self.check() 160 | 161 | def remove_short_sentences(self, min_len): 162 | assert min_len >= 0 163 | if min_len == 0: 164 | return 165 | init_size = len(self.pos) 166 | indices = np.arange(len(self.pos)) 167 | indices = indices[self.lengths[indices] >= min_len] 168 | self.pos = self.pos[indices] 169 | self.lengths = self.pos[:, 1] - self.pos[:, 0] 170 | logger.info("Removed %i too short sentences" % (init_size - len(indices))) 171 | self.check() 172 | 173 | def select_data(self, a, b): 174 | """ 175 | Only select a subset of the dataset. 
176 | """ 177 | assert 0 <= a < b <= len(self.pos) 178 | logger.info("Selecting sentences from %i to %i ..." % (a, b)) 179 | 180 | # sub-select 181 | self.pos = self.pos[a:b] 182 | self.lengths = self.pos[:, 1] - self.pos[:, 0] 183 | 184 | # re-index 185 | min_pos = self.pos.min() 186 | max_pos = self.pos.max() 187 | self.pos -= min_pos 188 | self.sent = self.sent[min_pos:max_pos + 1] 189 | 190 | # sanity checks 191 | self.check() 192 | 193 | def get_batches_iterator(self, batches, return_indices): 194 | """ 195 | Return a sentences iterator, given the associated sentence batches. 196 | """ 197 | assert type(return_indices) is bool 198 | 199 | for sentence_ids in batches: 200 | if 0 < self.max_batch_size < len(sentence_ids): 201 | np.random.shuffle(sentence_ids) 202 | sentence_ids = sentence_ids[:self.max_batch_size] 203 | pos = self.pos[sentence_ids] 204 | sent = [self.sent[a:b] for a, b in pos] 205 | sent = self.batch_sentences(sent) 206 | yield (sent, sentence_ids) if return_indices else sent 207 | 208 | def get_iterator(self, shuffle, group_by_size=False, n_sentences=-1, seed=None, return_indices=False): 209 | """ 210 | Return a sentences iterator. 211 | """ 212 | assert seed is None or shuffle is True and type(seed) is int 213 | rng = np.random.RandomState(seed) 214 | n_sentences = len(self.pos) if n_sentences == -1 else n_sentences 215 | assert 0 < n_sentences <= len(self.pos) 216 | assert type(shuffle) is bool and type(group_by_size) is bool 217 | #assert group_by_size is False or shuffle is True 218 | 219 | # sentence lengths 220 | lengths = self.lengths + 2 221 | 222 | # select sentences to iterate over 223 | if shuffle: 224 | indices = rng.permutation(len(self.pos))[:n_sentences] 225 | else: 226 | indices = np.arange(n_sentences) 227 | 228 | # group sentences by lengths 229 | if group_by_size: 230 | indices = indices[np.argsort(lengths[indices], kind='mergesort')] 231 | 232 | # create batches - either have a fixed number of sentences, or a similar number of tokens 233 | if self.tokens_per_batch == -1: 234 | batches = np.array_split(indices, math.ceil(len(indices) * 1. 
/ self.batch_size)) 235 | else: 236 | batch_ids = np.cumsum(lengths[indices]) // self.tokens_per_batch 237 | _, bounds = np.unique(batch_ids, return_index=True) 238 | batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)] 239 | if bounds[-1] < len(indices): 240 | batches.append(indices[bounds[-1]:]) 241 | 242 | # optionally shuffle batches 243 | if shuffle: 244 | rng.shuffle(batches) 245 | 246 | # sanity checks 247 | assert n_sentences == sum([len(x) for x in batches]) 248 | assert lengths[indices].sum() == sum([lengths[x].sum() for x in batches]) 249 | # assert set.union(*[set(x.tolist()) for x in batches]) == set(range(n_sentences)) # slow 250 | 251 | # return the iterator 252 | return self.get_batches_iterator(batches, return_indices) 253 | 254 | 255 | class ParallelDataset(Dataset): 256 | 257 | def __init__(self, sent1, pos1, sent2, pos2, params): 258 | 259 | self.eos_index = params.eos_index 260 | self.pad_index = params.pad_index 261 | self.batch_size = params.batch_size 262 | self.tokens_per_batch = params.tokens_per_batch 263 | self.max_batch_size = params.max_batch_size 264 | 265 | self.sent1 = sent1 266 | self.sent2 = sent2 267 | self.pos1 = pos1 268 | self.pos2 = pos2 269 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0] 270 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0] 271 | 272 | # check number of sentences 273 | assert len(self.pos1) == (self.sent1 == self.eos_index).sum() 274 | assert len(self.pos2) == (self.sent2 == self.eos_index).sum() 275 | 276 | # remove empty sentences 277 | self.remove_empty_sentences() 278 | 279 | # sanity checks 280 | self.check() 281 | 282 | def __len__(self): 283 | """ 284 | Number of sentences in the dataset. 285 | """ 286 | return len(self.pos1) 287 | 288 | def check(self): 289 | """ 290 | Sanity checks. 291 | """ 292 | eos = self.eos_index 293 | assert len(self.pos1) == len(self.pos2) > 0 # check number of sentences 294 | assert len(self.pos1) == (self.sent1[self.pos1[:, 1]] == eos).sum() # check sentences indices 295 | assert len(self.pos2) == (self.sent2[self.pos2[:, 1]] == eos).sum() # check sentences indices 296 | assert eos <= self.sent1.min() < self.sent1.max() # check dictionary indices 297 | assert eos <= self.sent2.min() < self.sent2.max() # check dictionary indices 298 | assert self.lengths1.min() > 0 # check empty sentences 299 | assert self.lengths2.min() > 0 # check empty sentences 300 | 301 | def remove_empty_sentences(self): 302 | """ 303 | Remove empty sentences. 304 | """ 305 | init_size = len(self.pos1) 306 | indices = np.arange(len(self.pos1)) 307 | indices = indices[self.lengths1[indices] > 0] 308 | indices = indices[self.lengths2[indices] > 0] 309 | self.pos1 = self.pos1[indices] 310 | self.pos2 = self.pos2[indices] 311 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0] 312 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0] 313 | logger.info("Removed %i empty sentences." % (init_size - len(indices))) 314 | self.check() 315 | 316 | def remove_long_sentences(self, max_len): 317 | """ 318 | Remove sentences exceeding a certain length. 
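A pair is dropped if either side is longer than max_len, and max_len == 0
disables the filter entirely. Hypothetical example (the threshold is an
assumption):

    dataset.remove_long_sentences(100)  # drop pairs where either side exceeds 100 tokens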
319 | """ 320 | assert max_len >= 0 321 | if max_len == 0: 322 | return 323 | init_size = len(self.pos1) 324 | indices = np.arange(len(self.pos1)) 325 | indices = indices[self.lengths1[indices] <= max_len] 326 | indices = indices[self.lengths2[indices] <= max_len] 327 | self.pos1 = self.pos1[indices] 328 | self.pos2 = self.pos2[indices] 329 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0] 330 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0] 331 | logger.info("Removed %i too long sentences." % (init_size - len(indices))) 332 | self.check() 333 | 334 | def select_data(self, a, b): 335 | """ 336 | Only select a subset of the dataset. 337 | """ 338 | assert 0 <= a < b <= len(self.pos1) 339 | logger.info("Selecting sentences from %i to %i ..." % (a, b)) 340 | 341 | # sub-select 342 | self.pos1 = self.pos1[a:b] 343 | self.pos2 = self.pos2[a:b] 344 | self.lengths1 = self.pos1[:, 1] - self.pos1[:, 0] 345 | self.lengths2 = self.pos2[:, 1] - self.pos2[:, 0] 346 | 347 | # re-index 348 | min_pos1 = self.pos1.min() 349 | max_pos1 = self.pos1.max() 350 | min_pos2 = self.pos2.min() 351 | max_pos2 = self.pos2.max() 352 | self.pos1 -= min_pos1 353 | self.pos2 -= min_pos2 354 | self.sent1 = self.sent1[min_pos1:max_pos1 + 1] 355 | self.sent2 = self.sent2[min_pos2:max_pos2 + 1] 356 | 357 | # sanity checks 358 | self.check() 359 | 360 | def get_batches_iterator(self, batches, return_indices): 361 | """ 362 | Return a sentences iterator, given the associated sentence batches. 363 | """ 364 | assert type(return_indices) is bool 365 | 366 | for sentence_ids in batches: 367 | if 0 < self.max_batch_size < len(sentence_ids): 368 | np.random.shuffle(sentence_ids) 369 | sentence_ids = sentence_ids[:self.max_batch_size] 370 | pos1 = self.pos1[sentence_ids] 371 | pos2 = self.pos2[sentence_ids] 372 | sent1 = self.batch_sentences([self.sent1[a:b] for a, b in pos1]) 373 | sent2 = self.batch_sentences([self.sent2[a:b] for a, b in pos2]) 374 | yield (sent1, sent2, sentence_ids) if return_indices else (sent1, sent2) 375 | 376 | def get_iterator(self, shuffle, group_by_size=False, n_sentences=-1, return_indices=False): 377 | """ 378 | Return a sentences iterator. 379 | """ 380 | n_sentences = len(self.pos1) if n_sentences == -1 else n_sentences 381 | assert 0 < n_sentences <= len(self.pos1) 382 | assert type(shuffle) is bool and type(group_by_size) is bool 383 | 384 | # sentence lengths 385 | lengths = self.lengths1 + self.lengths2 + 4 386 | 387 | # select sentences to iterate over 388 | if shuffle: 389 | indices = np.random.permutation(len(self.pos1))[:n_sentences] 390 | else: 391 | indices = np.arange(n_sentences) 392 | 393 | # group sentences by lengths 394 | if group_by_size: 395 | indices = indices[np.argsort(lengths[indices], kind='mergesort')] 396 | 397 | # create batches - either have a fixed number of sentences, or a similar number of tokens 398 | if self.tokens_per_batch == -1: 399 | batches = np.array_split(indices, math.ceil(len(indices) * 1. 
/ self.batch_size))
400 |         else:
401 |             batch_ids = np.cumsum(lengths[indices]) // self.tokens_per_batch
402 |             _, bounds = np.unique(batch_ids, return_index=True)
403 |             batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)]
404 |             if bounds[-1] < len(indices):
405 |                 batches.append(indices[bounds[-1]:])
406 | 
407 |         # optionally shuffle batches
408 |         if shuffle:
409 |             np.random.shuffle(batches)
410 | 
411 |         # sanity checks
412 |         assert n_sentences == sum([len(x) for x in batches])
413 |         assert lengths[indices].sum() == sum([lengths[x].sum() for x in batches])
414 |         # assert set.union(*[set(x.tolist()) for x in batches]) == set(range(n_sentences)) # slow
415 | 
416 |         # return the iterator
417 |         return self.get_batches_iterator(batches, return_indices)
418 | 
--------------------------------------------------------------------------------
/MASS-unsupNMT/src/data/dictionary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # NOTICE FILE in the root directory of this source tree.
6 | #
7 | 
8 | import os
9 | import numpy as np
10 | import torch
11 | from logging import getLogger
12 | 
13 | 
14 | logger = getLogger()
15 | 
16 | 
17 | BOS_WORD = '<s>'
18 | EOS_WORD = '</s>'
19 | PAD_WORD = '<pad>'
20 | UNK_WORD = '<unk>'
21 | 
22 | SPECIAL_WORD = '<special%i>'
23 | SPECIAL_WORDS = 10
24 | 
25 | SEP_WORD = SPECIAL_WORD % 0
26 | MASK_WORD = SPECIAL_WORD % 1
27 | 
28 | 
29 | class Dictionary(object):
30 | 
31 |     def __init__(self, id2word, word2id, counts):
32 |         assert len(id2word) == len(word2id) == len(counts)
33 |         self.id2word = id2word
34 |         self.word2id = word2id
35 |         self.counts = counts
36 |         self.bos_index = word2id[BOS_WORD]
37 |         self.eos_index = word2id[EOS_WORD]
38 |         self.pad_index = word2id[PAD_WORD]
39 |         self.unk_index = word2id[UNK_WORD]
40 |         self.check_valid()
41 | 
42 |     def __len__(self):
43 |         """
44 |         Returns the number of words in the dictionary.
45 |         """
46 |         return len(self.id2word)
47 | 
48 |     def __getitem__(self, i):
49 |         """
50 |         Returns the word of the specified index.
51 |         """
52 |         return self.id2word[i]
53 | 
54 |     def __contains__(self, w):
55 |         """
56 |         Returns whether a word is in the dictionary.
57 |         """
58 |         return w in self.word2id
59 | 
60 |     def __eq__(self, y):
61 |         """
62 |         Compare this dictionary with another one.
63 |         """
64 |         self.check_valid()
65 |         y.check_valid()
66 |         if len(self.id2word) != len(y):
67 |             return False
68 |         return all(self.id2word[i] == y[i] for i in range(len(y)))
69 | 
70 |     def check_valid(self):
71 |         """
72 |         Check that the dictionary is valid.
73 |         """
74 |         assert self.bos_index == 0
75 |         assert self.eos_index == 1
76 |         assert self.pad_index == 2
77 |         assert self.unk_index == 3
78 |         assert all(self.id2word[4 + i] == SPECIAL_WORD % i for i in range(SPECIAL_WORDS))
79 |         assert len(self.id2word) == len(self.word2id) == len(self.counts)
80 |         assert set(self.word2id.keys()) == set(self.counts.keys())
81 |         for i in range(len(self.id2word)):
82 |             assert self.word2id[self.id2word[i]] == i
83 |         last_count = 1e18
84 |         for i in range(4 + SPECIAL_WORDS, len(self.id2word) - 1):
85 |             count = self.counts[self.id2word[i]]
86 |             assert count <= last_count
87 |             last_count = count
88 | 
89 |     def index(self, word, no_unk=False):
90 |         """
91 |         Returns the index of the specified word.
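With no_unk=False (the default), out-of-vocabulary words map to unk_index.
Small sketch ('hello' is just a placeholder word):

    word_id = dico.index('hello')               # id of 'hello', or dico.unk_index
    word_id = dico.index('hello', no_unk=True)  # raises KeyError if the word is missing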
92 | """ 93 | if no_unk: 94 | return self.word2id[word] 95 | else: 96 | return self.word2id.get(word, self.unk_index) 97 | 98 | def max_vocab(self, max_vocab): 99 | """ 100 | Limit the vocabulary size. 101 | """ 102 | assert max_vocab >= 1 103 | init_size = len(self) 104 | self.id2word = {k: v for k, v in self.id2word.items() if k < max_vocab} 105 | self.word2id = {v: k for k, v in self.id2word.items()} 106 | self.counts = {k: v for k, v in self.counts.items() if k in self.word2id} 107 | self.check_valid() 108 | logger.info("Maximum vocabulary size: %i. Dictionary size: %i -> %i (removed %i words)." 109 | % (max_vocab, init_size, len(self), init_size - len(self))) 110 | 111 | def min_count(self, min_count): 112 | """ 113 | Threshold on the word frequency counts. 114 | """ 115 | assert min_count >= 0 116 | init_size = len(self) 117 | self.id2word = {k: v for k, v in self.id2word.items() if self.counts[self.id2word[k]] >= min_count or k < 4 + SPECIAL_WORDS} 118 | self.word2id = {v: k for k, v in self.id2word.items()} 119 | self.counts = {k: v for k, v in self.counts.items() if k in self.word2id} 120 | self.check_valid() 121 | logger.info("Minimum frequency count: %i. Dictionary size: %i -> %i (removed %i words)." 122 | % (min_count, init_size, len(self), init_size - len(self))) 123 | 124 | @staticmethod 125 | def read_vocab(vocab_path): 126 | """ 127 | Create a dictionary from a vocabulary file. 128 | """ 129 | skipped = 0 130 | assert os.path.isfile(vocab_path), vocab_path 131 | word2id = {BOS_WORD: 0, EOS_WORD: 1, PAD_WORD: 2, UNK_WORD: 3} 132 | for i in range(SPECIAL_WORDS): 133 | word2id[SPECIAL_WORD % i] = 4 + i 134 | counts = {k: 0 for k in word2id.keys()} 135 | f = open(vocab_path, 'r', encoding='utf-8') 136 | for i, line in enumerate(f): 137 | if '\u2028' in line: 138 | skipped += 1 139 | continue 140 | line = line.rstrip().split() 141 | if len(line) != 2: 142 | skipped += 1 143 | continue 144 | assert len(line) == 2, (i, line) 145 | # assert line[0] not in word2id and line[1].isdigit(), (i, line) 146 | assert line[1].isdigit(), (i, line) 147 | if line[0] in word2id: 148 | skipped += 1 149 | print('%s already in vocab' % line[0]) 150 | continue 151 | if not line[1].isdigit(): 152 | skipped += 1 153 | print('Empty word at line %s with count %s' % (i, line)) 154 | continue 155 | word2id[line[0]] = 4 + SPECIAL_WORDS + i - skipped # shift because of extra words 156 | counts[line[0]] = int(line[1]) 157 | f.close() 158 | id2word = {v: k for k, v in word2id.items()} 159 | dico = Dictionary(id2word, word2id, counts) 160 | logger.info("Read %i words from the vocabulary file." % len(dico)) 161 | if skipped > 0: 162 | logger.warning("Skipped %i empty lines!" % skipped) 163 | return dico 164 | 165 | @staticmethod 166 | def index_data(path, bin_path, dico): 167 | """ 168 | Index sentences with a dictionary. 169 | """ 170 | if bin_path is not None and os.path.isfile(bin_path): 171 | print("Loading data from %s ..." % bin_path) 172 | data = torch.load(bin_path) 173 | assert dico == data['dico'] 174 | return data 175 | 176 | positions = [] 177 | sentences = [] 178 | unk_words = {} 179 | 180 | # index sentences 181 | f = open(path, 'r', encoding='utf-8') 182 | for i, line in enumerate(f): 183 | if i % 1000000 == 0 and i > 0: 184 | print(i) 185 | s = line.rstrip().split() 186 | # skip empty sentences 187 | if len(s) == 0: 188 | print("Empty sentence in line %i." 
% i) 189 | # index sentence words 190 | count_unk = 0 191 | indexed = [] 192 | for w in s: 193 | word_id = dico.index(w, no_unk=False) 194 | # if we find a special word which is not an unknown word, skip the sentence 195 | if 0 <= word_id < 4 + SPECIAL_WORDS and word_id != 3: 196 | logger.warning('Found unexpected special word "%s" (%i)!!' % (w, word_id)) 197 | continue 198 | assert word_id >= 0 199 | indexed.append(word_id) 200 | if word_id == dico.unk_index: 201 | unk_words[w] = unk_words.get(w, 0) + 1 202 | count_unk += 1 203 | # add sentence 204 | positions.append([len(sentences), len(sentences) + len(indexed)]) 205 | sentences.extend(indexed) 206 | sentences.append(1) # EOS index 207 | f.close() 208 | 209 | # tensorize data 210 | positions = np.int64(positions) 211 | if len(dico) < 1 << 16: 212 | sentences = np.uint16(sentences) 213 | elif len(dico) < 1 << 31: 214 | sentences = np.int32(sentences) 215 | else: 216 | raise Exception("Dictionary is too big.") 217 | assert sentences.min() >= 0 218 | data = { 219 | 'dico': dico, 220 | 'positions': positions, 221 | 'sentences': sentences, 222 | 'unk_words': unk_words, 223 | } 224 | if bin_path is not None: 225 | print("Saving the data to %s ..." % bin_path) 226 | torch.save(data, bin_path, pickle_protocol=4) 227 | 228 | return data 229 | -------------------------------------------------------------------------------- /MASS-unsupNMT/src/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MASS/779f22fc47c8a256d8bc04826ebe1c8307063cbe/MASS-unsupNMT/src/evaluation/__init__.py -------------------------------------------------------------------------------- /MASS-unsupNMT/src/evaluation/multi-bleu.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
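#
# Usage, following the help message printed below (file names are placeholders):
#
#   perl multi-bleu.perl [-lc] reference < hypothesis
#
# References are read from `reference`, or from `reference0`, `reference1`, ...
# when several exist. As the note at the end of this script points out, the
# score depends on the tokenization of both files, so keep it consistent.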
5 | 6 | # $Id$ 7 | use warnings; 8 | use strict; 9 | 10 | my $lowercase = 0; 11 | if ($ARGV[0] eq "-lc") { 12 | $lowercase = 1; 13 | shift; 14 | } 15 | 16 | my $stem = $ARGV[0]; 17 | if (!defined $stem) { 18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 20 | exit(1); 21 | } 22 | 23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 24 | 25 | my @REF; 26 | my $ref=0; 27 | while(-e "$stem$ref") { 28 | &add_to_ref("$stem$ref",\@REF); 29 | $ref++; 30 | } 31 | &add_to_ref($stem,\@REF) if -e $stem; 32 | die("ERROR: could not find reference file $stem") unless scalar @REF; 33 | 34 | # add additional references explicitly specified on the command line 35 | shift; 36 | foreach my $stem (@ARGV) { 37 | &add_to_ref($stem,\@REF) if -e $stem; 38 | } 39 | 40 | 41 | 42 | sub add_to_ref { 43 | my ($file,$REF) = @_; 44 | my $s=0; 45 | if ($file =~ /.gz$/) { 46 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 47 | } else { 48 | open(REF,$file) or die "Can't read $file"; 49 | } 50 | while() { 51 | chop; 52 | push @{$$REF[$s++]}, $_; 53 | } 54 | close(REF); 55 | } 56 | 57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 58 | my $s=0; 59 | while() { 60 | chop; 61 | $_ = lc if $lowercase; 62 | my @WORD = split; 63 | my %REF_NGRAM = (); 64 | my $length_translation_this_sentence = scalar(@WORD); 65 | my ($closest_diff,$closest_length) = (9999,9999); 66 | foreach my $reference (@{$REF[$s]}) { 67 | # print "$s $_ <=> $reference\n"; 68 | $reference = lc($reference) if $lowercase; 69 | my @WORD = split(' ',$reference); 70 | my $length = scalar(@WORD); 71 | my $diff = abs($length_translation_this_sentence-$length); 72 | if ($diff < $closest_diff) { 73 | $closest_diff = $diff; 74 | $closest_length = $length; 75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 76 | } elsif ($diff == $closest_diff) { 77 | $closest_length = $length if $length < $closest_length; 78 | # from two references with the same closeness to me 79 | # take the *shorter* into account, not the "first" one. 80 | } 81 | for(my $n=1;$n<=4;$n++) { 82 | my %REF_NGRAM_N = (); 83 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 84 | my $ngram = "$n"; 85 | for(my $w=0;$w<$n;$w++) { 86 | $ngram .= " ".$WORD[$start+$w]; 87 | } 88 | $REF_NGRAM_N{$ngram}++; 89 | } 90 | foreach my $ngram (keys %REF_NGRAM_N) { 91 | if (!defined($REF_NGRAM{$ngram}) || 92 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 93 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 94 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 95 | } 96 | } 97 | } 98 | } 99 | $length_translation += $length_translation_this_sentence; 100 | $length_reference += $closest_length; 101 | for(my $n=1;$n<=4;$n++) { 102 | my %T_NGRAM = (); 103 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 104 | my $ngram = "$n"; 105 | for(my $w=0;$w<$n;$w++) { 106 | $ngram .= " ".$WORD[$start+$w]; 107 | } 108 | $T_NGRAM{$ngram}++; 109 | } 110 | foreach my $ngram (keys %T_NGRAM) { 111 | $ngram =~ /^(\d+) /; 112 | my $n = $1; 113 | # my $corr = 0; 114 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 115 | $TOTAL[$n] += $T_NGRAM{$ngram}; 116 | if (defined($REF_NGRAM{$ngram})) { 117 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 118 | $CORRECT[$n] += $T_NGRAM{$ngram}; 119 | # $corr = $T_NGRAM{$ngram}; 120 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 121 | } 122 | else { 123 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 124 | # $corr = $REF_NGRAM{$ngram}; 125 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 126 | } 127 | } 128 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 129 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 130 | } 131 | } 132 | $s++; 133 | } 134 | my $brevity_penalty = 1; 135 | my $bleu = 0; 136 | 137 | my @bleu=(); 138 | 139 | for(my $n=1;$n<=4;$n++) { 140 | if (defined ($TOTAL[$n])){ 141 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 142 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 143 | }else{ 144 | $bleu[$n]=0; 145 | } 146 | } 147 | 148 | if ($length_reference==0){ 149 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 150 | exit(1); 151 | } 152 | 153 | if ($length_translation<$length_reference) { 154 | $brevity_penalty = exp(1-$length_reference/$length_translation); 155 | } 156 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 157 | my_log( $bleu[2] ) + 158 | my_log( $bleu[3] ) + 159 | my_log( $bleu[4] ) ) / 4) ; 160 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 161 | 100*$bleu, 162 | 100*$bleu[1], 163 | 100*$bleu[2], 164 | 100*$bleu[3], 165 | 100*$bleu[4], 166 | $brevity_penalty, 167 | $length_translation / $length_reference, 168 | $length_translation, 169 | $length_reference; 170 | 171 | 172 | # print STDERR "It is in-advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n"; 173 | 174 | sub my_log { 175 | return -9999999999 unless $_[0]; 176 | return log($_[0]); 177 | } 178 | -------------------------------------------------------------------------------- /MASS-unsupNMT/src/fp16.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # NOTICE FILE in the root directory of this source tree. 6 | # 7 | 8 | # 9 | # float16 related code 10 | # https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/fp16util.py 11 | # 12 | 13 | import torch 14 | 15 | 16 | def BN_convert_float(module): 17 | ''' 18 | Designed to work with network_to_half. 19 | BatchNorm layers need parameters in single precision. 20 | Find all layers and convert them back to float. This can't 21 | be done with built in .apply as that function will apply 22 | fn to all modules, parameters, and buffers. Thus we wouldn't 23 | be able to guard the float conversion based on the module type. 24 | ''' 25 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 26 | module.float() 27 | for child in module.children(): 28 | BN_convert_float(child) 29 | return module 30 | 31 | 32 | def network_to_half(network): 33 | """ 34 | Convert model to half precision in a batchnorm-safe way. 35 | """ 36 | return BN_convert_float(network.half()) 37 | -------------------------------------------------------------------------------- /MASS-unsupNMT/src/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # NOTICE FILE in the root directory of this source tree. 
6 | # 7 | 8 | import logging 9 | import time 10 | from datetime import timedelta 11 | 12 | 13 | class LogFormatter(): 14 | 15 | def __init__(self): 16 | self.start_time = time.time() 17 | 18 | def format(self, record): 19 | elapsed_seconds = round(record.created - self.start_time) 20 | 21 | prefix = "%s - %s - %s" % ( 22 | record.levelname, 23 | time.strftime('%x %X'), 24 | timedelta(seconds=elapsed_seconds) 25 | ) 26 | message = record.getMessage() 27 | message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3)) 28 | return "%s - %s" % (prefix, message) if message else '' 29 | 30 | 31 | def create_logger(filepath, rank): 32 | """ 33 | Create a logger. 34 | Use a different log file for each process. 35 | """ 36 | # create log formatter 37 | log_formatter = LogFormatter() 38 | 39 | # create file handler and set level to debug 40 | if filepath is not None: 41 | if rank > 0: 42 | filepath = '%s-%i' % (filepath, rank) 43 | file_handler = logging.FileHandler(filepath, "a") 44 | file_handler.setLevel(logging.DEBUG) 45 | file_handler.setFormatter(log_formatter) 46 | 47 | # create console handler and set level to info 48 | console_handler = logging.StreamHandler() 49 | console_handler.setLevel(logging.INFO) 50 | console_handler.setFormatter(log_formatter) 51 | 52 | # create logger and set level to debug 53 | logger = logging.getLogger() 54 | logger.handlers = [] 55 | logger.setLevel(logging.DEBUG) 56 | logger.propagate = False 57 | if filepath is not None: 58 | logger.addHandler(file_handler) 59 | logger.addHandler(console_handler) 60 | 61 | # reset logger elapsed time 62 | def reset_time(): 63 | log_formatter.start_time = time.time() 64 | logger.reset_time = reset_time 65 | 66 | return logger 67 | -------------------------------------------------------------------------------- /MASS-unsupNMT/src/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # NOTICE FILE in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import os 10 | import torch 11 | 12 | from .transformer import TransformerModel 13 | 14 | 15 | logger = getLogger() 16 | 17 | 18 | def check_model_params(params): 19 | """ 20 | Check models parameters. 
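For example, the default --word_mask_keep_rand "0.8,0.1,0.1" is split here into
params.word_mask / params.word_keep / params.word_rand; the three fractions
must lie in [0, 1] and sum to 1 (illustrative values, mirroring the train.py
defaults).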
21 | """ 22 | # masked language modeling task parameters 23 | assert params.bptt >= 1 24 | assert 0 <= params.word_pred < 1 25 | assert 0 <= params.sample_alpha < 1 26 | s = params.word_mask_keep_rand.split(',') 27 | assert len(s) == 3 28 | s = [float(x) for x in s] 29 | assert all([0 <= x <= 1 for x in s]) and sum(s) == 1 30 | params.word_mask = s[0] 31 | params.word_keep = s[1] 32 | params.word_rand = s[2] 33 | 34 | # input sentence noise for DAE 35 | if len(params.ae_steps) == 0: 36 | assert params.word_shuffle == 0 37 | assert params.word_dropout == 0 38 | assert params.word_blank == 0 39 | else: 40 | assert params.word_shuffle == 0 or params.word_shuffle > 1 41 | assert 0 <= params.word_dropout < 1 42 | assert 0 <= params.word_blank < 1 43 | 44 | # model dimensions 45 | assert params.emb_dim % params.n_heads == 0 46 | 47 | # share input and output embeddings 48 | assert params.share_inout_emb is False or params.asm is False 49 | 50 | # adaptive softmax 51 | if params.asm: 52 | assert params.asm_div_value > 1 53 | s = params.asm_cutoffs.split(',') 54 | assert all([x.isdigit() for x in s]) 55 | params.asm_cutoffs = [int(x) for x in s] 56 | assert params.max_vocab == -1 or params.asm_cutoffs[-1] < params.max_vocab 57 | 58 | # reload a pretrained model 59 | if params.reload_model != '': 60 | if params.encoder_only: 61 | assert os.path.isfile(params.reload_model) 62 | else: 63 | s = params.reload_model.split(',') 64 | assert len(s) == 2 65 | assert all([x == '' or os.path.isfile(x) for x in s]) 66 | 67 | 68 | def set_pretrain_emb(model, dico, word2id, embeddings): 69 | """ 70 | Pretrain word embeddings. 71 | """ 72 | n_found = 0 73 | with torch.no_grad(): 74 | for i in range(len(dico)): 75 | idx = word2id.get(dico[i], None) 76 | if idx is None: 77 | continue 78 | n_found += 1 79 | model.embeddings.weight[i] = embeddings[idx].cuda() 80 | model.pred_layer.proj.weight[i] = embeddings[idx].cuda() 81 | logger.info("Pretrained %i/%i words (%.3f%%)." 82 | % (n_found, len(dico), 100. * n_found / len(dico))) 83 | 84 | 85 | def build_model(params, dico): 86 | """ 87 | Build model. 88 | """ 89 | if params.encoder_only: 90 | # build 91 | model = TransformerModel(params, dico, is_encoder=True, with_output=True) 92 | 93 | # reload a pretrained model 94 | if params.reload_model != '': 95 | logger.info("Reloading model from %s ..." % params.reload_model) 96 | reloaded = torch.load(params.reload_model, map_location=lambda storage, loc: storage.cuda(params.local_rank))['model'] 97 | if all([k.startswith('module.') for k in reloaded.keys()]): 98 | reloaded = {k[len('module.'):]: v for k, v in reloaded.items()} 99 | 100 | # # HACK to reload models with less layers 101 | # for i in range(12, 24): 102 | # for k in TRANSFORMER_LAYER_PARAMS: 103 | # k = k % i 104 | # if k in model.state_dict() and k not in reloaded: 105 | # logger.warning("Parameter %s not found. Ignoring ..." 
% k) 106 | # reloaded[k] = model.state_dict()[k] 107 | 108 | model.load_state_dict(reloaded) 109 | 110 | logger.debug("Model: {}".format(model)) 111 | logger.info("Number of parameters (model): %i" % sum([p.numel() for p in model.parameters() if p.requires_grad])) 112 | 113 | return model.cuda() 114 | 115 | else: 116 | # build 117 | encoder = TransformerModel(params, dico, is_encoder=True, with_output=True) # TODO: only output when necessary - len(params.clm_steps + params.mlm_steps) > 0 118 | decoder = TransformerModel(params, dico, is_encoder=False, with_output=True) 119 | 120 | # reload a pretrained model 121 | if params.reload_model != '': 122 | enc_path, dec_path = params.reload_model.split(',') 123 | assert not (enc_path == '' and dec_path == '') 124 | 125 | # reload encoder 126 | if enc_path != '': 127 | logger.info("Reloading encoder from %s ..." % enc_path) 128 | enc_reload = torch.load(enc_path, map_location=lambda storage, loc: storage.cuda(params.local_rank)) 129 | enc_reload = enc_reload['model' if 'model' in enc_reload else 'encoder'] 130 | if all([k.startswith('module.') for k in enc_reload.keys()]): 131 | enc_reload = {k[len('module.'):]: v for k, v in enc_reload.items()} 132 | encoder.load_state_dict(enc_reload) 133 | 134 | # reload decoder 135 | if dec_path != '': 136 | logger.info("Reloading decoder from %s ..." % dec_path) 137 | dec_reload = torch.load(dec_path, map_location=lambda storage, loc: storage.cuda(params.local_rank)) 138 | dec_reload = dec_reload['model' if 'model' in dec_reload else 'decoder'] 139 | if all([k.startswith('module.') for k in dec_reload.keys()]): 140 | dec_reload = {k[len('module.'):]: v for k, v in dec_reload.items()} 141 | decoder.load_state_dict(dec_reload, strict=False) 142 | 143 | logger.debug("Encoder: {}".format(encoder)) 144 | logger.debug("Decoder: {}".format(decoder)) 145 | logger.info("Number of parameters (encoder): %i" % sum([p.numel() for p in encoder.parameters() if p.requires_grad])) 146 | logger.info("Number of parameters (decoder): %i" % sum([p.numel() for p in decoder.parameters() if p.requires_grad])) 147 | 148 | return encoder.cuda(), decoder.cuda() 149 | -------------------------------------------------------------------------------- /MASS-unsupNMT/src/slurm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # NOTICE FILE in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | import os 10 | import sys 11 | import torch 12 | import socket 13 | import signal 14 | import subprocess 15 | 16 | 17 | logger = getLogger() 18 | 19 | 20 | def sig_handler(signum, frame): 21 | logger.warning("Signal handler called with signal " + str(signum)) 22 | prod_id = int(os.environ['SLURM_PROCID']) 23 | logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id)) 24 | if prod_id == 0: 25 | logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID']) 26 | os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID']) 27 | else: 28 | logger.warning("Not the master process, no need to requeue.") 29 | sys.exit(-1) 30 | 31 | 32 | def term_handler(signum, frame): 33 | logger.warning("Signal handler called with signal " + str(signum)) 34 | logger.warning("Bypassing SIGTERM.") 35 | 36 | 37 | def init_signal_handler(): 38 | """ 39 | Handle signals sent by SLURM for time limit / pre-emption. 
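SIGUSR1 makes the master process requeue the job (`scontrol requeue
$SLURM_JOB_ID`), while SIGTERM is bypassed (see term_handler above). For the
warning signal to arrive, the submission script is assumed to request it, e.g.:

    #SBATCH --signal=USR1@120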
40 | """ 41 | signal.signal(signal.SIGUSR1, sig_handler) 42 | signal.signal(signal.SIGTERM, term_handler) 43 | logger.warning("Signal handler installed.") 44 | 45 | 46 | def init_distributed_mode(params): 47 | """ 48 | Handle single and multi-GPU / multi-node / SLURM jobs. 49 | Initialize the following variables: 50 | - n_nodes 51 | - node_id 52 | - local_rank 53 | - global_rank 54 | - world_size 55 | """ 56 | params.is_slurm_job = 'SLURM_JOB_ID' in os.environ and not params.debug_slurm 57 | print("SLURM job: %s" % str(params.is_slurm_job)) 58 | 59 | # SLURM job 60 | if params.is_slurm_job: 61 | 62 | assert params.local_rank == -1 # on the cluster, this is handled by SLURM 63 | 64 | SLURM_VARIABLES = [ 65 | 'SLURM_JOB_ID', 66 | 'SLURM_JOB_NODELIST', 'SLURM_JOB_NUM_NODES', 'SLURM_NTASKS', 'SLURM_TASKS_PER_NODE', 67 | 'SLURM_MEM_PER_NODE', 'SLURM_MEM_PER_CPU', 68 | 'SLURM_NODEID', 'SLURM_PROCID', 'SLURM_LOCALID', 'SLURM_TASK_PID' 69 | ] 70 | 71 | PREFIX = "%i - " % int(os.environ['SLURM_PROCID']) 72 | for name in SLURM_VARIABLES: 73 | value = os.environ.get(name, None) 74 | print(PREFIX + "%s: %s" % (name, str(value))) 75 | 76 | # # job ID 77 | # params.job_id = os.environ['SLURM_JOB_ID'] 78 | 79 | # number of nodes / node ID 80 | params.n_nodes = int(os.environ['SLURM_JOB_NUM_NODES']) 81 | params.node_id = int(os.environ['SLURM_NODEID']) 82 | 83 | # local rank on the current node / global rank 84 | params.local_rank = int(os.environ['SLURM_LOCALID']) 85 | params.global_rank = int(os.environ['SLURM_PROCID']) 86 | 87 | # number of processes / GPUs per node 88 | params.world_size = int(os.environ['SLURM_NTASKS']) 89 | params.n_gpu_per_node = params.world_size // params.n_nodes 90 | 91 | # define master address and master port 92 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']]) 93 | params.master_addr = hostnames.split()[0].decode('utf-8') 94 | assert 10001 <= params.master_port <= 20000 or params.world_size == 1 95 | print(PREFIX + "Master address: %s" % params.master_addr) 96 | print(PREFIX + "Master port : %i" % params.master_port) 97 | 98 | # set environment variables for 'env://' 99 | os.environ['MASTER_ADDR'] = params.master_addr 100 | os.environ['MASTER_PORT'] = str(params.master_port) 101 | os.environ['WORLD_SIZE'] = str(params.world_size) 102 | os.environ['RANK'] = str(params.global_rank) 103 | 104 | # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch 105 | elif params.local_rank != -1: 106 | 107 | assert params.master_port == -1 108 | 109 | # read environment variables 110 | params.global_rank = int(os.environ['RANK']) 111 | params.world_size = int(os.environ['WORLD_SIZE']) 112 | params.n_gpu_per_node = int(os.environ['NGPU']) 113 | 114 | # number of nodes / node ID 115 | params.n_nodes = params.world_size // params.n_gpu_per_node 116 | params.node_id = params.global_rank // params.n_gpu_per_node 117 | 118 | # local job (single GPU) 119 | else: 120 | assert params.local_rank == -1 121 | assert params.master_port == -1 122 | params.n_nodes = 1 123 | params.node_id = 0 124 | params.local_rank = 0 125 | params.global_rank = 0 126 | params.world_size = 1 127 | params.n_gpu_per_node = 1 128 | 129 | # sanity checks 130 | assert params.n_nodes >= 1 131 | assert 0 <= params.node_id < params.n_nodes 132 | assert 0 <= params.local_rank <= params.global_rank < params.world_size 133 | assert params.world_size == params.n_nodes * params.n_gpu_per_node 134 | 135 | # define whether this is the master process / 
if we are in distributed mode 136 | params.is_master = params.node_id == 0 and params.local_rank == 0 137 | params.multi_node = params.n_nodes > 1 138 | params.multi_gpu = params.world_size > 1 139 | 140 | # summary 141 | PREFIX = "%i - " % params.global_rank 142 | print(PREFIX + "Number of nodes: %i" % params.n_nodes) 143 | print(PREFIX + "Node ID : %i" % params.node_id) 144 | print(PREFIX + "Local rank : %i" % params.local_rank) 145 | print(PREFIX + "Global rank : %i" % params.global_rank) 146 | print(PREFIX + "World size : %i" % params.world_size) 147 | print(PREFIX + "GPUs per node : %i" % params.n_gpu_per_node) 148 | print(PREFIX + "Master : %s" % str(params.is_master)) 149 | print(PREFIX + "Multi-node : %s" % str(params.multi_node)) 150 | print(PREFIX + "Multi-GPU : %s" % str(params.multi_gpu)) 151 | print(PREFIX + "Hostname : %s" % socket.gethostname()) 152 | 153 | # set GPU device 154 | torch.cuda.set_device(params.local_rank) 155 | 156 | # initialize multi-GPU 157 | if params.multi_gpu: 158 | 159 | # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization 160 | # 'env://' will read these environment variables: 161 | # MASTER_PORT - required; has to be a free port on machine with rank 0 162 | # MASTER_ADDR - required (except for rank 0); address of rank 0 node 163 | # WORLD_SIZE - required; can be set either here, or in a call to init function 164 | # RANK - required; can be set either here, or in a call to init function 165 | 166 | print("Initializing PyTorch distributed ...") 167 | torch.distributed.init_process_group( 168 | init_method='env://', 169 | backend='nccl', 170 | ) 171 | -------------------------------------------------------------------------------- /MASS-unsupNMT/src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # NOTICE FILE in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import re 10 | import sys 11 | import pickle 12 | import random 13 | import inspect 14 | import getpass 15 | import argparse 16 | import subprocess 17 | import numpy as np 18 | import torch 19 | from torch import optim 20 | 21 | from .logger import create_logger 22 | 23 | 24 | FALSY_STRINGS = {'off', 'false', '0'} 25 | TRUTHY_STRINGS = {'on', 'true', '1'} 26 | 27 | DUMP_PATH = '/checkpoint/%s/dumped' % getpass.getuser() 28 | DYNAMIC_COEFF = ['lambda_clm', 'lambda_mlm', 'lambda_pc', 'lambda_ae', 'lambda_mt', 'lambda_bt', 'lambda_mass', 'lambda_bmt', 'lambda_span'] 29 | 30 | 31 | class AttrDict(dict): 32 | def __init__(self, *args, **kwargs): 33 | super(AttrDict, self).__init__(*args, **kwargs) 34 | self.__dict__ = self 35 | 36 | 37 | def bool_flag(s): 38 | """ 39 | Parse boolean arguments from the command line. 
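Accepts "off"/"false"/"0" as False and "on"/"true"/"1" as True (case-insensitive);
anything else raises an argparse error. Typical use, as in train.py:

    parser.add_argument("--fp16", type=bool_flag, default=False,
                        help="Run model with float16")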
40 | """ 41 | if s.lower() in FALSY_STRINGS: 42 | return False 43 | elif s.lower() in TRUTHY_STRINGS: 44 | return True 45 | else: 46 | raise argparse.ArgumentTypeError("Invalid value for a boolean flag!") 47 | 48 | 49 | def initialize_exp(params): 50 | """ 51 | Initialize the experience: 52 | - dump parameters 53 | - create a logger 54 | """ 55 | # dump parameters 56 | get_dump_path(params) 57 | pickle.dump(params, open(os.path.join(params.dump_path, 'params.pkl'), 'wb')) 58 | 59 | # get running command 60 | command = ["python", sys.argv[0]] 61 | for x in sys.argv[1:]: 62 | if x.startswith('--'): 63 | assert '"' not in x and "'" not in x 64 | command.append(x) 65 | else: 66 | assert "'" not in x 67 | if re.match('^[a-zA-Z0-9_]+$', x): 68 | command.append("%s" % x) 69 | else: 70 | command.append("'%s'" % x) 71 | command = ' '.join(command) 72 | params.command = command + ' --exp_id "%s"' % params.exp_id 73 | 74 | # check experiment name 75 | assert len(params.exp_name.strip()) > 0 76 | 77 | # create a logger 78 | logger = create_logger(os.path.join(params.dump_path, 'train.log'), rank=getattr(params, 'global_rank', 0)) 79 | logger.info("============ Initialized logger ============") 80 | logger.info("\n".join("%s: %s" % (k, str(v)) 81 | for k, v in sorted(dict(vars(params)).items()))) 82 | logger.info("The experiment will be stored in %s\n" % params.dump_path) 83 | logger.info("Running command: %s" % command) 84 | logger.info("") 85 | return logger 86 | 87 | 88 | def get_dump_path(params): 89 | """ 90 | Create a directory to store the experiment. 91 | """ 92 | dump_path = DUMP_PATH if params.dump_path == '' else params.dump_path 93 | assert len(params.exp_name) > 0 94 | 95 | # create the sweep path if it does not exist 96 | sweep_path = os.path.join(dump_path, params.exp_name) 97 | if not os.path.exists(sweep_path): 98 | subprocess.Popen("mkdir -p %s" % sweep_path, shell=True).wait() 99 | 100 | # create an ID for the job if it is not given in the parameters. 101 | # if we run on the cluster, the job ID is the one of Chronos. 102 | # otherwise, it is randomly generated 103 | if params.exp_id == '': 104 | chronos_job_id = os.environ.get('CHRONOS_JOB_ID') 105 | slurm_job_id = os.environ.get('SLURM_JOB_ID') 106 | assert chronos_job_id is None or slurm_job_id is None 107 | exp_id = chronos_job_id if chronos_job_id is not None else slurm_job_id 108 | if exp_id is None: 109 | chars = 'abcdefghijklmnopqrstuvwxyz0123456789' 110 | while True: 111 | exp_id = ''.join(random.choice(chars) for _ in range(10)) 112 | if not os.path.isdir(os.path.join(sweep_path, exp_id)): 113 | break 114 | else: 115 | assert exp_id.isdigit() 116 | params.exp_id = exp_id 117 | 118 | # create the dump folder / update parameters 119 | params.dump_path = os.path.join(sweep_path, params.exp_id) 120 | if not os.path.isdir(params.dump_path): 121 | subprocess.Popen("mkdir -p %s" % params.dump_path, shell=True).wait() 122 | 123 | 124 | class AdamInverseSqrtWithWarmup(optim.Adam): 125 | """ 126 | Decay the LR based on the inverse square root of the update number. 127 | We also support a warmup phase where we linearly increase the learning rate 128 | from some initial learning rate (`warmup-init-lr`) until the configured 129 | learning rate (`lr`). Thereafter we decay proportional to the number of 130 | updates, with a decay factor set to align with the configured learning rate. 
131 | During warmup: 132 | lrs = torch.linspace(warmup_init_lr, lr, warmup_updates) 133 | lr = lrs[update_num] 134 | After warmup: 135 | lr = decay_factor / sqrt(update_num) 136 | where 137 | decay_factor = lr * sqrt(warmup_updates) 138 | """ 139 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 140 | weight_decay=0, warmup_updates=4000, warmup_init_lr=1e-7): 141 | super().__init__( 142 | params, 143 | lr=warmup_init_lr, 144 | betas=betas, 145 | eps=eps, 146 | weight_decay=weight_decay, 147 | ) 148 | self.warmup_updates = warmup_updates 149 | self.warmup_init_lr = warmup_init_lr 150 | # linearly warmup for the first warmup_updates 151 | warmup_end_lr = lr 152 | self.lr_step = (warmup_end_lr - warmup_init_lr) / warmup_updates 153 | # then, decay prop. to the inverse square root of the update number 154 | self.decay_factor = warmup_end_lr * warmup_updates ** 0.5 155 | for param_group in self.param_groups: 156 | param_group['num_updates'] = 0 157 | 158 | def get_lr_for_step(self, num_updates): 159 | # update learning rate 160 | if num_updates < self.warmup_updates: 161 | return self.warmup_init_lr + num_updates * self.lr_step 162 | else: 163 | return self.decay_factor * (num_updates ** -0.5) 164 | 165 | def step(self, closure=None): 166 | super().step(closure) 167 | for param_group in self.param_groups: 168 | param_group['num_updates'] += 1 169 | param_group['lr'] = self.get_lr_for_step(param_group['num_updates']) 170 | 171 | 172 | def get_optimizer(parameters, s): 173 | """ 174 | Parse optimizer parameters. 175 | Input should be of the form: 176 | - "sgd,lr=0.01" 177 | - "adagrad,lr=0.1,lr_decay=0.05" 178 | """ 179 | if "," in s: 180 | method = s[:s.find(',')] 181 | optim_params = {} 182 | for x in s[s.find(',') + 1:].split(','): 183 | split = x.split('=') 184 | assert len(split) == 2 185 | assert re.match("^[+-]?(\d+(\.\d*)?|\.\d+)$", split[1]) is not None 186 | optim_params[split[0]] = float(split[1]) 187 | else: 188 | method = s 189 | optim_params = {} 190 | 191 | if method == 'adadelta': 192 | optim_fn = optim.Adadelta 193 | elif method == 'adagrad': 194 | optim_fn = optim.Adagrad 195 | elif method == 'adam': 196 | optim_fn = optim.Adam 197 | optim_params['betas'] = (optim_params.get('beta1', 0.9), optim_params.get('beta2', 0.999)) 198 | optim_params.pop('beta1', None) 199 | optim_params.pop('beta2', None) 200 | elif method == 'adam_inverse_sqrt': 201 | optim_fn = AdamInverseSqrtWithWarmup 202 | optim_params['betas'] = (optim_params.get('beta1', 0.9), optim_params.get('beta2', 0.999)) 203 | optim_params.pop('beta1', None) 204 | optim_params.pop('beta2', None) 205 | elif method == 'adamax': 206 | optim_fn = optim.Adamax 207 | elif method == 'asgd': 208 | optim_fn = optim.ASGD 209 | elif method == 'rmsprop': 210 | optim_fn = optim.RMSprop 211 | elif method == 'rprop': 212 | optim_fn = optim.Rprop 213 | elif method == 'sgd': 214 | optim_fn = optim.SGD 215 | assert 'lr' in optim_params 216 | else: 217 | raise Exception('Unknown optimization method: "%s"' % method) 218 | 219 | # check that we give good parameters to the optimizer 220 | expected_args = inspect.getargspec(optim_fn.__init__)[0] 221 | assert expected_args[:2] == ['self', 'params'] 222 | if not all(k in expected_args[2:] for k in optim_params.keys()): 223 | raise Exception('Unexpected parameters: expected "%s", got "%s"' % ( 224 | str(expected_args[2:]), str(optim_params.keys()))) 225 | 226 | return optim_fn(parameters, **optim_params) 227 | 228 | 229 | def to_cuda(*args): 230 | """ 231 | Move tensors to CUDA. 
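None entries are passed through unchanged, so optional tensors can be moved in
a single call. Sketch (the tensor names are assumptions):

    x, lengths, positions = to_cuda(x, lengths, positions)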
232 | """ 233 | return [None if x is None else x.cuda() for x in args] 234 | 235 | 236 | def restore_segmentation(path): 237 | """ 238 | Take a file segmented with BPE and restore it to its original segmentation. 239 | """ 240 | assert os.path.isfile(path) 241 | restore_cmd = "sed -i -r 's/(@@ )|(@@ ?$)//g' %s" 242 | subprocess.Popen(restore_cmd % path, shell=True).wait() 243 | 244 | 245 | def parse_lambda_config(params): 246 | """ 247 | Parse the configuration of lambda coefficient (for scheduling). 248 | x = "3" # lambda will be a constant equal to x 249 | x = "0:1,1000:0" # lambda will start from 1 and linearly decrease to 0 during the first 1000 iterations 250 | x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000 iterations, then will linearly increase to 1 until iteration 2000 251 | """ 252 | for name in DYNAMIC_COEFF: 253 | x = getattr(params, name) 254 | split = x.split(',') 255 | if len(split) == 1: 256 | setattr(params, name, float(x)) 257 | setattr(params, name + '_config', None) 258 | else: 259 | split = [s.split(':') for s in split] 260 | assert all(len(s) == 2 for s in split) 261 | assert all(k.isdigit() for k, _ in split) 262 | assert all(int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1)) 263 | setattr(params, name, float(split[0][1])) 264 | setattr(params, name + '_config', [(int(k), float(v)) for k, v in split]) 265 | 266 | 267 | def get_lambda_value(config, n_iter): 268 | """ 269 | Compute a lambda value according to its schedule configuration. 270 | """ 271 | ranges = [i for i in range(len(config) - 1) if config[i][0] <= n_iter < config[i + 1][0]] 272 | if len(ranges) == 0: 273 | assert n_iter >= config[-1][0] 274 | return config[-1][1] 275 | assert len(ranges) == 1 276 | i = ranges[0] 277 | x_a, y_a = config[i] 278 | x_b, y_b = config[i + 1] 279 | return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a) 280 | 281 | 282 | def update_lambdas(params, n_iter): 283 | """ 284 | Update all lambda coefficients. 285 | """ 286 | for name in DYNAMIC_COEFF: 287 | config = getattr(params, name + '_config') 288 | if config is not None: 289 | setattr(params, name, get_lambda_value(config, n_iter)) 290 | 291 | 292 | def set_sampling_probs(data, params): 293 | """ 294 | Set the probability of sampling specific languages / language pairs during training. 295 | """ 296 | coeff = params.lg_sampling_factor 297 | if coeff == -1: 298 | return 299 | assert coeff > 0 300 | 301 | # monolingual data 302 | params.mono_list = [k for k, v in data['mono_stream'].items() if 'train' in v] 303 | if len(params.mono_list) > 0: 304 | probs = np.array([1.0 * len(data['mono_stream'][lang]['train']) for lang in params.mono_list]) 305 | probs /= probs.sum() 306 | probs = np.array([p ** coeff for p in probs]) 307 | probs /= probs.sum() 308 | params.mono_probs = probs 309 | 310 | # parallel data 311 | params.para_list = [k for k, v in data['para'].items() if 'train' in v] 312 | if len(params.para_list) > 0: 313 | probs = np.array([1.0 * len(data['para'][(l1, l2)]['train']) for (l1, l2) in params.para_list]) 314 | probs /= probs.sum() 315 | probs = np.array([p ** coeff for p in probs]) 316 | probs /= probs.sum() 317 | params.para_probs = probs 318 | 319 | 320 | def concat_batches(x1, len1, lang1_id, x2, len2, lang2_id, pad_idx, eos_idx, reset_positions): 321 | """ 322 | Concat batches with different languages. 
323 | """ 324 | assert reset_positions is False or lang1_id != lang2_id 325 | lengths = len1 + len2 326 | if not reset_positions: 327 | lengths -= 1 328 | slen, bs = lengths.max().item(), lengths.size(0) 329 | 330 | x = x1.new(slen, bs).fill_(pad_idx) 331 | x[:len1.max().item()].copy_(x1) 332 | positions = torch.arange(slen)[:, None].repeat(1, bs).to(x1.device) 333 | langs = x1.new(slen, bs).fill_(lang1_id) 334 | 335 | for i in range(bs): 336 | l1 = len1[i] if reset_positions else len1[i] - 1 337 | x[l1:l1 + len2[i], i].copy_(x2[:len2[i], i]) 338 | if reset_positions: 339 | positions[l1:, i] -= len1[i] 340 | langs[l1:, i] = lang2_id 341 | 342 | assert (x == eos_idx).long().sum().item() == (4 if reset_positions else 3) * bs 343 | 344 | return x, lengths, positions, langs 345 | 346 | 347 | def truncate(x, lengths, max_len, eos_index): 348 | """ 349 | Truncate long sentences. 350 | """ 351 | if lengths.max().item() > max_len: 352 | x = x[:max_len].clone() 353 | lengths = lengths.clone() 354 | for i in range(len(lengths)): 355 | if lengths[i] > max_len: 356 | lengths[i] = max_len 357 | x[max_len - 1, i] = eos_index 358 | return x, lengths 359 | 360 | 361 | def shuf_order(langs, params=None, n=5): 362 | """ 363 | Randomize training order. 364 | """ 365 | if len(langs) == 0: 366 | return [] 367 | 368 | if params is None: 369 | return [langs[i] for i in np.random.permutation(len(langs))] 370 | 371 | # sample monolingual and parallel languages separately 372 | mono = [l1 for l1, l2 in langs if l2 is None] 373 | para = [(l1, l2) for l1, l2 in langs if l2 is not None] 374 | 375 | # uniform / weighted sampling 376 | if params.lg_sampling_factor == -1: 377 | p_mono = None 378 | p_para = None 379 | else: 380 | p_mono = np.array([params.mono_probs[params.mono_list.index(k)] for k in mono]) 381 | p_para = np.array([params.para_probs[params.para_list.index(tuple(sorted(k)))] for k in para]) 382 | p_mono = p_mono / p_mono.sum() 383 | p_para = p_para / p_para.sum() 384 | 385 | s_mono = [mono[i] for i in np.random.choice(len(mono), size=min(n, len(mono)), p=p_mono, replace=True)] if len(mono) > 0 else [] 386 | s_para = [para[i] for i in np.random.choice(len(para), size=min(n, len(para)), p=p_para, replace=True)] if len(para) > 0 else [] 387 | 388 | assert len(s_mono) + len(s_para) > 0 389 | return [(lang, None) for lang in s_mono] + s_para 390 | -------------------------------------------------------------------------------- /MASS-unsupNMT/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | # 4 | # Copyright (c) 2019-present, Facebook, Inc. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # NOTICE FILE in the root directory of this source tree. 9 | # 10 | 11 | import json 12 | import argparse 13 | import torch 14 | import numpy as np 15 | from torch import nn 16 | 17 | from src.slurm import init_signal_handler, init_distributed_mode 18 | from src.data.loader import check_data_params, load_data 19 | from src.utils import bool_flag, initialize_exp, set_sampling_probs, shuf_order 20 | from src.model import check_model_params, build_model 21 | from src.trainer import SingleTrainer, EncDecTrainer 22 | from src.evaluation.evaluator import SingleEvaluator, EncDecEvaluator 23 | 24 | import apex 25 | from src.fp16 import network_to_half 26 | 27 | 28 | def get_parser(): 29 | """ 30 | Generate a parameters parser. 
31 | """ 32 | # parse parameters 33 | parser = argparse.ArgumentParser(description="Language transfer") 34 | 35 | # main parameters 36 | parser.add_argument("--dump_path", type=str, default="./dumped/", 37 | help="Experiment dump path") 38 | parser.add_argument("--exp_name", type=str, default="", 39 | help="Experiment name") 40 | parser.add_argument("--save_periodic", type=int, default=0, 41 | help="Save the model periodically (0 to disable)") 42 | parser.add_argument("--exp_id", type=str, default="", 43 | help="Experiment ID") 44 | 45 | # float16 46 | parser.add_argument("--fp16", type=bool_flag, default=False, 47 | help="Run model with float16") 48 | 49 | # only use an encoder (use a specific decoder for machine translation) 50 | parser.add_argument("--encoder_only", type=bool_flag, default=True, 51 | help="Only use an encoder") 52 | parser.add_argument("--english_only", type=bool_flag, default=False, 53 | help="Only use english domain (equal to only use one language)") 54 | 55 | # model parameters 56 | parser.add_argument("--emb_dim", type=int, default=512, 57 | help="Embedding layer size") 58 | parser.add_argument("--n_layers", type=int, default=4, 59 | help="Number of Transformer layers") 60 | parser.add_argument("--n_dec_layers", type=int, default=6, 61 | help="Number of Decoder Transformer layers") 62 | parser.add_argument("--n_heads", type=int, default=8, 63 | help="Number of Transformer heads") 64 | parser.add_argument("--dropout", type=float, default=0, 65 | help="Dropout") 66 | parser.add_argument("--attention_dropout", type=float, default=0, 67 | help="Dropout in the attention layer") 68 | parser.add_argument("--gelu_activation", type=bool_flag, default=False, 69 | help="Use a GELU activation instead of ReLU") 70 | parser.add_argument("--share_inout_emb", type=bool_flag, default=True, 71 | help="Share input and output embeddings") 72 | parser.add_argument("--sinusoidal_embeddings", type=bool_flag, default=False, 73 | help="Use sinusoidal embeddings") 74 | parser.add_argument("--attention_setting", type=str, default="v1", choices=["v1", "v2"], 75 | help="Setting for attention module, benefits for distinguish language") 76 | 77 | # adaptive softmax 78 | parser.add_argument("--asm", type=bool_flag, default=False, 79 | help="Use adaptive softmax") 80 | if parser.parse_known_args()[0].asm: 81 | parser.add_argument("--asm_cutoffs", type=str, default="8000,20000", 82 | help="Adaptive softmax cutoffs") 83 | parser.add_argument("--asm_div_value", type=float, default=4, 84 | help="Adaptive softmax cluster sizes ratio") 85 | 86 | # causal language modeling task parameters 87 | parser.add_argument("--context_size", type=int, default=0, 88 | help="Context size (0 means that the first elements in sequences won't have any context)") 89 | 90 | # masked language modeling task parameters 91 | parser.add_argument("--word_pred", type=float, default=0.15, 92 | help="Fraction of words for which we need to make a prediction") 93 | parser.add_argument("--sample_alpha", type=float, default=0, 94 | help="Exponent for transforming word counts to probabilities (~word2vec sampling)") 95 | parser.add_argument("--word_mask_keep_rand", type=str, default="0.8,0.1,0.1", 96 | help="Fraction of words to mask out / keep / randomize, among the words to predict") 97 | 98 | # input sentence noise 99 | parser.add_argument("--word_shuffle", type=float, default=0, 100 | help="Randomly shuffle input words (0 to disable)") 101 | parser.add_argument("--word_dropout", type=float, default=0, 102 | help="Randomly dropout 
input words (0 to disable)") 103 | parser.add_argument("--word_blank", type=float, default=0, 104 | help="Randomly blank input words (0 to disable)") 105 | parser.add_argument("--word_mass", type=float, default=0, 106 | help="Randomly mask input words (0 to disable)") 107 | 108 | # data 109 | parser.add_argument("--data_path", type=str, default="", 110 | help="Data path") 111 | parser.add_argument("--lgs", type=str, default="", 112 | help="Languages (lg1-lg2-lg3 .. ex: en-fr-es-de)") 113 | parser.add_argument("--max_vocab", type=int, default=-1, 114 | help="Maximum vocabulary size (-1 to disable)") 115 | parser.add_argument("--min_count", type=int, default=0, 116 | help="Minimum vocabulary count") 117 | parser.add_argument("--lg_sampling_factor", type=float, default=-1, 118 | help="Language sampling factor") 119 | 120 | # batch parameters 121 | parser.add_argument("--bptt", type=int, default=256, 122 | help="Sequence length") 123 | parser.add_argument("--min_len", type=int, default=0, 124 | help="Minimum length of sentences (after BPE)") 125 | parser.add_argument("--max_len", type=int, default=100, 126 | help="Maximum length of sentences (after BPE)") 127 | parser.add_argument("--group_by_size", type=bool_flag, default=True, 128 | help="Sort sentences by size during the training") 129 | parser.add_argument("--batch_size", type=int, default=32, 130 | help="Number of sentences per batch") 131 | parser.add_argument("--max_batch_size", type=int, default=0, 132 | help="Maximum number of sentences per batch (used in combination with tokens_per_batch, 0 to disable)") 133 | parser.add_argument("--tokens_per_batch", type=int, default=-1, 134 | help="Number of tokens per batch") 135 | 136 | # training parameters 137 | parser.add_argument("--split_data", type=bool_flag, default=False, 138 | help="Split data across workers of a same node") 139 | parser.add_argument("--optimizer", type=str, default="adam,lr=0.0001", 140 | help="Optimizer (SGD / RMSprop / Adam, etc.)") 141 | parser.add_argument("--clip_grad_norm", type=float, default=5, 142 | help="Clip gradients norm (0 to disable)") 143 | parser.add_argument("--epoch_size", type=int, default=100000, 144 | help="Epoch size / evaluation frequency (-1 for parallel data size)") 145 | parser.add_argument("--max_epoch", type=int, default=100000, 146 | help="Maximum epoch size") 147 | parser.add_argument("--stopping_criterion", type=str, default="", 148 | help="Stopping criterion, and number of non-increase before stopping the experiment") 149 | parser.add_argument("--validation_metrics", type=str, default="", 150 | help="Validation metrics") 151 | 152 | # training coefficients 153 | parser.add_argument("--lambda_mlm", type=str, default="1", 154 | help="Prediction coefficient (MLM)") 155 | parser.add_argument("--lambda_clm", type=str, default="1", 156 | help="Causal coefficient (LM)") 157 | parser.add_argument("--lambda_bmt", type=str, default="1", 158 | help="Back Parallel coefficient") 159 | parser.add_argument("--lambda_pc", type=str, default="1", 160 | help="PC coefficient") 161 | parser.add_argument("--lambda_ae", type=str, default="1", 162 | help="AE coefficient") 163 | parser.add_argument("--lambda_mt", type=str, default="1", 164 | help="MT coefficient") 165 | parser.add_argument("--lambda_bt", type=str, default="1", 166 | help="BT coefficient") 167 | parser.add_argument("--lambda_mass", type=str, default="1", 168 | help="MASS coefficient") 169 | parser.add_argument("--lambda_span", type=str, default="10000", 170 | help="Span coefficient") 171 | 172 | 
# training steps 173 | parser.add_argument("--clm_steps", type=str, default="", 174 | help="Causal prediction steps (CLM)") 175 | parser.add_argument("--mlm_steps", type=str, default="", 176 | help="Masked prediction steps (MLM / TLM)") 177 | parser.add_argument("--bmt_steps", type=str, default="", 178 | help="Back Machine Translation step") 179 | parser.add_argument("--mass_steps", type=str, default="", 180 | help="MASS prediction steps") 181 | parser.add_argument("--mt_steps", type=str, default="", 182 | help="Machine translation steps") 183 | parser.add_argument("--ae_steps", type=str, default="", 184 | help="Denoising auto-encoder steps") 185 | parser.add_argument("--bt_steps", type=str, default="", 186 | help="Back-translation steps") 187 | parser.add_argument("--pc_steps", type=str, default="", 188 | help="Parallel classification steps") 189 | 190 | # reload a pretrained model 191 | parser.add_argument("--reload_model", type=str, default="", 192 | help="Reload a pretrained model") 193 | 194 | # beam search (for MT only) 195 | parser.add_argument("--beam_size", type=int, default=1, 196 | help="Beam size, default = 1 (greedy decoding)") 197 | parser.add_argument("--length_penalty", type=float, default=1, 198 | help="Length penalty, values < 1.0 favor shorter sentences, while values > 1.0 favor longer ones.") 199 | parser.add_argument("--early_stopping", type=bool_flag, default=False, 200 | help="Early stopping, stop as soon as we have `beam_size` hypotheses, although longer ones may have better scores.") 201 | 202 | # evaluation 203 | parser.add_argument("--eval_bleu", type=bool_flag, default=False, 204 | help="Evaluate BLEU score during MT training") 205 | parser.add_argument("--eval_only", type=bool_flag, default=False, 206 | help="Only run evaluations") 207 | 208 | # debug 209 | parser.add_argument("--debug_train", type=bool_flag, default=False, 210 | help="Use valid sets for train sets (faster loading)") 211 | parser.add_argument("--debug_slurm", type=bool_flag, default=False, 212 | help="Debug multi-GPU / multi-node within a SLURM job") 213 | 214 | # multi-gpu / multi-node 215 | parser.add_argument("--local_rank", type=int, default=-1, 216 | help="Multi-GPU - Local rank") 217 | parser.add_argument("--master_port", type=int, default=-1, 218 | help="Master port (for multi-node SLURM jobs)") 219 | 220 | return parser 221 | 222 | 223 | def main(params): 224 | 225 | # initialize the multi-GPU / multi-node training 226 | init_distributed_mode(params) 227 | 228 | # initialize the experiment 229 | logger = initialize_exp(params) 230 | 231 | # initialize SLURM signal handler for time limit / pre-emption 232 | init_signal_handler() 233 | 234 | # load data 235 | data = load_data(params) 236 | 237 | # build model 238 | if params.encoder_only: 239 | model = build_model(params, data['dico']) 240 | else: 241 | encoder, decoder = build_model(params, data['dico']) 242 | 243 | # float16 244 | if params.fp16: 245 | assert torch.backends.cudnn.enabled 246 | if params.encoder_only: 247 | model = network_to_half(model) 248 | else: 249 | encoder = network_to_half(encoder) 250 | decoder = network_to_half(decoder) 251 | 252 | # distributed 253 | if params.multi_gpu: 254 | logger.info("Using nn.parallel.DistributedDataParallel ...") 255 | if params.encoder_only: 256 | model = apex.parallel.DistributedDataParallel(model, delay_allreduce=True) 257 | else: 258 | encoder = apex.parallel.DistributedDataParallel(encoder, delay_allreduce=True) 259 | decoder = apex.parallel.DistributedDataParallel(decoder, 
delay_allreduce=True) 260 | 261 | # build trainer, reload potential checkpoints / build evaluator 262 | if params.encoder_only: 263 | trainer = SingleTrainer(model, data, params) 264 | evaluator = SingleEvaluator(trainer, data, params) 265 | else: 266 | trainer = EncDecTrainer(encoder, decoder, data, params) 267 | evaluator = EncDecEvaluator(trainer, data, params) 268 | 269 | # evaluation 270 | if params.eval_only: 271 | scores = evaluator.run_all_evals(trainer) 272 | for k, v in scores.items(): 273 | logger.info("%s -> %.6f" % (k, v)) 274 | logger.info("__log__:%s" % json.dumps(scores)) 275 | exit() 276 | 277 | # set sampling probabilities for training 278 | set_sampling_probs(data, params) 279 | 280 | # language model training 281 | for _ in range(params.max_epoch): 282 | 283 | logger.info("============ Starting epoch %i ... ============" % trainer.epoch) 284 | 285 | trainer.n_sentences = 0 286 | 287 | while trainer.n_sentences < trainer.epoch_size: 288 | 289 | # CLM steps 290 | for lang1, lang2 in shuf_order(params.clm_steps, params): 291 | trainer.clm_step(lang1, lang2, params.lambda_clm) 292 | 293 | # MLM steps (also includes TLM if lang2 is not None) 294 | for lang1, lang2 in shuf_order(params.mlm_steps, params): 295 | trainer.mlm_step(lang1, lang2, params.lambda_mlm) 296 | 297 | # parallel classification steps 298 | for lang1, lang2 in shuf_order(params.pc_steps, params): 299 | trainer.pc_step(lang1, lang2, params.lambda_pc) 300 | 301 | # denoising auto-encoder steps 302 | for lang in shuf_order(params.ae_steps): 303 | trainer.mt_step(lang, lang, params.lambda_ae) 304 | 305 | # mass prediction steps 306 | for lang in shuf_order(params.mass_steps): 307 | trainer.mass_step(lang, params.lambda_mass) 308 | 309 | # machine translation steps 310 | for lang1, lang2 in shuf_order(params.mt_steps, params): 311 | trainer.mt_step(lang1, lang2, params.lambda_mt) 312 | 313 | # back-translation steps 314 | for lang1, lang2, lang3 in shuf_order(params.bt_steps): 315 | trainer.bt_step(lang1, lang2, lang3, params.lambda_bt) 316 | 317 | # back-parallel steps 318 | for lang1, lang2 in shuf_order(params.bmt_steps, params): 319 | trainer.bmt_step(lang1, lang2, params.lambda_bmt) 320 | 321 | trainer.iter() 322 | 323 | logger.info("============ End of epoch %i ============" % trainer.epoch) 324 | 325 | # evaluate perplexity 326 | scores = evaluator.run_all_evals(trainer) 327 | 328 | # print / JSON log 329 | for k, v in scores.items(): 330 | logger.info("%s -> %.6f" % (k, v)) 331 | if params.is_master: 332 | logger.info("__log__:%s" % json.dumps(scores)) 333 | 334 | # end of epoch 335 | trainer.save_best_model(scores) 336 | trainer.save_periodic() 337 | trainer.end_epoch(scores) 338 | 339 | 340 | if __name__ == '__main__': 341 | 342 | # generate parser / parse parameters 343 | parser = get_parser() 344 | params = parser.parse_args() 345 | 346 | # check parameters 347 | check_data_params(params) 348 | check_model_params(params) 349 | 350 | # run experiment 351 | main(params) 352 | -------------------------------------------------------------------------------- /MASS-unsupNMT/translate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # NOTICE FILE in the root directory of this source tree. 6 | # 7 | # Translate sentences from the input stream. 8 | # The model will be faster is sentences are sorted by length. 
9 | # Input sentences must have the same tokenization and BPE codes than the ones used in the model. 10 | # 11 | # Usage: 12 | # cat source_sentences.bpe | \ 13 | # python translate.py --exp_name translate \ 14 | # --src_lang en --tgt_lang fr \ 15 | # --model_path trained_model.pth --output_path output 16 | # 17 | 18 | import os 19 | import io 20 | import sys 21 | import argparse 22 | import torch 23 | 24 | from src.utils import AttrDict 25 | from src.utils import bool_flag, initialize_exp 26 | from src.data.dictionary import Dictionary 27 | from src.model.transformer import TransformerModel 28 | 29 | from src.fp16 import network_to_half 30 | 31 | 32 | def get_parser(): 33 | """ 34 | Generate a parameters parser. 35 | """ 36 | # parse parameters 37 | parser = argparse.ArgumentParser(description="Translate sentences") 38 | 39 | # main parameters 40 | parser.add_argument("--dump_path", type=str, default="./dumped/", help="Experiment dump path") 41 | parser.add_argument("--exp_name", type=str, default="", help="Experiment name") 42 | parser.add_argument("--exp_id", type=str, default="", help="Experiment ID") 43 | parser.add_argument("--fp16", type=bool_flag, default=False, help="Run model with float16") 44 | parser.add_argument("--batch_size", type=int, default=32, help="Number of sentences per batch") 45 | 46 | # model / output paths 47 | parser.add_argument("--model_path", type=str, default="", help="Model path") 48 | parser.add_argument("--output_path", type=str, default="", help="Output path") 49 | 50 | parser.add_argument("--beam", type=int, default=1, help="Beam size") 51 | parser.add_argument("--length_penalty", type=float, default=1, help="length penalty") 52 | 53 | # parser.add_argument("--max_vocab", type=int, default=-1, help="Maximum vocabulary size (-1 to disable)") 54 | # parser.add_argument("--min_count", type=int, default=0, help="Minimum vocabulary count") 55 | 56 | # source language / target language 57 | parser.add_argument("--src_lang", type=str, default="", help="Source language") 58 | parser.add_argument("--tgt_lang", type=str, default="", help="Target language") 59 | 60 | return parser 61 | 62 | 63 | def main(params): 64 | 65 | # initialize the experiment 66 | logger = initialize_exp(params) 67 | 68 | # generate parser / parse parameters 69 | parser = get_parser() 70 | params = parser.parse_args() 71 | reloaded = torch.load(params.model_path) 72 | model_params = AttrDict(reloaded['params']) 73 | logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys())) 74 | 75 | # update dictionary parameters 76 | for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']: 77 | setattr(params, name, getattr(model_params, name)) 78 | 79 | # build dictionary / build encoder / build decoder / reload weights 80 | dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts']) 81 | encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval() 82 | decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval() 83 | encoder.load_state_dict(reloaded['encoder']) 84 | decoder.load_state_dict(reloaded['decoder']) 85 | params.src_id = model_params.lang2id[params.src_lang] 86 | params.tgt_id = model_params.lang2id[params.tgt_lang] 87 | 88 | # float16 89 | if params.fp16: 90 | assert torch.backends.cudnn.enabled 91 | encoder = network_to_half(encoder) 92 | decoder = network_to_half(decoder) 93 | 94 | # read sentences from stdin 95 | src_sent = [] 96 | 
for line in sys.stdin.readlines(): 97 | assert len(line.strip().split()) > 0 98 | src_sent.append(line) 99 | logger.info("Read %i sentences from stdin. Translating ..." % len(src_sent)) 100 | 101 | f = io.open(params.output_path, 'w', encoding='utf-8') 102 | 103 | for i in range(0, len(src_sent), params.batch_size): 104 | 105 | # prepare batch 106 | word_ids = [torch.LongTensor([dico.index(w) for w in s.strip().split()]) 107 | for s in src_sent[i:i + params.batch_size]] 108 | lengths = torch.LongTensor([len(s) + 2 for s in word_ids]) 109 | batch = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(params.pad_index) 110 | batch[0] = params.eos_index 111 | for j, s in enumerate(word_ids): 112 | if lengths[j] > 2: # if sentence not empty 113 | batch[1:lengths[j] - 1, j].copy_(s) 114 | batch[lengths[j] - 1, j] = params.eos_index 115 | langs = batch.clone().fill_(params.src_id) 116 | 117 | # encode source batch and translate it 118 | encoded = encoder('fwd', x=batch.cuda(), lengths=lengths.cuda(), langs=langs.cuda(), causal=False) 119 | encoded = encoded.transpose(0, 1) 120 | if params.beam == 1: 121 | decoded, dec_lengths = decoder.generate(encoded, lengths.cuda(), params.tgt_id, max_len=int(1.5 * lengths.max().item() + 10)) 122 | else: 123 | decoded, dec_lengths = decoder.generate_beam( 124 | encoded, lengths.cuda(), params.tgt_id, beam_size=params.beam, 125 | length_penalty=params.length_penalty, 126 | early_stopping=False, 127 | max_len=int(1.5 * lengths.max().item() + 10)) 128 | 129 | # convert sentences to words 130 | for j in range(decoded.size(1)): 131 | 132 | # remove delimiters 133 | sent = decoded[:, j] 134 | delimiters = (sent == params.eos_index).nonzero().view(-1) 135 | assert len(delimiters) >= 1 and delimiters[0].item() == 0 136 | sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]] 137 | 138 | # output translation 139 | source = src_sent[i + j].strip() 140 | target = " ".join([dico[sent[k].item()] for k in range(len(sent))]) 141 | sys.stderr.write("%i / %i: %s -> %s\n" % (i + j, len(src_sent), source, target)) 142 | f.write(target + "\n") 143 | 144 | f.close() 145 | 146 | 147 | if __name__ == '__main__': 148 | 149 | # generate parser / parse parameters 150 | parser = get_parser() 151 | params = parser.parse_args() 152 | 153 | # check parameters 154 | assert os.path.isfile(params.model_path) 155 | assert params.src_lang != '' and params.tgt_lang != '' and params.src_lang != params.tgt_lang 156 | assert params.output_path and not os.path.isfile(params.output_path) 157 | 158 | # translate 159 | with torch.no_grad(): 160 | main(params) 161 | -------------------------------------------------------------------------------- /MASS-unsupNMT/translate_ensemble.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | # 4 | # Translate sentences from the input stream. 5 | # The model will be faster if sentences are sorted by length. 6 | # Input sentences must have the same tokenization and BPE codes as the ones used in the model. 7 | # It also supports ensembling multiple models, beam search and length penalty. 
8 | # 9 | # Usage: 10 | # cat source_sentences.bpe | \ 11 | # python translate.py --exp_name translate \ 12 | # --exp_id en-fr \ 13 | # --src_lang en --tgt_lang fr \ 14 | # --model_path model1.pth,model2.pth --output_path output \ 15 | # --beam 10 --length_penalty 1.1 16 | # 17 | 18 | import os 19 | import io 20 | import sys 21 | import argparse 22 | import torch 23 | import math 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | 27 | from collections import OrderedDict 28 | 29 | from src.utils import AttrDict 30 | from src.utils import bool_flag, initialize_exp 31 | from src.data.dictionary import Dictionary 32 | from src.model.transformer import TransformerModel 33 | from src.model.transformer import BeamHypotheses 34 | 35 | from src.fp16 import network_to_half 36 | 37 | 38 | def get_parser(): 39 | """ 40 | Generate a parameters parser. 41 | """ 42 | # parse parameters 43 | parser = argparse.ArgumentParser(description="Translate sentences") 44 | 45 | # main parameters 46 | parser.add_argument("--dump_path", type=str, default="./dumped/", help="Experiment dump path") 47 | parser.add_argument("--exp_name", type=str, default="", help="Experiment name") 48 | parser.add_argument("--exp_id", type=str, default="", help="Experiment ID") 49 | parser.add_argument("--fp16", type=bool_flag, default=False, help="Run model with float16") 50 | parser.add_argument("--batch_size", type=int, default=32, help="Number of sentences per batch") 51 | 52 | # model / output paths 53 | parser.add_argument("--model_path", type=str, default="", help="Model path") 54 | parser.add_argument("--output_path", type=str, default="", help="Output path") 55 | 56 | parser.add_argument("--beam", type=int, default=1, help="Beam size") 57 | parser.add_argument("--length_penalty", type=float, default=1, help="length penalty") 58 | 59 | # source language / target language 60 | parser.add_argument("--src_lang", type=str, default="", help="Source language") 61 | parser.add_argument("--tgt_lang", type=str, default="", help="Target language") 62 | 63 | return parser 64 | 65 | 66 | def generate_beam(decoders, src_encodeds, src_len, tgt_lang_id, beam_size, length_penalty, early_stopping, max_len=200, params=None): 67 | assert params is not None 68 | 69 | src_encs = [] 70 | 71 | bs = len(src_len) 72 | n_words = params.n_words 73 | 74 | src_len = src_len.unsqueeze(1).expand(bs, beam_size).contiguous().view(-1) 75 | for i in range(len(src_encodeds)): 76 | src_encodeds[i] = src_encodeds[i].unsqueeze(1).expand((bs, beam_size) + src_encodeds[i].shape[1:]).contiguous().view((bs * beam_size,) + src_encodeds[i].shape[1:]) 77 | 78 | generated = src_len.new(max_len, bs * beam_size) 79 | generated.fill_(params.pad_index) 80 | generated[0].fill_(params.eos_index) 81 | 82 | generated_hyps = [BeamHypotheses(beam_size, max_len, length_penalty, early_stopping) for _ in range(bs)] 83 | 84 | positions = src_len.new(max_len).long() 85 | positions = torch.arange(max_len, out=positions).unsqueeze(1).expand_as(generated) 86 | 87 | langs = positions.clone().fill_(tgt_lang_id) 88 | beam_scores = src_encodeds[0].new(bs, beam_size).fill_(0) 89 | beam_scores[:, 1:] = -1e9 90 | beam_scores = beam_scores.view(-1) 91 | 92 | cur_len = 1 93 | caches = [{'slen': 0} for i in range(len(decoders))] 94 | done = [False for _ in range(bs)] 95 | 96 | while cur_len < max_len: 97 | avg_scores = [] 98 | #avg_scores = None 99 | for i, (src_enc, decoder) in enumerate(zip(src_encodeds, decoders)): 100 | tensor = decoder.forward( 101 | 'fwd', 102 | 
x=generated[:cur_len], 103 | lengths=src_len.new(bs * beam_size).fill_(cur_len), 104 | positions=positions[:cur_len], 105 | langs=langs[:cur_len], 106 | causal=True, 107 | src_enc=src_enc, 108 | src_len=src_len, 109 | cache=caches[i] 110 | ) 111 | assert tensor.size() == (1, bs * beam_size, decoder.dim) 112 | tensor = tensor.data[-1, :, :] # (bs * beam_size, dim) 113 | scores = decoder.pred_layer.get_scores(tensor) # (bs * beam_size, n_words) 114 | scores = F.log_softmax(scores, dim=-1) # (bs * beam_size, n_words) 115 | 116 | avg_scores.append(scores) 117 | 118 | avg_scores = torch.logsumexp(torch.stack(avg_scores, dim=0), dim=0) - math.log(len(decoders)) 119 | #avg_scores.div_(len(decoders)) 120 | _scores = avg_scores + beam_scores[:, None].expand_as(avg_scores) 121 | _scores = _scores.view(bs, beam_size * n_words) 122 | next_scores, next_words = torch.topk(_scores, 2 * beam_size, dim=1, largest=True, sorted=True) 123 | assert next_scores.size() == next_words.size() == (bs, 2 * beam_size) 124 | 125 | next_batch_beam = [] 126 | 127 | for sent_id in range(bs): 128 | 129 | # if we are done with this sentence 130 | done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item()) 131 | if done[sent_id]: 132 | next_batch_beam.extend([(0, params.pad_index, 0)] * beam_size) # pad the batch 133 | continue 134 | 135 | # next sentence beam content 136 | next_sent_beam = [] 137 | 138 | # next words for this sentence 139 | for idx, value in zip(next_words[sent_id], next_scores[sent_id]): 140 | 141 | # get beam and word IDs 142 | beam_id = idx // n_words 143 | word_id = idx % n_words 144 | 145 | # end of sentence, or next word 146 | if word_id == params.eos_index or cur_len + 1 == max_len: 147 | generated_hyps[sent_id].add(generated[:cur_len, sent_id * beam_size + beam_id].clone(), value.item()) 148 | else: 149 | next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id)) 150 | 151 | # the beam for next step is full 152 | if len(next_sent_beam) == beam_size: 153 | break 154 | 155 | # update next beam content 156 | assert len(next_sent_beam) == 0 if cur_len + 1 == max_len else beam_size 157 | if len(next_sent_beam) == 0: 158 | next_sent_beam = [(0, params.pad_index, 0)] * beam_size # pad the batch 159 | next_batch_beam.extend(next_sent_beam) 160 | assert len(next_batch_beam) == beam_size * (sent_id + 1) 161 | 162 | # sanity check / prepare next batch 163 | assert len(next_batch_beam) == bs * beam_size 164 | beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) 165 | beam_words = generated.new([x[1] for x in next_batch_beam]) 166 | beam_idx = src_len.new([x[2] for x in next_batch_beam]) 167 | 168 | # re-order batch and internal states 169 | generated = generated[:, beam_idx] 170 | generated[cur_len] = beam_words 171 | for cache in caches: 172 | for k in cache.keys(): 173 | if k != 'slen': 174 | cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx]) 175 | 176 | # update current length 177 | cur_len = cur_len + 1 178 | 179 | # stop when we are done with each sentence 180 | if all(done): 181 | break 182 | 183 | tgt_len = src_len.new(bs) 184 | best = [] 185 | 186 | for i, hypotheses in enumerate(generated_hyps): 187 | best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] 188 | tgt_len[i] = len(best_hyp) + 1 # +1 for the symbol 189 | best.append(best_hyp) 190 | 191 | # generate target batch 192 | decoded = src_len.new(tgt_len.max().item(), bs).fill_(params.pad_index) 193 | for i, hypo in enumerate(best): 194 | decoded[:tgt_len[i] - 1, i] = hypo 195 | 
decoded[tgt_len[i] - 1, i] = params.eos_index 196 | 197 | # sanity check 198 | assert (decoded == params.eos_index).sum() == 2 * bs 199 | 200 | return decoded, tgt_len 201 | 202 | 203 | def main(params): 204 | 205 | # initialize the experiment 206 | logger = initialize_exp(params) 207 | parser = get_parser() 208 | params = parser.parse_args() 209 | models_path = params.model_path.split(',') 210 | 211 | # generate parser / parse parameters 212 | models_reloaded = [] 213 | for model_path in models_path: 214 | models_reloaded.append(torch.load(model_path)) 215 | model_params = AttrDict(models_reloaded[0]['params']) 216 | logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys())) 217 | 218 | # update dictionary parameters 219 | for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']: 220 | setattr(params, name, getattr(model_params, name)) 221 | 222 | # build dictionary / build encoder / build decoder / reload weights 223 | dico = Dictionary(models_reloaded[0]['dico_id2word'], models_reloaded[0]['dico_word2id'], models_reloaded[0]['dico_counts']) 224 | params.src_id = model_params.lang2id[params.src_lang] 225 | params.tgt_id = model_params.lang2id[params.tgt_lang] 226 | 227 | encoders = [] 228 | decoders = [] 229 | 230 | def package_module(modules): 231 | state_dict = OrderedDict() 232 | for k, v in modules.items(): 233 | if k.startswith('module.'): 234 | state_dict[k[7:]] = v 235 | else: 236 | state_dict[k] = v 237 | return state_dict 238 | 239 | for reloaded in models_reloaded: 240 | encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval() 241 | decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval() 242 | encoder.load_state_dict(package_module(reloaded['encoder'])) 243 | decoder.load_state_dict(package_module(reloaded['decoder'])) 244 | 245 | # float16 246 | if params.fp16: 247 | assert torch.backends.cudnn.enabled 248 | encoder = network_to_half(encoder) 249 | decoder = network_to_half(decoder) 250 | 251 | encoders.append(encoder) 252 | decoders.append(decoder) 253 | 254 | #src_sent = ['Poly@@ gam@@ ie statt Demokratie .'] 255 | src_sent = [] 256 | for line in sys.stdin.readlines(): 257 | assert len(line.strip().split()) > 0 258 | src_sent.append(line) 259 | 260 | f = io.open(params.output_path, 'w', encoding='utf-8') 261 | 262 | for i in range(0, len(src_sent), params.batch_size): 263 | 264 | # prepare batch 265 | word_ids = [torch.LongTensor([dico.index(w) for w in s.strip().split()]) 266 | for s in src_sent[i:i + params.batch_size]] 267 | lengths = torch.LongTensor([len(s) + 2 for s in word_ids]) 268 | batch = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(params.pad_index) 269 | batch[0] = params.eos_index 270 | for j, s in enumerate(word_ids): 271 | if lengths[j] > 2: # if sentence not empty 272 | batch[1:lengths[j] - 1, j].copy_(s) 273 | batch[lengths[j] - 1, j] = params.eos_index 274 | langs = batch.clone().fill_(params.src_id) 275 | 276 | # encode source batch and translate it 277 | encodeds = [] 278 | for encoder in encoders: 279 | encoded = encoder('fwd', x=batch.cuda(), lengths=lengths.cuda(), langs=langs.cuda(), causal=False) 280 | encoded = encoded.transpose(0, 1) 281 | encodeds.append(encoded) 282 | 283 | assert encoded.size(0) == lengths.size(0) 284 | 285 | decoded, dec_lengths = generate_beam(decoders, encodeds, lengths.cuda(), params.tgt_id, 286 | beam_size=params.beam, 287 | length_penalty=params.length_penalty, 288 | 
early_stopping=False, 289 | max_len=int(1.5 * lengths.max().item() + 10), params=params) 290 | 291 | # convert sentences to words 292 | for j in range(decoded.size(1)): 293 | 294 | # remove delimiters 295 | sent = decoded[:, j] 296 | delimiters = (sent == params.eos_index).nonzero().view(-1) 297 | assert len(delimiters) >= 1 and delimiters[0].item() == 0 298 | sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]] 299 | 300 | # output translation 301 | source = src_sent[i + j].strip() 302 | target = " ".join([dico[sent[k].item()] for k in range(len(sent))]) 303 | sys.stderr.write("%i / %i: %s -> %s\n" % (i + j, len(src_sent), source, target)) 304 | f.write(target + "\n") 305 | 306 | f.close() 307 | 308 | 309 | if __name__ == '__main__': 310 | 311 | # generate parser / parse parameters 312 | parser = get_parser() 313 | params = parser.parse_args() 314 | 315 | # check parameters 316 | #assert os.path.isfile(params.model_path) 317 | assert params.src_lang != '' and params.tgt_lang != '' and params.src_lang != params.tgt_lang 318 | assert params.output_path and not os.path.isfile(params.output_path) 319 | 320 | # translate 321 | with torch.no_grad(): 322 | main(params) 323 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 
22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /figs/mass.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MASS/779f22fc47c8a256d8bc04826ebe1c8307063cbe/figs/mass.png --------------------------------------------------------------------------------