├── .gitignore
├── LICENSE
├── Pipfile
├── Pipfile.lock
├── README.md
├── examples
│   └── die_vs_data_rest_api
│       ├── .gitignore
│       ├── README.md
│       ├── app.py
│       ├── app
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   └── __init__.cpython-37.pyc
│       │   └── config.json
│       └── requirements.txt
├── model_cards
│   └── README.md
├── notebooks
│   ├── demo_RobBERT_for_conll_ner.ipynb
│   ├── demo_RobBERT_for_masked_LM.ipynb
│   ├── die_dat_demo.ipynb
│   ├── evaluate_zeroshot_wordlists.ipynb
│   ├── evaluate_zeroshot_wordlists_v2.ipynb
│   └── finetune_dbrd.ipynb
├── requirements.txt
├── res
│   ├── dbrd.png
│   ├── gender_diff.png
│   ├── robbert_2022_logo.png
│   ├── robbert_2022_logo_with_name.png
│   ├── robbert_2023_logo.png
│   ├── robbert_logo.png
│   ├── robbert_logo_with_name.png
│   └── robbert_pos_accuracy.png
├── src
│   ├── __init__.py
│   ├── bert_masked_lm_adapter.py
│   ├── convert_roberta_dict.py
│   ├── evaluate_zeroshot_wordlist.py
│   ├── multiprocessing_bpe_encoder.py
│   ├── preprocess_conll2002_ner.py
│   ├── preprocess_dbrd.py
│   ├── preprocess_diedat.py
│   ├── preprocess_diedat.sh
│   ├── preprocess_lassy_ud.py
│   ├── preprocess_util.py
│   ├── preprocess_wordlist_mask.py
│   ├── pretrain.pbs
│   ├── run_lm.py
│   ├── split_dbrd_training.sh
│   ├── textdataset.py
│   ├── train.py
│   ├── train_config.py
│   ├── train_diedat.sh
│   └── wordlistfiller.py
└── tests
    ├── __init__.py
    └── test_convert_roberta_dict.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | .idea/
3 | models/
4 | src/__pycache__/
5 |
6 | venv/
7 | .env/
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Pieter Delobelle, Thomas Winters, Bettina Berendt
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Pipfile:
--------------------------------------------------------------------------------
1 | [[source]]
2 | name = "pypi"
3 | url = "https://pypi.org/simple"
4 | verify_ssl = true
5 |
6 | [dev-packages]
7 |
8 | [packages]
9 | fairseq = "*"
10 | nltk = "*"
11 | transformers = "*"
12 | tensorboardx = "*"
13 | tokenizers = "~=0.4.2"
14 | jupyter = "*"
15 |
16 | [requires]
17 | python_version = "3.7"
18 |
--------------------------------------------------------------------------------
/Pipfile.lock:
--------------------------------------------------------------------------------
1 | {
2 | "_meta": {
3 | "hash": {
4 | "sha256": "de30df368a36e74d2c9c8d1d71d6017149e55b4570e2ae1d31539a99ee42feb2"
5 | },
6 | "pipfile-spec": 6,
7 | "requires": {
8 | "python_version": "3.7"
9 | },
10 | "sources": [
11 | {
12 | "name": "pypi",
13 | "url": "https://pypi.org/simple",
14 | "verify_ssl": true
15 | }
16 | ]
17 | },
18 | "default": {
19 | "boto3": {
20 | "hashes": [
21 | "sha256:2cb3063d347c0d06f96dcd10d69247d3e3d12eda4b1ed9fdec62ac25c7fee8b6",
22 | "sha256:5bffdfc7d3465f053db38a64821fbeb1f8044ffafb71c90587f1c293d334305f"
23 | ],
24 | "version": "==1.11.16"
25 | },
26 | "botocore": {
27 | "hashes": [
28 | "sha256:63dc203046e47573add16c348f00a8e2fe3bb4975323458787e7b2bebeec5923",
29 | "sha256:d31d7e0e12fb64d6bd331c6b853750c71423d560cee01a989f607dd45407ef65"
30 | ],
31 | "version": "==1.14.16"
32 | },
33 | "certifi": {
34 | "hashes": [
35 | "sha256:017c25db2a153ce562900032d5bc68e9f191e44e9a0f762f373977de9df1fbb3",
36 | "sha256:25b64c7da4cd7479594d035c08c2d809eb4aab3a26e5a990ea98cc450c320f1f"
37 | ],
38 | "version": "==2019.11.28"
39 | },
40 | "cffi": {
41 | "hashes": [
42 | "sha256:001bf3242a1bb04d985d63e138230802c6c8d4db3668fb545fb5005ddf5bb5ff",
43 | "sha256:00789914be39dffba161cfc5be31b55775de5ba2235fe49aa28c148236c4e06b",
44 | "sha256:028a579fc9aed3af38f4892bdcc7390508adabc30c6af4a6e4f611b0c680e6ac",
45 | "sha256:14491a910663bf9f13ddf2bc8f60562d6bc5315c1f09c704937ef17293fb85b0",
46 | "sha256:1cae98a7054b5c9391eb3249b86e0e99ab1e02bb0cc0575da191aedadbdf4384",
47 | "sha256:2089ed025da3919d2e75a4d963d008330c96751127dd6f73c8dc0c65041b4c26",
48 | "sha256:2d384f4a127a15ba701207f7639d94106693b6cd64173d6c8988e2c25f3ac2b6",
49 | "sha256:337d448e5a725bba2d8293c48d9353fc68d0e9e4088d62a9571def317797522b",
50 | "sha256:399aed636c7d3749bbed55bc907c3288cb43c65c4389964ad5ff849b6370603e",
51 | "sha256:3b911c2dbd4f423b4c4fcca138cadde747abdb20d196c4a48708b8a2d32b16dd",
52 | "sha256:3d311bcc4a41408cf5854f06ef2c5cab88f9fded37a3b95936c9879c1640d4c2",
53 | "sha256:62ae9af2d069ea2698bf536dcfe1e4eed9090211dbaafeeedf5cb6c41b352f66",
54 | "sha256:66e41db66b47d0d8672d8ed2708ba91b2f2524ece3dee48b5dfb36be8c2f21dc",
55 | "sha256:675686925a9fb403edba0114db74e741d8181683dcf216be697d208857e04ca8",
56 | "sha256:7e63cbcf2429a8dbfe48dcc2322d5f2220b77b2e17b7ba023d6166d84655da55",
57 | "sha256:8a6c688fefb4e1cd56feb6c511984a6c4f7ec7d2a1ff31a10254f3c817054ae4",
58 | "sha256:8c0ffc886aea5df6a1762d0019e9cb05f825d0eec1f520c51be9d198701daee5",
59 | "sha256:95cd16d3dee553f882540c1ffe331d085c9e629499ceadfbda4d4fde635f4b7d",
60 | "sha256:99f748a7e71ff382613b4e1acc0ac83bf7ad167fb3802e35e90d9763daba4d78",
61 | "sha256:b8c78301cefcf5fd914aad35d3c04c2b21ce8629b5e4f4e45ae6812e461910fa",
62 | "sha256:c420917b188a5582a56d8b93bdd8e0f6eca08c84ff623a4c16e809152cd35793",
63 | "sha256:c43866529f2f06fe0edc6246eb4faa34f03fe88b64a0a9a942561c8e22f4b71f",
64 | "sha256:cab50b8c2250b46fe738c77dbd25ce017d5e6fb35d3407606e7a4180656a5a6a",
65 | "sha256:cef128cb4d5e0b3493f058f10ce32365972c554572ff821e175dbc6f8ff6924f",
66 | "sha256:cf16e3cf6c0a5fdd9bc10c21687e19d29ad1fe863372b5543deaec1039581a30",
67 | "sha256:e56c744aa6ff427a607763346e4170629caf7e48ead6921745986db3692f987f",
68 | "sha256:e577934fc5f8779c554639376beeaa5657d54349096ef24abe8c74c5d9c117c3",
69 | "sha256:f2b0fa0c01d8a0c7483afd9f31d7ecf2d71760ca24499c8697aeb5ca37dc090c"
70 | ],
71 | "version": "==1.14.0"
72 | },
73 | "chardet": {
74 | "hashes": [
75 | "sha256:84ab92ed1c4d4f16916e05906b6b75a6c0fb5db821cc65e70cbd64a3e2a5eaae",
76 | "sha256:fc323ffcaeaed0e0a02bf4d117757b98aed530d9ed4531e3e15460124c106691"
77 | ],
78 | "version": "==3.0.4"
79 | },
80 | "click": {
81 | "hashes": [
82 | "sha256:2335065e6395b9e67ca716de5f7526736bfa6ceead690adf616d925bdc622b13",
83 | "sha256:5b94b49521f6456670fdb30cd82a4eca9412788a93fa6dd6df72c94d5a8ff2d7"
84 | ],
85 | "version": "==7.0"
86 | },
87 | "cython": {
88 | "hashes": [
89 | "sha256:01d566750e7c08e5f094419f8d1ee90e7fa286d8d77c4569748263ed5f05280a",
90 | "sha256:072cb90e2fe4b5cc27d56de12ec5a00311eee781c2d2e3f7c98a82319103c7ed",
91 | "sha256:0e078e793a9882bf48194b8b5c9b40c75769db1859cd90b210a4d7bf33cda2b1",
92 | "sha256:1a3842be21d1e25b7f3440a0c881ef44161937273ea386c30c0e253e30c63740",
93 | "sha256:1dc973bdea03c65f03f41517e4f0fc2b717d71cfbcf4ec34adac7e5bee71303e",
94 | "sha256:214a53257c100e93e7673e95ab448d287a37626a3902e498025993cc633647ae",
95 | "sha256:30462d61e7e290229a64e1c3682b4cc758ffc441e59cc6ce6fae059a05df305b",
96 | "sha256:34004f60b1e79033b0ca29b9ab53a86c12bcaab56648b82fbe21c007cd73d867",
97 | "sha256:34c888a57f419c63bef63bc0911c5bb407b93ed5d6bdeb1587dca2cd1dd56ad1",
98 | "sha256:3dd0cba13b36ff969232930bd6db08d3da0798f1fac376bd1fa4458f4b55d802",
99 | "sha256:4e5acf3b856a50d0aaf385f06a7b56a128a296322a9740f5f279c96619244308",
100 | "sha256:60d859e1efa5cc80436d58aecd3718ff2e74b987db0518376046adedba97ac30",
101 | "sha256:61e505379497b624d6316dd67ef8100aaadca0451f48f8c6fff8d622281cd121",
102 | "sha256:6f6de0bee19c70cb01e519634f0c35770de623006e4876e649ee4a960a147fec",
103 | "sha256:77ac051b7caf02938a32ea0925f558534ab2a99c0c98c681cc905e3e8cba506e",
104 | "sha256:7e4d74515d92c4e2be7201aaef7a51705bd3d5957df4994ddfe1b252195b5e27",
105 | "sha256:993837bbf0849e3b176e1ef6a50e9b8c2225e895501b85d56f4bb65a67f5ea25",
106 | "sha256:9a5f0cf8b95c0c058e413679a650f70dcc97764ccb2a6d5ccc6b08d44c9b334c",
107 | "sha256:9f2839396d21d5537bc9ff53772d44db39b0efb6bf8b6cac709170483df53a5b",
108 | "sha256:b8ba4b4ee3addc26bc595a51b6240b05a80e254b946d624fff6506439bc323d1",
109 | "sha256:bb6d90180eff72fc5a30099c442b8b0b5a620e84bf03ef32a55e3f7bd543f32e",
110 | "sha256:c3d778304209cc39f8287da22f2180f34d2c2ee46cd55abd82e48178841b37b1",
111 | "sha256:c562bc316040097e21357e783286e5eca056a5b2750e89d9d75f9541c156b6dc",
112 | "sha256:d114f9c0164df8fcd2880e4ba96986d7b0e7218f6984acc4989ff384c5d3d512",
113 | "sha256:d282b030ed5c736e4cdb1713a0c4fad7027f4e3959dc4b8fdb7c75042d83ed1b",
114 | "sha256:d8c73fe0ec57a0e4fdf5d2728b5e18b63980f55f1baf51b6bac6a73e8cbb7186",
115 | "sha256:e5c8f4198e25bc4b0e4a884377e0c0e46ca273993679e3bcc212ef96d4211b83",
116 | "sha256:e7f1dcc0e8c3e18fa2fddca4aecdf71c5651555a8dc9a0cd3a1d164cbce6cb35",
117 | "sha256:ea3b61bff995de49b07331d1081e0056ea29901d3e995aa989073fe2b1be0cb7",
118 | "sha256:ea5f987b4da530822fa797cf2f010193be77ea9e232d07454e3194531edd8e58",
119 | "sha256:f91b16e73eca996f86d1943be3b2c2b679b03e068ed8c82a5506c1e65766e4a6"
120 | ],
121 | "version": "==0.29.15"
122 | },
123 | "docutils": {
124 | "hashes": [
125 | "sha256:6c4f696463b79f1fb8ba0c594b63840ebd41f059e92b31957c46b74a4599b6d0",
126 | "sha256:9e4d7ecfc600058e07ba661411a2b7de2fd0fafa17d1a7f7361cd47b1175c827",
127 | "sha256:a2aeea129088da402665e92e0b25b04b073c04b2dce4ab65caaa38b7ce2e1a99"
128 | ],
129 | "version": "==0.15.2"
130 | },
131 | "fairseq": {
132 | "hashes": [
133 | "sha256:61206358b79f325ea0b46cfd8c95cdb81bfbcfb43cf12b47d1d5124ce7321d3b"
134 | ],
135 | "index": "pypi",
136 | "version": "==0.9.0"
137 | },
138 | "filelock": {
139 | "hashes": [
140 | "sha256:18d82244ee114f543149c66a6e0c14e9c4f8a1044b5cdaadd0f82159d6a6ff59",
141 | "sha256:929b7d63ec5b7d6b71b0fa5ac14e030b3f70b75747cef1b10da9b879fef15836"
142 | ],
143 | "version": "==3.0.12"
144 | },
145 | "idna": {
146 | "hashes": [
147 | "sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407",
148 | "sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c"
149 | ],
150 | "version": "==2.8"
151 | },
152 | "jmespath": {
153 | "hashes": [
154 | "sha256:3720a4b1bd659dd2eecad0666459b9788813e032b83e7ba58578e48254e0a0e6",
155 | "sha256:bde2aef6f44302dfb30320115b17d030798de8c4110e28d5cf6cf91a7a31074c"
156 | ],
157 | "version": "==0.9.4"
158 | },
159 | "joblib": {
160 | "hashes": [
161 | "sha256:0630eea4f5664c463f23fbf5dcfc54a2bc6168902719fa8e19daf033022786c8",
162 | "sha256:bdb4fd9b72915ffb49fde2229ce482dd7ae79d842ed8c2b4c932441495af1403"
163 | ],
164 | "version": "==0.14.1"
165 | },
166 | "nltk": {
167 | "hashes": [
168 | "sha256:bed45551259aa2101381bbdd5df37d44ca2669c5c3dad72439fa459b29137d94"
169 | ],
170 | "index": "pypi",
171 | "version": "==3.4.5"
172 | },
173 | "numpy": {
174 | "hashes": [
175 | "sha256:1786a08236f2c92ae0e70423c45e1e62788ed33028f94ca99c4df03f5be6b3c6",
176 | "sha256:17aa7a81fe7599a10f2b7d95856dc5cf84a4eefa45bc96123cbbc3ebc568994e",
177 | "sha256:20b26aaa5b3da029942cdcce719b363dbe58696ad182aff0e5dcb1687ec946dc",
178 | "sha256:2d75908ab3ced4223ccba595b48e538afa5ecc37405923d1fea6906d7c3a50bc",
179 | "sha256:39d2c685af15d3ce682c99ce5925cc66efc824652e10990d2462dfe9b8918c6a",
180 | "sha256:56bc8ded6fcd9adea90f65377438f9fea8c05fcf7c5ba766bef258d0da1554aa",
181 | "sha256:590355aeade1a2eaba17617c19edccb7db8d78760175256e3cf94590a1a964f3",
182 | "sha256:70a840a26f4e61defa7bdf811d7498a284ced303dfbc35acb7be12a39b2aa121",
183 | "sha256:77c3bfe65d8560487052ad55c6998a04b654c2fbc36d546aef2b2e511e760971",
184 | "sha256:9537eecf179f566fd1c160a2e912ca0b8e02d773af0a7a1120ad4f7507cd0d26",
185 | "sha256:9acdf933c1fd263c513a2df3dceecea6f3ff4419d80bf238510976bf9bcb26cd",
186 | "sha256:ae0975f42ab1f28364dcda3dde3cf6c1ddab3e1d4b2909da0cb0191fa9ca0480",
187 | "sha256:b3af02ecc999c8003e538e60c89a2b37646b39b688d4e44d7373e11c2debabec",
188 | "sha256:b6ff59cee96b454516e47e7721098e6ceebef435e3e21ac2d6c3b8b02628eb77",
189 | "sha256:b765ed3930b92812aa698a455847141869ef755a87e099fddd4ccf9d81fffb57",
190 | "sha256:c98c5ffd7d41611407a1103ae11c8b634ad6a43606eca3e2a5a269e5d6e8eb07",
191 | "sha256:cf7eb6b1025d3e169989416b1adcd676624c2dbed9e3bcb7137f51bfc8cc2572",
192 | "sha256:d92350c22b150c1cae7ebb0ee8b5670cc84848f6359cf6b5d8f86617098a9b73",
193 | "sha256:e422c3152921cece8b6a2fb6b0b4d73b6579bd20ae075e7d15143e711f3ca2ca",
194 | "sha256:e840f552a509e3380b0f0ec977e8124d0dc34dc0e68289ca28f4d7c1d0d79474",
195 | "sha256:f3d0a94ad151870978fb93538e95411c83899c9dc63e6fb65542f769568ecfa5"
196 | ],
197 | "version": "==1.18.1"
198 | },
199 | "portalocker": {
200 | "hashes": [
201 | "sha256:6f57aabb25ba176462dc7c63b86c42ad6a9b5bd3d679a9d776d0536bfb803d54",
202 | "sha256:dac62e53e5670cb40d2ee4cdc785e6b829665932c3ee75307ad677cf5f7d2e9f"
203 | ],
204 | "version": "==1.5.2"
205 | },
206 | "protobuf": {
207 | "hashes": [
208 | "sha256:0bae429443cc4748be2aadfdaf9633297cfaeb24a9a02d0ab15849175ce90fab",
209 | "sha256:24e3b6ad259544d717902777b33966a1a069208c885576254c112663e6a5bb0f",
210 | "sha256:310a7aca6e7f257510d0c750364774034272538d51796ca31d42c3925d12a52a",
211 | "sha256:52e586072612c1eec18e1174f8e3bb19d08f075fc2e3f91d3b16c919078469d0",
212 | "sha256:73152776dc75f335c476d11d52ec6f0f6925774802cd48d6189f4d5d7fe753f4",
213 | "sha256:7774bbbaac81d3ba86de646c39f154afc8156717972bf0450c9dbfa1dc8dbea2",
214 | "sha256:82d7ac987715d8d1eb4068bf997f3053468e0ce0287e2729c30601feb6602fee",
215 | "sha256:8eb9c93798b904f141d9de36a0ba9f9b73cc382869e67c9e642c0aba53b0fc07",
216 | "sha256:adf0e4d57b33881d0c63bb11e7f9038f98ee0c3e334c221f0858f826e8fb0151",
217 | "sha256:c40973a0aee65422d8cb4e7d7cbded95dfeee0199caab54d5ab25b63bce8135a",
218 | "sha256:c77c974d1dadf246d789f6dad1c24426137c9091e930dbf50e0a29c1fcf00b1f",
219 | "sha256:dd9aa4401c36785ea1b6fff0552c674bdd1b641319cb07ed1fe2392388e9b0d7",
220 | "sha256:e11df1ac6905e81b815ab6fd518e79be0a58b5dc427a2cf7208980f30694b956",
221 | "sha256:e2f8a75261c26b2f5f3442b0525d50fd79a71aeca04b5ec270fc123536188306",
222 | "sha256:e512b7f3a4dd780f59f1bf22c302740e27b10b5c97e858a6061772668cd6f961",
223 | "sha256:ef2c2e56aaf9ee914d3dccc3408d42661aaf7d9bb78eaa8f17b2e6282f214481",
224 | "sha256:fac513a9dc2a74b99abd2e17109b53945e364649ca03d9f7a0b96aa8d1807d0a",
225 | "sha256:fdfb6ad138dbbf92b5dbea3576d7c8ba7463173f7d2cb0ca1bd336ec88ddbd80"
226 | ],
227 | "version": "==3.11.3"
228 | },
229 | "pycparser": {
230 | "hashes": [
231 | "sha256:a988718abfad80b6b157acce7bf130a30876d27603738ac39f140993246b25b3"
232 | ],
233 | "version": "==2.19"
234 | },
235 | "python-dateutil": {
236 | "hashes": [
237 | "sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c",
238 | "sha256:75bb3f31ea686f1197762692a9ee6a7550b59fc6ca3a1f4b5d7e32fb98e2da2a"
239 | ],
240 | "version": "==2.8.1"
241 | },
242 | "regex": {
243 | "hashes": [
244 | "sha256:07b39bf943d3d2fe63d46281d8504f8df0ff3fe4c57e13d1656737950e53e525",
245 | "sha256:0932941cdfb3afcbc26cc3bcf7c3f3d73d5a9b9c56955d432dbf8bbc147d4c5b",
246 | "sha256:0e182d2f097ea8549a249040922fa2b92ae28be4be4895933e369a525ba36576",
247 | "sha256:10671601ee06cf4dc1bc0b4805309040bb34c9af423c12c379c83d7895622bb5",
248 | "sha256:23e2c2c0ff50f44877f64780b815b8fd2e003cda9ce817a7fd00dea5600c84a0",
249 | "sha256:26ff99c980f53b3191d8931b199b29d6787c059f2e029b2b0c694343b1708c35",
250 | "sha256:27429b8d74ba683484a06b260b7bb00f312e7c757792628ea251afdbf1434003",
251 | "sha256:3e77409b678b21a056415da3a56abfd7c3ad03da71f3051bbcdb68cf44d3c34d",
252 | "sha256:4e8f02d3d72ca94efc8396f8036c0d3bcc812aefc28ec70f35bb888c74a25161",
253 | "sha256:4eae742636aec40cf7ab98171ab9400393360b97e8f9da67b1867a9ee0889b26",
254 | "sha256:6a6ae17bf8f2d82d1e8858a47757ce389b880083c4ff2498dba17c56e6c103b9",
255 | "sha256:6a6ba91b94427cd49cd27764679024b14a96874e0dc638ae6bdd4b1a3ce97be1",
256 | "sha256:7bcd322935377abcc79bfe5b63c44abd0b29387f267791d566bbb566edfdd146",
257 | "sha256:98b8ed7bb2155e2cbb8b76f627b2fd12cf4b22ab6e14873e8641f266e0fb6d8f",
258 | "sha256:bd25bb7980917e4e70ccccd7e3b5740614f1c408a642c245019cff9d7d1b6149",
259 | "sha256:d0f424328f9822b0323b3b6f2e4b9c90960b24743d220763c7f07071e0778351",
260 | "sha256:d58e4606da2a41659c84baeb3cfa2e4c87a74cec89a1e7c56bee4b956f9d7461",
261 | "sha256:e3cd21cc2840ca67de0bbe4071f79f031c81418deb544ceda93ad75ca1ee9f7b",
262 | "sha256:e6c02171d62ed6972ca8631f6f34fa3281d51db8b326ee397b9c83093a6b7242",
263 | "sha256:e7c7661f7276507bce416eaae22040fd91ca471b5b33c13f8ff21137ed6f248c",
264 | "sha256:ecc6de77df3ef68fee966bb8cb4e067e84d4d1f397d0ef6fce46913663540d77"
265 | ],
266 | "version": "==2020.1.8"
267 | },
268 | "requests": {
269 | "hashes": [
270 | "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4",
271 | "sha256:9cf5292fcd0f598c671cfc1e0d7d1a7f13bb8085e9a590f48c010551dc6c4b31"
272 | ],
273 | "version": "==2.22.0"
274 | },
275 | "s3transfer": {
276 | "hashes": [
277 | "sha256:2482b4259524933a022d59da830f51bd746db62f047d6eb213f2f8855dcb8a13",
278 | "sha256:921a37e2aefc64145e7b73d50c71bb4f26f46e4c9f414dc648c6245ff92cf7db"
279 | ],
280 | "version": "==0.3.3"
281 | },
282 | "sacrebleu": {
283 | "hashes": [
284 | "sha256:0a4b9e53b742d95fcd2f32e4aaa42aadcf94121d998ca19c66c05e7037d1eeee",
285 | "sha256:6ca2418e4474120537155d77695aa5303b2cf1a4227aec32074b4b4d8c107e2b"
286 | ],
287 | "version": "==1.4.3"
288 | },
289 | "sacremoses": {
290 | "hashes": [
291 | "sha256:34dcfaacf9fa34a6353424431f0e4fcc60e8ebb27ffee320d57396690b712a3b"
292 | ],
293 | "version": "==0.0.38"
294 | },
295 | "sentencepiece": {
296 | "hashes": [
297 | "sha256:0a98ec863e541304df23a37787033001b62cb089f4ed9307911791d7e210c0b1",
298 | "sha256:0ad221ea7914d65f57d3e3af7ae48852b5035166493312b5025367585b43ac41",
299 | "sha256:0f72c4151791de7242e7184a9b7ef12503cef42e9a5a0c1b3510f2c68874e810",
300 | "sha256:22fe7d92203fadbb6a0dc7d767430d37cdf3a9da4a0f2c5302c7bf294f7bfd8f",
301 | "sha256:2a72d4c3d0dbb1e099ddd2dc6b724376d3d7ff77ba494756b894254485bec4b4",
302 | "sha256:30791ce80a557339e17f1290c68dccd3f661612fdc6b689b4e4f21d805b64952",
303 | "sha256:39904713b81869db10de53fe8b3719f35acf77f49351f28ceaad0d360f2f6305",
304 | "sha256:3d5a2163deea95271ce8e38dfd0c3c924bea92aaf63bdda69b5458628dacc8bd",
305 | "sha256:3f3dee204635c33ca2e450e17ee9e0e92f114a47f853c2e44e7f0f0ab444d8d0",
306 | "sha256:4dcea889af53f669dc39d1ca870c37c52bb3110fcd96a2e7330d288400958281",
307 | "sha256:4e36a92558ad9e2f91b311c5bcea90b7a63c567c0e7e20da44d6a6f01031b57e",
308 | "sha256:576bf820eb963e6f275d4005ed5334fbed59eb54bed508e5cae6d16c7179710f",
309 | "sha256:6d2bbdbf296d96304c6345675749981bb17dcf2a7163d2fec38f70a704b75669",
310 | "sha256:76fdce3e7e614e24b35167c22c9c388e0c843be53d99afb5e1f25f6bfe04e228",
311 | "sha256:97b8ee26892d236b2620af8ddae11713fbbb2dae9adf4ad5e988e5a82ce50a90",
312 | "sha256:b3b6fe02af7ea4823c19e0d8efddc10ff59b8449bc1ae9921f9dd8ad33802c33",
313 | "sha256:b416f514fff8785a1113e6c07f696e52967fc979d6cd946e454a8660cca72ef8",
314 | "sha256:bf0bad6ba01ace3e938ffdf05c42b24d8fd3740487ba865504795a0bb9b1f2b3",
315 | "sha256:c00387970360ec0369b5e7c75f3977fb14330df75465200c13bafb7a632d2e6b",
316 | "sha256:c23fb7bb949934998375d41dbe54d4df1778a3b9dcb24bc2ddaaa595819ed1da",
317 | "sha256:dfdcf48678656592b11d11e2102c52c38122e309f7a1a5272305d397cfe21ce0",
318 | "sha256:fb69c5ba325b900cf2b91f517b46eec8ce3c50995955e293b46681d832021c0e",
319 | "sha256:fba83bef6c7a7899cd811d9b1195e748722eb2a9737c3f3890160f0e01e3ad08",
320 | "sha256:fe115aee209197839b2a357e34523e23768d553e8a69eac2b558499ccda56f80",
321 | "sha256:ffdf51218a3d7e0dad79bdffd21ad15a23cbb9c572d2300c3295c6efc6c2357e"
322 | ],
323 | "version": "==0.1.85"
324 | },
325 | "six": {
326 | "hashes": [
327 | "sha256:236bdbdce46e6e6a3d61a337c0f8b763ca1e8717c03b369e87a7ec7ce1319c0a",
328 | "sha256:8f3cd2e254d8f793e7f3d6d9df77b92252b52637291d0f0da013c76ea2724b6c"
329 | ],
330 | "version": "==1.14.0"
331 | },
332 | "tensorboardx": {
333 | "hashes": [
334 | "sha256:835d85db0aef2c6768f07c35e69a74e3dcb122d6afceaf2b8504d7d16c7209a5",
335 | "sha256:8e336103a66b1b97a663057cc13d1db4419f7a12f332b8364386dbf8635031d9"
336 | ],
337 | "index": "pypi",
338 | "version": "==2.0"
339 | },
340 | "tokenizers": {
341 | "hashes": [
342 | "sha256:08e08027564194e16aa647d180837d292b2c9c5ef772fed15badcc88e2474a8f",
343 | "sha256:1385deb90ec76cbee59b50298c8d2dc5909cda080a706d263e4f81c8474ba53d",
344 | "sha256:3ebe7f0bff9e30ab15dec4846c54c9085e02e47711eb7253d36a6777eadc2948",
345 | "sha256:4b7c42b644a1c5705a59b14c53c84b50b8f0b9c0f5f952a8a91a350403e7615f",
346 | "sha256:503418d5195ae1a483ced0257a0d2f4583456aa49bdfe0014c8605babf244ac5",
347 | "sha256:5ba2c6eaac2e8e0a2d839c0420d16707496b5e93b1454029d19487c5dd8c9b62",
348 | "sha256:7de28e0bebd0904b990560a1f14c3c5600da29be287e544bdf19e6970ea11d54",
349 | "sha256:82e8c3b13a66410358753b7e48776749935851cdb49a3d0c139a046178ec4f49",
350 | "sha256:a4d1ef6ee9221e7f9c1a4c122a15e93f0961977aaae2813b7b405c778728dcee",
351 | "sha256:a66ff87c32a221a126904d7ec972e7c8e0033486b24f8777c0f056aedbc09011",
352 | "sha256:a7f5e43674dd5b012ad29b79a32f0652ecfff3a3ed1c04f9073038c4bf63829d",
353 | "sha256:bb44fa1b268d1bbdf2bb14cd82da6ffb93d19638157c77f9e17e246928f0233f",
354 | "sha256:ce75c75430a3dfc33a10c90c1607d44b172c6d2ea19d586692b6cc9ba6ec5e14"
355 | ],
356 | "index": "pypi",
357 | "version": "==0.0.11"
358 | },
359 | "torch": {
360 | "hashes": [
361 | "sha256:271d4d1e44df6ed57c530f8849b028447c62b8a19b8e8740dd9baa56e7f682c1",
362 | "sha256:30ce089475b287a37d6fbb8d71853e672edaf66699e3dd2eb19be6ce6296732a",
363 | "sha256:405b9eb40e44037d2525b3ddb5bc4c66b519cd742bff249d4207d23f83e88ea5",
364 | "sha256:504915c6bc6051ba6a4c2a43c446463dff04411e352f1e26fe13debeae431778",
365 | "sha256:54d06a0e8ee85e5a437c24f4af9f4196c819294c23ffb5914e177756f55f1829",
366 | "sha256:6f2fd9eb8c7eaf38a982ab266dbbfba0f29fb643bc74e677d045d6f2595e4692",
367 | "sha256:8856f334aa9ecb742c1504bd2563d0ffb8dceb97149c8d72a04afa357f667dbc",
368 | "sha256:8fff03bf7b474c16e4b50da65ea14200cc64553b67b9b2307f9dc7e8c69b9d28",
369 | "sha256:9a1b1db73d8dcfd94b2eee24b939368742aa85f1217c55b8f5681e76c581e99a",
370 | "sha256:bb1e87063661414e1149bef2e3a2499ce0b5060290799d7e26bc5578037075ba",
371 | "sha256:d7b34a78f021935ad727a3bede56a8a8d4fda0b0272314a04c5f6890bbe7bb29"
372 | ],
373 | "version": "==1.4.0"
374 | },
375 | "tqdm": {
376 | "hashes": [
377 | "sha256:251ee8440dbda126b8dfa8a7c028eb3f13704898caaef7caa699b35e119301e2",
378 | "sha256:fe231261cfcbc6f4a99165455f8f6b9ef4e1032a6e29bccf168b4bf42012f09c"
379 | ],
380 | "version": "==4.42.1"
381 | },
382 | "transformers": {
383 | "hashes": [
384 | "sha256:995393b9ce764044287847792476101cd4e8377c756874df1116b221980749ad",
385 | "sha256:c5e765d3fd1a654e27b0f675e1cc1e210139429cc00d1b816c3c21502ace941e"
386 | ],
387 | "index": "pypi",
388 | "version": "==2.4.1"
389 | },
390 | "typing": {
391 | "hashes": [
392 | "sha256:91dfe6f3f706ee8cc32d38edbbf304e9b7583fb37108fef38229617f8b3eba23",
393 | "sha256:c8cabb5ab8945cd2f54917be357d134db9cc1eb039e59d1606dc1e60cb1d9d36",
394 | "sha256:f38d83c5a7a7086543a0f649564d661859c5146a85775ab90c0d2f93ffaa9714"
395 | ],
396 | "version": "==3.7.4.1"
397 | },
398 | "urllib3": {
399 | "hashes": [
400 | "sha256:2f3db8b19923a873b3e5256dc9c2dedfa883e33d87c690d9c7913e1f40673cdc",
401 | "sha256:87716c2d2a7121198ebcb7ce7cccf6ce5e9ba539041cfbaeecfb641dc0bf6acc"
402 | ],
403 | "markers": "python_version != '3.4'",
404 | "version": "==1.25.8"
405 | }
406 | },
407 | "develop": {}
408 | }
409 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | [RobBERT on Hugging Face](https://huggingface.co/pdelobelle/robbert-v2-dutch-base)
10 |
11 |
12 | # RobBERT: Dutch RoBERTa-based Language Model.
13 |
14 | RobBERT is the state-of-the-art Dutch BERT model.
15 | It is a large pre-trained general Dutch language model that can be fine-tuned on a given dataset to perform any text classification, regression or token-tagging task.
16 | As such, it has been successfully used by many [researchers](https://scholar.google.com/scholar?oi=bibs&hl=en&cites=7180110604335112086) and [practitioners](https://huggingface.co/models?search=robbert) for achieving state-of-the-art performance for a wide range of Dutch natural language processing tasks, including:
17 |
18 | - [Emotion detection](https://www.aclweb.org/anthology/2021.wassa-1.27/)
19 | - Sentiment analysis ([book reviews](https://arxiv.org/pdf/2001.06286.pdf), [news articles](https://biblio.ugent.be/publication/8704637/file/8704638.pdf)*)
20 | - [Coreference resolution](https://arxiv.org/pdf/2001.06286.pdf)
21 | - Named entity recognition ([CoNLL](https://arxiv.org/pdf/2001.06286.pdf), [job titles](https://arxiv.org/pdf/2004.02814.pdf)*, [SoNaR](https://github.com/proycon/deepfrog))
22 | - Part-of-speech tagging ([Small UD Lassy](https://arxiv.org/pdf/2001.06286.pdf), [CGN](https://github.com/proycon/deepfrog))
23 | - [Zero-shot word prediction](https://arxiv.org/pdf/2001.06286.pdf)
24 | - [Humor detection](https://arxiv.org/pdf/2010.13652.pdf)
25 | - [Cyberbullying detection](https://www.cambridge.org/core/journals/natural-language-engineering/article/abs/automatic-classification-of-participant-roles-in-cyberbullying-can-we-detect-victims-bullies-and-bystanders-in-social-media-text/A2079C2C738C29428E666810B8903342)
26 | - [Correcting dt-spelling mistakes](https://gitlab.com/spelfouten/dutch-simpletransformers/)*
27 |
28 | and also achieved outstanding, near-sota results for:
29 |
30 | - [Natural language inference](https://arxiv.org/pdf/2101.05716.pdf)*
31 | - [Review classification](https://medium.com/broadhorizon-cmotions/nlp-with-r-part-5-state-of-the-art-in-nlp-transformers-bert-3449e3cd7494)*
32 |
33 | \* *Note that several evaluations use RobBERT-v1, and that the second and improved RobBERT-v2 outperforms this first model on everything we tested*
34 |
35 | *(Also note that this list is not exhaustive. If you used RobBERT for your application, we are happy to know about it! Send us a mail, or add it yourself to this list by sending a pull request with the edit!)*
36 |
37 | To use the RobBERT model using [HuggingFace transformers](https://huggingface.co/transformers/), use the name [`pdelobelle/robbert-v2-dutch-base`](https://huggingface.co/pdelobelle/robbert-v2-dutch-base).
38 |
39 | More in-depth information about RobBERT can be found in our [blog post](https://pieter.ai/robbert/) and in [our paper](https://arxiv.org/abs/2001.06286).
40 |
41 | ## Table of contents
42 | * [How To Use](#how-to-use)
43 | + [Using Huggingface Transformers (easiest)](#using-huggingface-transformers-easiest)
44 | + [Using Fairseq (harder)](#using-fairseq-harder)
45 | * [Technical Details From The Paper](#technical-details-from-the-paper)
46 | + [Our Performance Evaluation Results](#our-performance-evaluation-results)
47 | + [Sentiment analysis](#sentiment-analysis)
48 | + [Die/Dat (coreference resolution)](#diedat-coreference-resolution)
49 | - [Finetuning on whole dataset](#finetuning-on-whole-dataset)
50 | - [Finetuning on 10K examples](#finetuning-on-10k-examples)
51 | - [Using zero-shot word masking task](#using-zero-shot-word-masking-task)
52 | + [Part-of-Speech Tagging.](#part-of-speech-tagging)
53 | + [Named Entity Recognition](#named-entity-recognition)
54 | * [Pre-Training Procedure Details](#pre-training-procedure-details)
55 | * [Investigating Limitations and Bias](#investigating-limitations-and-bias)
56 | * [How to Replicate Our Paper Experiments](#how-to-replicate-our-paper-experiments)
57 | + [Classification](#classification)
58 | - [Sentiment analysis using the Dutch Book Review Dataset](#sentiment-analysis-using-the-dutch-book-review-dataset)
59 | - [Predicting the Dutch pronouns _die_ and _dat_](#predicting-the-dutch-pronouns-die-and-dat)
60 | * [Name Origin of RobBERT](#name-origin-of-robbert)
61 | * [Credits and citation](#credits-and-citation)
62 |
63 |
64 |
65 | ## How To Use
66 |
67 | RobBERT uses the [RoBERTa](https://arxiv.org/abs/1907.11692) architecture and pre-training but with a Dutch tokenizer and training data. RoBERTa is the robustly optimized English BERT model, making it even more powerful than the original BERT model. Given this same architecture, RobBERT can easily be finetuned and used for inference with [code to finetune RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html) models and most code used for BERT models, e.g. as provided by the [HuggingFace Transformers](https://huggingface.co/transformers/) library.
68 |
69 | RobBERT can easily be used in two different ways: either with the [Fairseq RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta) code or with [HuggingFace Transformers](https://github.com/huggingface/transformers).
70 |
71 | By default, RobBERT has the masked language model head used in training. This can be used as a zero-shot way to fill masks in sentences. It can be tested out for free on [RobBERT's hosted inference API of Huggingface](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=De+hoofdstad+van+Belgi%C3%AB+is+%3Cmask%3E.). You can also create a new prediction head for your own task by using any of HuggingFace's [RoBERTa-runners](https://huggingface.co/transformers/v2.7.0/examples.html#language-model-training) or [their fine-tuning notebooks](https://huggingface.co/transformers/v4.1.1/notebooks.html) by changing the model name to `pdelobelle/robbert-v2-dutch-base`, or use the original fairseq [RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta) training regimes.
72 |
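For example, the same zero-shot mask filling can be done locally with the `fill-mask` pipeline. A minimal sketch, assuming a reasonably recent 🤗 Transformers version:

```python
from transformers import pipeline

# Load RobBERT together with its masked-language-model head.
fill_mask = pipeline("fill-mask", model="pdelobelle/robbert-v2-dutch-base")

# RoBERTa-style models use <mask> as their mask token.
print(fill_mask("De hoofdstad van België is <mask>."))
```
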
73 | ### Using Huggingface Transformers (easiest)
74 |
75 | You can easily [download RobBERT v2](https://huggingface.co/pdelobelle/robbert-v2-dutch-base) using [🤗 Transformers](https://github.com/huggingface/transformers).
76 | Use the following code to download the base model and finetune it yourself, or use one of our finetuned models (documented on [our project site](https://pieter.ai/robbert/)).
77 |
78 | ```python
79 | from transformers import RobertaTokenizer, RobertaForSequenceClassification
80 | tokenizer = RobertaTokenizer.from_pretrained("pdelobelle/robbert-v2-dutch-base")
81 | model = RobertaForSequenceClassification.from_pretrained("pdelobelle/robbert-v2-dutch-base")
82 | ```
83 |
84 | Starting with `transformers v2.4.0` (or installing from source), you can use AutoTokenizer and AutoModel.
85 | You can then use most of [HuggingFace's BERT-based notebooks](https://huggingface.co/transformers/v4.1.1/notebooks.html) for finetuning RobBERT on your type of Dutch language dataset.
86 |
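For instance, a minimal sketch using the auto classes (shown here with the masked-LM head; `AutoModelForMaskedLM` assumes a reasonably recent transformers version):

```python
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("pdelobelle/robbert-v2-dutch-base")
model = AutoModelForMaskedLM.from_pretrained("pdelobelle/robbert-v2-dutch-base")
```
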
87 | ### Using Fairseq (harder)
88 |
89 | Alternatively, you can also use RobBERT through the original [RoBERTa architecture code](https://github.com/pytorch/fairseq/tree/master/examples/roberta).
90 | You can download RobBERT v2's Fairseq model here: [(RobBERT-base, 1.5 GB)](https://github.com/iPieter/BERDT/releases/download/v1.0/RobBERT-base.pt).
91 | Using RobBERT's `model.pt`, this method allows you to use all other functionalities of [RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta).
92 |
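As a minimal sketch of how the checkpoint can be loaded with Fairseq's RoBERTa interface (assuming the download is unpacked into a local directory, here hypothetically `models/robbert-base`, that also contains the matching dictionary and BPE files):

```python
from fairseq.models.roberta import RobertaModel

# Hypothetical local directory holding model.pt plus the matching dict.txt and BPE files.
robbert = RobertaModel.from_pretrained("models/robbert-base", checkpoint_file="model.pt")
robbert.eval()  # disable dropout for inference

tokens = robbert.encode("Hallo wereld!")       # BPE-encode a Dutch sentence
features = robbert.extract_features(tokens)    # last-layer representations
```
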
93 |
94 | ## Technical Details From The Paper
95 |
96 |
97 | ### Our Performance Evaluation Results
98 |
99 | All experiments are described in more detail in our [paper](https://arxiv.org/abs/2001.06286), with the code in [our GitHub repository](https://github.com/iPieter/RobBERT).
100 |
101 | ### Sentiment analysis
102 | Predicting whether a review is positive or negative using the [Dutch Book Reviews Dataset](https://github.com/benjaminvdb/DBRD).
103 |
104 | | Model | Accuracy [%] |
105 | |-------------------|--------------------------|
106 | | ULMFiT | 93.8 |
107 | | BERTje | 93.0 |
108 | | RobBERT v2 | **95.1** |
109 |
110 | ### Die/Dat (coreference resolution)
111 |
112 | We measured how well the models are able to do coreference resolution by predicting whether "die" or "dat" should be filled into a sentence.
113 | For this, we used the [EuroParl corpus](https://www.statmt.org/europarl/).
114 |
115 | #### Finetuning on whole dataset
116 |
117 | | Model | Accuracy [%] | F1 [%] |
118 | |-------------------|--------------------------|--------------|
119 | | [Baseline](https://arxiv.org/abs/2001.02943) (LSTM) | | 75.03 |
120 | | mBERT | 98.285 | 98.033 |
121 | | BERTje | 98.268 | 98.014 |
122 | | RobBERT v2 | **99.232** | **99.121** |
123 |
124 | #### Finetuning on 10K examples
125 |
126 | We also measured the performance using only 10K training examples.
127 | This experiment clearly illustrates that RobBERT outperforms other models when there is little data available.
128 |
129 | | Model | Accuracy [%] | F1 [%] |
130 | |-------------------|--------------------------|--------------|
131 | | mBERT | 92.157 | 90.898 |
132 | | BERTje | 93.096 | 91.279 |
133 | | RobBERT v2 | **97.816** | **97.514** |
134 |
135 | #### Using zero-shot word masking task
136 |
137 | Since BERT models are pre-trained using the word masking task, we can use this to predict whether "die" or "dat" is more likely.
138 | This experiment shows that RobBERT has internalised more information about Dutch than other models.
139 |
140 | | Model | Accuracy [%] |
141 | |-------------------|--------------------------|
142 | | ZeroR | 66.70 |
143 | | mBERT | 90.21 |
144 | | BERTje | 94.94 |
145 | | RobBERT v2 | **98.75** |
146 |
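To try this kind of zero-shot comparison yourself, here is a minimal sketch that scores both candidate pronouns with the MLM head (assuming a recent, v4+ 🤗 Transformers version; the candidate scoring mirrors the approach used in `examples/die_vs_data_rest_api/app/__init__.py`):

```python
import torch
from transformers import RobertaTokenizer, RobertaForMaskedLM

tokenizer = RobertaTokenizer.from_pretrained("pdelobelle/robbert-v2-dutch-base")
model = RobertaForMaskedLM.from_pretrained("pdelobelle/robbert-v2-dutch-base")
model.eval()  # disable dropout

# Replace the pronoun with the mask token and compare the logits of both candidates.
sentence = f"Daar loopt {tokenizer.mask_token} meisje."
candidates = ["die", "dat"]
candidate_ids = tokenizer.convert_tokens_to_ids(candidates)

inputs = tokenizer(sentence, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

mask_position = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
print(candidates[int(logits[0, mask_position, candidate_ids].argmax())])  # preferred pronoun
```
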
147 | ### Part-of-Speech Tagging.
148 |
149 | Using the [Lassy UD dataset](https://universaldependencies.org/treebanks/nl_lassysmall/index.html).
150 |
151 |
152 | | Model | Accuracy [%] |
153 | |-------------------|--------------------------|
154 | | Frog | 91.7 |
155 | | mBERT | **96.5** |
156 | | BERTje | 96.3 |
157 | | RobBERT v2 | 96.4 |
158 |
159 | Interestingly, we found that when dealing with **small data sets**, RobBERT v2 **significantly outperforms** other models.
160 |
161 |
162 |
163 |
164 |
165 | ### Named Entity Recognition
166 |
167 | Using the [CoNLL 2002 evaluation script](https://www.clips.uantwerpen.be/conll2002/ner/).
168 |
169 |
170 | | Model | Accuracy [%] |
171 | |-------------------|--------------------------|
172 | | Frog | 57.31 |
173 | | mBERT | **90.94** |
174 | | BERT-NL | 89.7 |
175 | | BERTje | 88.3 |
176 | | RobBERT v2 | 89.08 |
177 |
178 |
179 | ## Pre-Training Procedure Details
180 |
181 | We pre-trained RobBERT using the RoBERTa training regime.
182 | We pre-trained our model on the Dutch section of the [OSCAR corpus](https://oscar-corpus.com/), a large multilingual corpus which was obtained by language classification in the Common Crawl corpus.
183 | This Dutch corpus is 39 GB in size, with 6.6 billion words spread over 126 million lines of text (where each line could contain multiple sentences), making it larger than the corpora used by concurrently developed Dutch BERT models.
184 |
185 |
186 | RobBERT shares its architecture with [RoBERTa's base model](https://github.com/pytorch/fairseq/tree/master/examples/roberta), which itself is a replication and improvement over BERT.
187 | Like BERT, its architecture consists of 12 self-attention layers with 12 heads, for a total of 117M trainable parameters.
188 | One difference from the original BERT model is the pre-training task specified by RoBERTa: only the MLM task is used, not the NSP task.
189 | During pre-training, it thus only predicts which words are masked in certain positions of given sentences.
190 | The training process uses the Adam optimizer with polynomial decay of the learning rate l_r=10^-6 and a ramp-up period of 1000 iterations, with hyperparameters beta_1=0.9
191 | and RoBERTa's default beta_2=0.98.
192 | Additionally, a weight decay of 0.1 and a small dropout of 0.1 helps prevent the model from overfitting.
193 |
194 |
195 | RobBERT was trained on a computing cluster with 4 Nvidia P100 GPUs per node, where the number of nodes was dynamically adjusted while keeping a fixed batch size of 8192 sentences.
196 | At most 20 nodes were used (i.e. 80 GPUs), and the median was 5 nodes.
197 | By using gradient accumulation, the batch size could be set independently of the number of GPUs available, in order to maximally utilize the cluster.
198 | Using the [Fairseq library](https://github.com/pytorch/fairseq/tree/master/examples/roberta), the model trained for two epochs (over 16k batches in total), which took about three days on the computing cluster.
199 | In between training jobs on the computing cluster, 2 Nvidia 1080 Ti's also covered some parameter updates for RobBERT v2.
200 |
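To illustrate how gradient accumulation keeps this effective batch size fixed regardless of the number of GPUs, here is a small sketch (the per-GPU batch size below is a hypothetical value, not the exact setting used for RobBERT):

```python
def accumulation_steps(target_batch: int, num_gpus: int, per_gpu_batch: int) -> int:
    """Gradient-accumulation steps needed to reach `target_batch` sentences per update."""
    sentences_per_pass = num_gpus * per_gpu_batch  # one forward/backward pass over all GPUs
    assert target_batch % sentences_per_pass == 0
    return target_batch // sentences_per_pass

# e.g. 4 nodes x 4 P100 GPUs = 16 GPUs, with a hypothetical 32 sentences per GPU:
print(accumulation_steps(8192, num_gpus=16, per_gpu_batch=32))  # -> 16
```
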
201 |
202 | ## Investigating Limitations and Bias
203 |
204 | In the [RobBERT paper](https://arxiv.org/abs/2001.06286), we also investigated potential sources of bias in RobBERT.
205 |
206 | We found that the zero-shot model estimates the probability of *hij* (he) to be higher than that of *zij* (she) for most occupations in bleached template sentences, regardless of the actual gender ratio of those occupations in reality.
207 |
208 |
209 |
210 |
211 |
212 | By augmenting the DBRD Dutch Book Reviews sentiment analysis dataset with the stated gender of the author of the review, we found that highly positive reviews written by women were generally more accurately detected by RobBERT as being positive than those written by men.
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 | ## How to Replicate Our Paper Experiments
225 |
226 | You can replicate the experiments done in our paper by following these steps.
227 | You can install the required dependencies with either the requirements.txt file or Pipenv:
228 | - Install the dependencies from the requirements.txt file using `pip install -r requirements.txt`
229 | - OR install them using [Pipenv](https://pipenv.readthedocs.io/en/latest/) *(install Pipenv itself by running `pip install pipenv` in your terminal)* and then run `pipenv install`.
230 |
231 |
232 | ### Classification
233 | In this section we describe how to use the scripts we provide to fine-tune models, which should be general enough to reuse for other desired textual classification tasks.
234 |
235 | #### Sentiment analysis using the Dutch Book Review Dataset
236 |
237 | - Download the Dutch book review dataset from [https://github.com/benjaminvdb/DBRD](https://github.com/benjaminvdb/DBRD), and save it to `data/raw/DBRD`
238 | - Run `src/preprocess_dbrd.py` to prepare the dataset.
239 | - To avoid training blindly, we recommend keeping aside a small evaluation set from the training set. For this, run `src/split_dbrd_training.sh`.
240 | - Follow the notebook `notebooks/finetune_dbrd.ipynb` to finetune the model.
241 |
242 | #### Predicting the Dutch pronouns _die_ and _dat_
243 | We fine-tune our model on the Dutch [Europarl corpus](http://www.statmt.org/europarl/). You can download it first with:
244 |
245 | ```
246 | cd data/raw/europarl/
247 | wget -N 'http://www.statmt.org/europarl/v7/nl-en.tgz'
248 | tar zxvf nl-en.tgz
249 | ```
250 | As a sanity check, now you should have the following files in your `data/raw/europarl` folder:
251 |
252 | ```
253 | europarl-v7.nl-en.en
254 | europarl-v7.nl-en.nl
255 | nl-en.tgz
256 | ```
257 |
258 | Then you can run the preprocessing with the following script, which will first process the Europarl corpus to remove sentences without any _die_ or _dat_.
259 | Afterwards, it will flip the pronoun and join both sentences together with a `` token.
260 |
261 | ```
262 | python src/preprocess_diedat.py
263 | . src/preprocess_diedat.sh
264 | ```
265 |
266 | *Note: you can monitor the progress of the first preprocessing step with `watch -n 2 wc -l data/europarl-v7.nl-en.nl.sentences`. This will take a while, but it is certainly not needed to use all inputs; this is, after all, why you want to use a pre-trained language model. You can terminate the Python script at any time, and the second step will only use the sentences processed so far.*
267 |
268 | ## Name Origin of RobBERT
269 |
270 | Most BERT-like models have the word *BERT* in their name (e.g. [RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html), [ALBERT](https://arxiv.org/abs/1909.11942), [CamemBERT](https://camembert-model.fr/), and [many, many others](https://huggingface.co/models?search=bert)).
271 | As such, we queried our newly trained model using its masked language model to name itself *\bert* using [all](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=Mijn+naam+is+%3Cmask%3Ebert.) [kinds](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=Hallo%2C+ik+ben+%3Cmask%3Ebert.) [of](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=Leuk+je+te+ontmoeten%2C+ik+heet+%3Cmask%3Ebert.) [prompts](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=Niemand+weet%2C+niemand+weet%2C+dat+ik+%3Cmask%3Ebert+heet.), and it consistently called itself RobBERT.
272 | We thought it was really quite fitting, given that RobBERT is a [*very* Dutch name](https://en.wikipedia.org/wiki/Robbert) *(and thus clearly a Dutch language model)*, and additionally has a high similarity to its root architecture, namely [RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html).
273 |
274 | Since *"rob"* is a Dutch words to denote a seal, we decided to draw a seal and dress it up like [Bert from Sesame Street](https://muppet.fandom.com/wiki/Bert) for the [RobBERT logo](https://github.com/iPieter/RobBERT/blob/master/res/robbert_logo.png).
275 |
276 | ## Credits and citation
277 |
278 | This project is created by [Pieter Delobelle](https://github.com/iPieter), [Thomas Winters](https://github.com/twinters) and Bettina Berendt.
279 |
280 | We are grateful to Liesbeth Allein for her work on die-dat disambiguation, Huggingface for their transformers package, Facebook for their Fairseq package, and all other people whose work we could use.
281 |
282 | We release our models and this code under the MIT license.
283 |
284 | If you would like to cite our paper or model, you can use the following BibTeX code:
285 |
286 | ```
287 | @inproceedings{delobelle2020robbert,
288 | title = "{R}ob{BERT}: a {D}utch {R}o{BERT}a-based {L}anguage {M}odel",
289 | author = "Delobelle, Pieter and
290 | Winters, Thomas and
291 | Berendt, Bettina",
292 | booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2020",
293 | month = nov,
294 | year = "2020",
295 | address = "Online",
296 | publisher = "Association for Computational Linguistics",
297 | url = "https://www.aclweb.org/anthology/2020.findings-emnlp.292",
298 | doi = "10.18653/v1/2020.findings-emnlp.292",
299 | pages = "3255--3265"
300 | }
301 | ```
302 |
--------------------------------------------------------------------------------
/examples/die_vs_data_rest_api/.gitignore:
--------------------------------------------------------------------------------
1 | venv
2 | app/__pycache__
3 |
--------------------------------------------------------------------------------
/examples/die_vs_data_rest_api/README.md:
--------------------------------------------------------------------------------
1 | # Demo: 'die' vs 'dat' as a REST endpoint with Flask
2 |
3 | As a demo, we release a small Flask example for a REST endpoint to analyse sentences.
4 | It will return whether a sentence contains the correct—according to RobBERT—use of 'die' or 'dat'.
5 |
6 | By default, a Flask server will listen on port 5000. The endpoint is `/`.
7 |
8 | ## Get started
9 | First install the dependencies from the requirements.txt file using `pip install -r requirements.txt`
10 |
11 | ```shell script
12 | $ python app.py --model-path DTAI-KULeuven/robbertje-shuffled-dutch-die-vs-dat --fast-model-path pdelobelle/robbert-v2-dutch-base
13 | ```
14 |
15 | ## Classification model
16 | Simply make an HTTP POST request to `/` with the parameter `sentence` filled in:
17 |
18 | ```shell script
19 | $ curl --data "sentence=Daar loopt _die_ meisje." localhost:5000
20 | ```
21 |
22 | This should give you the following response:
23 |
24 | ```json
25 | {
26 | "rating": 1,
27 | "interpretation": "incorrect",
28 | "confidence": 5.222124099731445,
29 | "sentence": "Daar loopt _die_ meisje."
30 | }
31 | ```
32 |
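For programmatic access, here is a minimal Python sketch using the `requests` library (assuming the server is running locally on the default port, as described above):

```python
import requests

# Same query as the curl example above.
response = requests.post("http://localhost:5000/", data={"sentence": "Daar loopt _die_ meisje."})
print(response.json())  # e.g. {"rating": 1, "interpretation": "incorrect", ...}
```
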
33 | ## Zero-shot model
33 | We also have a zero-shot model (using RobBERT base), which might be faster and easier to use. There is a small drop in accuracy, but that should be quite limited.
35 |
36 | To use the faster zero-shot model, just make an HTTP POST request to `/fast` with the same parameter `sentence` filled in:
37 |
38 | ```shell script
39 | $ curl --data "sentence=Daar loopt _die_ meisje." localhost:5000/fast
40 | ```
41 |
42 | This should give you the following response, which is similar, but also provides `die`, `dat`, `Die` or `Dat` as the rating value:
43 |
44 | ```json
45 | {
46 | "rating": "dat",
47 | "interpretation": "incorrect",
48 | "confidence": 2.567270278930664,
49 | "sentence": "Daar loopt _die_ meisie."
50 | }
51 | ```
--------------------------------------------------------------------------------
/examples/die_vs_data_rest_api/app.py:
--------------------------------------------------------------------------------
1 | from app import create_app
2 | import argparse
3 |
4 | def create_parser():
5 | "Utility function to create the CLI argument parser"
6 |
7 | parser = argparse.ArgumentParser(
8 | description="Create a REST endpoint for 'die' vs 'dat' disambiguation."
9 | )
10 |
11 | parser.add_argument("--model-path", help="Path to the finetuned RobBERT identifier.", required=False)
12 | parser.add_argument("--fast-model-path", help="Path to the mlm RobBERT identifier.", required=False)
13 |
14 | return parser
15 |
16 | if __name__ == "__main__":
17 | arg_parser = create_parser()
18 | args = arg_parser.parse_args()
19 |
21 | create_app(args.model_path, args.fast_model_path).run()
--------------------------------------------------------------------------------
/examples/die_vs_data_rest_api/app/__init__.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, request
2 | import os
3 | from transformers import (
4 | RobertaForSequenceClassification,
5 | RobertaForMaskedLM,
6 | RobertaTokenizer,
7 | )
8 | import torch
9 | import nltk
10 | from nltk.tokenize.treebank import TreebankWordDetokenizer
11 | import json
12 | import re
13 |
14 |
15 | def replace_query_token(sentence):
16 | "Small utility function to replace a sentence with `_die_` or `_dat_` with the proper RobBERT input."
17 | tokens = nltk.word_tokenize(sentence)
18 | tokens_swapped = nltk.word_tokenize(sentence)
19 | for i, word in enumerate(tokens):
20 | if word == "_die_":
21 | tokens[i] = "die"
22 | tokens_swapped[i] = "dat"
23 |
24 | elif word == "_dat_":
25 | tokens[i] = "dat"
26 | tokens_swapped[i] = "die"
27 |
28 | elif word == "_Dat_":
29 | tokens[i] = "Dat"
30 | tokens_swapped[i] = "Die"
31 |
32 | elif word == "_Die_":
33 | tokens[i] = "Die"
34 |
35 | tokens_swapped[i] = "Dat"
36 |
37 | if word.lower() == "_die_" or word.lower() == "_dat_":
38 | results = TreebankWordDetokenizer().detokenize(tokens)
39 | results_swapped = TreebankWordDetokenizer().detokenize(tokens_swapped)
40 |
41 | return "{} {}".format(results, results_swapped)
42 |
43 | # If we reach the end of the for loop, it means no query token was present
44 | raise ValueError("'die' or 'dat' should be surrounded by underscores.")
45 |
46 |
47 | def create_app(model_path: str, fast_model_path: str, device="cpu"):
48 | """
49 | Create the flask app.
50 |
51 | :param model_path: Path to the finetuned model.
52 | :param device: Pytorch device, default CPU (switch to 'cuda' if a GPU is present)
53 | :return: the flask app
54 | """
55 | app = Flask(__name__, instance_relative_config=True)
56 |
57 | print("initializing tokenizer and RobBERT.")
58 | if model_path:
59 | tokenizer: RobertaTokenizer = RobertaTokenizer.from_pretrained(
60 | model_path, use_auth_token=True
61 | )
62 | robbert = RobertaForSequenceClassification.from_pretrained(
63 | model_path, use_auth_token=True
64 | )
65 | robbert.eval()
66 | print("Loaded finetuned model")
67 |
68 | if fast_model_path:
69 | fast_tokenizer: RobertaTokenizer = RobertaTokenizer.from_pretrained(
70 | fast_model_path, use_auth_token=True
71 | )
72 | fast_robbert = RobertaForMaskedLM.from_pretrained(
73 | fast_model_path, use_auth_token=True
74 | )
75 | fast_robbert.eval()
76 |
77 | print("Loaded MLM model")
78 |
79 | possible_tokens = ["die", "dat", "Die", "Dat"]
80 |
81 | ids = fast_tokenizer.convert_tokens_to_ids(possible_tokens)
82 |
83 | mask_padding_with_zero = True
84 | block_size = 512
85 |
86 | # Disable dropout
87 |
88 | nltk.download("punkt")
89 |
90 | if fast_model_path:
91 |
92 | @app.route("/disambiguation/mlm/all", methods=["POST"])
93 | def split():
94 | sentence = request.form["sentence"]
95 |
96 | response = []
97 | old_pos = 0
98 | for match in re.finditer(r"(die|dat|Die|Dat)+", sentence):
99 | print(
100 | "match",
101 | match.group(),
102 | "start index",
103 | match.start(),
104 | "End index",
105 | match.end(),
106 | )
107 | with torch.no_grad():
108 | query = (
109 | sentence[: match.start()] + fast_tokenizer.mask_token + sentence[match.end() :]  # mask out the matched pronoun
110 | )
111 | print(query)
112 | if match.start() > 0:
113 | response.append({"part": sentence[old_pos: match.start()]})
114 |
115 | old_pos = match.end()
116 | inputs = fast_tokenizer.encode_plus(query, return_tensors="pt")
117 |
118 | outputs = fast_robbert(**inputs)
119 | masked_position = torch.where(
120 | inputs["input_ids"] == fast_tokenizer.mask_token_id
121 | )[1]
122 | if len(masked_position) > 1:
123 | return "No two queries allowed in one sentence.", 400
124 |
125 | print(outputs.logits[0, masked_position, ids])
126 | token = outputs.logits[0, masked_position, ids].argmax()
127 |
128 | confidence = float(outputs.logits[0, masked_position, ids].max())
129 |
130 | response.append({
131 | "predicted": possible_tokens[token],
132 | "input": match.group(),
133 | "interpretation": "correct"
134 | if possible_tokens[token] == match.group()
135 | else "incorrect",
136 | "confidence": confidence,
137 | "sentence": sentence,
138 | })
139 |
140 | # This would be a good place for logging/storing queries + results
141 | print(response)
142 |
143 | # inputs = fast_tokenizer.encode_plus(query, return_tensors="pt")
144 | response.append({"part": sentence[match.end():]})
145 | return json.dumps(response)
146 |
147 | @app.route("/disambiguation/mlm", methods=["POST"])
148 | def fast():
149 | sentence = request.form["sentence"]
150 | for i, x in enumerate(possible_tokens):
151 | if f"_{x}_" in sentence:
152 | masked_id = i
153 | query = sentence.replace(f"_{x}_", fast_tokenizer.mask_token)
154 |
155 | inputs = fast_tokenizer.encode_plus(query, return_tensors="pt")
156 |
157 | masked_position = torch.where(
158 | inputs["input_ids"] == fast_tokenizer.mask_token_id
159 | )[1]
160 | if len(masked_position) > 1:
161 | return "No two queries allowed in one sentence.", 400
162 |
163 | # self.examples.append([tokenizer.build_inputs_with_special_tokens(tokenized_text[0 : block_size]), [0], [0]])
164 | with torch.no_grad():
165 | outputs = fast_robbert(**inputs)
166 |
167 | print(outputs.logits[0, masked_position, ids])
168 | token = outputs.logits[0, masked_position, ids].argmax()
169 |
170 | confidence = float(outputs.logits[0, masked_position, ids].max())
171 |
172 | response = {
173 | "rating": possible_tokens[token],
174 | "interpretation": "correct" if token == masked_id else "incorrect",
175 | "confidence": confidence,
176 | "sentence": sentence,
177 | }
178 |
179 | # This would be a good place for logging/storing queries + results
180 | print(response)
181 |
182 | return json.dumps(response)
183 |
184 | if model_path:
185 |
186 | @app.route("/disambiguation/classifier", methods=["POST"])
187 | def main():
188 | sentence = request.form["sentence"]
189 | query = replace_query_token(sentence)
190 |
191 | tokenized_text = tokenizer.encode(
192 | tokenizer.tokenize(query)[-block_size + 3 : -1]
193 | )
194 |
195 | input_mask = [1 if mask_padding_with_zero else 0] * len(tokenized_text)
196 |
197 | pad_token = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
198 | while len(tokenized_text) < block_size:
199 | tokenized_text.append(pad_token)
200 | input_mask.append(0 if mask_padding_with_zero else 1)
201 | # segment_ids.append(pad_token_segment_id)
202 | # p_mask.append(1)
203 |
204 | # self.examples.append([tokenizer.build_inputs_with_special_tokens(tokenized_text[0 : block_size]), [0], [0]])
205 | batch = tuple(
206 | torch.tensor(t).to(torch.device(device))
207 | for t in [
208 | tokenized_text[0 : block_size - 3],
209 | input_mask[0 : block_size - 3],
210 | [0],
211 | [1][0],
212 | ]
213 | )
214 | inputs = {
215 | "input_ids": batch[0].unsqueeze(0),
216 | "attention_mask": batch[1].unsqueeze(0),
217 | "labels": batch[3].unsqueeze(0),
218 | }
219 | with torch.no_grad():
220 | outputs = robbert(**inputs)
221 |
222 | rating = outputs[1].argmax().item()
223 | confidence = outputs[1][0, rating].item()
224 |
225 | response = {
226 | "rating": rating,
227 | "interpretation": "incorrect" if rating == 1 else "correct",
228 | "confidence": confidence,
229 | "sentence": sentence,
230 | }
231 |
232 | # This would be a good place for logging/storing queries + results
233 | print(response)
234 |
235 | return json.dumps(response)
236 |
237 | return app
238 |
--------------------------------------------------------------------------------
/examples/die_vs_data_rest_api/app/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/examples/die_vs_data_rest_api/app/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/examples/die_vs_data_rest_api/app/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "port": 5000
3 | }
--------------------------------------------------------------------------------
/examples/die_vs_data_rest_api/requirements.txt:
--------------------------------------------------------------------------------
1 | boto3==1.12.36
2 | botocore==1.15.36
3 | certifi==2020.4.5.1
4 | chardet==3.0.4
5 | click==7.1.1
6 | docutils==0.15.2
7 | filelock==3.0.12
8 | Flask==1.1.2
9 | idna==2.9
10 | itsdangerous==1.1.0
11 | Jinja2==2.11.1
12 | jmespath==0.9.5
13 | joblib==0.14.1
14 | MarkupSafe==1.1.1
15 | nltk==3.4.5
16 | numpy==1.18.2
17 | python-dateutil==2.8.1
18 | regex==2020.4.4
19 | requests==2.23.0
20 | s3transfer==0.3.3
21 | sacremoses==0.0.38
22 | sentencepiece==0.1.85
23 | six==1.14.0
24 | tokenizers==0.5.2
25 | torch==1.4.0
26 | tqdm==4.45.0
27 | transformers==2.7.0
28 | urllib3==1.25.8
29 | Werkzeug==1.0.1
30 |
--------------------------------------------------------------------------------
/model_cards/README.md:
--------------------------------------------------------------------------------
1 | ---
2 | language: "nl"
3 | thumbnail: "https://github.com/iPieter/RobBERT/raw/master/res/robbert_logo.png"
4 | tags:
5 | - Dutch
6 | - Flemish
7 | - RoBERTa
8 | - RobBERT
9 | license: mit
10 | datasets:
11 | - oscar
12 | - oscar (NL)
13 | - dbrd
14 | - lassy-ud
15 | - europarl-mono
16 | - conll2002
17 | widget:
18 | - text: "Hallo, ik ben RobBERT, een taalmodel van de KU Leuven."
19 | ---
20 |
21 |
22 |
23 |
24 |
25 | # RobBERT: Dutch RoBERTa-based Language Model.
26 |
27 | [RobBERT](https://github.com/iPieter/RobBERT) is the state-of-the-art Dutch BERT model. It is a large pre-trained general Dutch language model that can be fine-tuned on a given dataset to perform any text classification, regression or token-tagging task. As such, it has been successfully used by many [researchers](https://scholar.google.com/scholar?oi=bibs&hl=en&cites=7180110604335112086) and [practitioners](https://huggingface.co/models?search=robbert) for achieving state-of-the-art performance for a wide range of Dutch natural language processing tasks, including:
28 |
29 | - [Emotion detection](https://www.aclweb.org/anthology/2021.wassa-1.27/)
30 | - Sentiment analysis ([book reviews](https://arxiv.org/pdf/2001.06286.pdf), [news articles](https://biblio.ugent.be/publication/8704637/file/8704638.pdf)*)
31 | - [Coreference resolution](https://arxiv.org/pdf/2001.06286.pdf)
32 | - Named entity recognition ([CoNLL](https://arxiv.org/pdf/2001.06286.pdf), [job titles](https://arxiv.org/pdf/2004.02814.pdf)*, [SoNaR](https://github.com/proycon/deepfrog))
33 | - Part-of-speech tagging ([Small UD Lassy](https://arxiv.org/pdf/2001.06286.pdf), [CGN](https://github.com/proycon/deepfrog))
34 | - [Zero-shot word prediction](https://arxiv.org/pdf/2001.06286.pdf)
35 | - [Humor detection](https://arxiv.org/pdf/2010.13652.pdf)
36 | - [Cyberbullying detection](https://www.cambridge.org/core/journals/natural-language-engineering/article/abs/automatic-classification-of-participant-roles-in-cyberbullying-can-we-detect-victims-bullies-and-bystanders-in-social-media-text/A2079C2C738C29428E666810B8903342)
37 | - [Correcting dt-spelling mistakes](https://gitlab.com/spelfouten/dutch-simpletransformers/)*
38 |
39 | and also achieved outstanding, near state-of-the-art results for:
40 |
41 | - [Natural language inference](https://arxiv.org/pdf/2101.05716.pdf)*
42 | - [Review classification](https://medium.com/broadhorizon-cmotions/nlp-with-r-part-5-state-of-the-art-in-nlp-transformers-bert-3449e3cd7494)*
43 |
44 | \* *Note that several evaluations use RobBERT-v1, and that the improved RobBERT-v2 outperforms this first model on everything we tested.*
45 |
46 | *(Also note that this list is not exhaustive. If you used RobBERT for your application, we are happy to know about it! Send us a mail, or add it yourself to this list by sending a pull request with the edit!)*
47 |
48 | More in-depth information about RobBERT can be found in our [blog post](https://people.cs.kuleuven.be/~pieter.delobelle/robbert/), [our paper](https://arxiv.org/abs/2001.06286) and [the RobBERT Github repository](https://github.com/iPieter/RobBERT).
49 |
50 |
51 | ## How to use
52 |
53 | RobBERT uses the [RoBERTa](https://arxiv.org/abs/1907.11692) architecture and pre-training but with a Dutch tokenizer and training data. RoBERTa is the robustly optimized English BERT model, making it even more powerful than the original BERT model. Given this same architecture, RobBERT can easily be fine-tuned and used for inference with [code to fine-tune RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html) models and most code used for BERT models, e.g. as provided by the [HuggingFace Transformers](https://huggingface.co/transformers/) library.
54 |
55 | By default, RobBERT has the masked language model head used in training. This can be used as a zero-shot way to fill masks in sentences. It can be tested out for free on [RobBERT's hosted inference API on HuggingFace](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=De+hoofdstad+van+Belgi%C3%AB+is+%3Cmask%3E.). You can also create a new prediction head for your own task by using any of HuggingFace's [RoBERTa-runners](https://huggingface.co/transformers/v2.7.0/examples.html#language-model-training) or [their fine-tuning notebooks](https://huggingface.co/transformers/v4.1.1/notebooks.html) by changing the model name to `pdelobelle/robbert-v2-dutch-base`, or use the original fairseq [RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta) training regimes.
56 |
57 | Use the following code to download the base model and finetune it yourself, or use one of our finetuned models (documented on [our project site](https://people.cs.kuleuven.be/~pieter.delobelle/robbert/)).
58 |
59 | ```python
60 | from transformers import RobertaTokenizer, RobertaForSequenceClassification
61 | tokenizer = RobertaTokenizer.from_pretrained("pdelobelle/robbert-v2-dutch-base")
62 | model = RobertaForSequenceClassification.from_pretrained("pdelobelle/robbert-v2-dutch-base")
63 | ```
64 |
65 | Starting with `transformers v2.4.0` (or when installing from source), you can use `AutoTokenizer` and `AutoModel`.
66 | You can then use most of [HuggingFace's BERT-based notebooks](https://huggingface.co/transformers/v4.1.1/notebooks.html) for finetuning RobBERT on your type of Dutch language dataset.
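As a minimal sketch (assuming a recent `transformers` version with the Auto classes; the example sentence is ours, purely for illustration), loading RobBERT with a fresh classification head looks like this:

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load RobBERT's tokenizer and weights, adding a new (untrained) 2-class head.
tokenizer = AutoTokenizer.from_pretrained("pdelobelle/robbert-v2-dutch-base")
model = AutoModelForSequenceClassification.from_pretrained(
    "pdelobelle/robbert-v2-dutch-base", num_labels=2
)

# A single forward pass on a Dutch sentence; the head still needs fine-tuning.
inputs = tokenizer("Dit is een voorbeeldzin.", return_tensors="pt")
logits = model(**inputs).logits
print(logits.shape)  # torch.Size([1, 2])
```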
67 |
68 |
69 | ## Technical Details From The Paper
70 |
71 |
72 | ### Our Performance Evaluation Results
73 |
74 | All experiments are described in more detail in our [paper](https://arxiv.org/abs/2001.06286), with the code in [our GitHub repository](https://github.com/iPieter/RobBERT).
75 |
76 | ### Sentiment analysis
77 | Predicting whether a review is positive or negative using the [Dutch Book Reviews Dataset](https://github.com/benjaminvdb/DBRD).
78 |
79 | | Model | Accuracy [%] |
80 | |-------------------|--------------------------|
81 | | ULMFiT | 93.8 |
82 | | BERTje | 93.0 |
83 | | RobBERT v2 | **95.1** |
84 |
85 | ### Die/Dat (coreference resolution)
86 |
87 | We measured how well the models are able to do coreference resolution by predicting whether "die" or "dat" should be filled into a sentence.
88 | For this, we used the [EuroParl corpus](https://www.statmt.org/europarl/).
89 |
90 | #### Finetuning on whole dataset
91 |
92 | | Model | Accuracy [%] | F1 [%] |
93 | |-------------------|--------------------------|--------------|
94 | | [Baseline](https://arxiv.org/abs/2001.02943) (LSTM) | | 75.03 |
95 | | mBERT | 98.285 | 98.033 |
96 | | BERTje | 98.268 | 98.014 |
97 | | RobBERT v2 | **99.232** | **99.121** |
98 |
99 | #### Finetuning on 10K examples
100 |
101 | We also measured the performance using only 10K training examples.
102 | This experiment clearly illustrates that RobBERT outperforms other models when there is little data available.
103 |
104 | | Model | Accuracy [%] | F1 [%] |
105 | |-------------------|--------------------------|--------------|
106 | | mBERT | 92.157 | 90.898 |
107 | | BERTje | 93.096 | 91.279 |
108 | | RobBERT v2 | **97.816** | **97.514** |
109 |
110 | #### Using zero-shot word masking task
111 |
112 | Since BERT models are pre-trained using the word masking task, we can use this to predict whether "die" or "dat" is more likely.
113 | This experiment shows that RobBERT has internalised more information about Dutch than other models.
114 |
115 | | Model | Accuracy [%] |
116 | |-------------------|--------------------------|
117 | | ZeroR | 66.70 |
118 | | mBERT | 90.21 |
119 | | BERTje | 94.94 |
120 | | RobBERT v2 | **98.75** |
121 |
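As an illustration, a minimal zero-shot sketch along these lines (using the HuggingFace checkpoint rather than the fairseq setup from the paper, and with an example sentence of our own) compares the scores of *die* and *dat* at the masked position:

```python
import torch
from transformers import RobertaTokenizer, RobertaForMaskedLM

tokenizer = RobertaTokenizer.from_pretrained("pdelobelle/robbert-v2-dutch-base")
model = RobertaForMaskedLM.from_pretrained("pdelobelle/robbert-v2-dutch-base", return_dict=True)
model.eval()

# Mask the position where either "die" or "dat" should go.
sentence = f"Daar loopt {tokenizer.mask_token} meisje."
input_ids = tokenizer.encode(sentence, return_tensors="pt")
mask_index = torch.where(input_ids[0] == tokenizer.mask_token_id)[0].item()

with torch.no_grad():
    logits = model(input_ids).logits[0, mask_index]

# Note the leading space: the BPE vocabulary distinguishes word-initial tokens.
die_id = tokenizer.encode(" die", add_special_tokens=False)[0]
dat_id = tokenizer.encode(" dat", add_special_tokens=False)[0]
print("die" if logits[die_id] > logits[dat_id] else "dat")
```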
122 | ### Part-of-Speech Tagging
123 |
124 | Using the [Lassy UD dataset](https://universaldependencies.org/treebanks/nl_lassysmall/index.html).
125 |
126 |
127 | | Model | Accuracy [%] |
128 | |-------------------|--------------------------|
129 | | Frog | 91.7 |
130 | | mBERT | **96.5** |
131 | | BERTje | 96.3 |
132 | | RobBERT v2 | 96.4 |
133 |
134 | Interestingly, we found that when dealing with **small data sets**, RobBERT v2 **significantly outperforms** other models.
135 |
136 |
137 |
138 |
139 |
140 | ### Named Entity Recognition
141 |
142 | Using the [CoNLL 2002 evaluation script](https://www.clips.uantwerpen.be/conll2002/ner/).
143 |
144 |
145 | | Model | Accuracy [%] |
146 | |-------------------|--------------------------|
147 | | Frog | 57.31 |
148 | | mBERT | **90.94** |
149 | | BERT-NL | 89.7 |
150 | | BERTje | 88.3 |
151 | | RobBERT v2 | 89.08 |
152 |
153 |
154 | ## Pre-Training Procedure Details
155 |
156 | We pre-trained RobBERT using the RoBERTa training regime.
157 | We pre-trained our model on the Dutch section of the [OSCAR corpus](https://oscar-corpus.com/), a large multilingual corpus which was obtained by language classification in the Common Crawl corpus.
158 | This Dutch corpus is 39 GB in size, with 6.6 billion words spread over 126 million lines of text, where each line can contain multiple sentences; this is more data than was used by concurrently developed Dutch BERT models.
159 |
160 |
161 | RobBERT shares its architecture with [RoBERTa's base model](https://github.com/pytorch/fairseq/tree/master/examples/roberta), which itself is a replication and improvement over BERT.
162 | Like BERT, its architecture consists of 12 self-attention layers with 12 heads, totalling 117M trainable parameters.
163 | One difference from the original BERT model is the pre-training task specified by RoBERTa: it uses only the MLM task and not the NSP task.
164 | During pre-training, it thus only predicts which words are masked in certain positions of given sentences.
165 | The training process uses the Adam optimizer with polynomial decay of the learning rate l_r=10^-6 and a ramp-up period of 1000 iterations, with hyperparameters beta_1=0.9
166 | and RoBERTa's default beta_2=0.98.
167 | Additionally, a weight decay of 0.1 and a small dropout of 0.1 help prevent the model from overfitting.
168 |
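For reference, this schedule can be sketched in a few lines of Python. This is illustrative only: the decay power of 1 (i.e. linear decay) is an assumption, the 16k total updates follow the training length mentioned below, and the exact values are set in the fairseq configuration.

```python
def robbert_lr(step, peak_lr=1e-6, warmup=1000, total_updates=16_000, power=1.0):
    """Warm-up followed by polynomial decay, as described above (illustrative only)."""
    if step < warmup:
        return peak_lr * step / warmup                      # linear ramp-up
    progress = (step - warmup) / max(1, total_updates - warmup)
    return peak_lr * (1.0 - progress) ** power              # polynomial decay towards 0

print(robbert_lr(500), robbert_lr(1000), robbert_lr(8000))
```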
169 |
170 | RobBERT was trained on a computing cluster with 4 Nvidia P100 GPUs per node, where the number of nodes was dynamically adjusted while keeping a fixed batch size of 8192 sentences.
171 | At most 20 nodes were used (i.e. 80 GPUs), and the median was 5 nodes.
172 | By using gradient accumulation, the batch size could be set independently of the number of GPUs available, in order to maximally utilize the cluster.
173 | Using the [Fairseq library](https://github.com/pytorch/fairseq/tree/master/examples/roberta), the model trained for two epochs (over 16k batches in total), which took about three days on the computing cluster.
174 | In between training jobs on the computing cluster, 2 Nvidia 1080 Ti's also covered some parameter updates for RobBERT v2.
175 |
176 |
177 | ## Investigating Limitations and Bias
178 |
179 | In the [RobBERT paper](https://arxiv.org/abs/2001.06286), we also investigated potential sources of bias in RobBERT.
180 |
181 | We found that the zero-shot model estimates the probability of *hij* (he) to be higher than *zij* (she) for most occupations in bleached template sentences, regardless of the occupation's actual gender ratio in reality.
182 |
183 |
184 |
185 |
186 |
187 | By augmenting the DBRD Dutch Book Reviews sentiment analysis dataset with the stated gender of the author of the review, we found that highly positive reviews written by women were generally more accurately detected by RobBERT as being positive than those written by men.
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 | ## How to Replicate Our Paper Experiments
196 | Replicating our paper experiments is [described in detail in the RobBERT repository README](https://github.com/iPieter/RobBERT#how-to-replicate-our-paper-experiments).
197 |
198 | ## Name Origin of RobBERT
199 |
200 | Most BERT-like models have the word *BERT* in their name (e.g. [RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html), [ALBERT](https://arxiv.org/abs/1909.11942), [CamemBERT](https://camembert-model.fr/), and [many, many others](https://huggingface.co/models?search=bert)).
201 | As such, we queried our newly trained model using its masked language model to name itself *\<mask\>bert* using [all](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=Mijn+naam+is+%3Cmask%3Ebert.) [kinds](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=Hallo%2C+ik+ben+%3Cmask%3Ebert.) [of](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=Leuk+je+te+ontmoeten%2C+ik+heet+%3Cmask%3Ebert.) [prompts](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=Niemand+weet%2C+niemand+weet%2C+dat+ik+%3Cmask%3Ebert+heet.), and it consistently called itself RobBERT.
202 | We thought it was really quite fitting, given that RobBERT is a [*very* Dutch name](https://en.wikipedia.org/wiki/Robbert) *(and thus clearly a Dutch language model)*, and additionally has a high similarity to its root architecture, namely [RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html).
203 |
204 | Since *"rob"* is the Dutch word for a seal, we decided to draw a seal and dress it up like [Bert from Sesame Street](https://muppet.fandom.com/wiki/Bert) for the [RobBERT logo](https://github.com/iPieter/RobBERT/blob/master/res/robbert_logo.png).
205 |
206 | ## Credits and citation
207 |
208 | This project was created by [Pieter Delobelle](https://people.cs.kuleuven.be/~pieter.delobelle), [Thomas Winters](https://thomaswinters.be) and [Bettina Berendt](https://people.cs.kuleuven.be/~bettina.berendt/).
209 | If you would like to cite our paper or model, you can use the following BibTeX:
210 |
211 | ```
212 | @inproceedings{delobelle2020robbert,
213 | title = "{R}ob{BERT}: a {D}utch {R}o{BERT}a-based {L}anguage {M}odel",
214 | author = "Delobelle, Pieter and
215 | Winters, Thomas and
216 | Berendt, Bettina",
217 | booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2020",
218 | month = nov,
219 | year = "2020",
220 | address = "Online",
221 | publisher = "Association for Computational Linguistics",
222 | url = "https://www.aclweb.org/anthology/2020.findings-emnlp.292",
223 | doi = "10.18653/v1/2020.findings-emnlp.292",
224 | pages = "3255--3265"
225 | }
226 | ```
227 |
--------------------------------------------------------------------------------
/notebooks/demo_RobBERT_for_conll_ner.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "pycharm": {
7 | "name": "#%% md\n"
8 | }
9 | },
10 | "source": [
11 | "# Demo of RobBERT for Dutch named entity recognition\n",
12 | "We use a [RobBERT (Delobelle et al., 2020)](https://arxiv.org/abs/2001.06286) model for NER.\n",
13 | "\n",
14 | "**Dependencies**\n",
15 | "- tokenizers\n",
16 | "- torch\n",
17 | "- transformers"
18 | ]
19 | },
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {},
23 | "source": [
24 | "First we load our RobBERT model that was pretrained on OSCAR and finetuned on Dutch named entity recognition. We also load in RobBERT's tokenizer.\n",
25 | "\n",
26 | "Because we only want to get results, we have to disable dropout etc. So we add `model.eval()`."
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 2,
32 | "metadata": {
33 | "collapsed": false,
34 | "jupyter": {
35 | "outputs_hidden": false
36 | },
37 | "pycharm": {
38 | "is_executing": false,
39 | "name": "#%%\n"
40 | }
41 | },
42 | "outputs": [
43 | {
44 | "name": "stderr",
45 | "output_type": "stream",
46 | "text": [
47 | "Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.\n"
48 | ]
49 | },
50 | {
51 | "name": "stdout",
52 | "output_type": "stream",
53 | "text": [
54 | "RobBERT model loaded\n"
55 | ]
56 | }
57 | ],
58 | "source": [
59 | "import torch\n",
60 | "from transformers import RobertaTokenizer, RobertaForTokenClassification\n",
61 | "\n",
62 | "tokenizer = RobertaTokenizer.from_pretrained('pdelobelle/robbert-v2-dutch-ner')\n",
63 | "model = RobertaForTokenClassification.from_pretrained('pdelobelle/robbert-v2-dutch-ner', return_dict=True)\n",
64 | "model.eval()\n",
65 | "print(\"RobBERT model loaded\")"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 3,
71 | "metadata": {
72 | "collapsed": false,
73 | "jupyter": {
74 | "outputs_hidden": false
75 | },
76 | "pycharm": {
77 | "is_executing": false,
78 | "name": "#%% \n"
79 | }
80 | },
81 | "outputs": [
82 | {
83 | "name": "stdout",
84 | "output_type": "stream",
85 | "text": [
86 | "input_ids:\n",
87 | "\ttensor([[ 0, 6079, 499, 38, 5, 13292, 11, 6422, 8, 7010,\n",
88 | " 9, 2617, 4, 2, 1],\n",
89 | " [ 0, 25907, 129, 1283, 8, 3971, 113, 28, 118, 71,\n",
90 | " 435, 38, 27600, 4, 2],\n",
91 | " [ 0, 9396, 89, 9, 797, 2877, 22, 11, 5, 4290,\n",
92 | " 445, 4, 2, 1, 1],\n",
93 | " [ 0, 7751, 6, 74, 458, 12, 3663, 14334, 342, 4,\n",
94 | " 2, 1, 1, 1, 1]])\n",
95 | "attention_mask:\n",
96 | "\ttensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],\n",
97 | " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
98 | " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
99 | " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]])\n",
100 | "Tokens:\n",
101 | "\t['', 'Jan', 'Ġging', 'Ġnaar', 'Ġde', 'Ġbakker', 'Ġin', 'ĠLeuven', 'Ġen', 'Ġkocht', 'Ġeen', 'Ġbrood', '.', '', '']\n",
102 | "\t['', 'Bedrijven', 'Ġzoals', 'ĠGoogle', 'Ġen', 'ĠMicrosoft', 'Ġdoen', 'Ġook', 'Ġheel', 'Ġveel', 'Ġonderzoek', 'Ġnaar', 'ĠNLP', '.', '']\n"
103 | ]
104 | }
105 | ],
106 | "source": [
107 | "inputs = tokenizer.batch_encode_plus(\n",
108 | " [\"Jan ging naar de bakker in Leuven en kocht een brood.\",\n",
109 | " \"Bedrijven zoals Google en Microsoft doen ook heel veel onderzoek naar NLP.\",\n",
110 | " \"Men moet een gegeven paard niet in de bek kijken.\",\n",
111 | " \"Hallo, mijn naam is RobBERT.\"],\n",
112 | " return_tensors=\"pt\", padding=True)\n",
113 | "for key, value in inputs.items():\n",
114 | " print(\"{}:\\n\\t{}\".format(key, value))\n",
115 | "print(\"Tokens:\\n\\t{}\".format(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]) ))\n",
116 | "print(\"\\t{}\".format(tokenizer.convert_ids_to_tokens(inputs['input_ids'][1]) ))"
117 | ]
118 | },
119 | {
120 | "cell_type": "markdown",
121 | "metadata": {
122 | "pycharm": {
123 | "is_executing": false,
124 | "name": "#%% md\n"
125 | }
126 | },
127 | "source": [
128 | "In our model config, we stored what labels we use. \n",
129 | "We can load these in and automatically convert our predictions to a human-readable format.\n",
130 | "For reference, we have 4 types of named entities:\n",
131 | "\n",
132 | "- PER\n",
133 | "- LOC\n",
134 | "- ORG\n",
135 | "- MISC\n",
136 | "\n",
137 | "And we mark the first token with `B-` and then we mark a continuation with `I-`. "
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 4,
143 | "metadata": {
144 | "collapsed": false,
145 | "jupyter": {
146 | "outputs_hidden": false
147 | },
148 | "pycharm": {
149 | "is_executing": false,
150 | "name": "#%%\n"
151 | }
152 | },
153 | "outputs": [
154 | {
155 | "name": "stdout",
156 | "output_type": "stream",
157 | "text": [
158 | "{0: 'B-PER', 1: 'B-ORG', 2: 'B-LOC', 3: 'B-MISC', 4: 'I-PER', 5: 'I-ORG', 6: 'I-LOC', 7: 'I-MISC', 8: 'O'}\n"
159 | ]
160 | }
161 | ],
162 | "source": [
163 | "print(model.config.id2label)"
164 | ]
165 | },
166 | {
167 | "cell_type": "markdown",
168 | "metadata": {
169 | "pycharm": {
170 | "name": "#%% md\n"
171 | }
172 | },
173 | "source": [
174 | "Ok, let's do some predictions! Since we have a batch of 4 sentences, we can do this in one batch—as long as it fits on your GPU.\n",
175 | "\n",
176 | "_If the formatting of this fails, you can try to zoom out or make the window wider_"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": 5,
182 | "metadata": {
183 | "collapsed": false,
184 | "jupyter": {
185 | "outputs_hidden": false
186 | },
187 | "pycharm": {
188 | "is_executing": false,
189 | "name": "#%%\n"
190 | }
191 | },
192 | "outputs": [
193 | {
194 | "name": "stdout",
195 | "output_type": "stream",
196 | "text": [
197 | "Sentence 0\n",
198 | " Jan Ġging Ġnaar Ġde Ġbakker Ġin ĠLeuven Ġen Ġkocht Ġeen Ġbrood . \n",
199 | "\n",
200 | "O B-PER O O O O O B-LOC O O O O O O O \n",
201 | "\n",
202 | "Sentence 1\n",
203 | " Bedrijven Ġzoals ĠGoogle Ġen ĠMicrosoft Ġdoen Ġook Ġheel Ġveel Ġonderzoek Ġnaar ĠNLP . \n",
204 | "\n",
205 | "O O O B-ORG O B-ORG O O O O O O B-MISC O O \n",
206 | "\n",
207 | "Sentence 2\n",
208 | " Men Ġmoet Ġeen Ġgegeven Ġpaard Ġniet Ġin Ġde Ġbek Ġkijken . \n",
209 | "\n",
210 | "O O O O O O O O O O O O O O O \n",
211 | "\n",
212 | "Sentence 3\n",
213 | " Hallo , Ġmijn Ġnaam Ġis ĠRob BER T . \n",
214 | "\n",
215 | "O O O O O O B-PER I-PER I-PER O O O O O I-PER \n",
216 | "\n"
217 | ]
218 | }
219 | ],
220 | "source": [
221 | "with torch.no_grad():\n",
222 | " results = model(**inputs)\n",
223 | " for i, input in enumerate(inputs['input_ids']):\n",
224 | " print(f\"Sentence {i}\")\n",
225 | " [print(\"{:12}\".format(token), end=\"\") for token in tokenizer.convert_ids_to_tokens(input) ]\n",
226 | " print('\\n')\n",
227 | " [print(\"{:12}\".format(model.config.id2label[item.item()]), end=\"\") for item in results.logits[i].argmax(axis=1)]\n",
228 | " print('\\n')"
229 | ]
230 | },
231 | {
232 | "cell_type": "markdown",
233 | "metadata": {
234 | "pycharm": {
235 | "name": "#%% md\n"
236 | }
237 | },
238 | "source": [
239 | "Ok, this works nicely! We have 'Jan', 'Leuven' and companies like 'Google' that are all labeled correctly. \n",
240 |     "In addition, RobBERT consists of multiple tokens (perhaps we should have added a token for its name), and that works with the `I-` tag as well.\n",
241 | "\n"
242 | ]
243 | }
244 | ],
245 | "metadata": {
246 | "kernelspec": {
247 | "display_name": "Python 3",
248 | "language": "python",
249 | "name": "python3"
250 | },
251 | "language_info": {
252 | "codemirror_mode": {
253 | "name": "ipython",
254 | "version": 3
255 | },
256 | "file_extension": ".py",
257 | "mimetype": "text/x-python",
258 | "name": "python",
259 | "nbconvert_exporter": "python",
260 | "pygments_lexer": "ipython3",
261 | "version": "3.7.3"
262 | },
263 | "pycharm": {
264 | "stem_cell": {
265 | "cell_type": "raw",
266 | "metadata": {
267 | "collapsed": false
268 | },
269 | "source": []
270 | }
271 | }
272 | },
273 | "nbformat": 4,
274 | "nbformat_minor": 4
275 | }
276 |
--------------------------------------------------------------------------------
/notebooks/demo_RobBERT_for_masked_LM.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "collapsed": true,
7 | "pycharm": {
8 | "name": "#%% md\n"
9 | }
10 | },
11 | "source": [
12 |     "# Demo of RobBERT for masked language modeling\n",
13 | "We use a [RobBERT (Delobelle et al., 2020)](https://arxiv.org/abs/2001.06286) model with the original pretraining head for MLM.\n",
14 | "\n",
15 | "**Dependencies**\n",
16 | "- tokenizers\n",
17 | "- torch\n",
18 | "- transformers"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 |     "First we load our pretrained RobBERT model. We also load in RobBERT's tokenizer.\n",
26 | "\n",
27 | "Because we only want to get results, we have to disable dropout etc. So we add `model.eval()`."
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "*Note: we pretrained both RobBERT v1 and RobBERT v2 in [Fairseq](https://github.com/pytorch/fairseq) and converted these checkpoints to HuggingFace. The MLM task behaves a bit differently.*"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 1,
40 | "metadata": {
41 | "pycharm": {
42 | "is_executing": false,
43 | "name": "#%%\n"
44 | }
45 | },
46 | "outputs": [
47 | {
48 | "name": "stderr",
49 | "text": [
50 | "Special tokens have been added in the vocabulary, make sure the associated word emebedding are fine-tuned or trained.\n"
51 | ],
52 | "output_type": "stream"
53 | },
54 | {
55 | "data": {
56 | "text/plain": "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=469740689.0, style=ProgressStyle(descri…",
57 | "application/vnd.jupyter.widget-view+json": {
58 | "version_major": 2,
59 | "version_minor": 0,
60 | "model_id": "4156a63ac55d4844a22db25f1db7ed52"
61 | }
62 | },
63 | "metadata": {},
64 | "output_type": "display_data"
65 | },
66 | {
67 | "name": "stdout",
68 | "text": [
69 | "\n",
70 | "RobBERT model loaded\n"
71 | ],
72 | "output_type": "stream"
73 | }
74 | ],
75 | "source": [
76 | "import torch\n",
77 | "from transformers import RobertaTokenizer, AutoModelForSequenceClassification, AutoConfig\n",
78 | "\n",
79 | "from transformers import RobertaTokenizer, RobertaForMaskedLM\n",
80 | "import torch\n",
81 | "tokenizer = RobertaTokenizer.from_pretrained('pdelobelle/robbert-v2-dutch-base')\n",
82 | "model = RobertaForMaskedLM.from_pretrained('pdelobelle/robbert-v2-dutch-base', return_dict=True)\n",
83 | "model = model.to( 'cuda' if torch.cuda.is_available() else 'cpu' )\n",
84 | "model.eval()\n",
85 | "#model = RobertaForMaskedLM.from_pretrained('pdelobelle/robbert-v2-dutch-base', return_dict=True)\n",
86 | "print(\"RobBERT model loaded\")"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 2,
92 | "metadata": {
93 | "pycharm": {
94 | "is_executing": false,
95 | "name": "#%% \n"
96 | }
97 | },
98 | "outputs": [],
99 | "source": [
100 | "sequence = f\"Er staat een {tokenizer.mask_token} in mijn tuin.\""
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 3,
106 | "metadata": {
107 | "pycharm": {
108 | "is_executing": false,
109 | "name": "#%%\n"
110 | }
111 | },
112 | "outputs": [],
113 | "source": [
114 | "input = tokenizer.encode(sequence, return_tensors=\"pt\").to( 'cuda' if torch.cuda.is_available() else 'cpu' )\n",
115 | "mask_token_index = torch.where(input == tokenizer.mask_token_id)[1]"
116 | ]
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "metadata": {
121 | "pycharm": {
122 | "is_executing": false,
123 | "name": "#%% md\n"
124 | }
125 | },
126 | "source": [
127 | "Now that we have our tokenized input and the position of the masked token, we pass the input through RobBERT. \n",
128 | "\n",
129 |     "This will give us a prediction for all tokens, but we're only interested in the `<mask>` token. "
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": 4,
135 | "metadata": {
136 | "pycharm": {
137 | "is_executing": false,
138 | "name": "#%%\n"
139 | }
140 | },
141 | "outputs": [],
142 | "source": [
143 | "with torch.no_grad():\n",
144 | " token_logits = model(input).logits"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 5,
150 | "metadata": {
151 | "pycharm": {
152 | "is_executing": false,
153 | "name": "#%%\n"
154 | }
155 | },
156 | "outputs": [
157 | {
158 | "name": "stdout",
159 | "text": [
160 | "Ġboom | id = 2600 | p = 0.1416003555059433\n",
161 | "Ġvijver | id = 8217 | p = 0.13144515454769135\n",
162 | "Ġplant | id = 2721 | p = 0.043418534100055695\n",
163 | "Ġhuis | id = 251 | p = 0.01847737282514572\n",
164 | "Ġparkeerplaats | id = 6889 | p = 0.018001794815063477\n",
165 | "Ġbankje | id = 21620 | p = 0.016940612345933914\n",
166 | "Ġmuur | id = 2035 | p = 0.014668751507997513\n",
167 | "Ġmoestuin | id = 17446 | p = 0.0144038125872612\n",
168 | "Ġzonnebloem | id = 30757 | p = 0.014375611208379269\n",
169 | "Ġschutting | id = 15000 | p = 0.013991709798574448\n",
170 | "Ġpaal | id = 8626 | p = 0.01358739286661148\n",
171 | "Ġbloem | id = 3001 | p = 0.01199684850871563\n",
172 | "Ġstal | id = 7416 | p = 0.011224730871617794\n",
173 | "Ġfontein | id = 23425 | p = 0.011203107424080372\n",
174 | "Ġtuin | id = 671 | p = 0.010676783509552479\n"
175 | ],
176 | "output_type": "stream"
177 | }
178 | ],
179 | "source": [
180 | "logits = token_logits[0, mask_token_index, :].squeeze()\n",
181 | "prob = logits.softmax(dim=0)\n",
182 | "values, indeces = prob.topk(k=15, dim=0)\n",
183 | "\n",
184 | "for index, token in enumerate(tokenizer.convert_ids_to_tokens(indeces)):\n",
185 | " print(f\"{token:20} | id = {indeces[index]:4} | p = {values[index]}\")"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "metadata": {
191 | "pycharm": {
192 | "is_executing": false,
193 | "name": "#%% md\n"
194 | }
195 | },
196 | "source": [
197 | "## RobBERT with pipelines\n",
198 |     "We can also use the `fill-mask` pipeline from HuggingFace, which does basically the same thing."
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 6,
204 | "outputs": [
205 | {
206 | "name": "stderr",
207 | "text": [
208 | "Special tokens have been added in the vocabulary, make sure the associated word emebedding are fine-tuned or trained.\n"
209 | ],
210 | "output_type": "stream"
211 | }
212 | ],
213 | "source": [
214 | "from transformers import pipeline\n",
215 | "p = pipeline(\"fill-mask\", model=\"pdelobelle/robbert-v2-dutch-base\")"
216 | ],
217 | "metadata": {
218 | "collapsed": false,
219 | "pycharm": {
220 | "name": "#%%\n",
221 | "is_executing": false
222 | }
223 | }
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": 7,
228 | "metadata": {
229 | "pycharm": {
230 | "is_executing": false,
231 | "name": "#%%\n"
232 | }
233 | },
234 | "outputs": [
235 | {
236 | "data": {
237 | "text/plain": "[{'sequence': 'Er staat een boomin mijn tuin.',\n 'score': 0.1416003555059433,\n 'token': 2600,\n 'token_str': 'Ġboom'},\n {'sequence': 'Er staat een vijverin mijn tuin.',\n 'score': 0.13144515454769135,\n 'token': 8217,\n 'token_str': 'Ġvijver'},\n {'sequence': 'Er staat een plantin mijn tuin.',\n 'score': 0.043418534100055695,\n 'token': 2721,\n 'token_str': 'Ġplant'},\n {'sequence': 'Er staat een huisin mijn tuin.',\n 'score': 0.01847737282514572,\n 'token': 251,\n 'token_str': 'Ġhuis'},\n {'sequence': 'Er staat een parkeerplaatsin mijn tuin.',\n 'score': 0.018001794815063477,\n 'token': 6889,\n 'token_str': 'Ġparkeerplaats'}]"
238 | },
239 | "metadata": {},
240 | "output_type": "execute_result",
241 | "execution_count": 7
242 | }
243 | ],
244 | "source": [
245 | "p(sequence)"
246 | ]
247 | },
248 | {
249 | "cell_type": "markdown",
250 | "source": [
251 | "That's it for this demo of the MLM head. If you use RobBERT in your academic work, you can cite it!\n",
252 | "\n",
253 | "\n",
254 | "```\n",
255 | "@misc{delobelle2020robbert,\n",
256 | " title={{R}ob{BERT}: a {D}utch {R}o{BERT}a-based Language Model},\n",
257 | " author={Pieter Delobelle and Thomas Winters and Bettina Berendt},\n",
258 | " year={2020},\n",
259 | " eprint={2001.06286},\n",
260 | " archivePrefix={arXiv},\n",
261 | " primaryClass={cs.CL}\n",
262 | "}\n",
263 | "```\n"
264 | ],
265 | "metadata": {
266 | "collapsed": false,
267 | "pycharm": {
268 | "name": "#%% md\n"
269 | }
270 | }
271 | }
272 | ],
273 | "metadata": {
274 | "kernelspec": {
275 | "display_name": "Python 3",
276 | "language": "python",
277 | "name": "python3"
278 | },
279 | "language_info": {
280 | "codemirror_mode": {
281 | "name": "ipython",
282 | "version": 3
283 | },
284 | "file_extension": ".py",
285 | "mimetype": "text/x-python",
286 | "name": "python",
287 | "nbconvert_exporter": "python",
288 | "pygments_lexer": "ipython3",
289 | "version": "3.7.4"
290 | },
291 | "pycharm": {
292 | "stem_cell": {
293 | "cell_type": "raw",
294 | "source": [],
295 | "metadata": {
296 | "collapsed": false
297 | }
298 | }
299 | }
300 | },
301 | "nbformat": 4,
302 | "nbformat_minor": 1
303 | }
--------------------------------------------------------------------------------
/notebooks/die_dat_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import nltk\n",
10 | "from nltk.tokenize.treebank import TreebankWordDetokenizer\n",
11 | "from fairseq.models.roberta import RobertaModel"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 2,
17 | "metadata": {},
18 | "outputs": [
19 | {
20 | "name": "stdout",
21 | "output_type": "stream",
22 | "text": [
23 | "loading archive file ../\n",
24 | "| [input] dictionary: 50265 types\n",
25 | "| [label] dictionary: 9 types\n"
26 | ]
27 | },
28 | {
29 | "data": {
30 | "text/plain": [
31 | "RobertaHubInterface(\n",
32 | " (model): RobertaModel(\n",
33 | " (decoder): RobertaEncoder(\n",
34 | " (sentence_encoder): TransformerSentenceEncoder(\n",
35 | " (embed_tokens): Embedding(50265, 768, padding_idx=1)\n",
36 | " (embed_positions): LearnedPositionalEmbedding(514, 768, padding_idx=1)\n",
37 | " (layers): ModuleList(\n",
38 | " (0): TransformerSentenceEncoderLayer(\n",
39 | " (self_attn): MultiheadAttention(\n",
40 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
41 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
42 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
43 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
44 | " )\n",
45 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
46 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
47 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
48 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
49 | " )\n",
50 | " (1): TransformerSentenceEncoderLayer(\n",
51 | " (self_attn): MultiheadAttention(\n",
52 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
53 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
54 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
55 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
56 | " )\n",
57 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
58 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
59 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
60 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
61 | " )\n",
62 | " (2): TransformerSentenceEncoderLayer(\n",
63 | " (self_attn): MultiheadAttention(\n",
64 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
65 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
66 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
67 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
68 | " )\n",
69 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
70 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
71 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
72 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
73 | " )\n",
74 | " (3): TransformerSentenceEncoderLayer(\n",
75 | " (self_attn): MultiheadAttention(\n",
76 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
77 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
78 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
79 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
80 | " )\n",
81 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
82 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
83 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
84 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
85 | " )\n",
86 | " (4): TransformerSentenceEncoderLayer(\n",
87 | " (self_attn): MultiheadAttention(\n",
88 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
89 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
90 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
91 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
92 | " )\n",
93 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
94 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
95 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
96 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
97 | " )\n",
98 | " (5): TransformerSentenceEncoderLayer(\n",
99 | " (self_attn): MultiheadAttention(\n",
100 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
101 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
102 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
103 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
104 | " )\n",
105 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
106 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
107 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
108 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
109 | " )\n",
110 | " (6): TransformerSentenceEncoderLayer(\n",
111 | " (self_attn): MultiheadAttention(\n",
112 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
113 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
114 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
115 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
116 | " )\n",
117 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
118 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
119 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
120 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
121 | " )\n",
122 | " (7): TransformerSentenceEncoderLayer(\n",
123 | " (self_attn): MultiheadAttention(\n",
124 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
125 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
126 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
127 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
128 | " )\n",
129 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
130 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
131 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
132 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
133 | " )\n",
134 | " (8): TransformerSentenceEncoderLayer(\n",
135 | " (self_attn): MultiheadAttention(\n",
136 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
137 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
138 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
139 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
140 | " )\n",
141 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
142 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
143 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
144 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
145 | " )\n",
146 | " (9): TransformerSentenceEncoderLayer(\n",
147 | " (self_attn): MultiheadAttention(\n",
148 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
149 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
150 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
151 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
152 | " )\n",
153 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
154 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
155 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
156 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
157 | " )\n",
158 | " (10): TransformerSentenceEncoderLayer(\n",
159 | " (self_attn): MultiheadAttention(\n",
160 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
161 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
162 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
163 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
164 | " )\n",
165 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
166 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
167 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
168 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
169 | " )\n",
170 | " (11): TransformerSentenceEncoderLayer(\n",
171 | " (self_attn): MultiheadAttention(\n",
172 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
173 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
174 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
175 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
176 | " )\n",
177 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
178 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
179 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
180 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
181 | " )\n",
182 | " )\n",
183 | " (emb_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
184 | " )\n",
185 | " (lm_head): RobertaLMHead(\n",
186 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
187 | " (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
188 | " )\n",
189 | " )\n",
190 | " (classification_heads): ModuleDict(\n",
191 | " (sentence_classification_head): RobertaClassificationHead(\n",
192 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
193 | " (dropout): Dropout(p=0.0, inplace=False)\n",
194 | " (out_proj): Linear(in_features=768, out_features=2, bias=True)\n",
195 | " )\n",
196 | " )\n",
197 | " )\n",
198 | ")"
199 | ]
200 | },
201 | "execution_count": 2,
202 | "metadata": {},
203 | "output_type": "execute_result"
204 | }
205 | ],
206 | "source": [
207 | "roberta = RobertaModel.from_pretrained(\n",
208 | " '../',\n",
209 | " checkpoint_file='checkpoints/checkpoint_best.pt',\n",
210 | " data_name_or_path=\"./data\"\n",
211 | ")\n",
212 | "\n",
213 | "roberta.eval() # disable dropout"
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": 3,
219 | "metadata": {},
220 | "outputs": [],
221 | "source": [
222 | "def replace_query_token(sentence):\n",
223 | " \"Small utility function to replace a sentence with `_die_` or `_dat_` with the proper RobBERT input.\"\n",
224 | " tokens = nltk.word_tokenize(sentence)\n",
225 | " tokens_swapped = nltk.word_tokenize(sentence)\n",
226 | " for i, word in enumerate(tokens):\n",
227 | " if word == \"_die_\":\n",
228 | " tokens[i] = \"die\"\n",
229 | " tokens_swapped[i] = \"dat\"\n",
230 | "\n",
231 | " elif word == \"_dat_\":\n",
232 | " tokens[i] = \"dat\"\n",
233 | " tokens_swapped[i] = \"die\"\n",
234 | "\n",
235 | " elif word == \"_Dat_\":\n",
236 | " tokens[i] = \"Dat\"\n",
237 | " tokens_swapped[i] = \"Die\"\n",
238 | "\n",
239 | " elif word == \"_Die_\":\n",
240 | " tokens[i] = \"Die\"\n",
241 | " tokens_swapped[i] = \"Dat\"\n",
242 | "\n",
243 | "\n",
244 | " if word.lower() == \"_die_\" or word.lower() == \"_dat_\":\n",
245 | " results = TreebankWordDetokenizer().detokenize(tokens)\n",
246 | " results_swapped = TreebankWordDetokenizer().detokenize(tokens_swapped)\n",
247 | "\n",
248 | " return \"{} {}\".format(results, results_swapped) \n"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 4,
254 | "metadata": {},
255 | "outputs": [
256 | {
257 | "name": "stdout",
258 | "output_type": "stream",
259 | "text": [
260 | "Correct sentence: True\n"
261 | ]
262 | }
263 | ],
264 | "source": [
265 | "sentence = \"Vervolgens zullen we de gebruikelijke procedure volgen, _dat_ wil zeggen dat we een voorstander en een tegenstander van dit verzoek het woord zullen geven.\"\n",
266 | "\n",
267 | "tokens = roberta.encode(replace_query_token(sentence))\n",
268 | "\n",
269 | "print(\"Correct sentence: \", roberta.predict('sentence_classification_head', tokens).argmax().item() == 1)"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": 5,
275 | "metadata": {},
276 | "outputs": [
277 | {
278 | "name": "stdout",
279 | "output_type": "stream",
280 | "text": [
281 | "Correct sentence: False\n"
282 | ]
283 | }
284 | ],
285 | "source": [
286 | "sentence = \"Vervolgens zullen we de gebruikelijke procedure volgen, _die_ wil zeggen dat we een voorstander en een tegenstander van dit verzoek het woord zullen geven.\"\n",
287 | "\n",
288 | "tokens = roberta.encode(replace_query_token(sentence))\n",
289 | "\n",
290 | "print(\"Correct sentence: \", roberta.predict('sentence_classification_head', tokens).argmax().item() == 1)"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": 6,
296 | "metadata": {},
297 | "outputs": [
298 | {
299 | "name": "stdout",
300 | "output_type": "stream",
301 | "text": [
302 | "Correct sentence: False\n"
303 | ]
304 | }
305 | ],
306 | "source": [
307 | "sentence = \"Daar loopt _die_ meisje.\"\n",
308 | "tokens = roberta.encode(replace_query_token(sentence))\n",
309 | "\n",
310 | "#print(roberta.predict('sentence_classification_head', tokens))\n",
311 | "print(\"Correct sentence: \", roberta.predict('sentence_classification_head', tokens).argmax().item() == 1)"
312 | ]
313 | }
314 | ],
315 | "metadata": {
316 | "kernelspec": {
317 | "display_name": "Python 3",
318 | "language": "python",
319 | "name": "python3"
320 | },
321 | "language_info": {
322 | "codemirror_mode": {
323 | "name": "ipython",
324 | "version": 3
325 | },
326 | "file_extension": ".py",
327 | "mimetype": "text/x-python",
328 | "name": "python",
329 | "nbconvert_exporter": "python",
330 | "pygments_lexer": "ipython3",
331 | "version": "3.7.4"
332 | }
333 | },
334 | "nbformat": 4,
335 | "nbformat_minor": 4
336 | }
--------------------------------------------------------------------------------
/notebooks/evaluate_zeroshot_wordlists_v2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true,
8 | "jupyter": {
9 | "outputs_hidden": true
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "from pathlib import Path\n",
15 | "import requests\n",
16 | "from fairseq.models.roberta import RobertaModel\n",
17 | "from src import evaluate_zeroshot_wordlist"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 2,
23 | "metadata": {
24 | "collapsed": false,
25 | "jupyter": {
26 | "outputs_hidden": false
27 | },
28 | "pycharm": {
29 | "name": "#%%\n"
30 | }
31 | },
32 | "outputs": [
33 | {
34 | "name": "stdout",
35 | "output_type": "stream",
36 | "text": [
37 | "loading archive file ../models/robbert.v2\n",
38 | "| dictionary: 39984 types\n",
39 | "RobBERT is loaded\n"
40 | ]
41 | }
42 | ],
43 | "source": [
44 | "robbert = RobertaModel.from_pretrained(\n",
45 | " '../models/robbert.v2',\n",
46 | " checkpoint_file='model.pt',\n",
47 | " gpt2_encoder_json='../models/robbert.v2/vocab.json',\n",
48 | " gpt2_vocab_bpe='../models/robbert.v2/merges.txt',\n",
49 | ")\n",
50 | "robbert.eval()\n",
51 | "print(\"RobBERT is loaded\")"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 3,
57 | "metadata": {
58 | "collapsed": false,
59 | "jupyter": {
60 | "outputs_hidden": false
61 | },
62 | "pycharm": {
63 | "name": "#%%\n"
64 | }
65 | },
66 | "outputs": [],
67 | "source": [
68 | "path = Path(\"..\", \"data\", \"processed\", \"wordlist\",\"die-dat.test.tsv\")"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": 4,
74 | "metadata": {
75 | "collapsed": false,
76 | "jupyter": {
77 | "outputs_hidden": false
78 | },
79 | "pycharm": {
80 | "name": "#%%\n"
81 | }
82 | },
83 | "outputs": [
84 | {
85 | "name": "stdout",
86 | "output_type": "stream",
87 | "text": [
88 | "98.51% / 9851 / 10000 / 0 errors / die / die / Het is ook wegens de verwoestingen - afhankelijkheid, armoede, gezondheidsproblemen - drugsverslaving bij onze jongeren aanricht, dat het mij verheugt vast te stellen dat het Nederlandse voorzitterschap van de bestrijding van de internationale criminaliteit en de drugshandel een van zijn prioriteiten heeft gemaakt . Praktische samenwerking tussen de politiediensten, douanediensten en justitiële diensten om de drugshandel actief te bestrijden?\n",
89 | "98.61% / 19723 / 20000 / 0 errors / dat / dat / We zullen erop toezien dat wij Noorwegen en IJsland nimmer voor verrassingen of voor voldongen feiten stellen op punt.\n",
90 | "98.62% / 29585 / 30000 / 0 errors / dat / dat / Ook hoop ik dat Europa's industriëlen hiernaar luisteren en Total en Club Med in de voetsporen zullen treden van Heineken, Carlsberg, PepsiCo, Levi, Apple en Thomas Cook.\n",
91 | "98.62% / 39450 / 40000 / 0 errors / dat / dat / Ik wijs er in dit verband op een van de voorstellen die de Commissie aan de Intergouvernementele Conferentie heeft gedaan, de samensmelting betreft van de drie bestaande Gemeenschappen tot één enkele, wat zal leiden tot grotere duidelijkheid op dit gebied.\n",
92 | "98.63% / 49315 / 50000 / 0 errors / die / die / Het is zaak eigen Europese minimumnormen vast te stellen alleen een voldoende bescherming van passagiers, zeelieden en milieu zullen garanderen.\n",
93 | "98.65% / 59188 / 60000 / 0 errors / dat / dat / Zonder deze hervormingen zal de Unie verwateren . Indien de Unie wordt beroofd van sterke en coherente structuren vrees ik ze eens zal bezwijken onder de historische opdracht van de uitbreiding.\n",
94 | "98.61% / 69025 / 70000 / 0 errors / dat / dat / Uit het bovenstaande blijkt de begrotingsperikelen waarmee wij als Parlement kampen grotendeels het gevolg zijn van beslissingen die de Raad zelf heeft genomen.\n",
95 | "98.61% / 78891 / 80000 / 0 errors / dat / dat / Mijnheer de Voorzitter, ik wil hier alleen aangeven mijn naam niet op de presentielijst van gisteren voorkomt, vermoedelijk vanwege nalatigheid mijnerzijds, aangezien ik vergeten heb te tekenen.\n",
96 | "Error with B4-0825/97 van de leden André-Léonard, Fassa en Bertens, namens de ELDR-Fractie, over de terugtrekking van het VN-onderzoeksteam uit Congo; -B4-0832/97 van de leden Aelvoet en Telkämper, namens de V-Fractie, over de VN-onderzoeksreis naar Congo; -B4-0850/97 van de leden Dury en Swoboda, namens de PSE-Fractie, over de weigering om een onderzoekscommissie van de Verenigde Naties tot de Democratische Republiek Congo toe te laten; -B4-0856/97 van de leden Hory, Dell'Alba en Dupuis, namens de ARE-Fractie, over de onderzoekscommissie van de Verenigde Naties inzake de mensenrechten in de Democratische Republiek Congo; -B4-0863/97 van de leden Chanterie, Stasi, Tindemans, Verwaerde, Maij-Weggen en Oomen-Ruijten, namens de PPE-Fractie, over de situatie in de Democratische Republiek Congo; -B4-0877/97 van de leden Pettinari, Carnero González, Ojala en Sjöstedt, namens de GUE/NGL-Fractie, over de internationale onderzoeksdelegatie van de Verenigde Naties voor de schendingen van de rechten van de mens in het voormalige Zaïre; -B4-0890/97 van de leden Pasty en Azzolini, namens de UPE-Fractie, over de toestand in de Democratische Republiek Congo; -B4-0830/97 van de leden Bertens en Larive, namens de ELDR-Fractie, over het standpunt van de Europese Unie over de bevordering van de rechten van de mens in China; -B4-0847/97 van de heer Swoboda, namens de PSE-Fractie, over de bevordering van de mensenrechten in China; -B4-0855/97 van de leden Dupuis, Dell'Alba en Hory, namens de ARE-Fractie, over het standpunt van de Europese Unie inzake de bevordering van de mensenrechten in China; -B4-0862/97 van de leden McMillan-Scott en Habsburg-Lothringen, namens de PPE-Fractie, over het standpunt van de Europese Unie inzake de bevordering van de mensenrechten in China; -B4-0872/97 van de leden Aglietta en Schroedter, namens de V-Fractie, over de bevordering van de rechten van de mens in China; -B4-0828/97 van de leden Cars en La Malfa, namens de ELDR-Fractie, over de situatie in Kosovo; -B4-0837/97 van de leden Aelvoet, Cohn-Bendit, Gahrton, Müller en Tamino, namens de V-Fractie, over de situatie in Kosovo; -B4-0848/97 van de heer Swoboda, namens de PSE-Fractie, over de situatie in Kosovo; -B4-0854/97 van de leden Dupuis en Dell'Alba, namens de ARE-Fractie, over de situatie in Kosovo; -B4-0865/97 van de leden Oostlander, Pack, Habsburg-Lothringen, Maij-Weggen, Posselt en Oomen-Ruijten, namens de PPE-Fractie, over de situatie in Kosovo; -B4-0878/97 van de leden Manisco, Sjöstedt, Sierra González en Mohammed Alí, namens de GUE/NGLFractie, over de schendingen van de mensenrechten in Kosovo; -B4-0858/97 van de heer Pradier, namens de ARE-Fractie, over de omstandigheden in de detentie-inrichting van Khiam; -B4-0864/97 van de leden Soulier en Peijs, namens de PPE-Fractie, over de situatie van Souha Bechara, in hechtenis wordt gehouden in Zuid-Libanon; -B4-0879/97 van de leden Wurtz, Castellina, Marset Campos, Miranda, Ephremidis, Alavanos en Seppänen, namens de GUE/NGL-Fractie, over de vrijlating van Souha Bechara; -B4-0849/97 van de leden Hoff, Wiersma, Bösch en Swoboda, namens de PSE-Fractie, over de politieke situatie in Slowakije; -B4-0827/97 van de leden André-Léonard, Fassa, Bertens en Nordmann, namens de ELDR-Fractie, over Algerije.\t0\n",
97 | "\n",
98 | "98.61% / 88753 / 90000 / 1 errors / dat / dat / Met onze amendementen trachten wij ertoe bij te dragen dat het Fonds zo goed mogelijk wordt gecoördineerd, het additioneel en ter vergroting van het eigen potentieel wordt aangewend en dat de publieke opinie er kennis van neemt en het op zijn waarde kan schatten.\n",
99 | "98.63% / 98628 / 100000 / 1 errors / dat / dat / Ik verwacht van de Commissie zij daartegen optreedt.\n",
100 | "98.61% / 108475 / 110000 / 1 errors / dat / dat / Mijnheer de Voorzitter, dames en heren, het is goed het Eurodac-systeem de mogelijkheid biedt de misbruiken tegen te gaan die om de meest uiteenlopende redenen van het asielrecht worden gemaakt.\n",
101 | "98.62% / 118344 / 120000 / 1 errors / dat / dat / Zoals een van mijn kiezers uit Swindon vroeg:``De ruimte voor gekissebis over wat een``miniem gebrek aan gelijkvormigheid\"voorstelt, is grenzeloos en consumenten zonder scrupules worden er wellicht toe aangemoedigd om de onzekerheid uit te buiten die door deze voorstellen wordt gecreëerd, wat ertoe zal leiden de rechtbanken een overvloed aan zaken te verwerken zullen krijgen.\n",
102 | "98.63% / 128214 / 130000 / 1 errors / dat / dat / Mijnheer de Voorzitter, ik was zeer geïnteresseerd in uw antwoord aan de heer McMahon, waarin u zei , als de regels of de procedures zouden worden gewijzigd, u ze vanzelfsprekend zou verwijzen naar de Commissie Reglement, onderzoek geloofsbrieven en immuniteiten.\n",
103 | "98.62% / 138072 / 140000 / 1 errors / dat / dat / Wat betekent anders de lidstaten zelf kunnen beslissen welke grenswaarden zij vaststellen?\n",
104 | "98.64% / 147953 / 150000 / 1 errors / die / die / Wat betreft de overheidssteun aan de landbouw die nog altijd noodzakelijk is, zijn wij van mening dat op de lange termijn slechts door de Europese belastingbetaler en door onze concurrenten binnen de WHO zal worden geaccepteerd, als rekening wordt gehouden met gerechtvaardigde sociale, ecologische en streekgebonden zorgen.\n",
105 | "98.65% / 157843 / 160000 / 1 errors / die / die / Als we dus Europese verkiezingen gaan steunen de nationaliteitsgrenzen overschrijden, keren we onze rug naar nationale vertegenwoordiging.\n",
106 | "98.65% / 167706 / 170000 / 1 errors / die / die / De EU is via verschillende lidstaten en via het voorzitterschap actief betrokken bij het werk van de contactgroep, een document met opties voor de toekomstige status van Kosovo heeft uitgewerkt en dit heeft overhandigd aan de strijdende partijen.\n",
107 | "98.67% / 177605 / 180000 / 1 errors / dat / dat / Volgens mijn berekening hebt u uw spreektijd met 20% overschreden . U bent echter zo onderhoudend dat wij niet erg vinden.\n",
108 | "98.68% / 187485 / 190000 / 1 errors / dat / dat / Het tweede probleem mij voor de kwijting voor 1996 relevant lijkt, is de vraag hoe de Commissie omgaat met disciplinaire procedures.\n",
109 | "98.69% / 197374 / 200000 / 1 errors / dat / dat / Zij heeft er in dit verband op gewezen amendementen die onverenigbaar zijn met het voorstel van de bevoegde commissie niet ontvankelijk zijn.\n",
110 | "Error with A4-0437/98 van de heer Elchlepp, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Litouwen, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4216/98 - COM (98) 0119 - C4-0592/98-98/0075 (CNS) ); -A4-0443/98 van de heer Seppänen, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Letland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4215/98 - COM (98) 0068 - C4-0593/98-98/0076 (CNS) ); -A4-0472/98 van de heer van Dam, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Estland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 63, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4214/98 - COM (98) 0118 - C4-0594/98-98/0077 (CNS) ) en-A4-0419/98 van de heer Schwaiger, over het voorstel voor een besluit van de Raad en de Commissie betreffende het standpunt dat de Gemeenschap zal innemen in de Associatieraad die is ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en Roemenië, anderzijds, die op 1 februari 1993 in Brussel werd ondertekend, inzake de vaststelling van de voorschriften voor de tenuitvoerlegging van artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst en voor de tenuitvoerlegging van artikel 9, lid 1, sub 1) en 2), van Protocol 2 betreffende EGKS-producten bij de Europa-Overeenkomst (COM (98) 0236 - C4-0275/98-98/0139 (CNS) ).\t0\n",
111 | "\n",
112 | "Error with A4-0437/98 van de heer Elchlepp, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Litouwen, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4216/98 - COM (98) 0119 - C4-0592/98-98/0075 (CNS) ); -A4-0443/98 van de heer Seppänen, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Letland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4215/98 - COM (98) 0068 - C4-0593/98-98/0076 (CNS) ); -A4-0472/98 van de heer van Dam, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Estland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 63, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4214/98 - COM (98) 0118 - C4-0594/98-98/0077 (CNS) ) en-A4-0419/98 van de heer Schwaiger, over het voorstel voor een besluit van de Raad en de Commissie betreffende het standpunt dat de Gemeenschap zal innemen in de Associatieraad die is ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en Roemenië, anderzijds, die op 1 februari 1993 in Brussel werd ondertekend, inzake de vaststelling van de voorschriften voor de tenuitvoerlegging van artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst en voor de tenuitvoerlegging van artikel 9, lid 1, sub 1) en 2), van Protocol 2 betreffende EGKS-producten bij de Europa-Overeenkomst (COM (98) 0236 - C4-0275/98-98/0139 (CNS) ).\t0\n",
113 | "\n",
114 | "Error with A4-0437/98 van de heer Elchlepp, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Litouwen, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4216/98 - COM (98) 0119 - C4-0592/98-98/0075 (CNS) ); -A4-0443/98 van de heer Seppänen, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Letland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4215/98 - COM (98) 0068 - C4-0593/98-98/0076 (CNS) ); -A4-0472/98 van de heer van Dam, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Estland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 63, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4214/98 - COM (98) 0118 - C4-0594/98-98/0077 (CNS) ) en-A4-0419/98 van de heer Schwaiger, over het voorstel voor een besluit van de Raad en de Commissie betreffende het standpunt dat de Gemeenschap zal innemen in de Associatieraad die is ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en Roemenië, anderzijds, die op 1 februari 1993 in Brussel werd ondertekend, inzake de vaststelling van de voorschriften voor de tenuitvoerlegging van artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst en voor de tenuitvoerlegging van artikel 9, lid 1, sub 1) en 2), van Protocol 2 betreffende EGKS-producten bij de Europa-Overeenkomst (COM (98) 0236 - C4-0275/98-98/0139 (CNS) ).\t0\n",
115 | "\n",
116 | "Error with A4-0437/98 van de heer Elchlepp, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Litouwen, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4216/98 - COM (98) 0119 - C4-0592/98-98/0075 (CNS) ); -A4-0443/98 van de heer Seppänen, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Letland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4215/98 - COM (98) 0068 - C4-0593/98-98/0076 (CNS) ); -A4-0472/98 van de heer van Dam, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Estland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 63, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4214/98 - COM (98) 0118 - C4-0594/98-98/0077 (CNS) ) en-A4-0419/98 van de heer Schwaiger, over het voorstel voor een besluit van de Raad en de Commissie betreffende het standpunt de Gemeenschap zal innemen in de Associatieraad die is ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en Roemenië, anderzijds, die op 1 februari 1993 in Brussel werd ondertekend, inzake de vaststelling van de voorschriften voor de tenuitvoerlegging van artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst en voor de tenuitvoerlegging van artikel 9, lid 1, sub 1) en 2), van Protocol 2 betreffende EGKS-producten bij de Europa-Overeenkomst (COM (98) 0236 - C4-0275/98-98/0139 (CNS) ).\t1\n",
117 | "\n",
118 | "Error with A4-0437/98 van de heer Elchlepp, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Litouwen, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4216/98 - COM (98) 0119 - C4-0592/98-98/0075 (CNS) ); -A4-0443/98 van de heer Seppänen, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Letland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4215/98 - COM (98) 0068 - C4-0593/98-98/0076 (CNS) ); -A4-0472/98 van de heer van Dam, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Estland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 63, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4214/98 - COM (98) 0118 - C4-0594/98-98/0077 (CNS) ) en-A4-0419/98 van de heer Schwaiger, over het voorstel voor een besluit van de Raad en de Commissie betreffende het standpunt dat de Gemeenschap zal innemen in de Associatieraad is ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en Roemenië, anderzijds, die op 1 februari 1993 in Brussel werd ondertekend, inzake de vaststelling van de voorschriften voor de tenuitvoerlegging van artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst en voor de tenuitvoerlegging van artikel 9, lid 1, sub 1) en 2), van Protocol 2 betreffende EGKS-producten bij de Europa-Overeenkomst (COM (98) 0236 - C4-0275/98-98/0139 (CNS) ).\t0\n",
119 | "\n",
120 | "Error with A4-0437/98 van de heer Elchlepp, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Litouwen, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4216/98 - COM (98) 0119 - C4-0592/98-98/0075 (CNS) ); -A4-0443/98 van de heer Seppänen, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Letland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4215/98 - COM (98) 0068 - C4-0593/98-98/0076 (CNS) ); -A4-0472/98 van de heer van Dam, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Estland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 63, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4214/98 - COM (98) 0118 - C4-0594/98-98/0077 (CNS) ) en-A4-0419/98 van de heer Schwaiger, over het voorstel voor een besluit van de Raad en de Commissie betreffende het standpunt dat de Gemeenschap zal innemen in de Associatieraad die is ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en Roemenië, anderzijds, op 1 februari 1993 in Brussel werd ondertekend, inzake de vaststelling van de voorschriften voor de tenuitvoerlegging van artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst en voor de tenuitvoerlegging van artikel 9, lid 1, sub 1) en 2), van Protocol 2 betreffende EGKS-producten bij de Europa-Overeenkomst (COM (98) 0236 - C4-0275/98-98/0139 (CNS) ).\t0\n",
121 | "\n",
122 | "98.69% / 207253 / 210000 / 7 errors / dat / dat / Mevrouw de Voorzitter, dames en heren, allereerst wil ik u hartelijk danken voor het vriendelijke onthaal en voor het feit ik er in het Europees Parlement voor het eerst officieel als fungerend voorzitter van de Raad van ministers van Landbouw en Visserij bij mag zijn en het woord tot u mag richten.\n",
123 | "98.70% / 217134 / 220000 / 7 errors / die / die / Ik kan daarom geen amendementen overnemen dat op de helling zetten.\n",
124 | "98.70% / 227010 / 230000 / 7 errors / die / die / Strafheffingen de EU-landen echter verschillend treffen, beïnvloeden de comparatieve kosten- en concurrentiepositie op de interne markt.\n",
125 | "98.70% / 236890 / 240000 / 7 errors / die / die / De doelstellingen zijn: ten eerste het garanderen van het recht op regelmatige informatie en raadpleging van werknemers over de economische en strategische ontwikkelingen van de onderneming en over de voor hen belangrijke beslissingen; ten tweede de versterking van de sociale dialoog en het wederzijds vertrouwen binnen de onderneming, om bij te kunnen dragen aan het voorspellen van eventuele risico's, de ontwikkeling van de flexibiliteit in de arbeidsorganisatie in een zeker kader, de bewustmaking van de werknemers met betrekking tot de noodzaak van aanpassing en aan hun bereidheid om actief mee te werken aan maatregelen en acties ter verhoging van hun vakbekwaamheid en inzetbaarheid; ten derde de opneming van het werkgelegenheidsvraagstuk - de toestand daarvan en de te verwachten ontwikkeling - in de informatie- en raadplegingsprocedures; ten vierde de verzekering van voorafgaande informatie en raadpleging van de werknemers bij beslissingen kunnen leiden tot ingrijpende veranderingen in de arbeidsorganisatie of de arbeidsovereenkomsten en ten vijfde de verbetering van de doeltreffendheid van deze procedure via de vaststelling van specifieke sancties voor ernstige schendingen van de opgelegde verplichtingen.\n",
126 | "98.71% / 246769 / 250000 / 7 errors / die / die / Geen woord van veroordeling of zelfs maar spijt voor de ontoelaatbare militaire agressie tegen een Europese staat, door een NAVO blindelings vrouwen en kinderen doodt.\n",
127 | "98.72% / 256661 / 260000 / 7 errors / die / die / Met veel genoegen heb ik deelgenomen aan deze belangrijke democratische gebeurtenis, in de Europese politiek zijn weerga niet kent.\n",
128 | "98.73% / 266559 / 270000 / 7 errors / die / die / Maar de inwerkingtreding van dit Verdrag in de VS zou regionale kernmachten onder enorme druk hebben gezet om toe te treden tot het Verdrag . Denk maar aan India en Pakistan, bezig zijn hun eigen regionale afschrikkingsevenwicht op te bouwen, of Noord-Korea, Iran, Irak of Israël.\n",
129 | "98.73% / 276458 / 280000 / 7 errors / dat / dat / Maar het geeft hoop dat bad practices op basis van goede gegevens vergeleken kunnen worden en een humaan en effectief drugsbeleid ook in een geïntegreerde Europese Unie mogelijk worden.\n",
130 | "Die/Dat: 0.9874826436379628 285184 288799 7\n"
131 | ]
132 | }
133 | ],
134 | "source": [
135 | "die_dat_correct, die_dat_total, die_dat_errors = evaluate_zeroshot_wordlist.evaluate(\n",
136 | " [\"die\", \"dat\"], path=path, model=robbert, print_step=10000)\n",
137 | "print(\"Die/Dat:\", die_dat_correct/die_dat_total, die_dat_correct, die_dat_total, die_dat_errors)"
138 | ]
139 | }
140 | ],
141 | "metadata": {
142 | "kernelspec": {
143 | "name": "python3",
144 | "language": "python",
145 | "display_name": "Python 3"
146 | },
147 | "language_info": {
148 | "codemirror_mode": {
149 | "name": "ipython",
150 | "version": 3
151 | },
152 | "file_extension": ".py",
153 | "mimetype": "text/x-python",
154 | "name": "python",
155 | "nbconvert_exporter": "python",
156 | "pygments_lexer": "ipython3",
157 | "version": "3.7.4"
158 | }
159 | },
160 | "nbformat": 4,
161 | "nbformat_minor": 4
162 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | certifi==2022.9.14
2 | charset-normalizer==2.1.1
3 | click==8.1.3
4 | filelock==3.8.0
5 | Flask==2.2.2
6 | huggingface-hub==0.9.1
7 | idna==3.4
8 | importlib-metadata==4.12.0
9 | itsdangerous==2.1.2
10 | Jinja2==3.1.2
11 | joblib==1.2.0
12 | MarkupSafe==2.1.1
13 | nltk==3.7
14 | numpy==1.21.6
15 | packaging==21.3
16 | pyparsing==3.0.9
17 | PyYAML==6.0
18 | regex==2022.9.13
19 | requests==2.28.1
20 | tokenizers==0.12.1
21 | torch==1.12.1
22 | tqdm==4.64.1
23 | transformers==4.22.0
24 | typing-extensions==4.3.0
25 | urllib3==1.26.12
26 | Werkzeug==2.2.2
27 | zipp==3.8.1
28 |
--------------------------------------------------------------------------------
/res/dbrd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/res/dbrd.png
--------------------------------------------------------------------------------
/res/gender_diff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/res/gender_diff.png
--------------------------------------------------------------------------------
/res/robbert_2022_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/res/robbert_2022_logo.png
--------------------------------------------------------------------------------
/res/robbert_2022_logo_with_name.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/res/robbert_2022_logo_with_name.png
--------------------------------------------------------------------------------
/res/robbert_2023_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/res/robbert_2023_logo.png
--------------------------------------------------------------------------------
/res/robbert_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/res/robbert_logo.png
--------------------------------------------------------------------------------
/res/robbert_logo_with_name.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/res/robbert_logo_with_name.png
--------------------------------------------------------------------------------
/res/robbert_pos_accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/res/robbert_pos_accuracy.png
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/bert_masked_lm_adapter.py:
--------------------------------------------------------------------------------
1 | # Class to mimic the fill_mask functionality of the RobertaModel, but for BERT models.
2 | # Used to evaluate word-list masks
3 | import numpy as np
4 | import torch
5 |
6 |
7 | class BertMaskedLMAdapter:
8 | def __init__(self, model_name=None, model=None, tokenizer=None):
9 | # Check imports
10 | from transformers import BertForMaskedLM, BertTokenizer
11 |
12 | self._tokenizer = (
13 | tokenizer if tokenizer else BertTokenizer.from_pretrained(model_name)
14 | )
15 | self._model = model if model else BertForMaskedLM.from_pretrained(model_name)
16 | self._model.eval()
17 |
18 | def fill_mask(self, text, topk=4):
19 |         text = text.replace("<mask>", "[MASK]")
20 | if not text.startswith("[CLS]"):
21 | text = "[CLS] " + text
22 | if not text.endswith("[SEP]"):
23 | text = text + " [SEP]"
24 |
25 | tokenized_text = self._tokenizer.tokenize(text)
26 | masked_index = tokenized_text.index("[MASK]")
27 | indexed_tokens = self._tokenizer.convert_tokens_to_ids(tokenized_text)
28 | # Create the segments tensors.
29 | segments_ids = [0] * len(tokenized_text)
30 | # Convert inputs to PyTorch tensors
31 | tokens_tensor = torch.tensor([indexed_tokens])
32 | segments_tensors = torch.tensor([segments_ids])
33 |
34 | with torch.no_grad():
35 |             predictions = self._model(tokens_tensor, token_type_ids=segments_tensors)
36 |
37 | predicted_indices = list(
38 |             np.argpartition(predictions[0][0][masked_index], -topk)[-topk:]
39 | )
40 | predicted_indices.reverse()
41 | return [
42 | self._convert_token(predicted_index)
43 | for predicted_index in predicted_indices
44 | ]
45 |
46 | def _convert_token(self, predicted_index):
47 | text = self._tokenizer.convert_ids_to_tokens([predicted_index])[0]
48 | return None, None, text if text.startswith("##") else (" " + text)
49 |
50 |
51 | if __name__ == "__main__":
52 | from pathlib import Path
53 |
54 | from src import evaluate_zeroshot_wordlist
55 |
56 | mlm = BertMaskedLMAdapter(model_name="bert-base-multilingual-uncased")
57 |     result = mlm.fill_mask("Er is een meisje <mask> daar loopt.", 1024)
58 | print(result)
59 |
60 | from src.wordlistfiller import WordListFiller
61 |
62 | wordfiller = WordListFiller(["die", "dat"], model=mlm)
63 |     print(wordfiller.find_optimal_word("Er is een meisje <mask> daar loopt."))
64 | evaluate_zeroshot_wordlist.evaluate(
65 | ["die", "dat"],
66 | path=Path("..", "data", "processed", "wordlist", "die-dat.test.tsv"),
67 |         model=mlm,
68 | print_step=10,
69 | )
70 |
--------------------------------------------------------------------------------
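A minimal usage sketch for the adapter above, mirroring its __main__ block but spelling out the interface contract; it assumes the repository root is on PYTHONPATH and that the bert-base-multilingual-uncased weights can be loaded.

# Sketch only: BertMaskedLMAdapter mimics fairseq's fill_mask(), so code that
# expects a RobertaModel (e.g. WordListFiller) can be handed a BERT model instead.
from src.bert_masked_lm_adapter import BertMaskedLMAdapter
from src.wordlistfiller import WordListFiller

mlm = BertMaskedLMAdapter(model_name="bert-base-multilingual-uncased")

# fill_mask() takes the RoBERTa-style "<mask>" placeholder and returns one
# (None, None, token) tuple per candidate, so callers that only read the last
# tuple element, as fairseq's output shape allows, keep working.
print(mlm.fill_mask("Er is een meisje <mask> daar loopt.", topk=4))

filler = WordListFiller(["die", "dat"], model=mlm)
print(filler.find_optimal_word("Er is een meisje <mask> daar loopt."))

--------------------------------------------------------------------------------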
/src/convert_roberta_dict.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import argparse
5 | from pathlib import Path
6 | import json
7 |
8 | def create_arg_parser():
9 | parser = argparse.ArgumentParser(
10 | description=main.__doc__)
11 |
12 | parser.add_argument("--path", help="Path to the corpus file.", metavar="path",
13 | default="../data/raw/UD_Dutch-LassySmall/")
14 | parser.add_argument(
15 | "--dict",
16 | help='path to dict.txt',
17 | required=True
18 | )
19 | parser.add_argument(
20 | "--vocab-bpe",
21 | type=str,
22 | help='path to vocab.json',
23 | required=True
24 | )
25 |
26 | parser.add_argument(
27 | "--output-vocab",
28 | type=str,
29 | help='Location for the output vocab.json',
30 | required=True
31 | )
32 |
33 | return parser
34 |
35 | def load_roberta_mapping(file):
36 | "Returns a dict with the position of each word-id in the dict.txt file."
37 |
38 | # This file is basically an ordered count and we're not interested in the count.
39 | lines = {line.rstrip('\n').split()[0]: k for k, line in enumerate(file)}
40 | return lines
41 |
42 |
43 | def map_roberta(mapping, vocab):
44 | "Combine vocab.json and dict.txt contents."
45 | inverse_vocab = {str(v): k for k, v in vocab.items()}
46 |
47 | # We add 4 extra tokens, so they also need to be added to the position id
48 |     EXTRA_TOKENS = {'<s>': 0, '<pad>': 1, '</s>': 2, '<unk>': 3}
49 | offset = len(EXTRA_TOKENS)
50 |
51 | output_vocab = EXTRA_TOKENS
52 | for word_id, position in mapping.items():
53 | if word_id in inverse_vocab:
54 | output_vocab[inverse_vocab[word_id]] = position + offset
55 | else:
56 | print("not found: {}".format(word_id))
57 | output_vocab[word_id] = position + offset
58 |
59 |     output_vocab['<mask>'] = len(output_vocab)
60 |
61 | for word in [ inverse_vocab[x] for x in (set([str(vocab[k]) for k in vocab])-set(mapping)-set(EXTRA_TOKENS.keys()))]:
62 | output_vocab[word] = len(output_vocab)
63 |
64 | return output_vocab
65 |
66 | def main(args: argparse.Namespace):
67 | "Merge a vocab.json file with a dict.txt created by Fairseq."
68 |
69 | # First we load the dict file created by Fairseq's Roberta
70 | with open(args.dict, encoding="utf-8") as dict_fp:
71 | mapping = load_roberta_mapping(dict_fp)
72 |
73 | # Now we load the vocab file
74 | with open(args.vocab_bpe, encoding="utf-8") as vocab_fp:
75 | vocab = json.load(vocab_fp)
76 |
77 | output_vocab = map_roberta(mapping, vocab)
78 |
79 | with open(args.output_vocab, "w", encoding="utf-8") as output_fp:
80 | json.dump(output_vocab, output_fp, ensure_ascii=False)
81 |
82 |
83 | if __name__ == '__main__':
84 | arg_parser = create_arg_parser()
85 | args = arg_parser.parse_args()
86 |
87 | main(args)
--------------------------------------------------------------------------------
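A toy, self-contained run of map_roberta() to make the merge above concrete; the token ids and BPE strings are invented for illustration, and the snippet assumes the repository root is on PYTHONPATH.

# Sketch only: dict.txt lists token ids ordered by frequency, vocab.json maps BPE
# strings to those ids; map_roberta() re-indexes the BPE strings to the dict.txt
# order, shifted by the four fairseq specials, and appends <mask> at the end.
from src.convert_roberta_dict import map_roberta

mapping = {"100": 0, "200": 1}               # token id -> line number in dict.txt
vocab = {"de": 100, "het": 200, "een": 300}  # BPE string -> token id from vocab.json

print(map_roberta(mapping, vocab))
# {'<s>': 0, '<pad>': 1, '</s>': 2, '<unk>': 3, 'de': 4, 'het': 5, '<mask>': 6, 'een': 7}
# "een" never made it into dict.txt, so it is appended after <mask>.

--------------------------------------------------------------------------------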
/src/evaluate_zeroshot_wordlist.py:
--------------------------------------------------------------------------------
1 | # Script for evaluating the base RobBERT model on a word-list classification task without any fine-tuning.
2 |
3 | import argparse
4 | from pathlib import Path
5 | from typing import List
6 |
7 | from fairseq.models.roberta import RobertaModel
8 |
9 | from src.wordlistfiller import WordListFiller
10 |
11 | models_path = Path("..", "data", "processed", "wordlist")
12 |
13 |
14 | def evaluate(words: List[str], path: Path = None, model: RobertaModel = None, print_step: int = 1000):
15 | if not model:
16 | model = RobertaModel.from_pretrained(
17 | '../models/robbert',
18 | checkpoint_file='model.pt'
19 | )
20 | model.eval()
21 |
22 | wordlistfiller = WordListFiller(words, model=model)
23 |
24 | dataset_path = path if path is not None else models_path / ("-".join(words) + ".tsv")
25 |
26 | correct = 0
27 | total = 0
28 | errors = 0
29 |
30 | with open(dataset_path) as input_file:
31 | for line in input_file:
32 | sentence, index = line.split('\t')
33 | expected = words[int(index.strip())]
34 |
35 | try:
36 | predicted = wordlistfiller.find_optimal_word(sentence)
37 | if predicted is None:
38 | errors += 1
39 | elif predicted == expected:
40 | correct += 1
41 | total += 1
42 |
43 | if total % print_step == 0:
44 | print("{0:.2f}%".format(100 * correct / total),
45 | correct, total, str(errors) + " errors", expected, predicted, sentence, sep=' / ')
46 | except Exception:
47 | print("Error with", line)
48 | errors += 1
49 | total += 1
50 |
51 | return correct, total, errors
52 |
53 |
54 | def create_parser():
55 | parser = argparse.ArgumentParser(
56 |         description="Evaluate the zero-shot word-list (e.g. die/dat) disambiguation task."
57 | )
58 | parser.add_argument("--words", help="List of comma-separated words to disambiguate", type=str, default="die,dat")
59 | parser.add_argument("--path", help="Path to the evaluation data", metavar="path", default=None)
60 | return parser
61 |
62 |
63 | if __name__ == "__main__":
64 | parser = create_parser()
65 | args = parser.parse_args()
66 |
67 | evaluate([x.strip() for x in args.words.split(",")], args.path)
68 |
--------------------------------------------------------------------------------
/src/multiprocessing_bpe_encoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import argparse
9 | import contextlib
10 | import sys
11 |
12 | from collections import Counter
13 | from multiprocessing import Pool
14 |
15 | from fairseq.data.encoders.gpt2_bpe import get_encoder
16 |
17 |
18 | def main():
19 | """
20 | Helper script to encode raw text with the GPT-2 BPE using multiple processes.
21 |
22 | The encoder.json and vocab.bpe files can be obtained here:
23 | - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
24 | - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe
25 | """
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument(
28 | "--encoder-json",
29 | help='path to encoder.json',
30 | )
31 | parser.add_argument(
32 | "--vocab-bpe",
33 | type=str,
34 | help='path to vocab.bpe',
35 | )
36 | parser.add_argument(
37 | "--inputs",
38 | nargs="+",
39 | default=['-'],
40 | help="input files to filter/encode",
41 | )
42 | parser.add_argument(
43 | "--outputs",
44 | nargs="+",
45 | default=['-'],
46 | help="path to save encoded outputs",
47 | )
48 | parser.add_argument(
49 | "--keep-empty",
50 | action="store_true",
51 | help="keep empty lines",
52 | )
53 | parser.add_argument("--workers", type=int, default=20)
54 | args = parser.parse_args()
55 |
56 | assert len(args.inputs) == len(args.outputs), \
57 | "number of input and output paths should match"
58 |
59 | with contextlib.ExitStack() as stack:
60 | inputs = [
61 | stack.enter_context(open(input, "r", encoding="utf-8"))
62 | if input != "-" else sys.stdin
63 | for input in args.inputs
64 | ]
65 | outputs = [
66 | stack.enter_context(open(output, "w", encoding="utf-8"))
67 | if output != "-" else sys.stdout
68 | for output in args.outputs
69 | ]
70 |
71 | encoder = MultiprocessingEncoder(args)
72 | pool = Pool(args.workers, initializer=encoder.initializer)
73 | encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100)
74 |
75 | stats = Counter()
76 | for i, (filt, enc_lines) in enumerate(encoded_lines, start=1):
77 | if filt == "PASS":
78 | for enc_line, output_h in zip(enc_lines, outputs):
79 | print(enc_line, file=output_h)
80 | else:
81 | stats["num_filtered_" + filt] += 1
82 | if i % 10000 == 0:
83 | print("processed {} lines".format(i), file=sys.stderr)
84 |
85 | for k, v in stats.most_common():
86 | print("[{}] filtered {} lines".format(k, v), file=sys.stderr)
87 |
88 |
89 | class MultiprocessingEncoder(object):
90 |
91 | def __init__(self, args):
92 | self.args = args
93 |
94 | def initializer(self):
95 | global bpe
96 | bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe)
97 |
98 | def encode(self, line):
99 | global bpe
100 | ids = bpe.encode(line)
101 | return list(map(str, ids))
102 |
103 | def decode(self, tokens):
104 | global bpe
105 | return bpe.decode(tokens)
106 |
107 | def encode_lines(self, lines):
108 | """
109 | Encode a set of lines. All lines will be encoded together.
110 | """
111 | enc_lines = []
112 | for line in lines:
113 | line = line.strip()
114 | if len(line) == 0 and not self.args.keep_empty:
115 | return ["EMPTY", None]
116 | tokens = self.encode(line)
117 | enc_lines.append(" ".join(tokens))
118 | return ["PASS", enc_lines]
119 |
120 | def decode_lines(self, lines):
121 | dec_lines = []
122 | for line in lines:
123 | tokens = map(int, line.strip().split())
124 | dec_lines.append(self.decode(tokens))
125 | return ["PASS", dec_lines]
126 |
127 |
128 | if __name__ == "__main__":
129 | main()
130 |
--------------------------------------------------------------------------------
/src/preprocess_conll2002_ner.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | from src import preprocess_util
5 |
6 | ner_tags = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG", "B-MISC", "I-MISC"]
7 |
8 | type_map = {
9 | "dev": "testa",
10 | "test": "testb",
11 | "train": "train",
12 | }
13 |
14 |
15 | def create_arg_parser():
16 | parser = argparse.ArgumentParser(
17 | description="Preprocess the CONLL2002 corpus for the NER task."
18 | )
19 | parser.add_argument("--path", help="Path to the corpus file.", metavar="path",
20 | default="../data/raw/conll2002/")
21 | parser.add_argument(
22 | "--encoder-json",
23 | help='path to encoder.json',
24 | default="../models/robbert/encoder.json"
25 | )
26 | parser.add_argument(
27 | "--vocab-bpe",
28 | type=str,
29 | help='path to vocab.bpe',
30 | default="../models/robbert/vocab.bpe"
31 | )
32 |
33 | return parser
34 |
35 |
36 | def get_dataset_suffix(type):
37 | return type_map.get(type, None)
38 |
39 |
40 | def get_label_index(label_name):
41 | return ner_tags.index(label_name)
42 |
43 |
44 | def process_connl2002_ner(arguments, type, processed_data_path, raw_data_path):
45 | output_sentences_path, output_labels_path, output_tokenized_sentences_path, output_tokenized_labels_path = preprocess_util.get_sequence_file_paths(
46 | processed_data_path, type)
47 | tokenizer = preprocess_util.get_tokenizer(arguments)
48 |
49 | with open(output_sentences_path, mode='w') as output_sentences:
50 | with open(output_labels_path, mode='w') as output_labels:
51 | with open(output_tokenized_sentences_path, mode='w') as output_tokenized_sentences:
52 | with open(output_tokenized_labels_path, mode='w') as output_tokenized_labels:
53 |
54 | dataset_suffix = get_dataset_suffix(type)
55 | if dataset_suffix is None:
56 | raise Exception("Invalid type", type)
57 |
58 | file_path = raw_data_path / ("ned." + dataset_suffix)
59 | with open(file_path) as file_to_read:
60 | # Add new line after seeing comments or new line
61 | has_content = False
62 | for line in file_to_read:
63 | if not line.startswith("-DOCSTART-") and len(line.strip()) > 0:
64 | has_content = True
65 |
66 | word, pos, ner = line.split(" ")
67 |
68 | # Write out normal word & label
69 | word = word.strip()
70 | label = str(get_label_index(ner.strip()))
71 | preprocess_util.write_sequence_word_label(word, label, tokenizer, output_sentences,
72 | output_labels,
73 | output_tokenized_sentences,
74 | output_tokenized_labels)
75 | elif has_content:
76 | output_sentences.write("\n")
77 | output_labels.write("\n")
78 | output_tokenized_sentences.write("\n")
79 | output_tokenized_labels.write("\n")
80 | has_content = False
81 |
82 |
83 | if __name__ == "__main__":
84 | arg_parser = create_arg_parser()
85 | args = arg_parser.parse_args()
86 |
87 | processed_data_path = Path("..", "data", "processed", "conll2002")
88 | processed_data_path.mkdir(parents=True, exist_ok=True)
89 |
90 | process_connl2002_ner(args, 'train', processed_data_path, Path(args.path))
91 | process_connl2002_ner(args, 'dev', processed_data_path, Path(args.path))
92 | process_connl2002_ner(args, 'test', processed_data_path, Path(args.path))
93 |
--------------------------------------------------------------------------------
/src/preprocess_dbrd.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from os import listdir
4 | from os.path import isfile, join
5 |
6 |
7 | def create_arg_parser():
8 | parser = argparse.ArgumentParser(
9 | description="Preprocess the Dutch Book Reviews Dataset corpus for the sentiment analysis tagging task."
10 | )
11 | parser.add_argument("--path", help="Path to the corpus folder.", metavar="path", default="data/raw/DBRD/")
12 |
13 | return parser
14 |
15 |
16 | def get_file_id(a):
17 | return int(a.split('_')[0])
18 |
19 |
20 | def get_files_of_folder(folder):
21 | return [f for f in listdir(folder) if isfile(join(folder, f))]
22 |
23 |
24 | def add_content_and_label(file_location, output_sentences, output_labels, label):
25 | with open(file_location) as file:
26 | content = " ".join(file.readlines()).replace('\n', ' ').replace('\r', '')
27 | single_spaced_content = ' '.join(content.split())
28 | output_sentences.write(single_spaced_content + "\n")
29 | output_labels.write(str(label) + "\n")
30 |
31 |
32 | def process_dbrd(raw_data_path, test_or_train):
33 | processed_data_path = Path("..", "data", "processed", "dbrd")
34 | processed_data_path.mkdir(parents=True, exist_ok=True)
35 |
36 | output_sentences_path = processed_data_path / (test_or_train + ".sentences.txt")
37 | output_labels_path = processed_data_path / (test_or_train + ".labels.txt")
38 |
39 | with open(output_sentences_path, mode='w') as output_sentences:
40 | with open(output_labels_path, mode='w') as output_labels:
41 | pos_files_folder = raw_data_path / test_or_train / 'pos'
42 | neg_files_folder = raw_data_path / test_or_train / 'neg'
43 |
44 | pos_files = get_files_of_folder(pos_files_folder)
45 | neg_files = get_files_of_folder(neg_files_folder)
46 |
47 | pos_files.sort(key=get_file_id)
48 | neg_files.sort(key=get_file_id)
49 |
50 |
51 | assert len(pos_files) == len(neg_files)
52 |
53 |             # interleave positive and negative reviews so the classes alternate in the training data
54 | for i in range(len(pos_files)):
55 | add_content_and_label(pos_files_folder / pos_files[i], output_sentences, output_labels, 1)
56 | add_content_and_label(neg_files_folder / neg_files[i], output_sentences, output_labels, 0)
57 |
58 |
59 | if __name__ == "__main__":
60 | arg_parser = create_arg_parser()
61 | args = arg_parser.parse_args()
62 |
63 | process_dbrd(Path(args.path), 'train')
64 | process_dbrd(Path(args.path), 'test')
65 |
--------------------------------------------------------------------------------
/src/preprocess_diedat.py:
--------------------------------------------------------------------------------
1 | import random
2 | import nltk
3 | from nltk.tokenize.treebank import TreebankWordDetokenizer
4 | import argparse
5 |
6 |
7 | def replace_die_dat_full_sentence(line, flout, fout):
8 | count = 0
9 | for i, word in enumerate(nltk.word_tokenize(line)):
10 | tokens = nltk.word_tokenize(line)
11 | if word == "die":
12 | tokens[i] = "dat"
13 | elif word == "dat":
14 | tokens[i] = "die"
15 | elif word == "Dat":
16 | tokens[i] = "Die"
17 | elif word == "Die":
18 |             tokens[i] = "Dat"
19 |
20 | if word.lower() == "die" or word.lower() == "dat":
21 | choice = random.getrandbits(1)
22 | results = TreebankWordDetokenizer().detokenize(tokens)
23 |
24 | if choice:
25 | output = "{} {}".format(results, line)
26 | else:
27 | output = "{} {}".format(line, results)
28 |
29 | fout.write(output + "\n")
30 | flout.write(str(choice) + "\n")
31 | count += 1
32 |
33 | return count
34 |
35 |
36 | def create_arg_parser():
37 | parser = argparse.ArgumentParser(
38 | description="Preprocess the europarl corpus for the die-dat task."
39 | )
40 | parser.add_argument("--path", help="Path to the corpus file.", metavar="path",
41 | default="data/raw/europarl/europarl-v7.nl-en.nl")
42 | parser.add_argument("--number", help="Number of examples in the output dataset", type=int, default=10000)
43 |
44 | return parser
45 |
46 |
47 | if __name__ == "__main__":
48 | parser = create_arg_parser()
49 |
50 | args = parser.parse_args()
51 |
52 | with open(args.path + '.labels', mode='a') as labels_output:
53 | with open(args.path + '.sentences', mode='a') as sentences_output:
54 | with open(args.path) as fp:
55 | lines_processed = 0
56 | for line in fp:
57 | line = line.replace('\n', '').replace('\r', '')
58 | count = replace_die_dat_full_sentence(line, labels_output, sentences_output)
59 | lines_processed += count
60 | if lines_processed >= args.number:
61 | break
62 |
--------------------------------------------------------------------------------
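A small worked example of what the script above writes for a single europarl line, assuming the repository root is on PYTHONPATH and nltk's punkt tokenizer data has been downloaded; io.StringIO stands in for the .labels/.sentences output files.

# Sketch only: for every "die"/"dat" in the line, the function writes a pair made
# of the original sentence and the die/dat-swapped sentence (in random order) to
# the sentences file, plus a 0/1 label to the labels file recording that order.
import io
from src.preprocess_diedat import replace_die_dat_full_sentence

labels_out, sentences_out = io.StringIO(), io.StringIO()
line = "Dit is een voorstel dat wij steunen."

count = replace_die_dat_full_sentence(line, labels_out, sentences_out)

print(count)                     # 1: the line contains one die/dat occurrence
print(sentences_out.getvalue())  # original and swapped sentence, space-separated
print(labels_out.getvalue())     # "0" or "1", depending on which came first

--------------------------------------------------------------------------------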
/src/preprocess_diedat.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | split -l $[ $(wc -l data/europarl-v7.nl-en.nl.labels|cut -d" " -f1) * 70 / 100 ] data/europarl-v7.nl-en.nl.labels
4 | mv xaa data/processed/diedat/train.labels
5 | mv xab data/processed/diedat/dev.labels
6 |
7 | split -l $[ $(wc -l data/europarl-v7.nl-en.nl.sentences|cut -d" " -f1) * 70 / 100 ] data/europarl-v7.nl-en.nl.sentences
8 | mv xaa data/processed/diedat/train.sentences
9 | mv xab data/processed/diedat/dev.sentences
10 |
11 | rm data/labels/dict.txt
12 |
13 | for SPLIT in train dev; do
14 | python src/multiprocessing_bpe_encoder.py\
15 | --encoder-json data/encoder.json \
16 | --vocab-bpe data/vocab.bpe \
17 | --inputs "data/processed/diedat/$SPLIT.sentences" \
18 | --outputs "data/processed/diedat/$SPLIT.sentences.bpe" \
19 | --workers 24 \
20 | --keep-empty
21 | done
22 |
23 | fairseq-preprocess \
24 | --only-source \
25 | --trainpref "data/processed/diedat/train.sentences.bpe" \
26 | --validpref "data/processed/diedat/dev.sentences.bpe" \
27 | --destdir "data/input0" \
28 | --workers 24 \
29 | --srcdict data/dict.txt
30 |
31 | fairseq-preprocess \
32 | --only-source \
33 | --trainpref "data/processed/diedat/train.labels" \
34 | --validpref "data/processed/diedat/dev.labels" \
35 | --destdir "data/diedat/labels" \
36 | --workers 24
37 |
38 |
--------------------------------------------------------------------------------
/src/preprocess_lassy_ud.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | from src import preprocess_util
5 |
6 | universal_pos_tags = ["ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ",
7 | "NOUN", "NUM", "PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X"]
8 |
9 |
10 | def create_arg_parser():
11 | parser = argparse.ArgumentParser(
12 | description="Preprocess the LASSY corpus for the POS-tagging task."
13 | )
14 | parser.add_argument("--path", help="Path to the corpus file.", metavar="path",
15 | default="../data/raw/UD_Dutch-LassySmall/")
16 | parser.add_argument(
17 | "--encoder-json",
18 | help='path to encoder.json',
19 | default="../models/robbert.large/encoder.json"
20 | )
21 | parser.add_argument(
22 | "--vocab-bpe",
23 | type=str,
24 | help='path to vocab.bpe',
25 | default="../models/robbert/vocab.bpe"
26 | )
27 |
28 | return parser
29 |
30 |
31 | def get_label_index(label_name):
32 | return universal_pos_tags.index(label_name)
33 |
34 |
35 | def process_lassy_ud(arguments, type, processed_data_path, raw_data_path):
36 | output_sentences_path, output_labels_path, output_tokenized_sentences_path, output_tokenized_labels_path = preprocess_util.get_sequence_file_paths(
37 | processed_data_path, type)
38 | tokenizer = preprocess_util.get_tokenizer(arguments)
39 |
40 | with open(output_sentences_path, mode='w') as output_sentences:
41 | with open(output_labels_path, mode='w') as output_labels:
42 | with open(output_tokenized_sentences_path, mode='w') as output_tokenized_sentences:
43 | with open(output_tokenized_labels_path, mode='w') as output_tokenized_labels:
44 |
45 | file_path = raw_data_path / ("nl_lassysmall-ud-" + type + ".conllu")
46 | with open(file_path) as file_to_read:
47 | # For removing first blank line
48 | has_content = False
49 | for line in file_to_read:
50 | if not line.startswith("#") and len(line.strip()) > 0:
51 | has_content = True
52 |
53 | index, word, main_word, universal_pos, detailed_tag, details, number, english_tag, \
54 | number_and_english_tag, space_after = line.split("\t")
55 |
56 | # Write out normal word & label
57 | word = word.strip()
58 | label = str(get_label_index(universal_pos.strip()))
59 | preprocess_util.write_sequence_word_label(word, label, tokenizer, output_sentences,
60 | output_labels,
61 | output_tokenized_sentences,
62 | output_tokenized_labels)
63 | elif has_content:
64 | output_sentences.write("\n")
65 | output_labels.write("\n")
66 | output_tokenized_sentences.write("\n")
67 | output_tokenized_labels.write("\n")
68 | has_content = False
69 |
70 |
71 | if __name__ == "__main__":
72 | arg_parser = create_arg_parser()
73 | args = arg_parser.parse_args()
74 |
75 | processed_data_path = Path("..", "data", "processed", "lassy_ud")
76 | processed_data_path.mkdir(parents=True, exist_ok=True)
77 |
78 | process_lassy_ud(args, 'train', processed_data_path, Path(args.path))
79 | process_lassy_ud(args, 'dev', processed_data_path, Path(args.path))
80 | process_lassy_ud(args, 'test', processed_data_path, Path(args.path))
81 |
--------------------------------------------------------------------------------
/src/preprocess_util.py:
--------------------------------------------------------------------------------
1 | from src.multiprocessing_bpe_encoder import MultiprocessingEncoder
2 |
3 | seperator_token = "\t"
4 |
5 |
6 | def get_sequence_file_paths(processed_data_path, type):
7 | output_sentences_path = processed_data_path / (type + ".sentences.tsv")
8 | output_labels_path = processed_data_path / (type + ".labels.tsv")
9 | output_tokenized_sentences_path = processed_data_path / (type + ".sentences.bpe")
10 | output_tokenized_labels_path = processed_data_path / (type + ".labels.bpe")
11 | return output_sentences_path, output_labels_path, output_tokenized_sentences_path, output_tokenized_labels_path
12 |
13 |
14 | def write_sequence_word_label(word, label, tokenizer, output_sentences, output_labels, output_tokenized_sentences,
15 | output_tokenized_labels):
16 | output_sentences.write(word.strip() + seperator_token)
17 | output_labels.write(label + seperator_token)
18 |
19 | # Write tokenized
20 | tokenized_word = tokenizer.encode(word.strip())
21 | output_tokenized_sentences.write(seperator_token.join(tokenized_word) + seperator_token)
22 | output_tokenized_labels.write(len(tokenized_word) * (label + seperator_token))
23 |
24 |
25 | def get_tokenizer(arguments):
26 | tokenizer = MultiprocessingEncoder(arguments)
27 | tokenizer.initializer()
28 | return tokenizer
29 |
--------------------------------------------------------------------------------
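A minimal sketch of how write_sequence_word_label() keeps tokens and labels aligned, using io.StringIO buffers instead of files and a stub tokenizer in place of the real MultiprocessingEncoder (which needs encoder.json/vocab.bpe); it assumes the repository root is on PYTHONPATH.

# Sketch only: the label of a word is repeated once for every BPE piece, so the
# *.sentences.bpe and *.labels.bpe files stay aligned token by token.
import io
from src.preprocess_util import write_sequence_word_label

class StubTokenizer:
    # Pretend BPE: split every word into two halves.
    def encode(self, word):
        return [word[: len(word) // 2], word[len(word) // 2:]]

sent, labels, tok_sent, tok_labels = (io.StringIO() for _ in range(4))
write_sequence_word_label("landbouw", "7", StubTokenizer(), sent, labels, tok_sent, tok_labels)

print(repr(tok_sent.getvalue()))    # 'land\tbouw\t'
print(repr(tok_labels.getvalue()))  # '7\t7\t' -> label repeated per BPE piece

--------------------------------------------------------------------------------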
/src/preprocess_wordlist_mask.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | from src.wordlistfiller import WordListFiller
5 |
6 |
7 | def create_arg_parser():
8 | parser = argparse.ArgumentParser(
9 | description="Preprocess the europarl corpus for the die-dat task."
10 | )
11 | parser.add_argument("--path", help="Path to the corpus file.", metavar="path",
12 | default="../data/raw/europarl-v7.nl-en.nl")
13 | parser.add_argument("--filename", help="Extra for file name", metavar="path",
14 | default="")
15 | parser.add_argument("--words", help="List of comma-separated words to disambiguate", type=str, default="die,dat")
16 | parser.add_argument("--number", help="Number of examples in the output dataset", type=int, default=10000000)
17 |
18 | return parser
19 |
20 |
21 | if __name__ == "__main__":
22 | parser = create_arg_parser()
23 | args = parser.parse_args()
24 |
25 | models_path = Path("..", "data", "processed", "wordlist")
26 | models_path.mkdir(parents=True, exist_ok=True)
27 |
28 | words = [x.strip() for x in args.words.split(",")]
29 | wordlistfiller = WordListFiller(words)
30 |
31 | output_path = models_path / (args.words.replace(',', '-')
32 | + (('.' + args.filename) if args.filename else '') + ".tsv")
33 |
34 | with open(output_path, mode='w') as output:
35 | with open(args.path) as input_file:
36 | number_of_lines_to_add = args.number
37 | for line in input_file:
38 | line = line.strip()
39 | sentences = wordlistfiller.occlude_target_words_index(line)
40 | number_of_sentences = len(sentences)
41 |
42 | for i in range(min(number_of_lines_to_add, number_of_sentences)):
43 | sentence = sentences[i]
44 | output.write(sentence[0] + "\t" + str(sentence[1]) + '\n')
45 | number_of_lines_to_add -= number_of_sentences
46 |
47 | if number_of_lines_to_add <= 0:
48 | break
49 |
--------------------------------------------------------------------------------
/src/pretrain.pbs:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l partition=gpu
3 | #PBS -l walltime=12:00:00
4 | #PBS -l nodes=5:ppn=36:gpus=4
5 | #PBS -l pmem=5gb
6 | #PBS -l excludenodes=r24g39:r22g41
7 | #PBS -N alBERT
8 | #PBS -A lp_dtai1
9 |
10 |
11 | echo "changing dir"
12 | cd $VSC_DATA
13 | source miniconda3/bin/activate torch
14 | cd alBERT
15 |
16 | DATA_DIR=/scratch/leuven/x/x/data-bin/nl_dedup
17 |
18 | MASTER_ADDR=$(cat $PBS_NODEFILE | head -n 1)
19 | MASTER_PORT=6666
20 | WORLD_SIZE=20
21 |
22 | echo "Master node is" $MASTER_ADDR
23 |
24 | pbsdsh -u sh $PBS_O_WORKDIR/run_node.sh $MASTER_ADDR $MASTER_PORT $WORLD_SIZE
25 |
26 | echo "Done training, cleaning up syncfile."
27 | rm /scratch/leuven/x/x/sync_torch/syncfile
28 |
--------------------------------------------------------------------------------
/src/split_dbrd_training.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd data/processed/dbrd
4 |
5 | mv train.sentences.txt traineval.sentences.txt
6 | mv train.labels.txt traineval.labels.txt
7 | echo "renamed files to traineval.*.txt"
8 |
9 | head traineval.sentences.txt -n -500 > train.sentences.txt
10 | head traineval.labels.txt -n -500 > train.labels.txt
11 | tail traineval.sentences.txt -n 500 > eval.sentences.txt
12 | tail traineval.labels.txt -n 500 > eval.labels.txt
13 |
--------------------------------------------------------------------------------
/src/textdataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import pickle
4 | import torch
5 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class TextDataset(Dataset):
11 | def __init__(self, tokenizer, model_name_or_path, file_path="train", block_size=512, overwrite_cache=True, mask_padding_with_zero=True):
12 | assert os.path.isfile(file_path + '.sentences.txt')
13 |
14 | assert os.path.isfile(file_path + '.labels.txt')
15 |
16 | directory, filename = os.path.split(file_path)
17 | cached_features_file = os.path.join(
18 | directory, model_name_or_path + "_cached_lm_" + str(block_size) + "_" + filename
19 | )
20 |
21 | if os.path.exists(cached_features_file) and not overwrite_cache:
22 | logger.info("Loading features from cached file %s", cached_features_file)
23 | with open(cached_features_file, "rb") as handle:
24 | self.examples = pickle.load(handle)
25 | else:
26 | logger.info("Creating features from dataset file at %s", directory)
27 |
28 | self.examples = []
29 | with open(file_path + ".labels.txt", encoding="utf-8") as flabel:
30 | with open(file_path + ".sentences.txt", encoding="utf-8") as f:
31 | for sentence in f:
32 | tokenized_text = tokenizer.encode(tokenizer.tokenize(sentence)[-block_size + 3 : -1])
33 |
34 | input_mask = [1 if mask_padding_with_zero else 0] * len(tokenized_text)
35 |
36 | pad_token = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
37 | while len(tokenized_text) < block_size:
38 | tokenized_text.append(pad_token)
39 | input_mask.append(0 if mask_padding_with_zero else 1)
40 | #segment_ids.append(pad_token_segment_id)
41 | #p_mask.append(1)
42 |
43 | #self.examples.append([tokenizer.build_inputs_with_special_tokens(tokenized_text[0 : block_size]), [0], [0]])
44 | label = next(flabel)
45 | self.examples.append([tokenized_text[0 : block_size - 3], input_mask[0 : block_size - 3], [1] if label.startswith("1") else [0]])
46 |
47 | logger.info("Saving features into cached file %s", cached_features_file)
48 | with open(cached_features_file, "wb") as handle:
49 | pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
50 |
51 | def __len__(self):
52 | return len(self.examples)
53 |
54 | def __getitem__(self, item):
55 | return [torch.tensor(self.examples[item][0]), torch.tensor(self.examples[item][1]), torch.tensor([0]), torch.tensor(self.examples[item][2])]
56 |
57 |
58 | def load_and_cache_examples(model_name_or_path, tokenizer, data_file):
59 | dataset = TextDataset(
60 | tokenizer,
61 | model_name_or_path,
62 | file_path=data_file,
63 | block_size=512
64 | )
65 | return dataset
66 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | from transformers import AdamW
2 | import torch
3 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
4 | import os
5 | from tqdm import tqdm, trange
6 | import pandas as pd
7 | from pycm import ConfusionMatrix
8 |
9 | try:
10 | from torch.utils.tensorboard import SummaryWriter
11 | except ImportError:
12 | from tensorboardX import SummaryWriter
13 |
14 | import json
15 |
16 | from transformers import get_linear_schedule_with_warmup
17 | import logging
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 | class Train:
22 | def train(args, train_dataset, model, tokenizer, evaluate_fn=None):
23 | """ Train the model """
24 |
25 | model.train()
26 |
27 | if args.local_rank in [-1, 0]:
28 | tb_writer = SummaryWriter()
29 |
30 | # Setup CUDA, GPU & distributed training
31 | if args.local_rank == -1 or args.no_cuda:
32 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
33 | args.n_gpu = torch.cuda.device_count()
34 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
35 | torch.cuda.set_device(args.local_rank)
36 | device = torch.device("cuda", args.local_rank)
37 | torch.distributed.init_process_group(backend="nccl")
38 | args.n_gpu = 1
39 | args.device = device
40 |
41 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
42 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
43 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
44 |
45 | if args.max_steps > 0:
46 | t_total = args.max_steps
47 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
48 | else:
49 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
50 |
51 | # Prepare optimizer and schedule (linear warmup and decay)
52 | no_decay = ["bias", "LayerNorm.weight"]
53 | optimizer_grouped_parameters = [
54 | {
55 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
56 | "weight_decay": args.weight_decay,
57 | },
58 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
59 | ]
60 |
61 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
62 | scheduler = get_linear_schedule_with_warmup(
63 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
64 | )
65 |
66 | # Check if saved optimizer or scheduler states exist
67 | if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
68 | os.path.join(args.model_name_or_path, "scheduler.pt")
69 | ):
70 | # Load in optimizer and scheduler states
71 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
72 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
73 |
74 | if args.fp16:
75 | try:
76 | from apex import amp
77 | except ImportError:
78 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
79 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
80 |
81 | # multi-gpu training (should be after apex fp16 initialization)
82 | if args.n_gpu > 1:
83 | model = torch.nn.DataParallel(model)
84 |
85 | # Distributed training (should be after apex fp16 initialization)
86 | if args.local_rank != -1:
87 | model = torch.nn.parallel.DistributedDataParallel(
88 | model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True,
89 | )
90 |
91 | if args.local_rank == 0:
92 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
93 |
94 | model.to(args.device)
95 |
96 | # Train!
97 | logger.info("***** Running training *****")
98 | logger.info(" Num examples = %d", len(train_dataset))
99 | logger.info(" Num Epochs = %d", args.num_train_epochs)
100 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
101 | logger.info(
102 | " Total train batch size (w. parallel, distributed & accumulation) = %d",
103 | args.train_batch_size
104 | * args.gradient_accumulation_steps
105 | * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
106 | )
107 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
108 | logger.info(" Total optimization steps = %d", t_total)
109 |
110 | global_step = 0
111 | epochs_trained = 0
112 | steps_trained_in_current_epoch = 0
113 | # Check if continuing training from a checkpoint
114 | if os.path.exists(args.model_name_or_path):
115 | # set global_step to gobal_step of last saved checkpoint from model path
116 | global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
117 | epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
118 | steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
119 |
120 | logger.info(" Continuing training from checkpoint, will skip to saved global_step")
121 | logger.info(" Continuing training from epoch %d", epochs_trained)
122 | logger.info(" Continuing training from global step %d", global_step)
123 | logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
124 |
125 | tr_loss, logging_loss = 0.0, 0.0
126 | model.zero_grad()
127 | train_iterator = trange(
128 | epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0],
129 | )
130 | #set_seed(args) # Added here for reproductibility
131 | for _ in train_iterator:
132 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0], position=0, leave=True)
133 | for step, batch in enumerate(epoch_iterator):
134 |
135 | # Skip past any already trained steps if resuming training
136 | if steps_trained_in_current_epoch > 0:
137 | steps_trained_in_current_epoch -= 1
138 | continue
139 |
140 | model.train()
141 | batch = tuple(t.to(args.device) for t in batch)
142 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
143 | if args.model_type != "distilbert":
144 | inputs["token_type_ids"] = (
145 | batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
146 | ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
147 | outputs = model(**inputs)
148 | loss = outputs[0] # model outputs are always tuple in transformers (see doc)
149 |
150 | if args.n_gpu > 1:
151 | loss = loss.mean() # mean() to average on multi-gpu parallel training
152 | if args.gradient_accumulation_steps > 1:
153 | loss = loss / args.gradient_accumulation_steps
154 |
155 | if args.fp16:
156 | with amp.scale_loss(loss, optimizer) as scaled_loss:
157 | scaled_loss.backward()
158 | else:
159 | loss.backward()
160 |
161 | tr_loss += loss.item()
162 | if (step + 1) % args.gradient_accumulation_steps == 0:
163 | if args.fp16:
164 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
165 | else:
166 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
167 |
168 | optimizer.step()
169 | scheduler.step() # Update learning rate schedule
170 | model.zero_grad()
171 | global_step += 1
172 |
173 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
174 | logs = {}
175 | if (
176 | args.local_rank == -1 and args.evaluate_during_training
177 | ): # Only evaluate when single GPU otherwise metrics may not average well
178 | results = evaluate(args, model, tokenizer)
179 | for key, value in results.items():
180 | eval_key = "eval_{}".format(key)
181 | logs[eval_key] = value
182 |
183 | loss_scalar = (tr_loss - logging_loss) / args.logging_steps
184 | learning_rate_scalar = scheduler.get_lr()[0]
185 | logs["learning_rate"] = learning_rate_scalar
186 | logs["loss"] = loss_scalar
187 | logging_loss = tr_loss
188 |
189 | for key, value in logs.items():
190 | tb_writer.add_scalar(key, value, global_step)
191 | epoch_iterator.set_postfix({**logs, **{"step": global_step}})
192 |
193 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
194 | # Save model checkpoint
195 | output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
196 | if not os.path.exists(output_dir):
197 | os.makedirs(output_dir)
198 | model_to_save = (
199 | model.module if hasattr(model, "module") else model
200 | ) # Take care of distributed/parallel training
201 | model_to_save.save_pretrained(output_dir)
202 | tokenizer.save_pretrained(output_dir)
203 |
204 | torch.save(args, os.path.join(output_dir, "training_args.bin"))
205 | logger.info("Saving model checkpoint to %s", output_dir)
206 |
207 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
208 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
209 | logger.info("Saving optimizer and scheduler states to %s", output_dir)
210 |
211 | if args.max_steps > 0 and global_step > args.max_steps:
212 | epoch_iterator.close()
213 | break
214 | if args.max_steps > 0 and global_step > args.max_steps:
215 | train_iterator.close()
216 | break
217 |
218 | if evaluate_fn is not None:
219 | results = pd.DataFrame(evaluate_fn(args.evaluate_dataset, model))
220 | cm = ConfusionMatrix(actual_vector=results['true'].values, predict_vector=results['predicted'].values)
221 | logs = {}
222 | logs["eval_f1_macro"] = cm.F1_Macro
223 | logs["eval_acc_macro"] = cm.ACC_Macro
224 | logs["eval_acc_overall"] = cm.Overall_ACC
225 | logger.info("Results on eval: {}".format(logs))
226 |
227 |
228 | if args.local_rank in [-1, 0]:
229 | tb_writer.close()
230 |
231 | return global_step, tr_loss / global_step
232 |
--------------------------------------------------------------------------------
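
A small illustration of the checkpoint-resume logic in train.py above: the global step is recovered purely from the checkpoint directory name, so a hypothetical path such as "output/checkpoint-1500/" resolves to step 1500 (a minimal sketch, the path is made up):

    # Hypothetical "checkpoint-<step>" path of the form produced by the save logic above.
    path = "output/checkpoint-1500/"
    global_step = int(path.split("-")[-1].split("/")[0])
    assert global_step == 1500
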
/src/train_config.py:
--------------------------------------------------------------------------------
1 | class Config:
2 |
3 | def __init__(self):
4 |         "Ugly hack to make the args object work without spending too much effort."
5 | self.local_rank = -1
6 | self.per_gpu_train_batch_size = 4
7 | self.gradient_accumulation_steps = 8
8 | self.n_gpu = 1
9 | self.max_steps = 2000
10 | self.weight_decay = 0
11 | self.learning_rate = 5e-5
12 | self.adam_epsilon = 2e-8
13 | self.warmup_steps = 250
14 | self.model_name_or_path = "bert"
15 | self.fp16 = False
16 | self.set_seed = "42"
17 | self.device=0
18 | self.model_type="roberta"
19 | self.no_cuda=False
20 | self.max_grad_norm=1.0
21 | self.logging_steps=1
22 | self.evaluate_during_training=False
23 | self.save_steps=2500
24 | self.output_dir="./"
25 | self.evaluate_dataset=""
26 |
--------------------------------------------------------------------------------
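
Note that Config above only covers part of what the training loop in train.py reads: attributes such as num_train_epochs and train_batch_size are used there but not defined here, so the caller (e.g. a notebook) is expected to set them. A minimal sketch, with example values only:

    from src.train_config import Config

    args = Config()
    # Attributes read by train.py but not defined in Config (values are examples only):
    args.num_train_epochs = 3
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    args.output_dir = "models/finetuned"
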
/src/train_diedat.sh:
--------------------------------------------------------------------------------
1 | TOTAL_NUM_UPDATES=7812 # 10 epochs
2 | WARMUP_UPDATES=469 # 6 percent of the number of updates
3 | LR=1e-05 # Peak LR for polynomial LR scheduler.
4 | HEAD_NAME=die-dat-head  # Custom name for the classification head.
5 | NUM_CLASSES=2 # Number of classes for the classification task.
6 | MAX_SENTENCES=8 # Batch size.
7 | ROBERTA_PATH=data/checkpoint_best.pt
8 |
9 | fairseq-train data \
10 | --restore-file $ROBERTA_PATH \
11 | --max-positions 512 \
12 | --max-sentences $MAX_SENTENCES \
13 | --max-tokens 4400 \
14 | --task sentence_prediction \
15 | --reset-optimizer --reset-dataloader --reset-meters \
16 | --required-batch-size-multiple 1 \
17 | --init-token 0 --separator-token 2 \
18 | --arch roberta_base \
19 | --criterion sentence_prediction \
20 | --num-classes $NUM_CLASSES \
21 | --dropout 0.1 --attention-dropout 0.1 \
22 | --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
23 | --clip-norm 0.0 \
24 | --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
25 | --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
26 | --max-epoch 10 \
27 | --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
28 | --truncate-sequence \
29 | --find-unused-parameters \
30 | --update-freq 4 \
31 | --save-interval 1 \
32 | --save-dir checkpoints
33 |
--------------------------------------------------------------------------------
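
After fine-tuning with the script above, the best checkpoint is written to checkpoints/checkpoint_best.pt. A minimal sketch of loading it back with fairseq for prediction; the head name is fairseq's default ("sentence_classification_head"), since the script does not pass --classification-head-name, and the input formatting from preprocess_diedat.py is not repeated here:

    from fairseq.models.roberta import RobertaModel

    # "data" must contain the same dictionary/BPE files that were used for training.
    roberta = RobertaModel.from_pretrained(
        "checkpoints", checkpoint_file="checkpoint_best.pt", data_name_or_path="data"
    )
    roberta.eval()

    # Encode one example (preprocessed like the training data) and classify it.
    tokens = roberta.encode("...")
    label = roberta.predict("sentence_classification_head", tokens).argmax().item()
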
/src/wordlistfiller.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Union
2 |
3 | from fairseq.models.roberta import RobertaHubInterface
4 | from nltk import word_tokenize
5 | from nltk.tokenize.treebank import TreebankWordDetokenizer
6 |
7 | from src.bert_masked_lm_adapter import BertMaskedLMAdapter
8 |
9 |
10 | class WordListFiller:
11 | def __init__(self, target_words: List[str],
12 | model: Union[RobertaHubInterface, BertMaskedLMAdapter] = None,
13 | detokenizer: TreebankWordDetokenizer = TreebankWordDetokenizer(),
14 | topk_limit=2048):
15 | self._model = model
16 | self._target_words = [x.lower().strip() for x in target_words]
17 | self._target_words_spaced = set([" " + word for word in self._target_words])
18 | self._detokenizer = detokenizer
19 | self._top_k_limit = topk_limit
20 |
21 | def find_optimal_word(self, text: str) -> str:
22 | if self._model is None:
23 | raise AttributeError("No model given to find the optimal word")
24 | topk = 4
25 | result = None
26 | while result is None and topk <= self._top_k_limit:
27 | filler_words = self._model.fill_mask(text, topk=topk)
28 | result = next((x[2].strip() for x in filler_words if x[2].lower() in self._target_words_spaced), None)
29 | topk *= 2
30 | return result
31 |
32 |     # Transforms a sentence into a list of (masked sentence, original word) pairs, one for each occurrence of a target word in the sentence.
33 | def occlude_target_words(self, input_sentence: str) -> List[Tuple[str, str]]:
34 | tokenized = word_tokenize(input_sentence)
35 | result = []
36 | for i in range(len(tokenized)):
37 | if tokenized[i] in self._target_words:
38 |                 new_sentence_tokens = tokenized[:i] + ["<mask>"] + tokenized[i + 1:]
39 | new_sentence = self._detokenizer.detokenize(new_sentence_tokens)
40 | result.append((new_sentence, tokenized[i]))
41 | return result
42 |
43 | def occlude_target_words_index(self, input_sentence: str) -> List[Tuple[str, int]]:
44 | return [(s[0], self._target_words.index(s[1].strip().lower())) for s in
45 | self.occlude_target_words(input_sentence)]
46 |
--------------------------------------------------------------------------------
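
A minimal usage sketch of WordListFiller for die/dat filling, assuming a local fairseq RoBERTa checkpoint (the paths are placeholders) and that nltk's punkt tokenizer data is installed:

    from fairseq.models.roberta import RobertaModel

    from src.wordlistfiller import WordListFiller

    model = RobertaModel.from_pretrained("models/robbert", checkpoint_file="model.pt")
    model.eval()

    filler = WordListFiller(["die", "dat"], model=model)

    # Mask each occurrence of a target word, then let the model pick the best fit.
    for masked_sentence, original_word in filler.occlude_target_words("Het boek dat ik las."):
        predicted = filler.find_optimal_word(masked_sentence)
        print(original_word, "->", predicted)
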
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_convert_roberta_dict.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from io import StringIO
3 |
4 | from src.convert_roberta_dict import load_roberta_mapping, map_roberta
5 |
6 |
7 | class ConvertRobertaTestCase(unittest.TestCase):
8 | def test_load_roberta_mapping(self):
9 | # Create a mock file
10 | file = StringIO(initial_value="3 12\n1 8\n2 7")
11 | mapping = load_roberta_mapping(file)
12 |
13 | # test our little mapping
14 | self.assertEqual(mapping['3'], 0, msg="First element in dict.txt is 3, should have id = 0")
15 | self.assertEqual(mapping['1'], 1, msg="Second element in dict.txt is 1, should have id = 1")
16 | self.assertEqual(mapping['2'], 2, msg="Third element in dict.txt is 2, should have id = 2")
17 |
18 | def test_map_roberta(self):
19 | file = StringIO(initial_value="3 12\n1 8\n2 7")
20 | mapping = load_roberta_mapping(file)
21 |
22 | vocab = {"Een": 3, "Twee": 2, "Drie": 1}
23 |
24 | output_vocab = map_roberta(mapping, vocab)
25 |
26 |         self.assertEqual(output_vocab['<s>'], 0, msg="Extra tokens")
27 |         self.assertEqual(output_vocab['<pad>'], 1, msg="Extra tokens")
28 |         self.assertEqual(output_vocab['</s>'], 2, msg="Extra tokens")
29 |         self.assertEqual(output_vocab['<unk>'], 3, msg="Extra tokens")
30 | self.assertEqual(output_vocab['Een'], 4, msg="'Een' has vocab_id = 3, which is mapped to 0 (+4)")
31 |         self.assertEqual(output_vocab['Twee'], 6, msg="'Twee' has vocab_id = 2, which is mapped to 2 (+4)")
32 | self.assertEqual(output_vocab['Drie'], 5, msg="'Drie' has vocab_id = 1, which is mapped to 1 (+4)")
33 |
34 |
35 | def test_map_roberta_unused_tokens(self):
36 | """
37 | The fast HuggingFace tokenizer requires that all tokens in the merges.txt are also present in the
38 | vocab.json. When converting a Fairseq dict.txt, this is not necessarily the case in a naive implementation.
39 |
40 | More info: https://github.com/huggingface/transformers/issues/9290
41 | """
42 |
43 | file = StringIO(initial_value="3 12\n1 8\n2 7")
44 | mapping = load_roberta_mapping(file)
45 |
46 |         # Tokens "Vier" and "Vijf" are not used in the mapping (= dict.txt)
47 | vocab = {"Een": 3, "Twee": 2, "Drie": 1, "Vier": 5, "Vijf": 4}
48 |
49 | output_vocab = map_roberta(mapping, vocab)
50 |
51 |         self.assertEqual(output_vocab['<s>'], 0, msg="Extra tokens")
52 |         self.assertEqual(output_vocab['<pad>'], 1, msg="Extra tokens")
53 |         self.assertEqual(output_vocab['</s>'], 2, msg="Extra tokens")
54 |         self.assertEqual(output_vocab['<unk>'], 3, msg="Extra tokens")
55 | self.assertEqual(output_vocab['Een'], 4, msg="'Een' has vocab_id = 3, which is mapped to 0 (+4)")
56 |         self.assertEqual(output_vocab['Twee'], 6, msg="'Twee' has vocab_id = 2, which is mapped to 2 (+4)")
57 | self.assertEqual(output_vocab['Drie'], 5, msg="'Drie' has vocab_id = 1, which is mapped to 1 (+4)")
58 |         self.assertIn(output_vocab['Vier'], [7, 8], msg="'Vier' has vocab_id = 5, which is mapped to the next available value")
59 |         self.assertIn(output_vocab['Vijf'], [8, 7], msg="'Vijf' has vocab_id = 4, which is mapped to the next available value")
60 | self.assertNotEqual(output_vocab['Vijf'], output_vocab['Vier'], msg="Unused tokens must have different values")
61 |
62 | def test_tokenization(self):
63 | sample_input = "De tweede poging: nog een test van de tokenizer met nummers."
64 | expected_output = [0, 62, 488, 5351, 30, 49, 9, 2142, 7, 5, 905, 859, 10605, 15, 3316, 4, 2]
65 |         # NOTE: no assertion yet; encoding sample_input with the converted tokenizer and comparing it to expected_output is still to be done.
66 |
67 | if __name__ == '__main__':
68 | unittest.main()
69 |
--------------------------------------------------------------------------------
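
The tests above use only the standard library unittest runner and can be run from the repository root, for example:

    python -m unittest tests.test_convert_roberta_dict
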