├── .gitignore
├── LICENSE
├── README.md
├── SECURITY.md
├── conlleval.pl
├── data_prep
│   ├── bio_dataset.py
│   └── multi_lingual_amazon.py
├── data_processing_scripts
│   └── amazon
│       ├── pickle_dataset.py
│       └── process_dataset.py
├── layers.py
├── models.py
├── options.py
├── scripts
│   ├── get_overall_perf_amazon.py
│   ├── get_overall_perf_ner.py
│   ├── train_amazon_3to1.sh
│   └── train_conll_ner_3to1.sh
├── train_cls_man_moe.py
├── train_tagging_man_moe.py
├── utils.py
└── vocab.py
/.gitignore:
--------------------------------------------------------------------------------
1 | save/
2 | data/
3 | logs/
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # pyenv
80 | .python-version
81 |
82 | # celery beat schedule file
83 | celerybeat-schedule
84 |
85 | # SageMath parsed files
86 | *.sage.py
87 |
88 | # Environments
89 | .env
90 | .venv
91 | env/
92 | venv/
93 | ENV/
94 | env.bak/
95 | venv.bak/
96 |
97 | # Spyder project settings
98 | .spyderproject
99 | .spyproject
100 |
101 | # Rope project settings
102 | .ropeproject
103 |
104 | # mkdocs documentation
105 | /site
106 |
107 | # mypy
108 | .mypy_cache/
109 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation. All rights reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Zero-Resource Multilingual Model Transfer
2 |
3 | This repo contains the source code for our ACL 2019 paper:
4 |
5 | [**Multi-Source Cross-Lingual Model Transfer: Learning What to Share**](https://www.aclweb.org/anthology/P19-1299)
6 |
7 | [Xilun Chen](http://www.cs.cornell.edu/~xlchen/),
8 | [Ahmed Hassan Awadallah](https://www.microsoft.com/en-us/research/people/hassanam/),
9 | [Hany Hassan](https://www.microsoft.com/en-us/research/people/hanyh/),
10 | [Wei Wang](https://www.microsoft.com/en-us/research/people/wawe/),
11 | [Claire Cardie](http://www.cs.cornell.edu/home/cardie/)
12 |
13 | The 57th Annual Meeting of the Association for Computational Linguistics (ACL 2019)
14 |
15 | [paper](https://www.aclweb.org/anthology/P19-1299),
16 | [arXiv](https://arxiv.org/abs/1810.03552),
17 | [bibtex](https://www.aclweb.org/anthology/P19-1299.bib)
18 |
19 | ## Introduction
20 | Modern NLP applications have enjoyed a great boost utilizing neural network models. Such deep neural models, however, are not applicable to most human languages due to the lack of annotated training data for various NLP tasks. Cross-lingual transfer learning (CLTL) is a viable method for building NLP models for a low-resource target language by leveraging labeled data from other (source) languages. In this work, we focus on the multilingual transfer setting, where training data in multiple source languages is leveraged to further boost target language performance.
21 |
22 | Unlike most existing methods that rely only on language-invariant features for CLTL, our approach coherently utilizes both **language-invariant** and **language-specific** features at instance level. Our model leverages adversarial networks to learn language-invariant features, and mixture-of-experts models to dynamically exploit the similarity between the target language and each individual source language. This enables our model to learn effectively what to share between various languages in the multilingual setup. Moreover, when coupled with unsupervised multilingual embeddings, our model can operate in a **zero-resource** setting where neither **target language training data** nor **cross-lingual resources** (e.g. parallel corpora or Machine Translation systems) are available. Our model achieves significant performance gains over prior art, as shown in an extensive set of experiments over multiple text classification and sequence tagging tasks including a large-scale industry dataset.
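
To make the mixture-of-experts component concrete, here is a minimal, self-contained PyTorch sketch of the idea: one expert per source language, with a gate that computes per-example mixture weights. `MoEGate` and its sizes are hypothetical and purely illustrative; the modules actually used by the paper are `MixtureOfExperts` and related classes in `layers.py`/`models.py`, which additionally include the shared feature extractor trained adversarially against a language discriminator.

```python
import torch
from torch import nn
import torch.nn.functional as F

class MoEGate(nn.Module):
    """Toy mixture-of-experts: one small expert network per source language,
    plus a gate that computes per-example mixture weights (illustration only)."""
    def __init__(self, input_size, hidden_size, output_size, num_experts):
        super().__init__()
        self.gate = nn.Linear(input_size, num_experts)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(input_size, hidden_size), nn.ReLU(),
                          nn.Linear(hidden_size, output_size))
            for _ in range(num_experts))

    def forward(self, x):                              # x: batch x input_size
        weights = F.softmax(self.gate(x), dim=-1)      # batch x num_experts
        outs = torch.stack([e(x) for e in self.experts], dim=1)  # batch x num_experts x output_size
        return (weights.unsqueeze(-1) * outs).sum(dim=1)         # batch x output_size

# e.g. 3 source-language experts over 300-d multilingual embeddings, 2 output classes
moe = MoEGate(input_size=300, hidden_size=128, output_size=2, num_experts=3)
print(moe(torch.randn(4, 300)).shape)  # torch.Size([4, 2])
```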
23 |
24 | ## Requirements
25 | - Python 3.6
26 | - PyTorch 0.4
27 | - PyTorchNet (for confusion matrix)
28 | - tqdm (for progress bar)
29 |
30 |
31 | ## File Structure
32 | ```
33 | .
34 | ├── LICENSE
35 | ├── README.md
36 | ├── conlleval.pl (official CoNLL evaluation script)
37 | ├── data_prep (data processing scripts)
38 | │ ├── bio_dataset.py (processing the CoNLL dataset)
39 | │ └── multi_lingual_amazon.py (processing the Amazon Review dataset)
40 | ├── data_processing_scripts (auxiliary scripts for dataset pre-processing)
41 | │ └── amazon
42 | │ ├── pickle_dataset.py
43 | │ └── process_dataset.py
44 | ├── layers.py (lower-level helper modules)
45 | ├── models.py (higher-level modules)
46 | ├── options.py (hyper-parameters aka. all the knobs you may want to turn)
47 | ├── scripts (scripts for training and evaluating the models)
48 | │ ├── get_overall_perf_amazon.py (evaluation script for Amazon Reviews)
49 | │ ├── get_overall_perf_ner.py (evaluation script for CoNLL NER)
50 | │ ├── train_amazon_3to1.sh (training script for Amazon Reviews)
51 | │ └── train_conll_ner_3to1.sh (training script for CoNLL NER)
52 | ├── train_cls_man_moe.py (training code for text classification)
53 | ├── train_tagging_man_moe.py (training code for sequence tagging)
54 | ├── utils.py (helper functions)
55 | └── vocab.py (building the vocabulary)
56 | ```
57 |
58 |
59 | ## Dataset
60 | The CoNLL [2002](https://www.clips.uantwerpen.be/conll2002/ner/), [2003](https://www.clips.uantwerpen.be/conll2003/ner/) and [Amazon](https://webis.de/data/webis-cls-10.html) datasets, as well as the multilingual word embeddings ([MUSE](https://github.com/facebookresearch/MUSE), [VecMap](https://github.com/artetxem/vecmap), [UMWE](https://github.com/ccsasuke/umwe)) are all publicly available online.
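
As a point of reference, these embeddings are typically distributed as plain-text `.vec` files: a `"<vocab size> <dimension>"` header line followed by one `word v1 ... vd` line per word. A minimal reader might look like the sketch below; `load_vec_file` is a hypothetical helper for illustration only, since the actual vocabulary and embedding handling in this repo is done by `vocab.py` and the scripts under `data_processing_scripts`.

```python
import numpy as np

def load_vec_file(path, max_words=200000):
    """Read a fastText/MUSE-style .vec file (illustrative helper, not part of this repo)."""
    vectors = {}
    with open(path, encoding='utf-8') as f:
        _, dim = map(int, f.readline().split())   # header: "<vocab size> <dimension>"
        for i, line in enumerate(f):
            if i >= max_words:
                break
            word, *vals = line.rstrip().split(' ')
            if len(vals) == dim:                   # skip malformed lines
                vectors[word] = np.asarray(vals, dtype=np.float32)
    return vectors
```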
61 |
62 | ## Run Experiments
63 |
64 | ### CoNLL Named Entity Recognition
65 | ```bash
66 | ./scripts/train_conll_ner_3to1.sh {exp_name}
67 | ```
68 |
69 | The following script prints a summary of the dev/test F1 scores for all languages:
70 |
71 | ```bash
72 | python scripts/get_overall_perf_ner.py save {exp_name}
73 | ```
74 |
75 | ### Multilingual Amazon Reviews
76 | ```bash
77 | ./scripts/train_amazon_3to1.sh {exp_name}
78 | ```
79 |
80 | The following script prints a summary of the dev/test F1 scores for all languages:
81 |
82 | ```bash
83 | python scripts/get_overall_perf_amazon.py save {exp_name}
84 | ```
85 |
86 | ## Citation
87 |
88 | If you find this project useful for your research, please kindly cite our ACL 2019 paper:
89 |
90 | ```bibtex
91 | @InProceedings{chen-etal-acl2019-multi-source,
92 | author = {Chen, Xilun and Hassan Awadallah, Ahmed and Hassan, Hany and Wang, Wei and Cardie, Claire},
93 | title = {Multi-Source Cross-Lingual Model Transfer: Learning What to Share},
94 | booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
95 | month = jul,
96 | year = "2019",
97 | address = "Florence, Italy",
98 | publisher = "Association for Computational Linguistics",
99 | }
100 | ```
101 |
102 | ## Contributing
103 |
104 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
105 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
106 | the rights to use your contribution. For details, visit https://cla.microsoft.com.
107 |
108 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide
109 | a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions
110 | provided by the bot. You will only need to do this once across all repos using our CLA.
111 |
112 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
113 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
114 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
115 |
116 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/conlleval.pl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/perl -w
2 | # conlleval: evaluate result of processing CoNLL-2000 shared task
3 | # usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file
4 | # README: http://cnts.uia.ac.be/conll2000/chunking/output.html
5 | # options: l: generate LaTeX output for tables like in
6 | # http://cnts.uia.ac.be/conll2003/ner/example.tex
7 | # r: accept raw result tags (without B- and I- prefix;
8 | # assumes one word per chunk)
9 | # d: alternative delimiter tag (default is single space)
10 | # o: alternative outside tag (default is O)
11 | # note: the file should contain lines with items separated
12 | # by $delimiter characters (default space). The final
13 | # two items should contain the correct tag and the
14 | # guessed tag in that order. Sentences should be
15 | # separated from each other by empty lines or lines
16 | # with $boundary fields (default -X-).
17 | # url: http://lcg-www.uia.ac.be/conll2000/chunking/
18 | # started: 1998-09-25
19 | # version: 2004-01-26
20 | # author: Erik Tjong Kim Sang
21 |
22 | use strict;
23 |
24 | my $false = 0;
25 | my $true = 42;
26 |
27 | my $boundary = "-X-"; # sentence boundary
28 | my $correct; # current corpus chunk tag (I,O,B)
29 | my $correctChunk = 0; # number of correctly identified chunks
30 | my $correctTags = 0; # number of correct chunk tags
31 | my $correctType; # type of current corpus chunk tag (NP,VP,etc.)
32 | my $delimiter = " "; # field delimiter
33 | my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979)
34 | my $firstItem; # first feature (for sentence boundary checks)
35 | my $foundCorrect = 0; # number of chunks in corpus
36 | my $foundGuessed = 0; # number of identified chunks
37 | my $guessed; # current guessed chunk tag
38 | my $guessedType; # type of current guessed chunk tag
39 | my $i; # miscellaneous counter
40 | my $inCorrect = $false; # currently processed chunk is correct until now
41 | my $lastCorrect = "O"; # previous chunk tag in corpus
42 | my $latex = 0; # generate LaTeX formatted output
43 | my $lastCorrectType = ""; # type of previously identified chunk tag
44 | my $lastGuessed = "O"; # previously identified chunk tag
45 | my $lastGuessedType = ""; # type of previous chunk tag in corpus
46 | my $lastType; # temporary storage for detecting duplicates
47 | my $line; # line
48 | my $nbrOfFeatures = -1; # number of features per line
49 | my $precision = 0.0; # precision score
50 | my $oTag = "O"; # outside tag, default O
51 | my $raw = 0; # raw input: add B to every token
52 | my $recall = 0.0; # recall score
53 | my $tokenCounter = 0; # token counter (ignores sentence breaks)
54 |
55 | my %correctChunk = (); # number of correctly identified chunks per type
56 | my %foundCorrect = (); # number of chunks in corpus per type
57 | my %foundGuessed = (); # number of identified chunks per type
58 |
59 | my @features; # features on line
60 | my @sortedTypes; # sorted list of chunk type names
61 |
62 | # sanity check
63 | while (@ARGV and $ARGV[0] =~ /^-/) {
64 | if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); }
65 | elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); }
66 | elsif ($ARGV[0] eq "-d") {
67 | shift(@ARGV);
68 | if (not defined $ARGV[0]) {
69 | die "conlleval: -d requires delimiter character";
70 | }
71 | $delimiter = shift(@ARGV);
72 | } elsif ($ARGV[0] eq "-o") {
73 | shift(@ARGV);
74 | if (not defined $ARGV[0]) {
75 | die "conlleval: -o requires delimiter character";
76 | }
77 | $oTag = shift(@ARGV);
78 | } else { die "conlleval: unknown argument $ARGV[0]\n"; }
79 | }
80 | if (@ARGV) { die "conlleval: unexpected command line argument\n"; }
81 | # process input
82 | while (<STDIN>) {
83 | chomp($line = $_);
84 | @features = split(/$delimiter/,$line);
85 | if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; }
86 | elsif ($nbrOfFeatures != $#features and @features != 0) {
87 | printf STDERR "unexpected number of features: %d (%d)\n",
88 | $#features+1,$nbrOfFeatures+1;
89 | exit(1);
90 | }
91 | if (@features == 0 or
92 | $features[0] eq $boundary) { @features = ($boundary,"O","O"); }
93 | if (@features < 2) {
94 | die "conlleval: unexpected number of features in line $line\n";
95 | }
96 | if ($raw) {
97 | if ($features[$#features] eq $oTag) { $features[$#features] = "O"; }
98 | if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; }
99 | if ($features[$#features] ne "O") {
100 | $features[$#features] = "B-$features[$#features]";
101 | }
102 | if ($features[$#features-1] ne "O") {
103 | $features[$#features-1] = "B-$features[$#features-1]";
104 | }
105 | }
106 | # 20040126 ET code which allows hyphens in the types
107 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
108 | $guessed = $1;
109 | $guessedType = $2;
110 | } else {
111 | $guessed = $features[$#features];
112 | $guessedType = "";
113 | }
114 | pop(@features);
115 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
116 | $correct = $1;
117 | $correctType = $2;
118 | } else {
119 | $correct = $features[$#features];
120 | $correctType = "";
121 | }
122 | pop(@features);
123 | # ($guessed,$guessedType) = split(/-/,pop(@features));
124 | # ($correct,$correctType) = split(/-/,pop(@features));
125 | $guessedType = $guessedType ? $guessedType : "";
126 | $correctType = $correctType ? $correctType : "";
127 | $firstItem = shift(@features);
128 |
129 | # 1999-06-26 sentence breaks should always be counted as out of chunk
130 | if ( $firstItem eq $boundary ) { $guessed = "O"; }
131 |
132 | if ($inCorrect) {
133 | if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
134 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
135 | $lastGuessedType eq $lastCorrectType) {
136 | $inCorrect=$false;
137 | $correctChunk++;
138 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
139 | $correctChunk{$lastCorrectType}+1 : 1;
140 | } elsif (
141 | &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) !=
142 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or
143 | $guessedType ne $correctType ) {
144 | $inCorrect=$false;
145 | }
146 | }
147 |
148 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
149 | &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
150 | $guessedType eq $correctType) { $inCorrect = $true; }
151 |
152 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) {
153 | $foundCorrect++;
154 | $foundCorrect{$correctType} = $foundCorrect{$correctType} ?
155 | $foundCorrect{$correctType}+1 : 1;
156 | }
157 | if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) {
158 | $foundGuessed++;
159 | $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ?
160 | $foundGuessed{$guessedType}+1 : 1;
161 | }
162 | if ( $firstItem ne $boundary ) {
163 | if ( $correct eq $guessed and $guessedType eq $correctType ) {
164 | $correctTags++;
165 | }
166 | $tokenCounter++;
167 | }
168 |
169 | $lastGuessed = $guessed;
170 | $lastCorrect = $correct;
171 | $lastGuessedType = $guessedType;
172 | $lastCorrectType = $correctType;
173 | }
174 | if ($inCorrect) {
175 | $correctChunk++;
176 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
177 | $correctChunk{$lastCorrectType}+1 : 1;
178 | }
179 |
180 | if (not $latex) {
181 | # compute overall precision, recall and FB1 (default values are 0.0)
182 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
183 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
184 | $FB1 = 2*$precision*$recall/($precision+$recall)
185 | if ($precision+$recall > 0);
186 |
187 | # print overall performance
188 | printf "processed $tokenCounter tokens with $foundCorrect phrases; ";
189 | printf "found: $foundGuessed phrases; correct: $correctChunk.\n";
190 | if ($tokenCounter>0) {
191 | printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter;
192 | printf "precision: %6.2f%%; ",$precision;
193 | printf "recall: %6.2f%%; ",$recall;
194 | printf "FB1: %6.2f\n",$FB1;
195 | }
196 | }
197 |
198 | # sort chunk type names
199 | undef($lastType);
200 | @sortedTypes = ();
201 | foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) {
202 | if (not($lastType) or $lastType ne $i) {
203 | push(@sortedTypes,($i));
204 | }
205 | $lastType = $i;
206 | }
207 | # print performance per chunk type
208 | if (not $latex) {
209 | for $i (@sortedTypes) {
210 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
211 | if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; }
212 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
213 | if (not($foundCorrect{$i})) { $recall = 0.0; }
214 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
215 | if ($precision+$recall == 0.0) { $FB1 = 0.0; }
216 | else { $FB1 = 2*$precision*$recall/($precision+$recall); }
217 | printf "%17s: ",$i;
218 | printf "precision: %6.2f%%; ",$precision;
219 | printf "recall: %6.2f%%; ",$recall;
220 | printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i};
221 | }
222 | } else {
223 | print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline";
224 | for $i (@sortedTypes) {
225 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
226 | if (not($foundGuessed{$i})) { $precision = 0.0; }
227 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
228 | if (not($foundCorrect{$i})) { $recall = 0.0; }
229 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
230 | if ($precision+$recall == 0.0) { $FB1 = 0.0; }
231 | else { $FB1 = 2*$precision*$recall/($precision+$recall); }
232 | printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\",
233 | $i,$precision,$recall,$FB1;
234 | }
235 | print "\\hline\n";
236 | $precision = 0.0;
237 | $recall = 0;
238 | $FB1 = 0.0;
239 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
240 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
241 | $FB1 = 2*$precision*$recall/($precision+$recall)
242 | if ($precision+$recall > 0);
243 | printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n",
244 | $precision,$recall,$FB1;
245 | }
246 |
247 | exit 0;
248 |
249 | # endOfChunk: checks if a chunk ended between the previous and current word
250 | # arguments: previous and current chunk tags, previous and current types
251 | # note: this code is capable of handling other chunk representations
252 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
253 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
254 |
255 | sub endOfChunk {
256 | my $prevTag = shift(@_);
257 | my $tag = shift(@_);
258 | my $prevType = shift(@_);
259 | my $type = shift(@_);
260 | my $chunkEnd = $false;
261 |
262 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; }
263 | if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; }
264 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; }
265 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
266 |
267 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; }
268 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; }
269 | if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; }
270 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
271 |
272 | if ($prevTag ne "O" and $prevTag ne "." and $prevType ne $type) {
273 | $chunkEnd = $true;
274 | }
275 |
276 | # corrected 1998-12-22: these chunks are assumed to have length 1
277 | if ( $prevTag eq "]" ) { $chunkEnd = $true; }
278 | if ( $prevTag eq "[" ) { $chunkEnd = $true; }
279 |
280 | return($chunkEnd);
281 | }
282 |
283 | # startOfChunk: checks if a chunk started between the previous and current word
284 | # arguments: previous and current chunk tags, previous and current types
285 | # note: this code is capable of handling other chunk representations
286 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
287 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
288 |
289 | sub startOfChunk {
290 | my $prevTag = shift(@_);
291 | my $tag = shift(@_);
292 | my $prevType = shift(@_);
293 | my $type = shift(@_);
294 | my $chunkStart = $false;
295 |
296 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; }
297 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; }
298 | if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; }
299 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
300 |
301 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; }
302 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; }
303 | if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; }
304 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
305 |
306 | if ($tag ne "O" and $tag ne "." and $prevType ne $type) {
307 | $chunkStart = $true;
308 | }
309 |
310 | # corrected 1998-12-22: these chunks are assumed to have length 1
311 | if ( $tag eq "[" ) { $chunkStart = $true; }
312 | if ( $tag eq "]" ) { $chunkStart = $true; }
313 |
314 | return($chunkStart);
315 | }
316 |
--------------------------------------------------------------------------------
/data_prep/bio_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT License.
3 |
4 | import copy
5 | import os
6 | import random
7 |
8 | import torch
9 | from torch.utils.data import Dataset
10 |
11 | from options import opt
12 | from utils import read_bio_samples
13 |
14 | class BioDataset(Dataset):
15 | """
16 | tag_vocab is updated
17 | vocab is updated if update_vocab is True
18 | self.raw_X, self.raw_Y: list of lists of strings
19 | self.X: list of dict {'words': list of int, 'chars': list of list of int}
20 | self.Y: list of lists of integers
21 | """
22 | def __init__(self, input_file, vocab, char_vocab, tag_vocab,
23 | update_vocab, encoding='utf-8', remove_empty=False):
24 | self.raw_X = []
25 | self.raw_Y = []
26 | with open(input_file, encoding=encoding) as inf:
27 | for lines in read_bio_samples(inf):
28 | x = [l.split()[0] for l in lines]
29 | y = [l.split()[-1] for l in lines]
30 | if remove_empty:
31 | if all([l=='O' for l in y]):
32 | continue
33 | self.raw_X.append(x)
34 | self.raw_Y.append(y)
35 | if vocab and update_vocab:
36 | for w in x:
37 | vocab.add_word(w)
38 | if char_vocab and update_vocab:
39 | for w in x:
40 | for ch in w:
41 | char_vocab.add_word(ch, normalize=opt.lowercase_char)
42 | tag_vocab.add_tags(y)
43 | self.X = []
44 | self.Y = []
45 | for xs, ys in zip(self.raw_X, self.raw_Y):
46 | x = {}
47 | if vocab:
48 | x['words'] = [vocab.lookup(w) for w in xs]
49 | if char_vocab:
50 | x['chars'] = [[char_vocab.lookup(ch, normalize=opt.lowercase_char) for ch in w] for w in xs]
51 | self.X.append(x)
52 | self.Y.append([tag_vocab.lookup(y) for y in ys])
53 | assert len(self.X) == len(self.Y)
54 |
55 | def __len__(self):
56 | return len(self.Y)
57 |
58 | def __getitem__(self, idx):
59 | return (self.X[idx], self.Y[idx])
60 |
61 | def get_subset(self, num_samples, shuffle=True):
62 | subset = copy.copy(self)
63 | if shuffle:
64 | idx = random.sample(range(len(self)), num_samples)
65 | else:
66 | idx = list(range(min(len(self), num_samples)))
67 | subset.raw_X = [subset.raw_X[i] for i in idx]
68 | subset.raw_Y = [subset.raw_Y[i] for i in idx]
69 | subset.X = [subset.X[i] for i in idx]
70 | subset.Y = [subset.Y[i] for i in idx]
71 | return subset
72 |
73 | class ConllDataset(BioDataset):
74 | def __init__(self, input_file, vocab, char_vocab, tag_vocab,
75 | update_vocab, remove_empty=False):
76 | if (input_file[17:20] == 'esp') or (input_file[17:20] == 'ned'):
77 | encoding = 'ISO-8859-1'
78 | else:
79 | encoding = 'utf-8'
80 | super().__init__(input_file, vocab, char_vocab, tag_vocab,
81 | update_vocab, encoding, remove_empty=remove_empty)
82 |
83 |
84 | def get_conll_ner_datasets(vocab, char_vocab, tag_vocab, data_dir, lang):
85 | print(f'Loading CoNLL NER data for {lang} Language..')
86 | train_set = ConllDataset(os.path.join(data_dir, f'{lang}.train'),
87 | vocab, char_vocab, tag_vocab, update_vocab=True, remove_empty=opt.remove_empty_samples)
88 | dev_set = ConllDataset(os.path.join(data_dir, f'{lang}.dev'),
89 | vocab, char_vocab, tag_vocab, update_vocab=True)
90 | test_set = ConllDataset(os.path.join(data_dir, f'{lang}.test'),
91 | vocab, char_vocab, tag_vocab, update_vocab=True)
92 | return train_set, dev_set, test_set, train_set
93 |
94 |
95 | def get_train_on_translation_conll_ner_datasets(vocab, char_vocab, tag_vocab, data_dir, lang):
96 | print(f'Loading Train-on-Translation CoNLL NER data for {lang} Language..')
97 | train_set = ConllDataset(os.path.join(data_dir, f'eng2{lang}_{lang}.train'),
98 | vocab, char_vocab, tag_vocab, update_vocab=True)
99 | dev_set = ConllDataset(os.path.join(data_dir, f'{lang}.dev'),
100 | vocab, char_vocab, tag_vocab, update_vocab=True)
101 | test_set = ConllDataset(os.path.join(data_dir, f'{lang}.test'),
102 | vocab, char_vocab, tag_vocab, update_vocab=False)
103 | return train_set, dev_set, test_set, train_set
104 |
105 |
106 | def get_test_on_translation_conll_ner_datasets(vocab, char_vocab, tag_vocab, data_dir, lang):
107 | print(f'Loading Test-on-Translation CoNLL NER data for {lang} Language..')
108 | train_set = ConllDataset(os.path.join(data_dir, f'eng.train'),
109 | vocab, char_vocab, tag_vocab, update_vocab=True)
110 | dev_set = ConllDataset(os.path.join(data_dir, f'eng2{lang}_eng.dev'),
111 | vocab, char_vocab, tag_vocab, update_vocab=True)
112 | test_set = ConllDataset(os.path.join(data_dir, f'eng2{lang}_eng.test'),
113 | vocab, char_vocab, tag_vocab, update_vocab=False)
114 | return train_set, dev_set, test_set, train_set
115 |
--------------------------------------------------------------------------------
/data_prep/multi_lingual_amazon.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT License.
3 |
4 | import glob
5 | import os
6 | import pickle
7 | import random
8 |
9 | import torch
10 | from torch.utils.data import Dataset
11 |
12 | from options import opt
13 |
14 |
15 | class MultiLangAmazonDataset(Dataset):
16 | """
17 | vocab is updated if update_vocab is True
18 | self.raw_X: list of lists of strings
19 | self.X, self.Y: list of lists of integers
20 | """
21 | num_labels = 2
22 | def __init__(self, input_file=None, vocab=None, char_vocab=None, max_seq_len=0, num_train_lines=0, update_vocab=False):
23 | self.raw_X = []
24 | self.X = []
25 | self.Y = []
26 | if input_file is None:
27 | return
28 | if input_file.endswith('.pkl'):
29 | with open(input_file, 'rb') as inf:
30 | reviews = pickle.load(inf)
31 | self.raw_X = reviews['X']
32 | self.Y = reviews['Y']
33 | if num_train_lines > 0:
34 | self.raw_X = self.raw_X[:num_train_lines]
35 | self.Y = self.Y[:num_train_lines]
36 | for x in self.raw_X:
37 | if max_seq_len > 0:
38 | x = x[:max_seq_len]
39 | if update_vocab:
40 | for w in x:
41 | vocab.add_word(w)
42 | if char_vocab and update_vocab:
43 | for w in x:
44 | for ch in w:
45 | char_vocab.add_word(ch, normalize=opt.lowercase_char)
46 | sample = {}
47 | if vocab:
48 | sample['words'] = [vocab.lookup(w) for w in x]
49 | if char_vocab:
50 | sample['chars'] = [[char_vocab.lookup(ch,
51 | normalize=opt.lowercase_char) for ch in w] for w in x]
52 | self.X.append(sample)
53 | else:
54 | with open(input_file) as inf:
55 | for cnt, line in enumerate(inf):
56 | parts = line.split('\t')
57 | assert len(parts) == 2, f"Incorrect format {line}"
58 | x = parts[1].rstrip().split(' ')
59 | self.raw_X.append(x)
60 | if max_seq_len > 0:
61 | x = x[:max_seq_len]
62 | if update_vocab:
63 | for w in x:
64 | vocab.add_word(w)
65 | if char_vocab and update_vocab:
66 | for w in x:
67 | for ch in w:
68 | char_vocab.add_word(ch, normalize=opt.lowercase_char)
69 | sample = {}
70 | if vocab:
71 | sample['words'] = [vocab.lookup(w) for w in x]
72 | if char_vocab:
73 | sample['chars'] = [[char_vocab.lookup(ch,
74 | normalize=opt.lowercase_char) for ch in w] for w in x]
75 | self.X.append(sample)
76 | self.Y.append(int(parts[0]))
77 | if num_train_lines > 0 and cnt+1 >= num_train_lines:
78 | break
79 | print(f"Loaded {cnt+1} samples")
80 | assert len(self.X) == len(self.Y)
81 |
82 | def __len__(self):
83 | return len(self.Y)
84 |
85 | def __getitem__(self, idx):
86 | return (self.X[idx], self.Y[idx])
87 |
88 | def get_subset(self, num_samples, shuffle=True):
89 | subset = MultiLangAmazonDataset()
90 | if shuffle:
91 | idx = random.sample(range(len(self)), num_samples)
92 | else:
93 | idx = list(range(min(len(self), num_samples)))
94 | # the new subset starts out empty, so index into self
95 | subset.raw_X = [self.raw_X[i] for i in idx]
96 | subset.X = [self.X[i] for i in idx]
97 | subset.Y = [self.Y[i] for i in idx]
98 | return subset
99 |
100 | def train_dev_split(self, dev_samples, shuffle=True):
101 | devset = MultiLangAmazonDataset()
102 | idx = list(range(len(self)))
103 | if shuffle:
104 | random.shuffle(idx)
105 | # dev set
106 | dev_idx = idx[:dev_samples]
107 | devset.raw_X = [self.raw_X[i] for i in dev_idx]
108 | devset.X = [self.X[i] for i in dev_idx]
109 | devset.Y = [self.Y[i] for i in dev_idx]
110 | # update train set
111 | train_idx = idx[dev_samples:]
112 | self.raw_X = [self.raw_X[i] for i in train_idx]
113 | self.X = [self.X[i] for i in train_idx]
114 | self.Y = [self.Y[i] for i in train_idx]
115 | self.max_seq_len = None
116 | return devset
117 |
118 | def get_max_seq_len(self):
119 | if not hasattr(self, 'max_seq_len') or self.max_seq_len is None:
120 | self.max_seq_len = max([len(x['words']) for x in self.X])
121 | return self.max_seq_len
122 |
123 |
124 | def get_multi_lingual_amazon_datasets(vocab, char_vocab, root_dir, domain, lang, max_seq_len):
125 | print(f'Loading Multi-Lingual Amazon Review data for {domain} Domain and {lang} Language..')
126 | data_dir = os.path.join(root_dir, lang, domain)
127 |
128 | train_set = MultiLangAmazonDataset(os.path.join(data_dir, 'train.pkl'),
129 | vocab, char_vocab, max_seq_len, 0, update_vocab=True)
130 | # train-dev split
131 | dev_set = train_set.train_dev_split(400, shuffle=True)
132 | test_set = MultiLangAmazonDataset(os.path.join(data_dir, 'test.pkl'),
133 | vocab, char_vocab, max_seq_len, 0, update_vocab=True)
134 | unlabeled_set = MultiLangAmazonDataset(os.path.join(data_dir, 'unlabeled.pkl'),
135 | vocab, char_vocab, max_seq_len, 50000, update_vocab=True)
136 | return train_set, dev_set, test_set, unlabeled_set
137 |
--------------------------------------------------------------------------------
/data_processing_scripts/amazon/pickle_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import sys
4 |
5 |
6 | def process_all_data(root_dir, out_dir):
7 | for lang in ['en', 'de', 'fr', 'ja']:
8 | for domain in ['books', 'dvd', 'music']:
9 | for split in ['train', 'test', 'unlabeled']:
10 | fn = os.path.join(out_dir, lang, domain, f'{split}.tok.txt')
11 | ofn = os.path.join(root_dir, lang, domain, f'{split}.pkl')
12 | with open(fn) as inf, open(ofn, 'wb') as ouf:
13 | reviews = {'X': [], 'Y': []}
14 | print(f"Processing file: {fn}")
15 | for line in inf:
16 | parts = line.split('\t')
17 | x = parts[1].rstrip().split(' ')
18 | y = int(parts[0])
19 | reviews['X'].append(x)
20 | reviews['Y'].append(y)
21 | pickle.dump(reviews, ouf)
22 |
23 |
24 | if __name__ == '__main__':
25 | process_all_data(sys.argv[1], sys.argv[2])
26 |
--------------------------------------------------------------------------------
/data_processing_scripts/amazon/process_dataset.py:
--------------------------------------------------------------------------------
1 | # Usage: python process_dataset.py dataset_dir output_dir
2 |
3 | import sys
4 | import os
5 | import gzip
6 | from xml.etree.cElementTree import iterparse,tostring
7 | import traceback
8 |
9 | from stanfordcorenlp import StanfordCoreNLP as corenlp
10 | import tinysegmenter
11 |
12 |
13 | # CoreNLP for tokenizing the translation output
14 | CORENLP_DIR = "./stanford-corenlp-full-2018-02-27"
15 | nlp = {}
16 | for lang in ['en', 'de', 'fr']:
17 | nlp[lang] = corenlp(CORENLP_DIR, lang=lang)
18 |
19 |
20 | def parse(itemfile):
21 | for event, elem in iterparse(itemfile):
22 | if elem.tag == "item":
23 | yield processItem(elem)
24 | elem.clear()
25 |
26 |
27 | def processItem(item):
28 | """ Process a review.
29 | Implement custom code here. Use 'item.find('tagname').text' to access the properties of a review.
30 | """
31 | review = {}
32 | # review.category = item.find("category").text
33 | review['rating'] = int(float(item.find("rating").text))
34 | # review.asin = item.find("asin").text
35 | # review.date = item.find("date").text
36 | review['text'] = item.find("text").text
37 | # review.summary = item.find("summary").text
38 | return review
39 |
40 |
41 | def process_all_data(root_dir, out_dir):
42 | for lang in ['en', 'de', 'fr', 'ja']:
43 | for domain in ['books', 'dvd', 'music']:
44 | for split in ['train', 'test', 'unlabeled']:
45 | fn = os.path.join(root_dir, lang, domain, f'{split}.review')
46 | ofn = os.path.join(out_dir, lang, domain, f'{split}.tok.txt')
47 | with open(fn) as inf, open(ofn, 'w') as ouf:
48 | print(f"Processing file: {fn}")
49 | for review in parse(inf):
50 | # binarize label
51 | label = 1 if review['rating'] > 3 else 0
52 | try:
53 | # remove line breaks
54 | raw_text = review['text'].replace('\n', ' ').replace('\t', ' ')
55 | if lang == 'ja':
56 | tok_text = tinysegmenter.tokenize(raw_text)
57 | else:
58 | tok_text = nlp[lang].word_tokenize(raw_text)
59 | except:
60 | print("Exception tokenizing", review)
61 | continue
62 | print(f"{label}\t{' '.join(tok_text)}", file=ouf)
63 |
64 |
65 | if __name__ == "__main__":
66 | process_all_data(sys.argv[1], sys.argv[2])
67 |
68 |
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT License.
3 |
4 | import torch
5 | from torch import autograd, nn
6 | import torch.nn.functional as functional
7 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
8 | from options import opt
9 | import utils
10 | class AveragingLayer(nn.Module):
11 | def __init__(self, word_emb):
12 | super(AveragingLayer, self).__init__()
13 | self.word_emb = word_emb
14 |
15 | def forward(self, input):
16 | """
17 | input: (data, lengths): (IntTensor(batch_size, max_sent_len), IntTensor(batch_size))
18 | """
19 | data, lengths = input
20 | data = autograd.Variable(data)
21 | lengths = autograd.Variable(lengths)
22 | embeds = self.word_emb(data)
23 | X = embeds.sum(1).squeeze(1)
24 | lengths = lengths.view(-1, 1).expand_as(X)
25 | return torch.div(X, lengths.float())
26 |
27 |
28 | class SummingLayer(nn.Module):
29 | def __init__(self, word_emb):
30 | super(SummingLayer, self).__init__()
31 | self.word_emb = word_emb
32 |
33 | def forward(self, input):
34 | """
35 | input: (data, lengths): (IntTensor(batch_size, max_sent_len), IntTensor(batch_size))
36 | """
37 | data, _ = input
38 | data = autograd.Variable(data)
39 | embeds = self.word_emb(data)
40 | X = embeds.sum(1).squeeze()
41 | return X
42 |
43 |
44 | class DotAttentionLayer(nn.Module):
45 | def __init__(self, hidden_size):
46 | super(DotAttentionLayer, self).__init__()
47 | self.hidden_size = hidden_size
48 | self.W = nn.Linear(hidden_size, 1, bias=False)
49 |
50 | def forward(self, input):
51 | """
52 | input: (unpacked_padded_output: batch_size x seq_len x hidden_size, lengths: batch_size)
53 | """
54 | inputs, lengths = input
55 | batch_size, max_len, _ = inputs.size()
56 | flat_input = inputs.contiguous().view(-1, self.hidden_size)
57 | logits = self.W(flat_input).view(batch_size, max_len)
58 | alphas = functional.softmax(logits, dim=-1)
59 |
60 | # computing mask
61 | idxes = torch.arange(0, max_len, out=torch.LongTensor(max_len))
62 | mask = (idxes < lengths.unsqueeze(1)).float()
105 | assert num_layers > 0, 'Invalid layer numbers'
106 | self.fcnet = nn.Sequential()
107 | for i in range(num_layers):
108 | if dropout > 0:
109 | self.fcnet.add_module('f-dropout-{}'.format(i), nn.Dropout(p=dropout))
110 | if i == 0:
111 | self.fcnet.add_module('f-linear-{}'.format(i),
112 | nn.Linear(len(kernel_sizes)*kernel_num, hidden_size))
113 | else:
114 | self.fcnet.add_module('f-linear-{}'.format(i), nn.Linear(hidden_size, hidden_size))
115 | # if batch_norm:
116 | # self.fcnet.add_module('f-bn-{}'.format(i), nn.LayerNorm(hidden_size))
117 | self.fcnet.add_module('f-relu-{}'.format(i), opt.act_unit)
118 |
119 | def forward(self, input):
120 | data, _ = input
121 | # conv
122 | if self.word_dropout > 0:
123 | data = functional.dropout(data, self.word_dropout, self.training)
124 | data = data.unsqueeze(1) # batch_size, 1, seq_len, input_size
125 | x = [functional.relu(conv(data)).squeeze(3) for conv in self.convs]
126 | x = [functional.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
127 | x = torch.cat(x, 1)
128 | # fcnet
129 | return self.fcnet(x)
130 |
131 |
132 | class LSTMSequencePooling(nn.Module):
133 | def __init__(self,
134 | num_layers,
135 | input_size,
136 | hidden_size,
137 | word_dropout,
138 | dropout,
139 | bdrnn,
140 | attn_type):
141 | super(LSTMSequencePooling, self).__init__()
142 | self.num_layers = num_layers
143 | self.dropout = dropout
144 | self.word_dropout = word_dropout
145 | self.bdrnn = bdrnn
146 | self.attn_type = attn_type
147 | self.input_size = input_size
148 | self.hidden_size = hidden_size//2 if bdrnn else hidden_size
149 | self.n_cells = self.num_layers*2 if bdrnn else self.num_layers
150 | self.rnn = nn.LSTM(input_size=input_size, hidden_size=self.hidden_size,
151 | num_layers=num_layers, dropout=dropout, bidirectional=bdrnn)
152 | if attn_type == 'dot':
153 | self.attn = DotAttentionLayer(hidden_size)
154 |
155 | def forward(self, input):
156 | data, lengths = input
157 | # lengths_list = lengths.tolist()
158 | batch_size = len(data)
159 | if self.word_dropout > 0:
160 | data = functional.dropout(data, self.word_dropout, self.training)
161 | packed = pack_padded_sequence(data, lengths, batch_first=True)
162 | output, (ht, ct) = self.rnn(packed)
163 |
164 | if self.attn_type == 'last':
165 | return ht[-1] if not self.bdrnn \
166 | else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
167 | elif self.attn_type == 'avg':
168 | unpacked_output = pad_packed_sequence(output, batch_first=True)[0]
169 | return torch.sum(unpacked_output, 1) / lengths.float().view(-1, 1)
170 | elif self.attn_type == 'dot':
171 | unpacked_output = pad_packed_sequence(output, batch_first=True)[0]
172 | return self.attn((unpacked_output, lengths))
173 | else:
174 | raise Exception('Please specify valid attention (pooling) mechanism')
175 |
176 |
177 | class LSTMFeatureExtractor(nn.Module):
178 | def __init__(self,
179 | emb_size,
180 | num_layers,
181 | hidden_size,
182 | word_dropout,
183 | dropout,
184 | bdrnn):
185 | super(LSTMFeatureExtractor, self).__init__()
186 | self.num_layers = num_layers
187 | self.bdrnn = bdrnn
188 | self.dropout = dropout
189 | self.word_dropout = word_dropout
190 | self.hidden_size = hidden_size//2 if bdrnn else hidden_size
191 | # self.n_cells = self.num_layers*2 if bdrnn else self.num_layers
192 |
193 | self.rnn = nn.LSTM(input_size=emb_size, hidden_size=self.hidden_size,
194 | num_layers=num_layers, dropout=dropout, bidirectional=bdrnn)
195 |
196 | def forward(self, input):
197 | embeds, lengths = input
198 | # lengths_list = lengths.tolist()
199 | batch_size, seq_len = embeds.size(0), embeds.size(1)
200 | # word_dropout before LSTM
201 | if self.word_dropout > 0:
202 | embeds = functional.dropout(embeds, self.word_dropout, self.training)
203 | packed = pack_padded_sequence(embeds, lengths, batch_first=True)
204 | # state_shape = self.n_cells, batch_size, self.hidden_size
205 | # h0 = c0 = autograd.Variable(embeds.data.new(*state_shape))
206 | output, _ = self.rnn(packed, None)
207 | output, _ = pad_packed_sequence(output, batch_first=True, total_length=seq_len)
208 | return output
209 |
210 |
211 | class SpMlpTagger(nn.Module):
212 | def __init__(self,
213 | num_layers,
214 | shared_input_size,
215 | private_input_size,
216 | concat_sp,
217 | hidden_size,
218 | output_size,
219 | dropout,
220 | batch_norm=False):
221 | super(SpMlpTagger, self).__init__()
222 | assert num_layers >= 0, 'Invalid layer numbers'
223 | self.concat_sp = concat_sp
224 | self.shared_input_size = shared_input_size
225 | self.private_input_size = private_input_size
226 | self.input_size = shared_input_size + private_input_size if concat_sp else shared_input_size
227 | self.hidden_size = hidden_size
228 | if opt.sp_attn:
229 | if concat_sp:
230 | self.sp_attn = nn.Linear(self.input_size, 2)
231 | else:
232 | self.sp_attn = nn.Linear(self.input_size, 1)
233 | if opt.C_input_gate:
234 | self.input_gate = nn.Linear(self.input_size, self.input_size)
235 | self.net = nn.Sequential()
236 | for i in range(num_layers):
237 | if dropout > 0:
238 | self.net.add_module('p-dropout-{}'.format(i), nn.Dropout(p=dropout))
239 | if i == 0:
240 | self.net.add_module('p-linear-{}'.format(i), nn.Linear(self.input_size, hidden_size))
241 | else:
242 | self.net.add_module('p-linear-{}'.format(i), nn.Linear(hidden_size, hidden_size))
243 | if batch_norm:
244 | self.net.add_module('p-bn-{}'.format(i), nn.LayerNorm(hidden_size))
245 | self.net.add_module('p-relu-{}'.format(i), opt.act_unit)
246 |
247 | self.net.add_module('p-linear-final', nn.Linear(hidden_size, output_size))
248 | self.net.add_module('p-logsoftmax', nn.LogSoftmax(dim=-1))
249 |
250 | def forward(self, input):
251 | fs, fp = input
252 | if self.shared_input_size == 0:
253 | features = fp
254 | elif self.private_input_size == 0:
255 | features = fs
256 | else:
257 | if self.concat_sp:
258 | if opt.sp_attn:
259 | features = torch.cat([fs, fp], dim=-1)
260 | # bs x seqlen x 2
261 | if opt.sp_sigmoid_attn:
262 | alphas = functional.sigmoid(self.sp_attn(features))
263 | else:
264 | alphas = functional.softmax(self.sp_attn(features), dim=-1)
265 | fs = fs * alphas[:,:,0].unsqueeze(-1)
266 | fp = fp * alphas[:,:,1].unsqueeze(-1)
267 | features = torch.cat([fs, fp], dim=-1)
268 | else:
269 | features = torch.cat([fs, fp], dim=-1)
270 | else:
271 | if opt.sp_attn:
272 | a1 = self.sp_attn(fs)
273 | a2 = self.sp_attn(fp)
274 | if opt.sp_sigmoid_attn:
275 | alphas = functional.sigmoid(torch.cat([a1, a2], dim=-1))
276 | else:
277 | alphas = functional.softmax(torch.cat([a1, a2], dim=-1), dim=-1)
278 | features = torch.stack([fs, fp], dim=2) # bs x seq_len x 2 x hidden_dim
279 | features = torch.sum(alphas.unsqueeze(-1) * features, dim=2)
280 | else:
281 | features = fs + fp
282 | if opt.C_input_gate:
283 | gates = self.input_gate(features.detach())
284 | gates = functional.sigmoid(gates)
285 | features = gates * features
286 | return self.net(features) # * mask.unsqueeze(2)
287 |
288 |
289 | class SpClassifier(nn.Module):
290 | def __init__(self,
291 | num_layers,
292 | shared_input_size,
293 | private_input_size,
294 | concat_sp,
295 | hidden_size,
296 | output_size,
297 | dropout,
298 | batch_norm=False):
299 | super(SpClassifier, self).__init__()
300 | assert num_layers >= 0, 'Invalid layer numbers'
301 | self.concat_sp = concat_sp
302 | self.shared_input_size = shared_input_size
303 | self.private_input_size = private_input_size
304 | self.input_size = shared_input_size + private_input_size if concat_sp else shared_input_size
305 | self.hidden_size = hidden_size
306 | if opt.sp_attn:
307 | if concat_sp:
308 | self.sp_attn = nn.Linear(self.input_size, 2)
309 | else:
310 | self.sp_attn = nn.Linear(self.input_size, 1)
311 | if opt.C_input_gate:
312 | self.input_gate = nn.Linear(self.input_size, self.input_size)
313 | self.net = nn.Sequential()
314 | for i in range(num_layers):
315 | if dropout > 0:
316 | self.net.add_module('p-dropout-{}'.format(i), nn.Dropout(p=dropout))
317 | if i == 0:
318 | self.net.add_module('p-linear-{}'.format(i), nn.Linear(self.input_size, hidden_size))
319 | else:
320 | self.net.add_module('p-linear-{}'.format(i), nn.Linear(hidden_size, hidden_size))
321 | if batch_norm:
322 | self.net.add_module('p-bn-{}'.format(i), nn.LayerNorm(hidden_size))
323 | self.net.add_module('p-relu-{}'.format(i), opt.act_unit)
324 |
325 | self.net.add_module('p-linear-final', nn.Linear(hidden_size, output_size))
326 | self.net.add_module('p-logsoftmax', nn.LogSoftmax(dim=-1))
327 |
328 | def forward(self, input):
329 | fs, fp = input
330 | if self.shared_input_size == 0:
331 | features = fp
332 | elif self.private_input_size == 0:
333 | features = fs
334 | else:
335 | if self.concat_sp:
336 | if opt.sp_attn:
337 | features = torch.cat([fs, fp], dim=-1)
338 | # bs x 2
339 | if opt.sp_sigmoid_attn:
340 | alphas = functional.sigmoid(self.sp_attn(features))
341 | else:
342 | alphas = functional.softmax(self.sp_attn(features), dim=-1)
343 | fs = fs * alphas[:,0].unsqueeze(-1)
344 | fp = fp * alphas[:,1].unsqueeze(-1)
345 | features = torch.cat([fs, fp], dim=-1)
346 | else:
347 | features = torch.cat([fs, fp], dim=-1)
348 | else:
349 | if opt.sp_attn:
350 | a1 = self.sp_attn(fs)
351 | a2 = self.sp_attn(fp)
352 | if opt.sp_sigmoid_attn:
353 | alphas = functional.sigmoid(torch.cat([a1, a2], dim=-1))
354 | else:
355 | alphas = functional.softmax(torch.cat([a1, a2], dim=-1), dim=-1)
356 | features = torch.stack([fs, fp], dim=2) # bs x 2 x hidden_dim
357 | features = torch.sum(alphas.unsqueeze(-1) * features, dim=-2)
358 | else:
359 | features = fs + fp
360 | if opt.C_input_gate:
361 | gates = self.input_gate(features.detach())
362 | gates = functional.sigmoid(gates)
363 | features = gates * features
364 | return self.net(features) # * mask.unsqueeze(2)
365 |
366 |
367 | class Mlp(nn.Module):
368 | """
369 | Use tanh for the last layer to better work like LSTM output
370 | """
371 | def __init__(self,
372 | num_layers,
373 | input_size,
374 | hidden_size,
375 | output_size,
376 | dropout,
377 | batch_norm=False):
378 | super(Mlp, self).__init__()
379 | assert num_layers >= 0, 'Invalid layer numbers'
380 | self.hidden_size = hidden_size
381 | self.net = nn.Sequential()
382 | for i in range(num_layers):
383 | if dropout > 0:
384 | self.net.add_module('p-dropout-{}'.format(i), nn.Dropout(p=dropout))
385 | if i == 0:
386 | self.net.add_module('p-linear-{}'.format(i), nn.Linear(input_size, hidden_size))
387 | hsize = hidden_size
388 | elif i+1 == num_layers:
389 | self.net.add_module('p-linear-{}'.format(i), nn.Linear(hidden_size, output_size))
390 | hsize = output_size
391 | else:
392 | self.net.add_module('p-linear-{}'.format(i), nn.Linear(hidden_size, hidden_size))
393 | if batch_norm:
394 | self.net.add_module('p-bn-{}'.format(i), nn.LayerNorm(hsize))
395 | if i+1 < num_layers:
396 | self.net.add_module('p-relu-{}'.format(i), opt.act_unit)
397 | else:
398 | if opt.moe_last_act == 'tanh':
399 | self.net.add_module('p-tanh-{}'.format(i), nn.Tanh())
400 | elif opt.moe_last_act == 'relu':
401 | self.net.add_module('p-relu-{}'.format(i), nn.ReLU())
402 | else:
403 | raise NotImplementedError(opt.moe_last_act)
404 |
405 | def forward(self, input):
406 | return self.net(input)
407 |
408 |
409 | class MlpTagger(nn.Module):
410 | def __init__(self,
411 | num_layers,
412 | input_size,
413 | hidden_size,
414 | output_size,
415 | dropout,
416 | batch_norm=False):
417 | super(MlpTagger, self).__init__()
418 | assert num_layers >= 0, 'Invalid layer numbers'
419 | self.hidden_size = hidden_size
420 | self.net = nn.Sequential()
421 | for i in range(num_layers):
422 | if dropout > 0:
423 | self.net.add_module('p-dropout-{}'.format(i), nn.Dropout(p=dropout))
424 | if i == 0:
425 | self.net.add_module('p-linear-{}'.format(i), nn.Linear(input_size, hidden_size))
426 | else:
427 | self.net.add_module('p-linear-{}'.format(i), nn.Linear(hidden_size, hidden_size))
428 | if batch_norm:
429 | self.net.add_module('p-bn-{}'.format(i), nn.LayerNorm(hidden_size))
430 | self.net.add_module('p-relu-{}'.format(i), opt.act_unit)
431 |
432 | self.net.add_module('p-linear-final', nn.Linear(hidden_size, output_size))
433 | self.net.add_module('p-logsoftmax', nn.LogSoftmax(dim=-1))
434 |
435 | def forward(self, input):
436 | return self.net(input)
437 |
438 |
439 | class MixtureOfExperts(nn.Module):
440 | def __init__(self,
441 | num_layers,
442 | input_size,
443 | num_experts,
444 | hidden_size,
445 | output_size,
446 | dropout,
447 | bn=False,
448 | is_tagger=True):
449 | super(MixtureOfExperts, self).__init__()
450 | self.num_experts = num_experts
451 | self.gates = nn.Linear(input_size, num_experts)
452 | mlp = MlpTagger if is_tagger else Mlp
453 | self.experts = nn.ModuleList([mlp(num_layers, \
454 | input_size, hidden_size, output_size, dropout, bn) \
455 | for _ in range(num_experts)])
456 |
457 | def forward(self, input):
458 | # input: bs x seqlen x input_size
459 | gate_input = input.detach() if opt.detach_gate_input else input
460 | gate_outs = self.gates(gate_input)
461 | gate_softmax = functional.softmax(gate_outs, dim=-1) # bs x seqlen x #experts
462 | # bs x seqlen x #experts x output_size
463 | expert_outs = torch.stack([exp(input) for exp in self.experts], dim=-2)
464 | # bs x seqlen x output_size
465 | output = torch.sum(gate_softmax.unsqueeze(-1) * expert_outs, dim=-2)
466 | # output logits
467 | return output, gate_outs
468 |
469 |
470 | class SpMixtureOfExperts(nn.Module):
471 | def __init__(self,
472 | num_layers,
473 | shared_input_size,
474 | private_input_size,
475 | concat_sp,
476 | num_experts,
477 | hidden_size,
478 | output_size,
479 | dropout,
480 | bn):
481 | super(SpMixtureOfExperts, self).__init__()
482 | self.shared_input_size = shared_input_size
483 | self.private_input_size = private_input_size
484 | self.concat_sp = concat_sp
485 | input_size = shared_input_size + private_input_size if concat_sp else shared_input_size
486 | if opt.sp_attn:
487 | if concat_sp:
488 | self.sp_attn = nn.Linear(input_size, 2)
489 | else:
490 | self.sp_attn = nn.Linear(input_size, 1)
491 | self.moe = MixtureOfExperts(num_layers, input_size, num_experts,
492 | hidden_size, output_size, dropout, bn, is_tagger=True)
493 |
494 | def forward(self, input):
495 | fs, fp = input
496 | if self.shared_input_size == 0:
497 | features = fp
498 | elif self.private_input_size == 0:
499 | features = fs
500 | else:
501 | if self.concat_sp:
502 | if opt.sp_attn:
503 | features = torch.cat([fs, fp], dim=-1)
504 | # bs x seqlen x 2
505 | if opt.sp_sigmoid_attn:
506 | alphas = functional.sigmoid(self.sp_attn(features))
507 | else:
508 | alphas = functional.softmax(self.sp_attn(features), dim=-1)
509 | fs = fs * alphas[:,:,0].unsqueeze(-1)
510 | fp = fp * alphas[:,:,1].unsqueeze(-1)
511 | features = torch.cat([fs, fp], dim=-1)
512 | else:
513 | features = torch.cat([fs, fp], dim=-1)
514 | else:
515 | if opt.sp_attn:
516 | a1 = self.sp_attn(fs)
517 | a2 = self.sp_attn(fp)
518 | if opt.sp_sigmoid_attn:
519 | alphas = functional.sigmoid(torch.cat([a1, a2], dim=-1))
520 | else:
521 | alphas = functional.softmax(torch.cat([a1, a2], dim=-1), dim=-1)
522 | features = torch.stack([fs, fp], dim=2) # bs x seq_len x 2 x hidden_dim
523 | features = torch.sum(alphas.unsqueeze(-1) * features, dim=2)
524 | else:
525 | features = fs + fp
526 | return self.moe(features)
527 |
528 |
529 | class SpAttnMixtureOfExperts(nn.Module):
530 | def __init__(self,
531 | num_layers,
532 | shared_input_size,
533 | private_input_size,
534 | concat_sp,
535 | num_experts,
536 | hidden_size,
537 | output_size,
538 | dropout,
539 | attn_type,
540 | bn):
541 | super(SpAttnMixtureOfExperts, self).__init__()
542 | self.shared_input_size = shared_input_size
543 | self.private_input_size = private_input_size
544 | self.concat_sp = concat_sp
545 | input_size = shared_input_size + private_input_size if concat_sp else shared_input_size
546 | if opt.sp_attn:
547 | if concat_sp:
548 | self.sp_attn = nn.Linear(input_size, 2)
549 | else:
550 | self.sp_attn = nn.Linear(input_size, 1)
551 | self.moe = MixtureOfExperts(num_layers, input_size, num_experts,
552 | hidden_size, output_size, dropout, bn, is_tagger=True)
553 | if attn_type == 'dot':
554 | self.attn = DotAttentionLayer(hidden_size)
555 |
556 | def forward(self, input):
557 | fs, fp, lengths = input
558 | if self.shared_input_size == 0:
559 | features = fp
560 | elif self.private_input_size == 0:
561 | features = fs
562 | else:
563 | if self.concat_sp:
564 | if opt.sp_attn:
565 | features = torch.cat([fs, fp], dim=-1)
566 | # bs x seqlen x 2
567 | if opt.sp_sigmoid_attn:
568 | alphas = functional.sigmoid(self.sp_attn(features))
569 | else:
570 | alphas = functional.softmax(self.sp_attn(features), dim=-1)
571 | fs = fs * alphas[:,:,0].unsqueeze(-1)
572 | fp = fp * alphas[:,:,1].unsqueeze(-1)
573 | features = torch.cat([fs, fp], dim=-1)
574 | else:
575 | features = torch.cat([fs, fp], dim=-1)
576 | else:
577 | if opt.sp_attn:
578 | a1 = self.sp_attn(fs)
579 | a2 = self.sp_attn(fp)
580 | if opt.sp_sigmoid_attn:
581 | alphas = functional.sigmoid(torch.cat([a1, a2], dim=-1))
582 | else:
583 | alphas = functional.softmax(torch.cat([a1, a2], dim=-1), dim=-1)
584 | features = torch.stack([fs, fp], dim=2) # bs x seq_len x 2 x hidden_dim
585 | features = torch.sum(alphas.unsqueeze(-1) * features, dim=2)
586 | else:
587 | features = fs + fp
588 | features = self.attn((features, lengths))
589 | return self.moe(features)
590 |
591 |
592 | class MLPLanguageDiscriminator(nn.Module):
593 | def __init__(self,
594 | num_layers,
595 | input_size,
596 | hidden_size,
597 | num_langs,
598 | loss_type,
599 | dropout,
600 | batch_norm=False):
601 | super(MLPLanguageDiscriminator, self).__init__()
602 | assert num_layers >= 0, 'Invalid layer numbers'
603 | self.num_langs = num_langs
604 | self.loss_type = loss_type
605 | self.net = nn.Sequential()
606 | for i in range(num_layers):
607 | if dropout > 0:
608 | self.net.add_module('d-dropout-{}'.format(i), nn.Dropout(p=dropout))
609 | if i == 0:
610 | self.net.add_module('d-linear-{}'.format(i), nn.Linear(input_size, hidden_size))
611 | else:
612 | self.net.add_module('d-linear-{}'.format(i), nn.Linear(hidden_size, hidden_size))
613 | if batch_norm:
614 | self.net.add_module('d-bn-{}'.format(i), nn.BatchNorm1d(hidden_size))
615 | self.net.add_module('d-relu-{}'.format(i), opt.act_unit)
616 |
617 | self.net.add_module('d-linear-final', nn.Linear(hidden_size, num_langs))
618 | if loss_type.lower() == 'gr' or loss_type.lower() == 'bs':
619 | self.net.add_module('d-logsoftmax', nn.LogSoftmax(dim=-1))
620 |
621 | def forward(self, input):
622 | scores = self.net(input)
623 | if self.loss_type.lower() == 'l2':
624 | # normalize
625 | scores = functional.relu(scores)
626 | scores /= torch.sum(scores, dim=1, keepdim=True)
627 | return scores
628 |
629 |
630 | class LanguageDiscriminator(nn.Module):
631 | def __init__(self,
632 | model,
633 | mlp_layers,
634 | input_size,
635 | hidden_size,
636 | num_langs,
637 | dropout,
638 | batch_norm=False,
639 | model_args=None):
640 | """
641 | If model is CNN/LSTM, D takes a sequence and predicts a scalar (model_args dictates architecture)
642 |         If model is MLP, D only takes a single token (and D should be applied individually to all tokens)
643 | If num_langs > 1, D uses a MAN-style model
644 | If num_langs == 1, D discriminates real vs fake for a specific language
645 | """
646 | super(LanguageDiscriminator, self).__init__()
647 | self.input_size = input_size
648 | self.model = model
649 | self.num_langs = num_langs
650 | if self.model.lower() == 'lstm':
651 | self.pool = LSTMSequencePooling(**model_args)
652 | elif self.model.lower() == 'cnn':
653 | self.pool = CNNSequencePooling(**model_args)
654 | elif self.model.lower() == 'mlp':
655 | self.pool = None
656 | else:
657 | raise Exception('Please specify valid model architecture')
658 |
659 | self.net = nn.Sequential()
660 | for i in range(mlp_layers):
661 | if dropout > 0:
662 | self.net.add_module('d-dropout-{}'.format(i), nn.Dropout(p=dropout))
663 | if i == 0:
664 | self.net.add_module('d-linear-{}'.format(i), nn.Linear(input_size, hidden_size))
665 | else:
666 | self.net.add_module('d-linear-{}'.format(i), nn.Linear(hidden_size, hidden_size))
667 | if batch_norm:
668 | self.net.add_module('d-bn-{}'.format(i), nn.LayerNorm(hidden_size))
669 | self.net.add_module('d-relu-{}'.format(i), opt.act_unit)
670 |
671 | self.net.add_module('d-linear-final', nn.Linear(hidden_size, num_langs))
672 | if opt.loss.lower() == 'gr' or opt.loss.lower() == 'bs':
673 | if num_langs > 1:
674 | self.net.add_module('d-logsoftmax', nn.LogSoftmax(dim=-1))
675 |
676 | def forward(self, input):
677 | data, lengths = input
678 | if self.pool:
679 | data = self.pool((data, lengths))
680 | # data: bs x input_size (with pooling)
681 | # bs*seq_len x input_size (without pooling)
682 | scores = self.net(data)
683 | if opt.loss.lower() == 'l2':
684 | # normalize
685 | scores = functional.relu(scores)
686 | scores /= torch.sum(scores, dim=1, keepdim=True)
687 | return scores
688 |
--------------------------------------------------------------------------------
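
Editorial note: the MixtureOfExperts.forward above reduces to a softmax gate over per-expert outputs followed by a weighted sum. Below is a minimal, self-contained sketch of that combination; it is not part of the repository, and the tensor sizes are made up for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F

bs, seq_len, input_size, hidden_size, output_size, num_experts = 2, 5, 8, 16, 4, 3
x = torch.randn(bs, seq_len, input_size)

gates = nn.Linear(input_size, num_experts)
experts = nn.ModuleList([
    nn.Sequential(nn.Linear(input_size, hidden_size), nn.ReLU(),
                  nn.Linear(hidden_size, output_size))
    for _ in range(num_experts)])

gate_outs = gates(x)                                        # bs x seq_len x num_experts
gate_softmax = F.softmax(gate_outs, dim=-1)
expert_outs = torch.stack([e(x) for e in experts], dim=-2)  # bs x seq_len x num_experts x output_size
output = (gate_softmax.unsqueeze(-1) * expert_outs).sum(dim=-2)  # bs x seq_len x output_size
print(output.shape)  # torch.Size([2, 5, 4])
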
/options.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT License.
3 |
4 | import argparse
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--random_seed', type=int, default=1)
11 | parser.add_argument('--max_epoch', type=int, default=30)
12 | parser.add_argument('--dataset', default='conll')
13 | parser.add_argument('--amazon_dir', default='./data/amazon/')
14 | parser.add_argument('--conll_dir', default='./data/conll_ner/')
15 | parser.add_argument('--train_on_translation/', dest='train_on_translation', action='store_true')
16 | parser.add_argument('--test_on_translation/', dest='test_on_translation', action='store_true')
17 | # for preprocessed amazon dataset; set to -1 to use 30000
18 | parser.add_argument('--domain', type=str, default='books')
19 | # labeled langs: if not set, will use default langs for the dataset
20 | parser.add_argument('--langs', type=str, nargs='+', default=[])
21 | parser.add_argument('--unlabeled_langs', type=str, nargs='+', default=[])
22 | parser.add_argument('--dev_langs', type=str, nargs='+', default=[])
23 | parser.add_argument('--use_charemb', action='store_true', default=True)
24 | parser.add_argument('--no_charemb', dest='use_charemb', action='store_false')
25 | parser.add_argument('--use_wordemb', dest='use_wordemb', action='store_true', default=True)
26 | parser.add_argument('--no_wordemb', dest='use_wordemb', action='store_false')
27 | # should be (lang, filename) pairs
28 | # will be made available as a dict lang->emb_filename
29 | parser.add_argument('--emb_filenames', type=str, nargs='+', default=[])
30 | # alternatively, choose from a set of default embs (paths hard-coded later in this file)
31 | # muse, muse-idchar, vecmap or mono
32 | parser.add_argument('--default_emb', type=str, default='muse')
33 | parser.add_argument('--emb_size', type=int, default=300)
34 | # char embeddings
35 | parser.add_argument('--charemb_size', type=int, default=128)
36 | parser.add_argument('--charemb_model', type=str, default='cnn') # cnn
37 | parser.add_argument('--charemb_num_layers', type=int, default=1)
38 | parser.add_argument('--fix_charemb', dest='fix_charemb', action='store_true', default=True)
39 | parser.add_argument('--no_fix_charemb', dest='fix_charemb', action='store_false')
40 | # lowercase the characters (useful for German NER)
41 | parser.add_argument('--lowercase_char', action='store_true', default=False)
42 | # for LSTM D model
43 | parser.add_argument('--charemb_lstm_layers', type=int, default=1)
44 | parser.add_argument('--charemb_attn', default='dot') # attention mechanism (for LSTM): avg, last, dot
45 | parser.add_argument('--charemb_no_bdrnn', dest='charemb_bdrnn', action='store_false', default=True) # bi-directional LSTM
46 | # for CNN model
47 | parser.add_argument('--charemb_kernel_num', type=int, default=200)
48 | parser.add_argument('--charemb_kernel_sizes', type=int, nargs='+', default=[3,4,5])
49 |
50 | parser.add_argument('--max_seq_len', type=int, default=0) # set to <=0 to not truncate
51 | # if True, training samples with all Os are removed
52 | parser.add_argument('--remove_empty_samples', action='store_true', default=False)
53 | # which data to be used as unlabeled data: train, unlabeled, or both
54 | parser.add_argument('--unlabeled_data', type=str, default='unlabeled')
55 | parser.add_argument('--model_save_file', default='./save/manmoe')
56 | #parser.add_argument('--output_pred', dest='output_pred', action='store_true')
57 | #parser.add_argument('--dump_gate_weights', dest='dump_gate_weights', action='store_true')
58 | parser.add_argument('--test_only', dest='test_only', action='store_true')
59 | parser.add_argument('--batch_size', type=int, default=16)
60 | parser.add_argument('--char_batch_size', type=int, default=0)
61 | # In PyTorch 0.3, Batch Norm no longer works for size 1 batch,
62 | # so we will skip leftover batch of size < batch_size
63 | parser.add_argument('--no_skip_leftover_batch', dest='skip_leftover_batch', action='store_false', default=True)
64 | parser.add_argument('--learning_rate', type=float, default=0.001)
65 | parser.add_argument('--D_learning_rate', type=float, default=0.001)
66 | parser.add_argument('--weight_decay', type=float, default=1e-8)
67 | parser.add_argument('--D_weight_decay', type=float, default=1e-8)
68 | # decay lr if validation f1 not increasing
69 | parser.add_argument('--lr_decay', type=float, default=1)
70 | parser.add_argument('--lr_decay_epochs', type=int, default=3)
71 | parser.add_argument('--fix_emb', dest='fix_emb', action='store_true', default=True)
72 | parser.add_argument('--no_fix_emb', dest='fix_emb', action='store_false')
73 | parser.add_argument('--random_emb', action='store_true', default=False)
74 | parser.add_argument('--F_layers', type=int, default=2)
75 | parser.add_argument('--model', default='lstm')
76 | # for LSTM model
77 | parser.add_argument('--bdrnn/', dest='bdrnn', action='store_true', default=True) # bi-directional LSTM
78 | parser.add_argument('--no_bdrnn/', dest='bdrnn', action='store_false', default=True) # bi-directional LSTM
79 | # for CNN model
80 | parser.add_argument('--kernel_num', type=int, default=200)
81 | parser.add_argument('--kernel_sizes', type=int, nargs='+', default=[3,4,5])
82 | # D parameters
83 | parser.add_argument('--D_layers', type=int, default=1)
84 | parser.add_argument('--D_model', default='cnn')
85 | # for LSTM D model
86 | parser.add_argument('--D_lstm_layers', type=int, default=1)
87 | parser.add_argument('--D_attn', default='dot') # attention mechanism (for LSTM): avg, last, dot
88 | parser.add_argument('--D_bdrnn/', dest='D_bdrnn', action='store_true', default=True) # bi-directional LSTM
89 | parser.add_argument('--D_no_bdrnn/', dest='D_bdrnn', action='store_false', default=True) # bi-directional LSTM
90 | # for CNN model
91 | parser.add_argument('--D_kernel_num', type=int, default=200)
92 | parser.add_argument('--D_kernel_sizes', type=int, nargs='+', default=[3,4,5])
93 |
94 | parser.add_argument('--C_layers', type=int, default=1)
95 | # gate the C input (the features)
96 | parser.add_argument('--C_input_gate/', dest='C_input_gate', action='store_true', default=False)
97 | # see the MAN paper; in this work, we only implement the GR loss
98 | parser.add_argument('--loss', default='gr')
99 | parser.add_argument('--shared_hidden_size', type=int, default=128)
100 | parser.add_argument('--private_hidden_size', type=int, default=128)
101 | # if concat_sp, the shared and private features are concatenated
102 | # otherwise, they are added (and private_hidden_size must be = shared_hidden_size)
103 | parser.add_argument('--concat_sp', dest='concat_sp', action='store_true', default=True)
104 | parser.add_argument('--add_sp', dest='concat_sp', action='store_false')
105 | parser.add_argument('--sp_attn', dest='sp_attn', action='store_true')
106 | parser.add_argument('--sp_sigmoid_attn', dest='sp_sigmoid_attn', action='store_true')
107 | parser.add_argument('--activation', default='relu') # relu, leaky
108 | parser.add_argument('--wgan_trick/', dest='wgan_trick', action='store_true', default=False)
109 | parser.add_argument('--no_wgan_trick/', dest='wgan_trick', action='store_false')
110 | parser.add_argument('--n_critic', type=int, default=2) # hyperparameter k in the paper
111 | parser.add_argument('--lambd', type=float, default=0.002)
112 | # lambda scheduling: not used
113 | parser.add_argument('--lambd_schedule', dest='lambd_schedule', action='store_true', default=False)
114 | # gradient penalty: not used
115 | parser.add_argument('--grad_penalty', dest='grad_penalty', default='none') #none, wgan or dragan
116 | parser.add_argument('--onesided_gp', dest='onesided_gp', action='store_true')
117 | parser.add_argument('--gp_lambd', type=float, default=0.1)
118 | # orthogonality penalty: not used
119 | parser.add_argument('--ortho_penalty', dest='ortho_penalty', type=float, default=0)
120 | # batch normalization
121 | parser.add_argument('--F_bn/', dest='F_bn', action='store_true', default=False)
122 | parser.add_argument('--no_F_bn/', dest='F_bn', action='store_false')
123 | parser.add_argument('--C_bn/', dest='C_bn', action='store_true', default=False)
124 | parser.add_argument('--no_C_bn/', dest='C_bn', action='store_false')
125 | parser.add_argument('--D_bn/', dest='D_bn', action='store_true', default=False)
126 | parser.add_argument('--no_D_bn/', dest='D_bn', action='store_false')
127 | parser.add_argument('--word_dropout', type=float, default=0.5)
128 | parser.add_argument('--char_dropout', type=float, default=0)
129 | parser.add_argument('--dropout', type=float, default=0.5)
130 | parser.add_argument('--mlp_dropout', type=float, default=0.5)
131 | parser.add_argument('--D_word_dropout', type=float, default=0)
132 | parser.add_argument('--D_dropout', type=float, default=0.5)
133 | parser.add_argument('--device/', dest='device', type=str, default='cuda')
134 | parser.add_argument('--use_data_parallel/', dest='use_data_parallel', action='store_true', default=False)
135 | parser.add_argument('--debug/', dest='debug', action='store_true')
136 | ###### MoE options ######
137 | # a shared LSTM is used, and each expert is a MLP
138 | parser.add_argument('--F_hidden_size', type=int, default=128)
139 | parser.add_argument('--expert_hidden_size', type=int, default=128)
140 | # if True, a shared expert is added (not used in paper)
141 | parser.add_argument('--expert_sp/', dest='expert_sp', action='store_true', default=False)
142 | parser.add_argument('--gate_loss_weight', type=float, default=0.01)
143 | parser.add_argument('--C_gate_loss_weight', type=float, default=0.01)
144 | parser.add_argument('--moe_last_act', type=str, default='tanh') # tanh or relu
145 | parser.add_argument('--detach_gate_input/', dest='detach_gate_input', action='store_true', default=True)
146 | parser.add_argument('--no_detach_gate_input/', dest='detach_gate_input', action='store_false')
147 | # only effective if MoE is added between LSTMs (otherwise, specify C_layers)
148 | parser.add_argument('--MoE_layers', type=int, default=2)
149 | parser.add_argument('--MoE_bn/', dest='MoE_bn', action='store_true', default=False)
150 | parser.add_argument('--no_MoE_bn/', dest='MoE_bn', action='store_false')
151 | ### Cross-Lingual Text Classification Options ###
152 | parser.add_argument('--F_attn', default='dot') # attention mechanism (for LSTM): avg, last, dot
153 | parser.add_argument('--Fp_MoE', dest='Fp_MoE', action='store_true', default=True)
154 | parser.add_argument('--no_Fp_MoE/', dest='Fp_MoE', action='store_false')
155 | parser.add_argument('--C_MoE', dest='C_MoE', action='store_true', default=True)
156 | parser.add_argument('--no_C_MoE/', dest='C_MoE', action='store_false')
157 | opt = parser.parse_args()
158 |
159 | # automatically prepared options
160 | if not torch.cuda.is_available():
161 | opt.device = 'cpu'
162 |
163 | opt.all_langs = opt.langs + opt.unlabeled_langs
164 | if len(opt.dev_langs) == 0:
165 | opt.dev_langs = opt.all_langs
166 |
167 | DEFAULT_EMB = {
168 | 'en-us': './data/embeddings/muse/wiki.200k.en.vec',
169 | 'de-de': './data/embeddings/muse/muse.200k.de-en.de.vec',
170 | 'es-es': './data/embeddings/muse/muse.200k.es-en.es.vec',
171 | 'zh-cn': './data/embeddings/muse/muse.200k.zh-en.zh.vec',
172 | 'en': './data/embeddings/muse/wiki.200k.en.vec',
173 | 'de': './data/embeddings/muse/muse.200k.de-en.de.vec',
174 | 'es': './data/embeddings/muse/muse.200k.es-en.es.vec',
175 | 'zh': './data/embeddings/muse/muse.200k.zh-en.zh.vec',
176 | 'fr': './data/embeddings/muse/muse.200k.fr-en.fr.vec',
177 | 'ja': './data/embeddings/muse/muse.200k.ja-en.ja.vec',
178 | 'eng': './data/embeddings/muse/wiki.200k.en.vec',
179 | 'deu': './data/embeddings/muse/muse.200k.de-en.de.vec',
180 | 'esp': './data/embeddings/muse/muse.200k.es-en.es.vec',
181 | 'ned': './data/embeddings/muse/muse.200k.nl-en.nl.vec'
182 | }
183 | UMWE_EMB = {
184 | 'en': './data/embeddings/wiki.200k.en.vec',
185 | 'de': './data/embeddings/umwe-fullexport/umwe.200k.de-en.de.vec',
186 | 'es': './data/embeddings/umwe-fullexport/umwe.200k.es-en.es.vec',
187 | 'zh': './data/embeddings/umwe-fullexport/umwe.200k.zh-en.zh.vec',
188 | 'fr': './data/embeddings/umwe-fullexport/umwe.200k.fr-en.fr.vec',
189 | 'ja': './data/embeddings/umwe-fullexport/umwe.200k.ja-en.ja.vec',
190 | 'eng': './data/embeddings/wiki.200k.en.vec',
191 | 'deu': './data/embeddings/umwe-fullexport/umwe.200k.deesnl2en.de.vec',
192 | 'esp': './data/embeddings/umwe-fullexport/umwe.200k.deesnl2en.es.vec',
193 | 'ned': './data/embeddings/umwe-fullexport/umwe.200k.deesnl2en.nl.vec'
194 | }
195 | DEFAULT_VECMAP_EMB = {
196 | 'en-us': './data/embeddings/vecmap/vecmap_idchar_ortho.200k.de-en.en.vec',
197 | 'de-de': './data/embeddings/vecmap/vecmap_idchar_ortho.200k.de-en.de.vec',
198 | 'es-es': './data/embeddings/vecmap/vecmap_idchar_ortho.200k.es-en.es.vec',
199 | 'zh-cn': './data/embeddings/vecmap/vecmap_idchar_ortho.200k.zh-en.zh.vec',
200 | 'en': './data/embeddings/vecmap/vecmap_idchar_ortho.200k.de-en.en.vec',
201 | 'de': './data/embeddings/vecmap/vecmap_idchar_ortho.200k.de-en.de.vec',
202 | 'fr': './data/embeddings/vecmap/vecmap_idchar_ortho.200k.fr-en.fr.vec',
203 | 'ja': './data/embeddings/vecmap/vecmap_idchar_ortho.200k.ja-en.ja.vec',
204 | 'eng': './data/embeddings/vecmap/vecmap_idchar_ortho.200k.de-en.en.vec',
205 | 'deu': './data/embeddings/vecmap/vecmap_idchar_ortho.200k.de-en.de.vec',
206 | 'esp': './data/embeddings/vecmap/vecmap_idchar_ortho.200k.es-en.es.vec',
207 | 'ned': './data/embeddings/vecmap/vecmap_idchar_ortho.200k.nl-en.nl.vec',
208 | }
209 |
210 | # init emb filenames
211 | assert len(opt.emb_filenames) % 2 == 0, "emb_filenames should be in (lang, filename) pairs"
212 | emb_fns = dict(zip(opt.emb_filenames[::2], opt.emb_filenames[1::2]))
213 | opt.emb_filenames = {}
214 | for i, lang in enumerate(opt.all_langs):
215 | if opt.default_emb == 'muse':
216 | def_fn = DEFAULT_EMB[lang]
217 | elif opt.default_emb == 'vecmap':
218 | def_fn = DEFAULT_VECMAP_EMB[lang]
219 | elif opt.default_emb == 'umwe':
220 | def_fn = UMWE_EMB[lang]
221 | else:
222 | def_fn = ''
223 | fn = emb_fns[lang] if lang in emb_fns else def_fn
224 | opt.emb_filenames[lang] = fn
225 |
226 | opt.max_kernel_size = max(opt.kernel_sizes + opt.D_kernel_sizes)
227 | if opt.activation.lower() == 'relu':
228 | opt.act_unit = nn.ReLU()
229 | elif opt.activation.lower() == 'leaky':
230 | opt.act_unit = nn.LeakyReLU()
231 | else:
232 | raise Exception(f'Unknown activation function {opt.activation}')
233 |
234 | opt.total_emb_size = 0
235 | if opt.use_wordemb:
236 | opt.total_emb_size += opt.emb_size
237 | if opt.use_charemb:
238 | opt.total_emb_size += opt.charemb_size
239 |
--------------------------------------------------------------------------------
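
Editorial note: options.py relies heavily on paired flags that share one dest, so a default of True can be switched off from the command line. A minimal sketch of that argparse pattern (not part of the repository):

import argparse

parser = argparse.ArgumentParser()
# Same pattern as --fix_emb / --no_fix_emb above: both flags write to a single dest.
parser.add_argument('--fix_emb', dest='fix_emb', action='store_true', default=True)
parser.add_argument('--no_fix_emb', dest='fix_emb', action='store_false')

print(parser.parse_args([]).fix_emb)                 # True
print(parser.parse_args(['--no_fix_emb']).fix_emb)   # False
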
/scripts/get_overall_perf_amazon.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT License.
3 |
4 | # python get_overall_perf.py model_dir suffix src_lang
5 | from collections import defaultdict
6 | import os
7 | import re
8 | import sys
9 |
10 |
11 | domains = ['books', 'dvd', 'music']
12 | langs = ['en', 'de', 'fr', 'ja']
13 |
14 |
15 | def get_overall_perf(folder, suffix, source_lang=None):
16 | devperf = defaultdict(lambda: {})
17 | testperf = defaultdict(lambda: {})
18 | for domain in domains:
19 | for i, lang in enumerate(langs):
20 | if source_lang:
21 |             if source_lang == 1:
22 |                 if lang == 'en':
23 |                     continue
24 |                 lang = f'en2{langs[i]}'
25 | elif source_lang == 3:
26 | srcs = [l for l in langs if l != langs[i]]
27 | lang = ''.join(srcs)+'2'+langs[i]
28 | logfile = os.path.join(folder, f"{domain}_{lang}_{suffix}", 'log.txt')
29 | if not os.path.exists(logfile):
30 | print('File not found:', logfile)
31 | continue
32 | else:
33 | print('Processing file:', logfile)
34 | with open(logfile) as inf:
35 | lines = inf.readlines()[-2:]
36 | try:
37 | devperf[domain][lang] = float(lines[0].split()[-1])
38 | testperf[domain][lang] = float(lines[1].split()[-1])
39 | except:
40 | print('Errors in ', logfile)
41 |
42 | rowtemp = "{0:8}{1:8}\t{2:8}"
43 | for domain in domains:
44 | if len(devperf[domain]) > 0:
45 | print(f'Domain: {domain}, Dev and Test')
46 | print(rowtemp.format('Lang', 'F1', 'F1'))
47 | for lang in devperf[domain]:
48 | row = [lang] + [devperf[domain][lang]] + [testperf[domain][lang]]
49 | print(rowtemp.format(*row))
50 | print(rowtemp.format(*['Avg',
51 | sum(devperf[domain].values())/len(devperf[domain]),
52 | sum(testperf[domain].values())/len(testperf[domain])]))
53 | print()
54 |
55 |
56 | if __name__ == '__main__':
57 |     assert len(sys.argv) > 2, 'Model dir and suffix are required.'
58 |     suffix = sys.argv[2]
59 | src = 3
60 | if len(sys.argv) > 3:
61 | src = int(sys.argv[3])
62 |
63 | get_overall_perf(sys.argv[1], suffix, src)
64 | for seed in range(1, 6):
65 | new_suffix = suffix + '_seed' + str(seed)
66 | print(f"Results for seed {seed}:")
67 | get_overall_perf(sys.argv[1], new_suffix, src)
68 |
--------------------------------------------------------------------------------
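
Editorial note: the script above only inspects the last two lines of each run's log.txt and takes the final whitespace-separated token as the dev and test scores; these correspond to the 'Validation Set Domain Average' / 'Test Set Domain Average' messages written at the end of train_cls_man_moe.py. A minimal sketch of that parsing, with synthetic numbers (not part of the repository):

tail = [
    'Validation Set Domain Average\t0.8123',   # second-to-last line of log.txt
    'Test Set Domain Average\t0.8045',         # last line of log.txt
]
dev_score = float(tail[0].split()[-1])
test_score = float(tail[1].split()[-1])
print(dev_score, test_score)  # 0.8123 0.8045
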
/scripts/get_overall_perf_ner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT License.
3 |
4 | # python get_overall_perf.py model_dir suffix src_lang
5 | from collections import defaultdict
6 | import os
7 | import re
8 | import sys
9 |
10 |
11 | langs = ['eng', 'deu', 'esp', 'ned']
12 | short_langs = ['en', 'de', 'es', 'nl']
13 | # prog = re.compile(r".*, p: (\d\.\d+), r: (\d\.\d+), f: (\d\.\d+).*")
14 |
15 |
16 | def get_overall_perf(folder, suffix, source_lang=None):
17 | devperf = {}
18 | testperf = {}
19 | for i, lang in enumerate(langs):
20 | if source_lang:
21 | if lang.startswith('en'):
22 | continue
23 | if source_lang == 1:
24 | lang = f'en2{short_langs[i]}'
25 | elif source_lang == 3:
26 | srcs = [l for l in short_langs if l != short_langs[i]]
27 | lang = ''.join(srcs)+'2'+short_langs[i]
28 | logfile = os.path.join(folder, f"conll_ner_{lang}_{suffix}", 'log.txt')
29 | # logfile = os.path.join(folder, f"{domain}_{suffix}_{lang}", 'log.txt')
30 | if not os.path.exists(logfile):
31 | print('File not found:', logfile)
32 | continue
33 | else:
34 | print('Processing file:', logfile)
35 | with open(logfile) as inf:
36 | lines = inf.readlines()[-2:]
37 | try:
38 | devperf[lang] = float(lines[0].split()[-1])
39 | testperf[lang] = float(lines[1].split()[-1])
40 | except:
41 | print('Errors in ', logfile)
42 |
43 | rowtemp = "{0:8}{1:8}\t{2:8}"
44 | if len(devperf) > 0:
45 | print(f'Dev and Test')
46 | print(rowtemp.format('Lang', 'F1', 'F1'))
47 | for lang in devperf:
48 | row = [lang] + [devperf[lang]] + [testperf[lang]]
49 | print(rowtemp.format(*row))
50 | print(rowtemp.format(*['Avg',
51 | sum(devperf.values())/len(devperf),
52 | sum(testperf.values())/len(testperf)]))
53 | print()
54 |
55 |
56 | if __name__ == '__main__':
57 |     assert len(sys.argv) > 2, 'Model dir and suffix are required.'
58 |     suffix = sys.argv[2]
59 | src = 3 # number of source languages
60 | if len(sys.argv) > 3:
61 | src = int(sys.argv[3])
62 |
63 | get_overall_perf(sys.argv[1], suffix, src)
64 | for seed in range(1, 6):
65 | new_suffix = suffix + '_seed' + str(seed)
66 | print(f"Results for seed {seed}:")
67 | get_overall_perf(sys.argv[1], new_suffix, src)
68 |
--------------------------------------------------------------------------------
/scripts/train_amazon_3to1.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT License.
3 |
4 | domains=("books" "dvd" "music")
5 | langs=("fr" "ja" "de")
6 | for domain in ${domains[@]}
7 | do
8 | python train_cls_man_moe.py --langs en de fr --unlabeled_langs ja --dev_langs ja --domain "$domain" --model_save_file "save/${domain}_endefr2ja_$1/" --fix_emb --model lstm --batch_size 16 --default_emb vecmap --no_charemb --n_critic 1 --lambd 0.002 "${@:2}"
9 | python train_cls_man_moe.py --langs en de ja --unlabeled_langs fr --dev_langs fr --domain "$domain" --model_save_file "save/${domain}_endeja2fr_$1/" --fix_emb --model lstm --batch_size 16 --default_emb vecmap --no_charemb --n_critic 1 --lambd 0.002 "${@:2}"
10 | python train_cls_man_moe.py --langs en fr ja --unlabeled_langs de --dev_langs de --domain "$domain" --model_save_file "save/${domain}_enfrja2de_$1/" --fix_emb --model lstm --batch_size 16 --default_emb vecmap --no_charemb --n_critic 1 --lambd 0.002 "${@:2}"
11 | done
12 |
--------------------------------------------------------------------------------
/scripts/train_conll_ner_3to1.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT License.
3 |
4 | python3 train_tagging_man_moe.py --dataset conll --langs eng deu esp --unlabeled_langs ned --dev_langs ned --model_save_file "save/conll_ner_endees2nl_$1/" --fix_emb --default_emb umwe --private_hidden_size 200 --shared_hidden_size 200 --n_critic 1 "${@:2}"
5 | python3 train_tagging_man_moe.py --dataset conll --langs eng deu ned --unlabeled_langs esp --dev_langs esp --model_save_file "save/conll_ner_endenl2es_$1/" --fix_emb --default_emb umwe --private_hidden_size 200 --shared_hidden_size 200 --n_critic 1 "${@:2}"
6 | python3 train_tagging_man_moe.py --dataset conll --langs eng esp ned --unlabeled_langs deu --dev_langs deu --model_save_file "save/conll_ner_enesnl2de_$1/" --fix_emb --lowercase_char --default_emb umwe --private_hidden_size 200 --shared_hidden_size 200 --n_critic 1 "${@:2}"
7 |
--------------------------------------------------------------------------------
/train_cls_man_moe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT License.
3 |
4 | # A shared LSTM + MAN for learning domain-invariant features
5 | # Another shared LSTM with MoE on top for learning domain-specific features
6 | # MoE MLP tagger
7 | from collections import defaultdict
8 | import io
9 | import itertools
10 | import logging
11 | import math
12 | import os
13 | import pickle
14 | import random
15 | import sys
16 | from tqdm import tqdm
17 |
18 | import numpy as np
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as functional
22 | import torch.optim as optim
23 | from torch.utils.data import ConcatDataset, DataLoader
24 | from torchnet.meter import ConfusionMeter
25 |
26 | from options import opt
27 | random.seed(opt.random_seed)
28 | np.random.seed(opt.random_seed)
29 | torch.manual_seed(opt.random_seed)
30 | torch.cuda.manual_seed_all(opt.random_seed)
31 |
32 | from data_prep.multi_lingual_amazon import *
33 | from models import *
34 | import utils
35 | from vocab import Vocab, TagVocab
36 |
37 | # save models and logging
38 | if not os.path.exists(opt.model_save_file):
39 | os.makedirs(opt.model_save_file)
40 | logging.basicConfig(stream=sys.stderr, level=logging.DEBUG if opt.debug else logging.INFO)
41 | log = logging.getLogger(__name__)
42 | fh = logging.FileHandler(os.path.join(opt.model_save_file, 'log.txt'))
43 | log.addHandler(fh)
44 | # output options
45 | log.info(opt)
46 |
47 |
48 | def train(vocabs, char_vocab, train_sets, dev_sets, test_sets, unlabeled_sets):
49 | """
50 | train_sets, dev_sets, test_sets: dict[lang] -> AmazonDataset
51 | For unlabeled langs, no train_sets are available
52 | """
53 | # dataset loaders
54 | train_loaders, unlabeled_loaders = {}, {}
55 | train_iters, unlabeled_iters, d_unlabeled_iters = {}, {}, {}
56 | dev_loaders, test_loaders = {}, {}
57 | my_collate = utils.sorted_cls_collate if opt.model=='lstm' else utils.unsorted_cls_collate
58 | for lang in opt.langs:
59 | train_loaders[lang] = DataLoader(train_sets[lang],
60 | opt.batch_size, shuffle=True, collate_fn = my_collate)
61 | train_iters[lang] = iter(train_loaders[lang])
62 | for lang in opt.dev_langs:
63 | dev_loaders[lang] = DataLoader(dev_sets[lang],
64 | opt.batch_size, shuffle=False, collate_fn = my_collate)
65 | test_loaders[lang] = DataLoader(test_sets[lang],
66 | opt.batch_size, shuffle=False, collate_fn = my_collate)
67 | for lang in opt.all_langs:
68 | if lang in opt.unlabeled_langs:
69 | uset = unlabeled_sets[lang]
70 | else:
71 | # for labeled langs, consider which data to use as unlabeled set
72 | if opt.unlabeled_data == 'both':
73 | uset = ConcatDataset([train_sets[lang], unlabeled_sets[lang]])
74 | elif opt.unlabeled_data == 'unlabeled':
75 | uset = unlabeled_sets[lang]
76 | elif opt.unlabeled_data == 'train':
77 | uset = train_sets[lang]
78 | else:
79 | raise Exception(f'Unknown options for the unlabeled data usage: {opt.unlabeled_data}')
80 | unlabeled_loaders[lang] = DataLoader(uset,
81 | opt.batch_size, shuffle=True, collate_fn = my_collate)
82 | unlabeled_iters[lang] = iter(unlabeled_loaders[lang])
83 | d_unlabeled_iters[lang] = iter(unlabeled_loaders[lang])
84 |
85 | # embeddings
86 | emb = MultiLangWordEmb(vocabs, char_vocab, opt.use_wordemb, opt.use_charemb).to(opt.device)
87 | # models
88 | F_s = None
89 | F_p = None
90 | C, D = None, None
91 | num_experts = len(opt.langs)+1 if opt.expert_sp else len(opt.langs)
92 | if opt.model.lower() == 'lstm':
93 | if opt.shared_hidden_size > 0:
94 | F_s = LSTMFeatureExtractor(opt.total_emb_size, opt.F_layers, opt.shared_hidden_size,
95 | opt.word_dropout, opt.dropout, opt.bdrnn)
96 | if opt.private_hidden_size > 0:
97 | if not opt.concat_sp:
98 | assert opt.shared_hidden_size == opt.private_hidden_size, "shared dim != private dim when using add_sp!"
99 | if opt.Fp_MoE:
100 | F_p = nn.Sequential(
101 | LSTMFeatureExtractor(opt.total_emb_size, opt.F_layers, opt.private_hidden_size,
102 | opt.word_dropout, opt.dropout, opt.bdrnn),
103 | MixtureOfExperts(opt.MoE_layers, opt.private_hidden_size,
104 | len(opt.langs), opt.private_hidden_size,
105 | opt.private_hidden_size, opt.dropout, opt.MoE_bn, False)
106 | )
107 | else:
108 |                 raise NotImplementedError()
109 | else:
110 | raise Exception(f'Unknown model architecture {opt.model}')
111 |
112 | if opt.C_MoE:
113 | C = SpAttnMixtureOfExperts(opt.C_layers, opt.shared_hidden_size, opt.private_hidden_size, opt.concat_sp,
114 | num_experts, opt.shared_hidden_size + opt.private_hidden_size, opt.num_labels,
115 | opt.mlp_dropout, opt.F_attn, opt.C_bn)
116 | else:
117 |         raise NotImplementedError()
118 |
119 | if opt.shared_hidden_size > 0 and opt.n_critic > 0:
120 | if opt.D_model.lower() == 'lstm':
121 | d_args = {
122 | 'num_layers': opt.D_lstm_layers,
123 | 'input_size': opt.shared_hidden_size,
124 | 'hidden_size': opt.shared_hidden_size,
125 | 'word_dropout': opt.D_word_dropout,
126 | 'dropout': opt.D_dropout,
127 | 'bdrnn': opt.D_bdrnn,
128 | 'attn_type': opt.D_attn
129 | }
130 | elif opt.D_model.lower() == 'cnn':
131 | d_args = {
132 | 'num_layers': 1,
133 | 'input_size': opt.shared_hidden_size,
134 | 'hidden_size': opt.shared_hidden_size,
135 | 'kernel_num': opt.D_kernel_num,
136 | 'kernel_sizes': opt.D_kernel_sizes,
137 | 'word_dropout': opt.D_word_dropout,
138 | 'dropout': opt.D_dropout
139 | }
140 | else:
141 | d_args = None
142 | if opt.D_model.lower() == 'mlp':
143 | D = MLPLanguageDiscriminator(opt.D_layers, opt.shared_hidden_size,
144 | opt.shared_hidden_size, len(opt.all_langs), opt.loss, opt.D_dropout, opt.D_bn)
145 | else:
146 | D = LanguageDiscriminator(opt.D_model, opt.D_layers,
147 | opt.shared_hidden_size, opt.shared_hidden_size,
148 | len(opt.all_langs), opt.D_dropout, opt.D_bn, d_args)
149 |
150 | F_s, C, D = F_s.to(opt.device) if F_s else None, C.to(opt.device), D.to(opt.device) if D else None
151 | if F_p:
152 | F_p = F_p.to(opt.device)
153 | # optimizers
154 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, itertools.chain(*map(list,
155 | [emb.parameters(), F_s.parameters() if F_s else [], \
156 | C.parameters(), F_p.parameters() if F_p else []]))),
157 | lr=opt.learning_rate,
158 | weight_decay=opt.weight_decay)
159 | if D:
160 | optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate, weight_decay=opt.D_weight_decay)
161 |
162 | # testing
163 | if opt.test_only:
164 | log.info(f'Loading model from {opt.model_save_file}...')
165 | if F_s:
166 | F_s.load_state_dict(torch.load(os.path.join(opt.model_save_file,
167 | f'netF_s.pth')))
168 | for lang in opt.all_langs:
169 | F_p.load_state_dict(torch.load(os.path.join(opt.model_save_file,
170 | f'net_F_p.pth')))
171 | C.load_state_dict(torch.load(os.path.join(opt.model_save_file,
172 | f'netC.pth')))
173 | if D:
174 | D.load_state_dict(torch.load(os.path.join(opt.model_save_file,
175 | f'netD.pth')))
176 |
177 | log.info('Evaluating validation sets:')
178 | acc = {}
179 | for lang in opt.all_langs:
180 | acc[lang] = evaluate_acc(f'{lang}_dev', dev_loaders[lang], vocabs[lang],
181 | emb, lang, F_s, F_p, C)
182 | avg_acc = sum([acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
183 | log.info(f'Average validation accuracy: {avg_acc}')
184 | log.info('Evaluating test sets:')
185 | test_acc = {}
186 | for lang in opt.all_langs:
187 | test_acc[lang] = evaluate_acc(f'{lang}_test', test_loaders[lang], vocabs[lang],
188 | emb, lang, F_s, F_p, C)
189 | avg_test_acc = sum([test_acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
190 | log.info(f'Average test accuracy: {avg_test_acc}')
191 | return {'valid': acc, 'test': test_acc}
192 |
193 | # training
194 | best_acc, best_avg_acc = defaultdict(float), 0.0
195 | # lambda scheduling
196 | if opt.lambd > 0 and opt.lambd_schedule:
197 | opt.lambd_orig = opt.lambd
198 | num_iter = int(utils.gmean([len(train_loaders[l]) for l in opt.langs]))
199 | for epoch in range(opt.max_epoch):
200 | emb.train()
201 | if F_s:
202 | F_s.train()
203 | C.train()
204 | if D:
205 | D.train()
206 | if F_p:
207 | F_p.train()
208 | # lambda scheduling
209 | if hasattr(opt, 'lambd_orig') and opt.lambd_schedule:
210 | if epoch == 0:
211 | opt.lambd = opt.lambd_orig
212 | elif epoch == 5:
213 | opt.lambd = 10 * opt.lambd_orig
214 | elif epoch == 15:
215 | opt.lambd = 100 * opt.lambd_orig
216 | log.info(f'Scheduling lambda = {opt.lambd}')
217 | # training accuracy
218 | correct, total = defaultdict(int), defaultdict(int)
219 | gate_correct = defaultdict(int)
220 | gate_total = defaultdict(int)
221 | c_gate_correct = defaultdict(int)
222 | # D accuracy
223 | d_correct, d_total = 0, 0
224 | for i in tqdm(range(num_iter), ascii=True):
225 | # D iterations
226 | if opt.shared_hidden_size > 0:
227 | utils.freeze_net(emb)
228 | utils.freeze_net(F_s)
229 | utils.freeze_net(F_p)
230 | utils.freeze_net(C)
231 | utils.unfreeze_net(D)
232 | # WGAN n_critic trick since D trains slower
233 | n_critic = opt.n_critic
234 | if opt.wgan_trick:
235 | if opt.n_critic>0 and ((epoch==0 and i<25) or i%500==0):
236 | n_critic = 100
237 |
238 | for _ in range(n_critic):
239 | D.zero_grad()
240 | loss_d = {}
241 | lang_features = {}
242 | # train on both labeled and unlabeled langs
243 | for lang in opt.all_langs:
244 | # targets not used
245 | d_inputs, _ = utils.endless_get_next_batch(
246 | unlabeled_loaders, d_unlabeled_iters, lang)
247 | d_inputs, d_lengths, mask, d_chars, d_char_lengths = d_inputs
248 | d_embeds = emb(lang, d_inputs, d_chars, d_char_lengths)
249 | shared_feat = F_s((d_embeds, d_lengths))
250 | if opt.grad_penalty != 'none':
251 | lang_features[lang] = shared_feat.detach()
252 | if opt.D_model.lower() == 'mlp':
253 | d_outputs = D(shared_feat)
254 | # if token-level D, we can reuse the gate label generator
255 | d_targets = utils.get_gate_label(d_outputs, lang, mask, False, all_langs=True)
256 | d_total += torch.sum(d_lengths).item()
257 | else:
258 | d_outputs = D((shared_feat, d_lengths))
259 | d_targets = utils.get_lang_label(opt.loss, lang, len(d_lengths))
260 | d_total += len(d_lengths)
261 | # D accuracy
262 | _, pred = torch.max(d_outputs, -1)
263 | d_correct += (pred==d_targets).sum().item()
264 | l_d = functional.nll_loss(d_outputs.view(-1, D.num_langs),
265 | d_targets.view(-1), ignore_index=-1)
266 | l_d.backward()
267 | loss_d[lang] = l_d.item()
268 | # gradient penalty
269 | if opt.grad_penalty != 'none':
270 | gp = utils.calc_gradient_penalty(D, lang_features,
271 | onesided=opt.onesided_gp, interpolate=(opt.grad_penalty=='wgan'))
272 | gp.backward()
273 | optimizerD.step()
274 |
275 | # F&C iteration
276 | utils.unfreeze_net(emb)
277 | if opt.use_wordemb and opt.fix_emb:
278 | for lang in emb.langs:
279 | emb.wordembs[lang].weight.requires_grad = False
280 | if opt.use_charemb and opt.fix_charemb:
281 | emb.charemb.weight.requires_grad = False
282 | utils.unfreeze_net(F_s)
283 | utils.unfreeze_net(F_p)
284 | utils.unfreeze_net(C)
285 | utils.freeze_net(D)
286 | emb.zero_grad()
287 | if F_s:
288 | F_s.zero_grad()
289 | if F_p:
290 | F_p.zero_grad()
291 | C.zero_grad()
292 | # optimizer.zero_grad()
293 | for lang in opt.langs:
294 | inputs, targets = utils.endless_get_next_batch(
295 | train_loaders, train_iters, lang)
296 | inputs, lengths, mask, chars, char_lengths = inputs
297 | # bs, seq_len = inputs.size()
298 | embeds = emb(lang, inputs, chars, char_lengths)
299 | shared_feat, private_feat = None, None
300 | if opt.shared_hidden_size > 0:
301 | shared_feat = F_s((embeds, lengths))
302 | if opt.private_hidden_size > 0:
303 | if opt.Fp_MoE:
304 | private_feat, gate_outputs = F_p((embeds, lengths))
305 | else:
306 | private_feat = F_p((embeds, lengths))
307 | if opt.C_MoE:
308 | c_outputs, c_gate_outputs = C((shared_feat, private_feat, lengths))
309 | else:
310 | c_outputs = C((shared_feat, private_feat))
311 |                 # sentence-level classification loss: one label per example (no padding)
312 | l_c = functional.nll_loss(c_outputs, targets)
313 | if F_p and opt.Fp_MoE:
314 | # token-level gate loss
315 | gate_targets = utils.get_gate_label(gate_outputs, lang, mask, False)
316 | l_gate = functional.cross_entropy(gate_outputs.view(-1, gate_outputs.size(-1)),
317 | gate_targets.view(-1), ignore_index=-1)
318 | l_c += opt.gate_loss_weight * l_gate
319 | _, gate_pred = torch.max(gate_outputs.view(-1, gate_outputs.size(-1)), -1)
320 | gate_correct[lang] += (gate_pred == gate_targets.view(-1)).sum().item()
321 | gate_total[lang] += torch.sum(lengths).item()
322 | if opt.C_MoE and opt.C_gate_loss_weight > 0:
323 | c_gate_targets = utils.get_cls_gate_label(c_gate_outputs, lang, opt.expert_sp)
324 | _, c_gate_pred = torch.max(c_gate_outputs, -1)
325 | if opt.expert_sp:
326 |                         l_c_gate = functional.binary_cross_entropy_with_logits(c_gate_outputs, c_gate_targets)
327 | c_gate_correct[lang] += torch.index_select(c_gate_targets,
328 | -1, c_gate_pred).sum().item()
329 | else:
330 | l_c_gate = functional.cross_entropy(c_gate_outputs, c_gate_targets)
331 | c_gate_correct[lang] += (c_gate_pred == c_gate_targets).sum().item()
332 | l_c += opt.C_gate_loss_weight * l_c_gate
333 | l_c.backward()
334 | _, pred = torch.max(c_outputs, -1)
335 | total[lang] += targets.size(0)
336 | correct[lang] += (pred == targets).sum().item()
337 |
338 | # update F with D gradients on all langs
339 | if D:
340 | for lang in opt.all_langs:
341 | inputs, _ = utils.endless_get_next_batch(
342 | unlabeled_loaders, unlabeled_iters, lang)
343 | inputs, lengths, mask, chars, char_lengths = inputs
344 | embeds = emb(lang, inputs, chars, char_lengths)
345 | shared_feat = F_s((embeds, lengths))
346 | if opt.D_model.lower() == 'mlp':
347 | d_outputs = D(shared_feat)
348 | # if token-level D, we can reuse the gate label generator
349 | d_targets = utils.get_gate_label(d_outputs, lang, mask, False, all_langs=True)
350 | else:
351 | d_outputs = D((shared_feat, lengths))
352 | d_targets = utils.get_lang_label(opt.loss, lang, len(lengths))
353 | l_d = functional.nll_loss(d_outputs.view(-1, D.num_langs),
354 | d_targets.view(-1), ignore_index=-1)
355 | if opt.lambd > 0:
356 | l_d *= -opt.lambd
357 | l_d.backward()
358 |
359 | optimizer.step()
360 |
361 | # end of epoch
362 | log.info('Ending epoch {}'.format(epoch+1))
363 | if d_total > 0:
364 | log.info('D Training Accuracy: {}%'.format(100.0*d_correct/d_total))
365 | log.info('Training accuracy:')
366 | log.info('\t'.join(opt.langs))
367 | log.info('\t'.join([str(100.0*correct[d]/total[d]) for d in opt.langs]))
368 | if F_p:
369 | log.info('Gate accuracy:')
370 | log.info('\t'.join([str(100.0*gate_correct[d]/gate_total[d]) for d in opt.langs]))
371 | if opt.C_MoE:
372 | log.info('Tagger Gate accuracy:')
373 | log.info('\t'.join([str(100.0*c_gate_correct[d]/total[d]) for d in opt.langs]))
374 | log.info('Evaluating validation sets:')
375 | acc = {}
376 | for lang in opt.dev_langs:
377 | acc[lang] = evaluate_acc(f'{lang}_dev', dev_loaders[lang], vocabs[lang],
378 | emb, lang, F_s, F_p, C)
379 | avg_acc = sum([acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
380 | log.info(f'Average validation accuracy: {avg_acc}')
381 | log.info('Evaluating test sets:')
382 | test_acc = {}
383 | for lang in opt.dev_langs:
384 | test_acc[lang] = evaluate_acc(f'{lang}_test', test_loaders[lang], vocabs[lang],
385 | emb, lang, F_s, F_p, C)
386 | avg_test_acc = sum([test_acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
387 | log.info(f'Average test accuracy: {avg_test_acc}')
388 |
389 | if avg_acc > best_avg_acc:
390 | log.info(f'New best average validation accuracy: {avg_acc}')
391 | best_acc['valid'] = acc
392 | best_acc['test'] = test_acc
393 | best_avg_acc = avg_acc
394 | with open(os.path.join(opt.model_save_file, 'options.pkl'), 'wb') as ouf:
395 | pickle.dump(opt, ouf)
396 | if F_s:
397 | torch.save(F_s.state_dict(),
398 | '{}/netF_s.pth'.format(opt.model_save_file))
399 | torch.save(emb.state_dict(),
400 | '{}/net_emb.pth'.format(opt.model_save_file))
401 | if F_p:
402 | torch.save(F_p.state_dict(),
403 | '{}/net_F_p.pth'.format(opt.model_save_file))
404 | torch.save(C.state_dict(),
405 | '{}/netC.pth'.format(opt.model_save_file))
406 | if D:
407 | torch.save(D.state_dict(),
408 | '{}/netD.pth'.format(opt.model_save_file))
409 |
410 | # end of training
411 | log.info(f'Best average validation accuracy: {best_avg_acc}')
412 | return best_acc
413 |
414 |
415 | def evaluate_acc(name, loader, vocab, emb, lang, F_s, F_p, C):
416 | emb.eval()
417 | if F_s:
418 | F_s.eval()
419 | if F_p:
420 | F_p.eval()
421 | C.eval()
422 | it = iter(loader)
423 | correct = 0
424 | total = 0
425 | confusion = ConfusionMeter(opt.num_labels)
426 |
427 | with torch.no_grad():
428 | for inputs, targets in tqdm(it, ascii=True):
429 | inputs, lengths, mask, chars, char_lengths = inputs
430 | embeds = (emb(lang, inputs, chars, char_lengths), lengths)
431 | shared_features, lang_features = None, None
432 | if opt.shared_hidden_size > 0:
433 | shared_features = F_s(embeds)
434 | if opt.private_hidden_size > 0:
435 | if not F_p:
436 | # unlabeled lang
437 | lang_features = torch.zeros(targets.size(0),
438 | targets.size(1), opt.private_hidden_size).to(opt.device)
439 | else:
440 | if opt.Fp_MoE:
441 | lang_features, gate_outputs = F_p(embeds)
442 | else:
443 | lang_features = F_p(embeds)
444 | if opt.C_MoE:
445 | outputs, _ = C((shared_features, lang_features, lengths))
446 | else:
447 | outputs = C((shared_features, lang_features))
448 | _, pred = torch.max(outputs, -1)
449 | confusion.add(pred.detach(), targets.detach())
450 | total += targets.size(0)
451 | correct += (pred == targets).sum().item()
452 | acc = correct / total
453 | log.info('{}: Accuracy on {} samples: {}%'.format(name, total, 100.0*acc))
454 | log.debug(confusion.conf)
455 | return acc
456 |
457 |
458 | def main():
459 | if not os.path.exists(opt.model_save_file):
460 | os.makedirs(opt.model_save_file)
461 | log.info('Running the S-MAN + P-MoE + C-MoE model...')
462 | vocabs = {}
463 | assert opt.use_wordemb or opt.use_charemb, "At least one of word or char embeddings must be used!"
464 | char_vocab = Vocab(opt.charemb_size) if opt.use_charemb else None
465 | log.info(f'Loading Datasets...')
466 | log.info(f'Domain: {opt.domain}')
467 | log.info(f'Languages {opt.langs}')
468 |
469 | log.info('Loading Embeddings...')
470 | train_sets, dev_sets, test_sets, unlabeled_sets = {}, {}, {}, {}
471 | for lang in opt.all_langs:
472 | log.info(f'Building Vocab for {lang}...')
473 | vocabs[lang] = Vocab(opt.emb_size, opt.emb_filenames[lang])
474 | assert not opt.train_on_translation or not opt.test_on_translation
475 | train_sets[lang], dev_sets[lang], test_sets[lang], unlabeled_sets[lang] = \
476 | get_multi_lingual_amazon_datasets(vocabs[lang], char_vocab, opt.amazon_dir,
477 | opt.domain, lang, opt.max_seq_len)
478 | opt.num_labels = MultiLangAmazonDataset.num_labels
479 | log.info(f'Done Loading Datasets.')
480 |
481 | cv = train(vocabs, char_vocab, train_sets, dev_sets, test_sets, unlabeled_sets)
482 | log.info(f'Training done...')
483 | acc = sum(cv['valid'].values()) / len(cv['valid'])
484 | log.info(f'Validation Set Domain Average\t{acc}')
485 | test_acc = sum(cv['test'].values()) / len(cv['test'])
486 | log.info(f'Test Set Domain Average\t{test_acc}')
487 | return cv
488 |
489 |
490 | if __name__ == '__main__':
491 | main()
492 |
--------------------------------------------------------------------------------
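
Editorial note: train_cls_man_moe.py alternates a discriminator step with a feature/classifier step, and implements the 'gr' loss by backpropagating the negated, lambda-scaled language-classification loss through the shared extractor. Below is a stripped-down sketch of those two steps with dummy tensors; it is not part of the repository, and it detaches features in the D step where the real code instead freezes/unfreezes modules with utils.freeze_net / utils.unfreeze_net.

import torch
import torch.nn as nn
import torch.nn.functional as F

feat_dim, n_langs, bs, lambd = 16, 4, 8, 0.002
F_s = nn.Linear(feat_dim, feat_dim)                              # stand-in for the shared LSTM F_s
D = nn.Sequential(nn.Linear(feat_dim, n_langs), nn.LogSoftmax(dim=-1))
opt_F = torch.optim.Adam(F_s.parameters(), lr=1e-3)
opt_D = torch.optim.Adam(D.parameters(), lr=1e-3)

x = torch.randn(bs, feat_dim)                                    # stand-in for embedded inputs
lang_labels = torch.randint(0, n_langs, (bs,))

# D step: train the discriminator to recognize the language of shared features.
D.zero_grad()
l_d = F.nll_loss(D(F_s(x).detach()), lang_labels)
l_d.backward()
opt_D.step()

# F step: negate and scale the same loss so F_s learns language-invariant features.
F_s.zero_grad()
l_adv = -lambd * F.nll_loss(D(F_s(x)), lang_labels)
l_adv.backward()
opt_F.step()
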
/train_tagging_man_moe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT License.
3 |
4 | # A shared LSTM + MAN for learning language-invariant features (F_s)
5 | # Another shared LSTM with MoE on top for learning language-specific features (F_p)
6 | # MoE MLP tagger (C)
7 | from collections import defaultdict
8 | import io
9 | import itertools
10 | import logging
11 | import os
12 | import pickle
13 | import random
14 | import shutil
15 | import sys
16 | from tqdm import tqdm
17 |
18 | import numpy as np
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as functional
22 | import torch.optim as optim
23 | from torch.utils.data import ConcatDataset, DataLoader
24 |
25 | from options import opt
26 | random.seed(opt.random_seed)
27 | np.random.seed(opt.random_seed)
28 | torch.manual_seed(opt.random_seed)
29 | torch.cuda.manual_seed_all(opt.random_seed)
30 |
31 | from data_prep.bio_dataset import *
32 | from models import *
33 | import utils
34 | from vocab import Vocab, TagVocab
35 |
36 | # save models and logging
37 | if not os.path.exists(opt.model_save_file):
38 | os.makedirs(opt.model_save_file)
39 | logging.basicConfig(stream=sys.stderr, level=logging.DEBUG if opt.debug else logging.INFO)
40 | log = logging.getLogger(__name__)
41 | fh = logging.FileHandler(os.path.join(opt.model_save_file, 'log.txt'))
42 | log.addHandler(fh)
43 | # output options
44 | log.info(opt)
45 |
46 |
47 | def train(vocabs, char_vocab, tag_vocab, train_sets, dev_sets, test_sets, unlabeled_sets):
48 | """
49 |     train_sets, dev_sets, test_sets: dict[lang] -> tagging dataset (see data_prep/bio_dataset.py)
50 | For unlabeled langs, no train_sets are available
51 | """
52 | # dataset loaders
53 | train_loaders, unlabeled_loaders = {}, {}
54 | train_iters, unlabeled_iters, d_unlabeled_iters = {}, {}, {}
55 | dev_loaders, test_loaders = {}, {}
56 | my_collate = utils.sorted_collate if opt.model=='lstm' else utils.unsorted_collate
57 | for lang in opt.langs:
58 | train_loaders[lang] = DataLoader(train_sets[lang],
59 | opt.batch_size, shuffle=True, collate_fn = my_collate)
60 | train_iters[lang] = iter(train_loaders[lang])
61 | for lang in opt.dev_langs:
62 | dev_loaders[lang] = DataLoader(dev_sets[lang],
63 | opt.batch_size, shuffle=False, collate_fn = my_collate)
64 | test_loaders[lang] = DataLoader(test_sets[lang],
65 | opt.batch_size, shuffle=False, collate_fn = my_collate)
66 | for lang in opt.all_langs:
67 | if lang in opt.unlabeled_langs:
68 | uset = unlabeled_sets[lang]
69 | else:
70 | # for labeled langs, consider which data to use as unlabeled set
71 | if opt.unlabeled_data == 'both':
72 | uset = ConcatDataset([train_sets[lang], unlabeled_sets[lang]])
73 | elif opt.unlabeled_data == 'unlabeled':
74 | uset = unlabeled_sets[lang]
75 | elif opt.unlabeled_data == 'train':
76 | uset = train_sets[lang]
77 | else:
78 | raise Exception(f'Unknown options for the unlabeled data usage: {opt.unlabeled_data}')
79 | unlabeled_loaders[lang] = DataLoader(uset,
80 | opt.batch_size, shuffle=True, collate_fn = my_collate)
81 | unlabeled_iters[lang] = iter(unlabeled_loaders[lang])
82 | d_unlabeled_iters[lang] = iter(unlabeled_loaders[lang])
83 |
84 | # embeddings
85 | emb = MultiLangWordEmb(vocabs, char_vocab, opt.use_wordemb, opt.use_charemb).to(opt.device)
86 | # models
87 | F_s = None
88 | F_p = None
89 | C, D = None, None
90 | num_experts = len(opt.langs)+1 if opt.expert_sp else len(opt.langs)
91 | if opt.model.lower() == 'lstm':
92 | if opt.shared_hidden_size > 0:
93 | F_s = LSTMFeatureExtractor(opt.total_emb_size, opt.F_layers, opt.shared_hidden_size,
94 | opt.word_dropout, opt.dropout, opt.bdrnn)
95 | if opt.private_hidden_size > 0:
96 | if not opt.concat_sp:
97 | assert opt.shared_hidden_size == opt.private_hidden_size, "shared dim != private dim when using add_sp!"
98 | F_p = nn.Sequential(
99 | LSTMFeatureExtractor(opt.total_emb_size, opt.F_layers, opt.private_hidden_size,
100 | opt.word_dropout, opt.dropout, opt.bdrnn),
101 | MixtureOfExperts(opt.MoE_layers, opt.private_hidden_size,
102 | len(opt.langs), opt.private_hidden_size,
103 | opt.private_hidden_size, opt.dropout, opt.MoE_bn, False)
104 | )
105 | else:
106 | raise Exception(f'Unknown model architecture {opt.model}')
107 |
108 | if opt.C_MoE:
109 | C = SpMixtureOfExperts(opt.C_layers, opt.shared_hidden_size, opt.private_hidden_size, opt.concat_sp,
110 | num_experts, opt.shared_hidden_size + opt.private_hidden_size, len(tag_vocab),
111 | opt.mlp_dropout, opt.C_bn)
112 | else:
113 | C = SpMlpTagger(opt.C_layers, opt.shared_hidden_size, opt.private_hidden_size, opt.concat_sp,
114 | opt.shared_hidden_size + opt.private_hidden_size, len(tag_vocab),
115 | opt.mlp_dropout, opt.C_bn)
116 | if opt.shared_hidden_size > 0 and opt.n_critic > 0:
117 | if opt.D_model.lower() == 'lstm':
118 | d_args = {
119 | 'num_layers': opt.D_lstm_layers,
120 | 'input_size': opt.shared_hidden_size,
121 | 'hidden_size': opt.shared_hidden_size,
122 | 'word_dropout': opt.D_word_dropout,
123 | 'dropout': opt.D_dropout,
124 | 'bdrnn': opt.D_bdrnn,
125 | 'attn_type': opt.D_attn
126 | }
127 | elif opt.D_model.lower() == 'cnn':
128 | d_args = {
129 | 'num_layers': 1,
130 | 'input_size': opt.shared_hidden_size,
131 | 'hidden_size': opt.shared_hidden_size,
132 | 'kernel_num': opt.D_kernel_num,
133 | 'kernel_sizes': opt.D_kernel_sizes,
134 | 'word_dropout': opt.D_word_dropout,
135 | 'dropout': opt.D_dropout
136 | }
137 | else:
138 | d_args = None
139 |
140 | if opt.D_model.lower() == 'mlp':
141 | D = MLPLanguageDiscriminator(opt.D_layers, opt.shared_hidden_size,
142 | opt.shared_hidden_size, len(opt.all_langs), opt.loss, opt.D_dropout, opt.D_bn)
143 | else:
144 | D = LanguageDiscriminator(opt.D_model, opt.D_layers,
145 | opt.shared_hidden_size, opt.shared_hidden_size,
146 | len(opt.all_langs), opt.D_dropout, opt.D_bn, d_args)
147 | if opt.use_data_parallel:
148 | F_s, C, D = nn.DataParallel(F_s).to(opt.device) if F_s else None, nn.DataParallel(C).to(opt.device), nn.DataParallel(D).to(opt.device) if D else None
149 | else:
150 | F_s, C, D = F_s.to(opt.device) if F_s else None, C.to(opt.device), D.to(opt.device) if D else None
151 | if F_p:
152 | if opt.use_data_parallel:
153 | F_p = nn.DataParallel(F_p).to(opt.device)
154 | else:
155 | F_p = F_p.to(opt.device)
156 | # optimizers
157 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, itertools.chain(*map(list,
158 | [emb.parameters(), F_s.parameters() if F_s else [], \
159 | C.parameters(), F_p.parameters() if F_p else []]))),
160 | lr=opt.learning_rate,
161 | weight_decay=opt.weight_decay)
162 | if D:
163 | optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate, weight_decay=opt.D_weight_decay)
164 |
165 | # testing
166 | if opt.test_only:
167 | log.info(f'Loading model from {opt.model_save_file}...')
168 | if F_s:
169 | F_s.load_state_dict(torch.load(os.path.join(opt.model_save_file,
170 | f'netF_s.pth')))
171 | for lang in opt.all_langs:
172 | F_p.load_state_dict(torch.load(os.path.join(opt.model_save_file,
173 | f'net_F_p.pth')))
174 | C.load_state_dict(torch.load(os.path.join(opt.model_save_file,
175 | f'netC.pth')))
176 | if D:
177 | D.load_state_dict(torch.load(os.path.join(opt.model_save_file,
178 | f'netD.pth')))
179 |
180 | log.info('Evaluating validation sets:')
181 | acc = {}
182 | log.info(dev_loaders)
183 | log.info(vocabs)
184 | for lang in opt.all_langs:
185 | acc[lang] = evaluate(f'{lang}_dev', dev_loaders[lang], vocabs[lang], tag_vocab,
186 | emb, lang, F_s, F_p, C)
187 | avg_acc = sum([acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
188 | log.info(f'Average validation accuracy: {avg_acc}')
189 | log.info('Evaluating test sets:')
190 | test_acc = {}
191 | for lang in opt.all_langs:
192 | test_acc[lang] = evaluate(f'{lang}_test', test_loaders[lang], vocabs[lang], tag_vocab,
193 | emb, lang, F_s, F_p, C)
194 | avg_test_acc = sum([test_acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
195 | log.info(f'Average test accuracy: {avg_test_acc}')
196 | return {'valid': acc, 'test': test_acc}
197 |
198 | # training
199 | best_acc, best_avg_acc = defaultdict(float), 0.0
200 | epochs_since_decay = 0
201 | # lambda scheduling
202 | if opt.lambd > 0 and opt.lambd_schedule:
203 | opt.lambd_orig = opt.lambd
204 | num_iter = int(utils.gmean([len(train_loaders[l]) for l in opt.langs]))
205 | # adapt max_epoch
206 | if opt.max_epoch > 0 and num_iter * opt.max_epoch < 15000:
207 | opt.max_epoch = 15000 // num_iter
208 | log.info(f"Setting max_epoch to {opt.max_epoch}")
209 | for epoch in range(opt.max_epoch):
210 | emb.train()
211 | if F_s:
212 | F_s.train()
213 | C.train()
214 | if D:
215 | D.train()
216 | if F_p:
217 | F_p.train()
218 |
219 | # lambda scheduling
220 | if hasattr(opt, 'lambd_orig') and opt.lambd_schedule:
221 | if epoch == 0:
222 | opt.lambd = opt.lambd_orig
223 | elif epoch == 5:
224 | opt.lambd = 10 * opt.lambd_orig
225 | elif epoch == 15:
226 | opt.lambd = 100 * opt.lambd_orig
227 | log.info(f'Scheduling lambda = {opt.lambd}')
228 |
229 | # training accuracy
230 | correct, total = defaultdict(int), defaultdict(int)
231 | gate_correct = defaultdict(int)
232 | c_gate_correct = defaultdict(int)
233 | # D accuracy
234 | d_correct, d_total = 0, 0
235 | for i in tqdm(range(num_iter), ascii=True):
236 | # D iterations
237 | if opt.shared_hidden_size > 0:
238 | utils.freeze_net(emb)
239 | utils.freeze_net(F_s)
240 | utils.freeze_net(F_p)
241 | utils.freeze_net(C)
242 | utils.unfreeze_net(D)
243 | # WGAN n_critic trick since D trains slower
244 | n_critic = opt.n_critic
245 | if opt.wgan_trick:
246 | if opt.n_critic>0 and ((epoch==0 and i<25) or i%500==0):
247 | n_critic = 100
248 |
249 | for _ in range(n_critic):
250 | D.zero_grad()
251 | loss_d = {}
252 | lang_features = {}
253 | # train on both labeled and unlabeled langs
254 | for lang in opt.all_langs:
255 | # targets not used
256 | d_inputs, _ = utils.endless_get_next_batch(
257 | unlabeled_loaders, d_unlabeled_iters, lang)
258 | d_inputs, d_lengths, mask, d_chars, d_char_lengths = d_inputs
259 | d_embeds = emb(lang, d_inputs, d_chars, d_char_lengths)
260 | shared_feat = F_s((d_embeds, d_lengths))
261 | if opt.grad_penalty != 'none':
262 | lang_features[lang] = shared_feat.detach()
263 | if opt.D_model.lower() == 'mlp':
264 | d_outputs = D(shared_feat)
265 | # if token-level D, we can reuse the gate label generator
266 | d_targets = utils.get_gate_label(d_outputs, lang, mask, False, all_langs=True)
267 | d_total += torch.sum(d_lengths).item()
268 | else:
269 | d_outputs = D((shared_feat, d_lengths))
270 | d_targets = utils.get_lang_label(opt.loss, lang, len(d_lengths))
271 | d_total += len(d_lengths)
272 | # D accuracy
273 | _, pred = torch.max(d_outputs, -1)
274 | # d_total += len(d_lengths)
275 | d_correct += (pred==d_targets).sum().item()
276 | if opt.use_data_parallel:
277 | l_d = functional.nll_loss(d_outputs.view(-1, D.module.num_langs),
278 | d_targets.view(-1), ignore_index=-1)
279 | else:
280 | l_d = functional.nll_loss(d_outputs.view(-1, D.num_langs),
281 | d_targets.view(-1), ignore_index=-1)
282 |
283 | l_d.backward()
284 | loss_d[lang] = l_d.item()
285 | # gradient penalty
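    | # optionally regularize D with a gradient penalty computed from the detached shared
    | # features (one-sided and/or with interpolated samples, depending on the options)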
286 | if opt.grad_penalty != 'none':
287 | gp = utils.calc_gradient_penalty(D, lang_features,
288 | onesided=opt.onesided_gp, interpolate=(opt.grad_penalty=='wgan'))
289 | gp.backward()
290 | optimizerD.step()
291 |
292 | # F&C iteration
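    | # Adversarial step 2: freeze D and train the embedder, F_s, F_p and the tagger C
    | # on labeled batches from the source languages.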
293 | utils.unfreeze_net(emb)
294 | if opt.use_wordemb and opt.fix_emb:
295 | for lang in emb.langs:
296 | emb.wordembs[lang].weight.requires_grad = False
297 | if opt.use_charemb and opt.fix_charemb:
298 | emb.charemb.weight.requires_grad = False
299 | utils.unfreeze_net(F_s)
300 | utils.unfreeze_net(F_p)
301 | utils.unfreeze_net(C)
302 | utils.freeze_net(D)
303 | emb.zero_grad()
304 | if F_s:
305 | F_s.zero_grad()
306 | if F_p:
307 | F_p.zero_grad()
308 | C.zero_grad()
309 | # optimizer.zero_grad()
310 | for lang in opt.langs:
311 | inputs, targets = utils.endless_get_next_batch(
312 | train_loaders, train_iters, lang)
313 | inputs, lengths, mask, chars, char_lengths = inputs
314 | bs, seq_len = inputs.size()
315 | embeds = emb(lang, inputs, chars, char_lengths)
316 | shared_feat, private_feat = None, None
317 | if opt.shared_hidden_size > 0:
318 | shared_feat = F_s((embeds, lengths))
319 | if opt.private_hidden_size > 0:
320 | private_feat, gate_outputs = F_p((embeds, lengths))
321 | if opt.C_MoE:
322 | c_outputs, c_gate_outputs = C((shared_feat, private_feat))
323 | else:
324 | c_outputs = C((shared_feat, private_feat))
325 | # targets are padded with -1
326 | l_c = functional.nll_loss(c_outputs.view(bs*seq_len, -1),
327 | targets.view(-1), ignore_index=-1)
328 | # gate loss
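    | # the private MoE gate is supervised with language-identity targets from
    | # utils.get_gate_label; its loss is added to the tagging loss, weighted by gate_loss_weight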
329 | if F_p:
330 | gate_targets = utils.get_gate_label(gate_outputs, lang, mask, False)
331 | l_gate = functional.cross_entropy(gate_outputs.view(bs*seq_len, -1),
332 | gate_targets.view(-1), ignore_index=-1)
333 | l_c += opt.gate_loss_weight * l_gate
334 | _, gate_pred = torch.max(gate_outputs.view(bs*seq_len, -1), -1)
335 | gate_correct[lang] += (gate_pred == gate_targets.view(-1)).sum().item()
336 | if opt.C_MoE and opt.C_gate_loss_weight > 0:
337 | c_gate_targets = utils.get_gate_label(c_gate_outputs, lang, mask, opt.expert_sp)
338 | _, c_gate_pred = torch.max(c_gate_outputs.view(bs*seq_len, -1), -1)
339 | if opt.expert_sp:
340 | l_c_gate = functional.binary_cross_entropy_with_logits(
341 | mask.unsqueeze(-1) * c_gate_outputs, c_gate_targets)
342 | c_gate_correct[lang] += torch.index_select(c_gate_targets.view(bs*seq_len, -1),
343 | -1, c_gate_pred.view(bs*seq_len)).sum().item()
344 | else:
345 | l_c_gate = functional.cross_entropy(c_gate_outputs.view(bs*seq_len, -1),
346 | c_gate_targets.view(-1), ignore_index=-1)
347 | c_gate_correct[lang] += (c_gate_pred == c_gate_targets.view(-1)).sum().item()
348 | l_c += opt.C_gate_loss_weight * l_c_gate
349 | l_c.backward()
350 | _, pred = torch.max(c_outputs, -1)
351 | total[lang] += torch.sum(lengths).item()
352 | correct[lang] += (pred == targets).sum().item()
353 |
354 | # update F with D gradients on all langs
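    | # Gradient reversal: with D frozen, the discriminator loss is scaled by -lambd before
    | # backward(), pushing the encoder and F_s towards language-invariant shared features.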
355 | if D:
356 | for lang in opt.all_langs:
357 | inputs, _ = utils.endless_get_next_batch(
358 | unlabeled_loaders, unlabeled_iters, lang)
359 | inputs, lengths, mask, chars, char_lengths = inputs
360 | embeds = emb(lang, inputs, chars, char_lengths)
361 | shared_feat = F_s((embeds, lengths))
362 | # d_outputs = D((shared_feat, lengths))
363 | if opt.D_model.lower() == 'mlp':
364 | d_outputs = D(shared_feat)
365 | # if token-level D, we can reuse the gate label generator
366 | d_targets = utils.get_gate_label(d_outputs, lang, mask, False, all_langs=True)
367 | else:
368 | d_outputs = D((shared_feat, lengths))
369 | d_targets = utils.get_lang_label(opt.loss, lang, len(lengths))
370 | if opt.use_data_parallel:
371 | l_d = functional.nll_loss(d_outputs.view(-1, D.module.num_langs),
372 | d_targets.view(-1), ignore_index=-1)
373 | else:
374 | l_d = functional.nll_loss(d_outputs.view(-1, D.num_langs),
375 | d_targets.view(-1), ignore_index=-1)
376 | if opt.lambd > 0:
377 | l_d *= -opt.lambd
378 | l_d.backward()
379 |
380 | optimizer.step()
381 |
382 | # end of epoch
383 | log.info('Ending epoch {}'.format(epoch+1))
384 | if d_total > 0:
385 | log.info('D Training Accuracy: {}%'.format(100.0*d_correct/d_total))
386 | log.info('Training accuracy:')
387 | log.info('\t'.join(opt.langs))
388 | log.info('\t'.join([str(100.0*correct[d]/total[d]) for d in opt.langs]))
389 | log.info('Gate accuracy:')
390 | log.info('\t'.join([str(100.0*gate_correct[d]/total[d]) for d in opt.langs]))
391 | log.info('Tagger Gate accuracy:')
392 | log.info('\t'.join([str(100.0*c_gate_correct[d]/total[d]) for d in opt.langs]))
393 | log.info('Evaluating validation sets:')
394 | acc = {}
395 | for lang in opt.dev_langs:
396 | acc[lang] = evaluate(f'{lang}_dev', dev_loaders[lang], vocabs[lang], tag_vocab,
397 | emb, lang, F_s, F_p, C)
398 | avg_acc = sum([acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
399 | log.info(f'Average validation accuracy: {avg_acc}')
400 | log.info('Evaluating test sets:')
401 | test_acc = {}
402 | for lang in opt.dev_langs:
403 | test_acc[lang] = evaluate(f'{lang}_test', test_loaders[lang], vocabs[lang], tag_vocab,
404 | emb, lang, F_s, F_p, C)
405 | avg_test_acc = sum([test_acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
406 | log.info(f'Average test accuracy: {avg_test_acc}')
407 |
408 | if avg_acc > best_avg_acc:
409 | epochs_since_decay = 0
410 | log.info(f'New best average validation accuracy: {avg_acc}')
411 | best_acc['valid'] = acc
412 | best_acc['test'] = test_acc
413 | best_avg_acc = avg_acc
414 | with open(os.path.join(opt.model_save_file, 'options.pkl'), 'wb') as ouf:
415 | pickle.dump(opt, ouf)
416 | if F_s:
417 | torch.save(F_s.state_dict(),
418 | '{}/netF_s.pth'.format(opt.model_save_file))
419 | torch.save(emb.state_dict(),
420 | '{}/net_emb.pth'.format(opt.model_save_file))
421 | if F_p:
422 | torch.save(F_p.state_dict(),
423 | '{}/net_F_p.pth'.format(opt.model_save_file))
424 | torch.save(C.state_dict(),
425 | '{}/netC.pth'.format(opt.model_save_file))
426 | if D:
427 | torch.save(D.state_dict(),
428 | '{}/netD.pth'.format(opt.model_save_file))
429 | else:
430 | epochs_since_decay += 1
431 | if opt.lr_decay < 1 and epochs_since_decay >= opt.lr_decay_epochs:
432 | epochs_since_decay = 0
433 | old_lr = optimizer.param_groups[0]['lr']
434 | optimizer.param_groups[0]['lr'] = old_lr * opt.lr_decay
435 | log.info(f'Decreasing LR to {old_lr * opt.lr_decay}')
436 |
437 | # end of training
438 | log.info(f'Best average validation accuracy: {best_avg_acc}')
439 | return best_acc
440 |
441 |
442 | def evaluate(name, loader, vocab, tag_vocab, emb, lang, F_s, F_p, C):
443 | emb.eval()
444 | if F_s:
445 | F_s.eval()
446 | if F_p:
447 | F_p.eval()
448 | C.eval()
449 | it = iter(loader)
450 | conll = io.StringIO()
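    | # predictions are written out in CoNLL format (word gold_tag pred_tag, one token per
    | # line, blank line between sentences) and scored with utils.conllF1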
451 | with torch.no_grad():
452 | for inputs, targets in tqdm(it, ascii=True):
453 | inputs, lengths, mask, chars, char_lengths = inputs
454 | embeds = (emb(lang, inputs, chars, char_lengths), lengths)
455 | shared_features, lang_features = None, None
456 | if opt.shared_hidden_size > 0:
457 | shared_features = F_s(embeds)
458 | if opt.private_hidden_size > 0:
459 | if not F_p:
460 | # unlabeled lang
461 | if opt.use_data_parallel:
462 | lang_features = torch.zeros(targets.size(0),
463 | targets.size(1), opt.private_hidden_size)
464 | lang_features = lang_features.to(opt.device)  # a plain tensor needs no DataParallel wrapper
465 | else:
466 | lang_features = torch.zeros(targets.size(0),
467 | targets.size(1), opt.private_hidden_size).to(opt.device)
468 | else:
469 | lang_features, gate_outputs = F_p(embeds)
470 | if opt.C_MoE:
471 | outputs, _ = C((shared_features, lang_features))
472 | else:
473 | outputs = C((shared_features, lang_features))
474 | _, pred = torch.max(outputs, -1)
475 | bs, seq_len = pred.size()
476 | for i in range(bs):
477 | for j in range(lengths[i]):
478 | word = vocab.get_word(inputs[i][j])
479 | gold_tag = tag_vocab.get_tag(targets[i][j])
480 | pred_tag = tag_vocab.get_tag(pred[i][j])
481 | conll.write(f'{word} {gold_tag} {pred_tag}\n')
482 | conll.write('\n')
483 | f1 = utils.conllF1(conll, log)
484 | conll.close()
485 | log.info('{}: F1 score: {}%'.format(name, f1))
486 | return f1
487 |
488 |
489 | def main():
490 | if not os.path.exists(opt.model_save_file):
491 | os.makedirs(opt.model_save_file)
492 | log.info('Running the S-MAN + P-MoE + C-MoE model...')
493 | vocabs = {}
494 | tag_vocab = TagVocab()
495 | assert opt.use_wordemb or opt.use_charemb, "At least one of word or char embeddings must be used!"
496 | char_vocab = Vocab(opt.charemb_size)
497 | log.info(f'Loading Datasets...')
498 | log.info(f'Languages {opt.langs}')
499 |
500 | log.info('Loading Embeddings...')
501 | train_sets, dev_sets, test_sets, unlabeled_sets = {}, {}, {}, {}
502 | for lang in opt.all_langs:
503 | log.info(f'Building Vocab for {lang}...')
504 | vocabs[lang] = Vocab(opt.emb_size, opt.emb_filenames[lang])
505 | assert not opt.train_on_translation or not opt.test_on_translation
506 | if opt.dataset.lower() == 'conll':
507 | get_dataset_fn = get_conll_ner_datasets
508 | if opt.train_on_translation:
509 | get_dataset_fn = get_train_on_translation_conll_ner_datasets
510 | if opt.test_on_translation:
511 | get_dataset_fn = get_test_on_translation_conll_ner_datasets
512 | train_sets[lang], dev_sets[lang], test_sets[lang], unlabeled_sets[lang] = \
513 | get_dataset_fn(vocabs[lang], char_vocab, tag_vocab, opt.conll_dir, lang)
514 | else:
515 | raise Exception(f"Unknown dataset {opt.dataset}")
516 |
517 | opt.num_labels = len(tag_vocab)
518 | log.info(f'Tagset: {tag_vocab.id2tag}')
519 | log.info(f'Done Loading Datasets.')
520 |
521 | cv = train(vocabs, char_vocab, tag_vocab, train_sets, dev_sets, test_sets, unlabeled_sets)
522 | log.info(f'Training done...')
523 | acc = sum(cv['valid'].values()) / len(cv['valid'])
524 | log.info(f'Validation Set Domain Average\t{acc}')
525 | test_acc = sum(cv['test'].values()) / len(cv['test'])
526 | log.info(f'Test Set Domain Average\t{test_acc}')
527 | return cv
528 |
529 |
530 | if __name__ == '__main__':
531 | main()
532 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT License.
3 |
4 | from collections import defaultdict
5 | import subprocess
6 |
7 | import numpy as np
8 | import torch
9 | from torch import autograd
10 | from options import opt
11 |
12 |
13 | def read_bio_samples(f):
14 | lines = []
15 | for line in f:
16 | if not line.rstrip():
17 | if len(lines) > 0:
18 | yield lines
19 | lines = []
20 | else:
21 | lines.append(line.rstrip())
22 | if len(lines) > 0:
23 | yield lines
24 | lines = []
25 |
26 | def freeze_net(net):
27 | if not net:
28 | return
29 | for p in net.parameters():
30 | p.requires_grad = False
31 |
32 |
33 | def unfreeze_net(net):
34 | if not net:
35 | return
36 | for p in net.parameters():
37 | p.requires_grad = True
38 |
39 |
40 | def sorted_collate(batch):
41 | return my_collate(batch, sort=True)
42 |
43 |
44 | def unsorted_collate(batch):
45 | return my_collate(batch, sort=False)
46 |
47 |
48 | def my_collate(batch, sort):
49 | x, y = zip(*batch)
50 | # extract input indices
51 | if opt.default_emb == 'elmo':
52 | x, y = raw_pad(x, y, sort)
53 | else:
54 | x, y = pad(x, y, opt.eos_idx, sort)
55 | return (x, y)
56 |
57 |
58 | def sorted_cls_collate(batch):
59 | return cls_my_collate(batch, sort=True)
60 |
61 |
62 | def unsorted_cls_collate(batch):
63 | return cls_my_collate(batch, sort=False)
64 |
65 |
66 | def cls_my_collate(batch, sort):
67 | x, y = zip(*batch)
68 | # extract input indices
69 | x, y = cls_pad(x, y, opt.eos_idx, sort)
70 | return (x, y)
71 |
72 |
73 | # TODO: skip building char inputs when char embeddings are unused, to speed things up
74 | def pad(x, y, eos_idx, sort):
75 | bs = len(x)
76 | chars = [sample['chars'] for sample in x]
77 | x = [sample['words'] for sample in x]
78 | lengths = [len(row) for row in x] if x else [len(row) for row in chars]
79 | max_len = max(lengths)
80 | # if using CNN, pad to at least the largest kernel size
81 | if opt.model.lower() == 'cnn' or opt.D_model.lower() == 'cnn':
82 | max_len = max(max_len, opt.max_kernel_size)
83 | if chars:
84 | char_lengths = [[len(w) for w in sample] for sample in chars]
85 | max_char_len = max([l for sample in char_lengths for l in sample])
86 | if opt.charemb_model.lower() == 'cnn':
87 | max_char_len = max(max_char_len, max(opt.charemb_kernel_sizes))
88 | # pad sequences
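   | # sequences are right-padded with eos_idx up to max_len; tags are padded with -1 so
   | # that padded positions are ignored by nll_loss(ignore_index=-1)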
89 | lengths = torch.tensor(lengths)
90 | padded_x = torch.full((len(x), max_len), eos_idx, dtype=torch.long)
91 | padded_y = torch.full((len(y), max_len), -1, dtype=torch.long)
92 | if chars:
93 | padded_chars = torch.full((bs, max_len, max_char_len), eos_idx, dtype=torch.long)
94 | padded_char_lengths = torch.zeros((len(x), max_len), dtype=torch.long)
95 | for i, (row, tag, char_row, cl_row) in enumerate(zip(x, y, chars, char_lengths)):
96 | assert len(row) == len(tag)
97 | assert eos_idx not in row, f'EOS in sequence {row}'
98 | padded_x[i][:len(row)] = torch.tensor(row)
99 | padded_y[i][:len(tag)] = torch.tensor(tag)
100 | if chars:
101 | padded_char_lengths[i][:len(cl_row)] = torch.tensor(cl_row)
102 | for j, ch_word in enumerate(char_row):
103 | padded_chars[i][j][:len(ch_word)] = torch.tensor(ch_word)
104 | # create mask
105 | idxes = torch.arange(0, max_len, dtype=torch.long).unsqueeze(0) # some day, you'll be able to directly do this on cuda
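    | # e.g. lengths = [3, 1] with max_len = 3 gives mask rows [1, 1, 1] and [1, 0, 0]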
106 | mask = (idxes < lengths.unsqueeze(1)).float()  # 1 for real tokens, 0 for padding
--------------------------------------------------------------------------------
/vocab.py:
--------------------------------------------------------------------------------
45 | self.unk_tok = ''
46 | self.add_word(self.unk_tok)
47 | self.unk_idx = self.w2vvocab[self.unk_tok]
48 |
49 | def base_form(word):
50 | return word.strip().lower()
51 |
52 | def new_rand_emb(self):
53 | vec = np.random.normal(0, 1, size=self.emb_size)
54 | vec /= sum(x*x for x in vec) ** .5
55 | return vec
56 |
57 | def init_embed_layer(self, fix_emb):
58 | # free some memory
59 | self.clear_pretrained_vectors()
60 | emb = nn.Embedding.from_pretrained(torch.tensor(self.embeddings, dtype=torch.float),
61 | freeze=fix_emb)
62 | emb.padding_idx = self.eos_idx
63 | assert len(emb.weight) == self.vocab_size
64 | return emb
65 |
66 | def add_word(self, word, normalize=True):
67 | if normalize:
68 | word = Vocab.base_form(word)
69 | if word not in self.w2vvocab:
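   | # reuse the pretrained vector when available (and random_emb is off); otherwise
   | # draw a random unit-norm vector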
70 | if not opt.random_emb and hasattr(self, 'pt_w2vvocab') and word in self.pt_w2vvocab:
71 | vector = self.pretrained[self.pt_w2vvocab[word]].copy()
72 | else:
73 | vector = self.new_rand_emb()
74 | self.v2wvocab.append(word)
75 | self.w2vvocab[word] = self.vocab_size
76 | self.embeddings.append(vector)
77 | self.vocab_size += 1
78 |
79 | def clear_pretrained_vectors(self):
80 | if hasattr(self, 'pretrained'):
81 | del self.pretrained
82 | del self.pt_w2vvocab
83 | del self.pt_v2wvocab
84 |
85 | def lookup(self, word, normalize=True):
86 | if normalize:
87 | word = Vocab.base_form(word)
88 | if word in self.w2vvocab:
89 | return self.w2vvocab[word]
90 | return self.unk_idx
91 |
92 | def get_word(self, i):
93 | return self.v2wvocab[i]
94 |
95 |
96 | class TagVocab:
97 | def __init__(self, bio=True):
98 | self.tag2id = {}
99 | self.id2tag = []
100 | self.bio = bio
101 |
102 | def __len__(self):
103 | return len(self.tag2id)
104 |
105 | def _add_tag(self, tag):
106 | if tag not in self.tag2id:
107 | self.tag2id[tag] = len(self)
108 | self.id2tag.append(tag)
109 |
110 | def add_tag(self, tag):
111 | self._add_tag(tag)
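    | # in BIO mode, seeing either a B- or I- tag registers both variants so the
    | # tagset stays closed under the BIO scheme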
112 | if self.bio:
113 | if tag.startswith('B-') or tag.startswith('I-'):
114 | self._add_tag(f'B-{tag[2:]}')
115 | self._add_tag(f'I-{tag[2:]}')
116 |
117 | def add_tags(self, tags):
118 | for tag in tags:
119 | self.add_tag(tag)
120 |
121 | def lookup(self, raw_y):
122 | return self.tag2id[raw_y]
123 |
124 | def get_tag(self, i):
125 | return self.id2tag[i]
126 |
127 |
--------------------------------------------------------------------------------