├── .gitignore ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── examples └── die_vs_data_rest_api │ ├── .gitignore │ ├── README.md │ ├── app.py │ ├── app │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-37.pyc │ └── config.json │ └── requirements.txt ├── model_cards └── README.md ├── notebooks ├── demo_RobBERT_for_conll_ner.ipynb ├── demo_RobBERT_for_masked_LM.ipynb ├── die_dat_demo.ipynb ├── evaluate_zeroshot_wordlists.ipynb ├── evaluate_zeroshot_wordlists_v2.ipynb └── finetune_dbrd.ipynb ├── requirements.txt ├── res ├── dbrd.png ├── gender_diff.png ├── robbert_2022_logo.png ├── robbert_2022_logo_with_name.png ├── robbert_2023_logo.png ├── robbert_logo.png ├── robbert_logo_with_name.png └── robbert_pos_accuracy.png ├── src ├── __init__.py ├── bert_masked_lm_adapter.py ├── convert_roberta_dict.py ├── evaluate_zeroshot_wordlist.py ├── multiprocessing_bpe_encoder.py ├── preprocess_conll2002_ner.py ├── preprocess_dbrd.py ├── preprocess_diedat.py ├── preprocess_diedat.sh ├── preprocess_lassy_ud.py ├── preprocess_util.py ├── preprocess_wordlist_mask.py ├── pretrain.pbs ├── run_lm.py ├── split_dbrd_training.sh ├── textdataset.py ├── train.py ├── train_config.py ├── train_diedat.sh └── wordlistfiller.py └── tests ├── __init__.py └── test_convert_roberta_dict.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | .idea/ 3 | models/ 4 | src/__pycache__/ 5 | 6 | venv/ 7 | .env/ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Pieter Delobelle, Thomas Winters, Bettina Berendt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | name = "pypi" 3 | url = "https://pypi.org/simple" 4 | verify_ssl = true 5 | 6 | [dev-packages] 7 | 8 | [packages] 9 | fairseq = "*" 10 | nltk = "*" 11 | transformers = "*" 12 | tensorboardx = "*" 13 | tokenizers = "~=0.4.2" 14 | jupyter = "*" 15 | 16 | [requires] 17 | python_version = "3.7" 18 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "de30df368a36e74d2c9c8d1d71d6017149e55b4570e2ae1d31539a99ee42feb2" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3.7" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.org/simple", 14 | "verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": { 19 | "boto3": { 20 | "hashes": [ 21 | "sha256:2cb3063d347c0d06f96dcd10d69247d3e3d12eda4b1ed9fdec62ac25c7fee8b6", 22 | "sha256:5bffdfc7d3465f053db38a64821fbeb1f8044ffafb71c90587f1c293d334305f" 23 | ], 24 | "version": "==1.11.16" 25 | }, 26 | "botocore": { 27 | "hashes": [ 28 | "sha256:63dc203046e47573add16c348f00a8e2fe3bb4975323458787e7b2bebeec5923", 29 | "sha256:d31d7e0e12fb64d6bd331c6b853750c71423d560cee01a989f607dd45407ef65" 30 | ], 31 | "version": "==1.14.16" 32 | }, 33 | "certifi": { 34 | "hashes": [ 35 | "sha256:017c25db2a153ce562900032d5bc68e9f191e44e9a0f762f373977de9df1fbb3", 36 | "sha256:25b64c7da4cd7479594d035c08c2d809eb4aab3a26e5a990ea98cc450c320f1f" 37 | ], 38 | "version": "==2019.11.28" 39 | }, 40 | "cffi": { 41 | "hashes": [ 42 | "sha256:001bf3242a1bb04d985d63e138230802c6c8d4db3668fb545fb5005ddf5bb5ff", 43 | "sha256:00789914be39dffba161cfc5be31b55775de5ba2235fe49aa28c148236c4e06b", 44 | "sha256:028a579fc9aed3af38f4892bdcc7390508adabc30c6af4a6e4f611b0c680e6ac", 45 | "sha256:14491a910663bf9f13ddf2bc8f60562d6bc5315c1f09c704937ef17293fb85b0", 46 | "sha256:1cae98a7054b5c9391eb3249b86e0e99ab1e02bb0cc0575da191aedadbdf4384", 47 | "sha256:2089ed025da3919d2e75a4d963d008330c96751127dd6f73c8dc0c65041b4c26", 48 | "sha256:2d384f4a127a15ba701207f7639d94106693b6cd64173d6c8988e2c25f3ac2b6", 49 | "sha256:337d448e5a725bba2d8293c48d9353fc68d0e9e4088d62a9571def317797522b", 50 | "sha256:399aed636c7d3749bbed55bc907c3288cb43c65c4389964ad5ff849b6370603e", 51 | "sha256:3b911c2dbd4f423b4c4fcca138cadde747abdb20d196c4a48708b8a2d32b16dd", 52 | "sha256:3d311bcc4a41408cf5854f06ef2c5cab88f9fded37a3b95936c9879c1640d4c2", 53 | "sha256:62ae9af2d069ea2698bf536dcfe1e4eed9090211dbaafeeedf5cb6c41b352f66", 54 | "sha256:66e41db66b47d0d8672d8ed2708ba91b2f2524ece3dee48b5dfb36be8c2f21dc", 55 | "sha256:675686925a9fb403edba0114db74e741d8181683dcf216be697d208857e04ca8", 56 | "sha256:7e63cbcf2429a8dbfe48dcc2322d5f2220b77b2e17b7ba023d6166d84655da55", 57 | "sha256:8a6c688fefb4e1cd56feb6c511984a6c4f7ec7d2a1ff31a10254f3c817054ae4", 58 | "sha256:8c0ffc886aea5df6a1762d0019e9cb05f825d0eec1f520c51be9d198701daee5", 59 | "sha256:95cd16d3dee553f882540c1ffe331d085c9e629499ceadfbda4d4fde635f4b7d", 60 | "sha256:99f748a7e71ff382613b4e1acc0ac83bf7ad167fb3802e35e90d9763daba4d78", 61 | "sha256:b8c78301cefcf5fd914aad35d3c04c2b21ce8629b5e4f4e45ae6812e461910fa", 62 | "sha256:c420917b188a5582a56d8b93bdd8e0f6eca08c84ff623a4c16e809152cd35793", 63 | "sha256:c43866529f2f06fe0edc6246eb4faa34f03fe88b64a0a9a942561c8e22f4b71f", 64 | 
"sha256:cab50b8c2250b46fe738c77dbd25ce017d5e6fb35d3407606e7a4180656a5a6a", 65 | "sha256:cef128cb4d5e0b3493f058f10ce32365972c554572ff821e175dbc6f8ff6924f", 66 | "sha256:cf16e3cf6c0a5fdd9bc10c21687e19d29ad1fe863372b5543deaec1039581a30", 67 | "sha256:e56c744aa6ff427a607763346e4170629caf7e48ead6921745986db3692f987f", 68 | "sha256:e577934fc5f8779c554639376beeaa5657d54349096ef24abe8c74c5d9c117c3", 69 | "sha256:f2b0fa0c01d8a0c7483afd9f31d7ecf2d71760ca24499c8697aeb5ca37dc090c" 70 | ], 71 | "version": "==1.14.0" 72 | }, 73 | "chardet": { 74 | "hashes": [ 75 | "sha256:84ab92ed1c4d4f16916e05906b6b75a6c0fb5db821cc65e70cbd64a3e2a5eaae", 76 | "sha256:fc323ffcaeaed0e0a02bf4d117757b98aed530d9ed4531e3e15460124c106691" 77 | ], 78 | "version": "==3.0.4" 79 | }, 80 | "click": { 81 | "hashes": [ 82 | "sha256:2335065e6395b9e67ca716de5f7526736bfa6ceead690adf616d925bdc622b13", 83 | "sha256:5b94b49521f6456670fdb30cd82a4eca9412788a93fa6dd6df72c94d5a8ff2d7" 84 | ], 85 | "version": "==7.0" 86 | }, 87 | "cython": { 88 | "hashes": [ 89 | "sha256:01d566750e7c08e5f094419f8d1ee90e7fa286d8d77c4569748263ed5f05280a", 90 | "sha256:072cb90e2fe4b5cc27d56de12ec5a00311eee781c2d2e3f7c98a82319103c7ed", 91 | "sha256:0e078e793a9882bf48194b8b5c9b40c75769db1859cd90b210a4d7bf33cda2b1", 92 | "sha256:1a3842be21d1e25b7f3440a0c881ef44161937273ea386c30c0e253e30c63740", 93 | "sha256:1dc973bdea03c65f03f41517e4f0fc2b717d71cfbcf4ec34adac7e5bee71303e", 94 | "sha256:214a53257c100e93e7673e95ab448d287a37626a3902e498025993cc633647ae", 95 | "sha256:30462d61e7e290229a64e1c3682b4cc758ffc441e59cc6ce6fae059a05df305b", 96 | "sha256:34004f60b1e79033b0ca29b9ab53a86c12bcaab56648b82fbe21c007cd73d867", 97 | "sha256:34c888a57f419c63bef63bc0911c5bb407b93ed5d6bdeb1587dca2cd1dd56ad1", 98 | "sha256:3dd0cba13b36ff969232930bd6db08d3da0798f1fac376bd1fa4458f4b55d802", 99 | "sha256:4e5acf3b856a50d0aaf385f06a7b56a128a296322a9740f5f279c96619244308", 100 | "sha256:60d859e1efa5cc80436d58aecd3718ff2e74b987db0518376046adedba97ac30", 101 | "sha256:61e505379497b624d6316dd67ef8100aaadca0451f48f8c6fff8d622281cd121", 102 | "sha256:6f6de0bee19c70cb01e519634f0c35770de623006e4876e649ee4a960a147fec", 103 | "sha256:77ac051b7caf02938a32ea0925f558534ab2a99c0c98c681cc905e3e8cba506e", 104 | "sha256:7e4d74515d92c4e2be7201aaef7a51705bd3d5957df4994ddfe1b252195b5e27", 105 | "sha256:993837bbf0849e3b176e1ef6a50e9b8c2225e895501b85d56f4bb65a67f5ea25", 106 | "sha256:9a5f0cf8b95c0c058e413679a650f70dcc97764ccb2a6d5ccc6b08d44c9b334c", 107 | "sha256:9f2839396d21d5537bc9ff53772d44db39b0efb6bf8b6cac709170483df53a5b", 108 | "sha256:b8ba4b4ee3addc26bc595a51b6240b05a80e254b946d624fff6506439bc323d1", 109 | "sha256:bb6d90180eff72fc5a30099c442b8b0b5a620e84bf03ef32a55e3f7bd543f32e", 110 | "sha256:c3d778304209cc39f8287da22f2180f34d2c2ee46cd55abd82e48178841b37b1", 111 | "sha256:c562bc316040097e21357e783286e5eca056a5b2750e89d9d75f9541c156b6dc", 112 | "sha256:d114f9c0164df8fcd2880e4ba96986d7b0e7218f6984acc4989ff384c5d3d512", 113 | "sha256:d282b030ed5c736e4cdb1713a0c4fad7027f4e3959dc4b8fdb7c75042d83ed1b", 114 | "sha256:d8c73fe0ec57a0e4fdf5d2728b5e18b63980f55f1baf51b6bac6a73e8cbb7186", 115 | "sha256:e5c8f4198e25bc4b0e4a884377e0c0e46ca273993679e3bcc212ef96d4211b83", 116 | "sha256:e7f1dcc0e8c3e18fa2fddca4aecdf71c5651555a8dc9a0cd3a1d164cbce6cb35", 117 | "sha256:ea3b61bff995de49b07331d1081e0056ea29901d3e995aa989073fe2b1be0cb7", 118 | "sha256:ea5f987b4da530822fa797cf2f010193be77ea9e232d07454e3194531edd8e58", 119 | "sha256:f91b16e73eca996f86d1943be3b2c2b679b03e068ed8c82a5506c1e65766e4a6" 120 | ], 121 | "version": 
"==0.29.15" 122 | }, 123 | "docutils": { 124 | "hashes": [ 125 | "sha256:6c4f696463b79f1fb8ba0c594b63840ebd41f059e92b31957c46b74a4599b6d0", 126 | "sha256:9e4d7ecfc600058e07ba661411a2b7de2fd0fafa17d1a7f7361cd47b1175c827", 127 | "sha256:a2aeea129088da402665e92e0b25b04b073c04b2dce4ab65caaa38b7ce2e1a99" 128 | ], 129 | "version": "==0.15.2" 130 | }, 131 | "fairseq": { 132 | "hashes": [ 133 | "sha256:61206358b79f325ea0b46cfd8c95cdb81bfbcfb43cf12b47d1d5124ce7321d3b" 134 | ], 135 | "index": "pypi", 136 | "version": "==0.9.0" 137 | }, 138 | "filelock": { 139 | "hashes": [ 140 | "sha256:18d82244ee114f543149c66a6e0c14e9c4f8a1044b5cdaadd0f82159d6a6ff59", 141 | "sha256:929b7d63ec5b7d6b71b0fa5ac14e030b3f70b75747cef1b10da9b879fef15836" 142 | ], 143 | "version": "==3.0.12" 144 | }, 145 | "idna": { 146 | "hashes": [ 147 | "sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407", 148 | "sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c" 149 | ], 150 | "version": "==2.8" 151 | }, 152 | "jmespath": { 153 | "hashes": [ 154 | "sha256:3720a4b1bd659dd2eecad0666459b9788813e032b83e7ba58578e48254e0a0e6", 155 | "sha256:bde2aef6f44302dfb30320115b17d030798de8c4110e28d5cf6cf91a7a31074c" 156 | ], 157 | "version": "==0.9.4" 158 | }, 159 | "joblib": { 160 | "hashes": [ 161 | "sha256:0630eea4f5664c463f23fbf5dcfc54a2bc6168902719fa8e19daf033022786c8", 162 | "sha256:bdb4fd9b72915ffb49fde2229ce482dd7ae79d842ed8c2b4c932441495af1403" 163 | ], 164 | "version": "==0.14.1" 165 | }, 166 | "nltk": { 167 | "hashes": [ 168 | "sha256:bed45551259aa2101381bbdd5df37d44ca2669c5c3dad72439fa459b29137d94" 169 | ], 170 | "index": "pypi", 171 | "version": "==3.4.5" 172 | }, 173 | "numpy": { 174 | "hashes": [ 175 | "sha256:1786a08236f2c92ae0e70423c45e1e62788ed33028f94ca99c4df03f5be6b3c6", 176 | "sha256:17aa7a81fe7599a10f2b7d95856dc5cf84a4eefa45bc96123cbbc3ebc568994e", 177 | "sha256:20b26aaa5b3da029942cdcce719b363dbe58696ad182aff0e5dcb1687ec946dc", 178 | "sha256:2d75908ab3ced4223ccba595b48e538afa5ecc37405923d1fea6906d7c3a50bc", 179 | "sha256:39d2c685af15d3ce682c99ce5925cc66efc824652e10990d2462dfe9b8918c6a", 180 | "sha256:56bc8ded6fcd9adea90f65377438f9fea8c05fcf7c5ba766bef258d0da1554aa", 181 | "sha256:590355aeade1a2eaba17617c19edccb7db8d78760175256e3cf94590a1a964f3", 182 | "sha256:70a840a26f4e61defa7bdf811d7498a284ced303dfbc35acb7be12a39b2aa121", 183 | "sha256:77c3bfe65d8560487052ad55c6998a04b654c2fbc36d546aef2b2e511e760971", 184 | "sha256:9537eecf179f566fd1c160a2e912ca0b8e02d773af0a7a1120ad4f7507cd0d26", 185 | "sha256:9acdf933c1fd263c513a2df3dceecea6f3ff4419d80bf238510976bf9bcb26cd", 186 | "sha256:ae0975f42ab1f28364dcda3dde3cf6c1ddab3e1d4b2909da0cb0191fa9ca0480", 187 | "sha256:b3af02ecc999c8003e538e60c89a2b37646b39b688d4e44d7373e11c2debabec", 188 | "sha256:b6ff59cee96b454516e47e7721098e6ceebef435e3e21ac2d6c3b8b02628eb77", 189 | "sha256:b765ed3930b92812aa698a455847141869ef755a87e099fddd4ccf9d81fffb57", 190 | "sha256:c98c5ffd7d41611407a1103ae11c8b634ad6a43606eca3e2a5a269e5d6e8eb07", 191 | "sha256:cf7eb6b1025d3e169989416b1adcd676624c2dbed9e3bcb7137f51bfc8cc2572", 192 | "sha256:d92350c22b150c1cae7ebb0ee8b5670cc84848f6359cf6b5d8f86617098a9b73", 193 | "sha256:e422c3152921cece8b6a2fb6b0b4d73b6579bd20ae075e7d15143e711f3ca2ca", 194 | "sha256:e840f552a509e3380b0f0ec977e8124d0dc34dc0e68289ca28f4d7c1d0d79474", 195 | "sha256:f3d0a94ad151870978fb93538e95411c83899c9dc63e6fb65542f769568ecfa5" 196 | ], 197 | "version": "==1.18.1" 198 | }, 199 | "portalocker": { 200 | "hashes": [ 201 | 
"sha256:6f57aabb25ba176462dc7c63b86c42ad6a9b5bd3d679a9d776d0536bfb803d54", 202 | "sha256:dac62e53e5670cb40d2ee4cdc785e6b829665932c3ee75307ad677cf5f7d2e9f" 203 | ], 204 | "version": "==1.5.2" 205 | }, 206 | "protobuf": { 207 | "hashes": [ 208 | "sha256:0bae429443cc4748be2aadfdaf9633297cfaeb24a9a02d0ab15849175ce90fab", 209 | "sha256:24e3b6ad259544d717902777b33966a1a069208c885576254c112663e6a5bb0f", 210 | "sha256:310a7aca6e7f257510d0c750364774034272538d51796ca31d42c3925d12a52a", 211 | "sha256:52e586072612c1eec18e1174f8e3bb19d08f075fc2e3f91d3b16c919078469d0", 212 | "sha256:73152776dc75f335c476d11d52ec6f0f6925774802cd48d6189f4d5d7fe753f4", 213 | "sha256:7774bbbaac81d3ba86de646c39f154afc8156717972bf0450c9dbfa1dc8dbea2", 214 | "sha256:82d7ac987715d8d1eb4068bf997f3053468e0ce0287e2729c30601feb6602fee", 215 | "sha256:8eb9c93798b904f141d9de36a0ba9f9b73cc382869e67c9e642c0aba53b0fc07", 216 | "sha256:adf0e4d57b33881d0c63bb11e7f9038f98ee0c3e334c221f0858f826e8fb0151", 217 | "sha256:c40973a0aee65422d8cb4e7d7cbded95dfeee0199caab54d5ab25b63bce8135a", 218 | "sha256:c77c974d1dadf246d789f6dad1c24426137c9091e930dbf50e0a29c1fcf00b1f", 219 | "sha256:dd9aa4401c36785ea1b6fff0552c674bdd1b641319cb07ed1fe2392388e9b0d7", 220 | "sha256:e11df1ac6905e81b815ab6fd518e79be0a58b5dc427a2cf7208980f30694b956", 221 | "sha256:e2f8a75261c26b2f5f3442b0525d50fd79a71aeca04b5ec270fc123536188306", 222 | "sha256:e512b7f3a4dd780f59f1bf22c302740e27b10b5c97e858a6061772668cd6f961", 223 | "sha256:ef2c2e56aaf9ee914d3dccc3408d42661aaf7d9bb78eaa8f17b2e6282f214481", 224 | "sha256:fac513a9dc2a74b99abd2e17109b53945e364649ca03d9f7a0b96aa8d1807d0a", 225 | "sha256:fdfb6ad138dbbf92b5dbea3576d7c8ba7463173f7d2cb0ca1bd336ec88ddbd80" 226 | ], 227 | "version": "==3.11.3" 228 | }, 229 | "pycparser": { 230 | "hashes": [ 231 | "sha256:a988718abfad80b6b157acce7bf130a30876d27603738ac39f140993246b25b3" 232 | ], 233 | "version": "==2.19" 234 | }, 235 | "python-dateutil": { 236 | "hashes": [ 237 | "sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c", 238 | "sha256:75bb3f31ea686f1197762692a9ee6a7550b59fc6ca3a1f4b5d7e32fb98e2da2a" 239 | ], 240 | "version": "==2.8.1" 241 | }, 242 | "regex": { 243 | "hashes": [ 244 | "sha256:07b39bf943d3d2fe63d46281d8504f8df0ff3fe4c57e13d1656737950e53e525", 245 | "sha256:0932941cdfb3afcbc26cc3bcf7c3f3d73d5a9b9c56955d432dbf8bbc147d4c5b", 246 | "sha256:0e182d2f097ea8549a249040922fa2b92ae28be4be4895933e369a525ba36576", 247 | "sha256:10671601ee06cf4dc1bc0b4805309040bb34c9af423c12c379c83d7895622bb5", 248 | "sha256:23e2c2c0ff50f44877f64780b815b8fd2e003cda9ce817a7fd00dea5600c84a0", 249 | "sha256:26ff99c980f53b3191d8931b199b29d6787c059f2e029b2b0c694343b1708c35", 250 | "sha256:27429b8d74ba683484a06b260b7bb00f312e7c757792628ea251afdbf1434003", 251 | "sha256:3e77409b678b21a056415da3a56abfd7c3ad03da71f3051bbcdb68cf44d3c34d", 252 | "sha256:4e8f02d3d72ca94efc8396f8036c0d3bcc812aefc28ec70f35bb888c74a25161", 253 | "sha256:4eae742636aec40cf7ab98171ab9400393360b97e8f9da67b1867a9ee0889b26", 254 | "sha256:6a6ae17bf8f2d82d1e8858a47757ce389b880083c4ff2498dba17c56e6c103b9", 255 | "sha256:6a6ba91b94427cd49cd27764679024b14a96874e0dc638ae6bdd4b1a3ce97be1", 256 | "sha256:7bcd322935377abcc79bfe5b63c44abd0b29387f267791d566bbb566edfdd146", 257 | "sha256:98b8ed7bb2155e2cbb8b76f627b2fd12cf4b22ab6e14873e8641f266e0fb6d8f", 258 | "sha256:bd25bb7980917e4e70ccccd7e3b5740614f1c408a642c245019cff9d7d1b6149", 259 | "sha256:d0f424328f9822b0323b3b6f2e4b9c90960b24743d220763c7f07071e0778351", 260 | 
"sha256:d58e4606da2a41659c84baeb3cfa2e4c87a74cec89a1e7c56bee4b956f9d7461", 261 | "sha256:e3cd21cc2840ca67de0bbe4071f79f031c81418deb544ceda93ad75ca1ee9f7b", 262 | "sha256:e6c02171d62ed6972ca8631f6f34fa3281d51db8b326ee397b9c83093a6b7242", 263 | "sha256:e7c7661f7276507bce416eaae22040fd91ca471b5b33c13f8ff21137ed6f248c", 264 | "sha256:ecc6de77df3ef68fee966bb8cb4e067e84d4d1f397d0ef6fce46913663540d77" 265 | ], 266 | "version": "==2020.1.8" 267 | }, 268 | "requests": { 269 | "hashes": [ 270 | "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4", 271 | "sha256:9cf5292fcd0f598c671cfc1e0d7d1a7f13bb8085e9a590f48c010551dc6c4b31" 272 | ], 273 | "version": "==2.22.0" 274 | }, 275 | "s3transfer": { 276 | "hashes": [ 277 | "sha256:2482b4259524933a022d59da830f51bd746db62f047d6eb213f2f8855dcb8a13", 278 | "sha256:921a37e2aefc64145e7b73d50c71bb4f26f46e4c9f414dc648c6245ff92cf7db" 279 | ], 280 | "version": "==0.3.3" 281 | }, 282 | "sacrebleu": { 283 | "hashes": [ 284 | "sha256:0a4b9e53b742d95fcd2f32e4aaa42aadcf94121d998ca19c66c05e7037d1eeee", 285 | "sha256:6ca2418e4474120537155d77695aa5303b2cf1a4227aec32074b4b4d8c107e2b" 286 | ], 287 | "version": "==1.4.3" 288 | }, 289 | "sacremoses": { 290 | "hashes": [ 291 | "sha256:34dcfaacf9fa34a6353424431f0e4fcc60e8ebb27ffee320d57396690b712a3b" 292 | ], 293 | "version": "==0.0.38" 294 | }, 295 | "sentencepiece": { 296 | "hashes": [ 297 | "sha256:0a98ec863e541304df23a37787033001b62cb089f4ed9307911791d7e210c0b1", 298 | "sha256:0ad221ea7914d65f57d3e3af7ae48852b5035166493312b5025367585b43ac41", 299 | "sha256:0f72c4151791de7242e7184a9b7ef12503cef42e9a5a0c1b3510f2c68874e810", 300 | "sha256:22fe7d92203fadbb6a0dc7d767430d37cdf3a9da4a0f2c5302c7bf294f7bfd8f", 301 | "sha256:2a72d4c3d0dbb1e099ddd2dc6b724376d3d7ff77ba494756b894254485bec4b4", 302 | "sha256:30791ce80a557339e17f1290c68dccd3f661612fdc6b689b4e4f21d805b64952", 303 | "sha256:39904713b81869db10de53fe8b3719f35acf77f49351f28ceaad0d360f2f6305", 304 | "sha256:3d5a2163deea95271ce8e38dfd0c3c924bea92aaf63bdda69b5458628dacc8bd", 305 | "sha256:3f3dee204635c33ca2e450e17ee9e0e92f114a47f853c2e44e7f0f0ab444d8d0", 306 | "sha256:4dcea889af53f669dc39d1ca870c37c52bb3110fcd96a2e7330d288400958281", 307 | "sha256:4e36a92558ad9e2f91b311c5bcea90b7a63c567c0e7e20da44d6a6f01031b57e", 308 | "sha256:576bf820eb963e6f275d4005ed5334fbed59eb54bed508e5cae6d16c7179710f", 309 | "sha256:6d2bbdbf296d96304c6345675749981bb17dcf2a7163d2fec38f70a704b75669", 310 | "sha256:76fdce3e7e614e24b35167c22c9c388e0c843be53d99afb5e1f25f6bfe04e228", 311 | "sha256:97b8ee26892d236b2620af8ddae11713fbbb2dae9adf4ad5e988e5a82ce50a90", 312 | "sha256:b3b6fe02af7ea4823c19e0d8efddc10ff59b8449bc1ae9921f9dd8ad33802c33", 313 | "sha256:b416f514fff8785a1113e6c07f696e52967fc979d6cd946e454a8660cca72ef8", 314 | "sha256:bf0bad6ba01ace3e938ffdf05c42b24d8fd3740487ba865504795a0bb9b1f2b3", 315 | "sha256:c00387970360ec0369b5e7c75f3977fb14330df75465200c13bafb7a632d2e6b", 316 | "sha256:c23fb7bb949934998375d41dbe54d4df1778a3b9dcb24bc2ddaaa595819ed1da", 317 | "sha256:dfdcf48678656592b11d11e2102c52c38122e309f7a1a5272305d397cfe21ce0", 318 | "sha256:fb69c5ba325b900cf2b91f517b46eec8ce3c50995955e293b46681d832021c0e", 319 | "sha256:fba83bef6c7a7899cd811d9b1195e748722eb2a9737c3f3890160f0e01e3ad08", 320 | "sha256:fe115aee209197839b2a357e34523e23768d553e8a69eac2b558499ccda56f80", 321 | "sha256:ffdf51218a3d7e0dad79bdffd21ad15a23cbb9c572d2300c3295c6efc6c2357e" 322 | ], 323 | "version": "==0.1.85" 324 | }, 325 | "six": { 326 | "hashes": [ 327 | 
"sha256:236bdbdce46e6e6a3d61a337c0f8b763ca1e8717c03b369e87a7ec7ce1319c0a", 328 | "sha256:8f3cd2e254d8f793e7f3d6d9df77b92252b52637291d0f0da013c76ea2724b6c" 329 | ], 330 | "version": "==1.14.0" 331 | }, 332 | "tensorboardx": { 333 | "hashes": [ 334 | "sha256:835d85db0aef2c6768f07c35e69a74e3dcb122d6afceaf2b8504d7d16c7209a5", 335 | "sha256:8e336103a66b1b97a663057cc13d1db4419f7a12f332b8364386dbf8635031d9" 336 | ], 337 | "index": "pypi", 338 | "version": "==2.0" 339 | }, 340 | "tokenizers": { 341 | "hashes": [ 342 | "sha256:08e08027564194e16aa647d180837d292b2c9c5ef772fed15badcc88e2474a8f", 343 | "sha256:1385deb90ec76cbee59b50298c8d2dc5909cda080a706d263e4f81c8474ba53d", 344 | "sha256:3ebe7f0bff9e30ab15dec4846c54c9085e02e47711eb7253d36a6777eadc2948", 345 | "sha256:4b7c42b644a1c5705a59b14c53c84b50b8f0b9c0f5f952a8a91a350403e7615f", 346 | "sha256:503418d5195ae1a483ced0257a0d2f4583456aa49bdfe0014c8605babf244ac5", 347 | "sha256:5ba2c6eaac2e8e0a2d839c0420d16707496b5e93b1454029d19487c5dd8c9b62", 348 | "sha256:7de28e0bebd0904b990560a1f14c3c5600da29be287e544bdf19e6970ea11d54", 349 | "sha256:82e8c3b13a66410358753b7e48776749935851cdb49a3d0c139a046178ec4f49", 350 | "sha256:a4d1ef6ee9221e7f9c1a4c122a15e93f0961977aaae2813b7b405c778728dcee", 351 | "sha256:a66ff87c32a221a126904d7ec972e7c8e0033486b24f8777c0f056aedbc09011", 352 | "sha256:a7f5e43674dd5b012ad29b79a32f0652ecfff3a3ed1c04f9073038c4bf63829d", 353 | "sha256:bb44fa1b268d1bbdf2bb14cd82da6ffb93d19638157c77f9e17e246928f0233f", 354 | "sha256:ce75c75430a3dfc33a10c90c1607d44b172c6d2ea19d586692b6cc9ba6ec5e14" 355 | ], 356 | "index": "pypi", 357 | "version": "==0.0.11" 358 | }, 359 | "torch": { 360 | "hashes": [ 361 | "sha256:271d4d1e44df6ed57c530f8849b028447c62b8a19b8e8740dd9baa56e7f682c1", 362 | "sha256:30ce089475b287a37d6fbb8d71853e672edaf66699e3dd2eb19be6ce6296732a", 363 | "sha256:405b9eb40e44037d2525b3ddb5bc4c66b519cd742bff249d4207d23f83e88ea5", 364 | "sha256:504915c6bc6051ba6a4c2a43c446463dff04411e352f1e26fe13debeae431778", 365 | "sha256:54d06a0e8ee85e5a437c24f4af9f4196c819294c23ffb5914e177756f55f1829", 366 | "sha256:6f2fd9eb8c7eaf38a982ab266dbbfba0f29fb643bc74e677d045d6f2595e4692", 367 | "sha256:8856f334aa9ecb742c1504bd2563d0ffb8dceb97149c8d72a04afa357f667dbc", 368 | "sha256:8fff03bf7b474c16e4b50da65ea14200cc64553b67b9b2307f9dc7e8c69b9d28", 369 | "sha256:9a1b1db73d8dcfd94b2eee24b939368742aa85f1217c55b8f5681e76c581e99a", 370 | "sha256:bb1e87063661414e1149bef2e3a2499ce0b5060290799d7e26bc5578037075ba", 371 | "sha256:d7b34a78f021935ad727a3bede56a8a8d4fda0b0272314a04c5f6890bbe7bb29" 372 | ], 373 | "version": "==1.4.0" 374 | }, 375 | "tqdm": { 376 | "hashes": [ 377 | "sha256:251ee8440dbda126b8dfa8a7c028eb3f13704898caaef7caa699b35e119301e2", 378 | "sha256:fe231261cfcbc6f4a99165455f8f6b9ef4e1032a6e29bccf168b4bf42012f09c" 379 | ], 380 | "version": "==4.42.1" 381 | }, 382 | "transformers": { 383 | "hashes": [ 384 | "sha256:995393b9ce764044287847792476101cd4e8377c756874df1116b221980749ad", 385 | "sha256:c5e765d3fd1a654e27b0f675e1cc1e210139429cc00d1b816c3c21502ace941e" 386 | ], 387 | "index": "pypi", 388 | "version": "==2.4.1" 389 | }, 390 | "typing": { 391 | "hashes": [ 392 | "sha256:91dfe6f3f706ee8cc32d38edbbf304e9b7583fb37108fef38229617f8b3eba23", 393 | "sha256:c8cabb5ab8945cd2f54917be357d134db9cc1eb039e59d1606dc1e60cb1d9d36", 394 | "sha256:f38d83c5a7a7086543a0f649564d661859c5146a85775ab90c0d2f93ffaa9714" 395 | ], 396 | "version": "==3.7.4.1" 397 | }, 398 | "urllib3": { 399 | "hashes": [ 400 | 
"sha256:2f3db8b19923a873b3e5256dc9c2dedfa883e33d87c690d9c7913e1f40673cdc", 401 | "sha256:87716c2d2a7121198ebcb7ce7cccf6ce5e9ba539041cfbaeecfb641dc0bf6acc" 402 | ], 403 | "markers": "python_version != '3.4'", 404 | "version": "==1.25.8" 405 | } 406 | }, 407 | "develop": {} 408 | } 409 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | RobBERT: A Dutch RoBERTa-based Language Model 3 |

4 | 5 | 6 | ![Python](https://img.shields.io/badge/python-v3.6+-blue.svg) 7 | ![Contributions welcome](https://img.shields.io/badge/contributions-welcome-orange.svg) 8 | ![GitHub](https://img.shields.io/github/license/ipieter/RobBERT) 9 | [![🤗 HuggingFace](https://img.shields.io/badge/🤗%20HuggingFace-%20pdelobelle%2Frobbert--v2--dutch--base-orange)](https://huggingface.co/pdelobelle/robbert-v2-dutch-base) 10 | 11 | 12 | # RobBERT: Dutch RoBERTa-based Language Model. 13 | 14 | RobBERT is the state-of-the-art Dutch BERT model. 15 | It is a large pre-trained general Dutch language model that can be fine-tuned on a given dataset to perform any text classification, regression or token-tagging task. 16 | As such, it has been successfully used by many [researchers](https://scholar.google.com/scholar?oi=bibs&hl=en&cites=7180110604335112086) and [practitioners](https://huggingface.co/models?search=robbert) for achieving state-of-the-art performance for a wide range of Dutch natural language processing tasks, including: 17 | 18 | - [Emotion detection](https://www.aclweb.org/anthology/2021.wassa-1.27/) 19 | - Sentiment analysis ([book reviews](https://arxiv.org/pdf/2001.06286.pdf), [news articles](https://biblio.ugent.be/publication/8704637/file/8704638.pdf)*) 20 | - [Coreference resolution](https://arxiv.org/pdf/2001.06286.pdf) 21 | - Named entity recognition ([CoNLL](https://arxiv.org/pdf/2001.06286.pdf), [job titles](https://arxiv.org/pdf/2004.02814.pdf)*, [SoNaR](https://github.com/proycon/deepfrog)) 22 | - Part-of-speech tagging ([Small UD Lassy](https://arxiv.org/pdf/2001.06286.pdf), [CGN](https://github.com/proycon/deepfrog)) 23 | - [Zero-shot word prediction](https://arxiv.org/pdf/2001.06286.pdf) 24 | - [Humor detection](https://arxiv.org/pdf/2010.13652.pdf) 25 | - [Cyberbullying detection](https://www.cambridge.org/core/journals/natural-language-engineering/article/abs/automatic-classification-of-participant-roles-in-cyberbullying-can-we-detect-victims-bullies-and-bystanders-in-social-media-text/A2079C2C738C29428E666810B8903342) 26 | - [Correcting dt-spelling mistakes](https://gitlab.com/spelfouten/dutch-simpletransformers/)* 27 | 28 | and also achieved outstanding, near-sota results for: 29 | 30 | - [Natural language inference](https://arxiv.org/pdf/2101.05716.pdf)* 31 | - [Review classification](https://medium.com/broadhorizon-cmotions/nlp-with-r-part-5-state-of-the-art-in-nlp-transformers-bert-3449e3cd7494)* 32 | 33 | \* *Note that several evaluations use RobBERT-v1, and that the second and improved RobBERT-v2 outperforms this first model on everything we tested* 34 | 35 | *(Also note that this list is not exhaustive. If you used RobBERT for your application, we are happy to know about it! Send us a mail, or add it yourself to this list by sending a pull request with the edit!)* 36 | 37 | To use the RobBERT model using [HuggingFace transformers](https://huggingface.co/transformers/), use the name [`pdelobelle/robbert-v2-dutch-base`](https://huggingface.co/pdelobelle/robbert-v2-dutch-base). 38 | 39 | More in-depth information about RobBERT can be found in our [blog post](https://pieter.ai/robbert/) and in [our paper](https://arxiv.org/abs/2001.06286). 
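For a quick first impression, here is a minimal sketch of zero-shot mask filling with the 🤗 Transformers `fill-mask` pipeline using that model name (the example sentence mirrors the hosted inference demo linked below; a recent `transformers` version is assumed):

```python
from transformers import pipeline

# Zero-shot mask filling with RobBERT's masked language model head.
unmasker = pipeline("fill-mask", model="pdelobelle/robbert-v2-dutch-base")

# RoBERTa-style models use "<mask>" as the mask token.
for prediction in unmasker("De hoofdstad van België is <mask>."):
    print(prediction["token_str"], round(prediction["score"], 3))
```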
40 | 41 | ## Table of contents 42 | * [How To Use](#how-to-use) 43 | + [Using Huggingface Transformers (easiest)](#using-huggingface-transformers-easiest) 44 | + [Using Fairseq (harder)](#using-fairseq-harder) 45 | * [Technical Details From The Paper](#technical-details-from-the-paper) 46 | + [Our Performance Evaluation Results](#our-performance-evaluation-results) 47 | + [Sentiment analysis](#sentiment-analysis) 48 | + [Die/Dat (coreference resolution)](#diedat-coreference-resolution) 49 | - [Finetuning on whole dataset](#finetuning-on-whole-dataset) 50 | - [Finetuning on 10K examples](#finetuning-on-10k-examples) 51 | - [Using zero-shot word masking task](#using-zero-shot-word-masking-task) 52 | + [Part-of-Speech Tagging.](#part-of-speech-tagging) 53 | + [Named Entity Recognition](#named-entity-recognition) 54 | * [Pre-Training Procedure Details](#pre-training-procedure-details) 55 | * [Investigating Limitations and Bias](#investigating-limitations-and-bias) 56 | * [How to Replicate Our Paper Experiments](#how-to-replicate-our-paper-experiments) 57 | + [Classification](#classification) 58 | - [Sentiment analysis using the Dutch Book Review Dataset](#sentiment-analysis-using-the-dutch-book-review-dataset) 59 | - [Predicting the Dutch pronouns _die_ and _dat_](#predicting-the-dutch-pronouns-die-and-dat) 60 | * [Name Origin of RobBERT](#name-origin-of-robbert) 61 | * [Credits and citation](#credits-and-citation) 62 | 63 | 64 | 65 | ## How To Use 66 | 67 | RobBERT uses the [RoBERTa](https://arxiv.org/abs/1907.11692) architecture and pre-training but with a Dutch tokenizer and training data. RoBERTa is the robustly optimized English BERT model, making it even more powerful than the original BERT model. Given this same architecture, RobBERT can easily be finetuned and used for inference with [code to finetune RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html) models and most code used for BERT models, e.g. as provided by the [HuggingFace Transformers](https://huggingface.co/transformers/) library. 68 | 69 | RobBERT can easily be used in two different ways: either using the [Fairseq RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta) code or using [HuggingFace Transformers](https://github.com/huggingface/transformers). 70 | 71 | By default, RobBERT has the masked language model head used in training. This can be used as a zero-shot way to fill masks in sentences. It can be tested out for free on [RobBERT's hosted inference API of Huggingface](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=De+hoofdstad+van+Belgi%C3%AB+is+%3Cmask%3E.). You can also create a new prediction head for your own task by using any of HuggingFace's [RoBERTa-runners](https://huggingface.co/transformers/v2.7.0/examples.html#language-model-training) or [their fine-tuning notebooks](https://huggingface.co/transformers/v4.1.1/notebooks.html) and changing the model name to `pdelobelle/robbert-v2-dutch-base`, or by using the original fairseq [RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta) training regimes. 72 | 73 | ### Using Huggingface Transformers (easiest) 74 | 75 | You can easily [download RobBERT v2](https://huggingface.co/pdelobelle/robbert-v2-dutch-base) using [🤗 Transformers](https://github.com/huggingface/transformers). 76 | Use the following code to download the base model and finetune it yourself, or use one of our finetuned models (documented on [our project site](https://pieter.ai/robbert/)).
77 | 78 | ```python 79 | from transformers import RobertaTokenizer, RobertaForSequenceClassification 80 | tokenizer = RobertaTokenizer.from_pretrained("pdelobelle/robbert-v2-dutch-base") 81 | model = RobertaForSequenceClassification.from_pretrained("pdelobelle/robbert-v2-dutch-base") 82 | ``` 83 | 84 | Starting with `transformers v2.4.0` (or installing from source), you can use AutoTokenizer and AutoModel. 85 | You can then use most of [HuggingFace's BERT-based notebooks](https://huggingface.co/transformers/v4.1.1/notebooks.html) for finetuning RobBERT on your type of Dutch language dataset. 86 | 87 | ### Using Fairseq (harder) 88 | 89 | Alternatively, you can also use RobBERT using the [RoBERTa architecture code]((https://github.com/pytorch/fairseq/tree/master/examples/roberta)). 90 | You can download RobBERT v2's Fairseq model here: [(RobBERT-base, 1.5 GB)](https://github.com/iPieter/BERDT/releases/download/v1.0/RobBERT-base.pt). 91 | Using RobBERT's `model.pt`, this method allows you to use all other functionalities of [RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta). 92 | 93 | 94 | ## Technical Details From The Paper 95 | 96 | 97 | ### Our Performance Evaluation Results 98 | 99 | All experiments are described in more detail in our [paper](https://arxiv.org/abs/2001.06286), with the code in [our GitHub repository](https://github.com/iPieter/RobBERT). 100 | 101 | ### Sentiment analysis 102 | Predicting whether a review is positive or negative using the [Dutch Book Reviews Dataset](https://github.com/benjaminvdb/DBRD). 103 | 104 | | Model | Accuracy [%] | 105 | |-------------------|--------------------------| 106 | | ULMFiT | 93.8 | 107 | | BERTje | 93.0 | 108 | | RobBERT v2 | **95.1** | 109 | 110 | ### Die/Dat (coreference resolution) 111 | 112 | We measured how well the models are able to do coreference resolution by predicting whether "die" or "dat" should be filled into a sentence. 113 | For this, we used the [EuroParl corpus](https://www.statmt.org/europarl/). 114 | 115 | #### Finetuning on whole dataset 116 | 117 | | Model | Accuracy [%] | F1 [%] | 118 | |-------------------|--------------------------|--------------| 119 | | [Baseline](https://arxiv.org/abs/2001.02943) (LSTM) | | 75.03 | 120 | | mBERT | 98.285 | 98.033 | 121 | | BERTje | 98.268 | 98.014 | 122 | | RobBERT v2 | **99.232** | **99.121** | 123 | 124 | #### Finetuning on 10K examples 125 | 126 | We also measured the performance using only 10K training examples. 127 | This experiment clearly illustrates that RobBERT outperforms other models when there is little data available. 128 | 129 | | Model | Accuracy [%] | F1 [%] | 130 | |-------------------|--------------------------|--------------| 131 | | mBERT | 92.157 | 90.898 | 132 | | BERTje | 93.096 | 91.279 | 133 | | RobBERT v2 | **97.816** | **97.514** | 134 | 135 | #### Using zero-shot word masking task 136 | 137 | Since BERT models are pre-trained using the word masking task, we can use this to predict whether "die" or "dat" is more likely. 138 | This experiment shows that RobBERT has internalised more information about Dutch than other models. 139 | 140 | | Model | Accuracy [%] | 141 | |-------------------|--------------------------| 142 | | ZeroR | 66.70 | 143 | | mBERT | 90.21 | 144 | | BERTje | 94.94 | 145 | | RobBERT v2 | **98.75** | 146 | 147 | ### Part-of-Speech Tagging. 148 | 149 | Using the [Lassy UD dataset](https://universaldependencies.org/treebanks/nl_lassysmall/index.html). 
150 | 151 | 152 | | Model | Accuracy [%] | 153 | |-------------------|--------------------------| 154 | | Frog | 91.7 | 155 | | mBERT | **96.5** | 156 | | BERTje | 96.3 | 157 | | RobBERT v2 | 96.4 | 158 | 159 | Interestingly, we found that when dealing with **small data sets**, RobBERT v2 **significantly outperforms** other models. 160 | 161 |

162 | RobBERT's performance on smaller datasets 163 |

164 | 165 | ### Named Entity Recognition 166 | 167 | Using the [CoNLL 2002 evaluation script](https://www.clips.uantwerpen.be/conll2002/ner/). 168 | 169 | 170 | | Model | Accuracy [%] | 171 | |-------------------|--------------------------| 172 | | Frog | 57.31 | 173 | | mBERT | **90.94** | 174 | | BERT-NL | 89.7 | 175 | | BERTje | 88.3 | 176 | | RobBERT v2 | 89.08 | 177 | 178 | 179 | ## Pre-Training Procedure Details 180 | 181 | We pre-trained RobBERT using the RoBERTa training regime. 182 | We pre-trained our model on the Dutch section of the [OSCAR corpus](https://oscar-corpus.com/), a large multilingual corpus which was obtained by language classification of the Common Crawl corpus. 183 | This Dutch corpus is 39GB large, with 6.6 billion words spread over 126 million lines of text, where each line could contain multiple sentences, thus using more data than concurrently developed Dutch BERT models. 184 | 185 | 186 | RobBERT shares its architecture with [RoBERTa's base model](https://github.com/pytorch/fairseq/tree/master/examples/roberta), which itself is a replication and improvement over BERT. 187 | Like BERT, its architecture consists of 12 self-attention layers with 12 heads, totalling 117M trainable parameters. 188 | One difference with the original BERT model is due to the different pre-training task specified by RoBERTa, using only the MLM task and not the NSP task. 189 | During pre-training, it thus only predicts which words are masked in certain positions of given sentences. 190 | The training process uses the Adam optimizer with polynomial decay of the learning rate l_r=10^-6 and a ramp-up period of 1000 iterations, with hyperparameters beta_1=0.9 191 | and RoBERTa's default beta_2=0.98. 192 | Additionally, a weight decay of 0.1 and a small dropout of 0.1 help prevent the model from overfitting. 193 | 194 | 195 | RobBERT was trained on a computing cluster with 4 Nvidia P100 GPUs per node, where the number of nodes was dynamically adjusted while keeping a fixed batch size of 8192 sentences. 196 | At most 20 nodes were used (i.e. 80 GPUs), and the median was 5 nodes. 197 | By using gradient accumulation, the batch size could be set independently of the number of GPUs available, in order to maximally utilize the cluster. 198 | Using the [Fairseq library](https://github.com/pytorch/fairseq/tree/master/examples/roberta), the model trained for two epochs, equalling over 16k batches in total, which took about three days on the computing cluster. 199 | In between training jobs on the computing cluster, 2 Nvidia 1080 Ti's also covered some parameter updates for RobBERT v2. 200 | 201 | 202 | ## Investigating Limitations and Bias 203 | 204 | In the [RobBERT paper](https://arxiv.org/abs/2001.06286), we also investigated potential sources of bias in RobBERT. 205 | 206 | We found that the zero-shot model estimates the probability of *hij* (he) to be higher than *zij* (she) for most occupations in bleached template sentences, regardless of the actual gender ratio of these occupations. 207 | 208 |
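A minimal sketch of this kind of zero-shot probing is shown below. The template sentence is our own illustrative example rather than one of the paper's templates, and it assumes that *Hij* and *Zij* are single tokens in RobBERT's vocabulary and that a recent 🤗 Transformers version is installed:

```python
import torch
from transformers import RobertaTokenizer, RobertaForMaskedLM

tokenizer = RobertaTokenizer.from_pretrained("pdelobelle/robbert-v2-dutch-base")
model = RobertaForMaskedLM.from_pretrained("pdelobelle/robbert-v2-dutch-base")
model.eval()

# Illustrative bleached template: "<mask> werkt als verpleegkundige." (works as a nurse).
sentence = f"{tokenizer.mask_token} werkt als verpleegkundige."
inputs = tokenizer.encode_plus(sentence, return_tensors="pt")
mask_position = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

with torch.no_grad():
    logits = model(**inputs).logits

# Compare the predicted probabilities of the two pronouns at the masked position.
pronouns = ["Hij", "Zij"]
ids = tokenizer.convert_tokens_to_ids(pronouns)
probabilities = logits[0, mask_position].softmax(dim=-1)[0, ids]
for pronoun, probability in zip(pronouns, probabilities):
    print(pronoun, float(probability))
```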

209 | RobBERT's estimated probabilities for *hij* and *zij* across occupations 210 |

211 | 212 | By augmenting the DBRD Dutch Book sentiment analysis dataset with the stated gender of each review's author, we found that highly positive reviews written by women were generally more accurately detected by RobBERT as being positive than those written by men. 213 | 214 |

215 | RobBERT's sentiment analysis accuracy on DBRD, split by the reviewer's stated gender 216 |

217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | ## How to Replicate Our Paper Experiments 225 | 226 | You can replicate the experiments done in our paper by following these steps. 227 | You can install the required dependencies using either the requirements.txt file or Pipenv: 228 | - Install the dependencies from the requirements.txt file using `pip install -r requirements.txt` 229 | - OR install using [Pipenv](https://pipenv.readthedocs.io/en/latest/) *(install by running `pip install pipenv` in your terminal)* by running `pipenv install`. 230 | 231 | 232 | ### Classification 233 | In this section we describe how to use the scripts we provide to fine-tune models, which should be general enough to reuse for other desired textual classification tasks. 234 | 235 | #### Sentiment analysis using the Dutch Book Review Dataset 236 | 237 | - Download the Dutch book review dataset from [https://github.com/benjaminvdb/DBRD](https://github.com/benjaminvdb/DBRD), and save it to `data/raw/DBRD` 238 | - Run `src/preprocess_dbrd.py` to prepare the dataset. 239 | - To not be blind during training, we recommend keeping aside a small evaluation set from the training set. For this, run `src/split_dbrd_training.sh`. 240 | - Follow the notebook `notebooks/finetune_dbrd.ipynb` to finetune the model. 241 | 242 | #### Predicting the Dutch pronouns _die_ and _dat_ 243 | We fine-tune our model on the Dutch [Europarl corpus](http://www.statmt.org/europarl/). You can download it first with: 244 | 245 | ``` 246 | cd data/raw/europarl/ 247 | wget -N 'http://www.statmt.org/europarl/v7/nl-en.tgz' 248 | tar zxvf nl-en.tgz 249 | ``` 250 | As a sanity check, you should now have the following files in your `data/raw/europarl` folder: 251 | 252 | ``` 253 | europarl-v7.nl-en.en 254 | europarl-v7.nl-en.nl 255 | nl-en.tgz 256 | ``` 257 | 258 | Then you can run the preprocessing with the following script, which will first process the Europarl corpus to remove sentences without any _die_ or _dat_. 259 | Afterwards, it will flip the pronoun and join both sentences together with a `<sep>` token. 260 | 261 | ``` 262 | python src/preprocess_diedat.py 263 | . src/preprocess_diedat.sh 264 | ``` 265 | 266 | _note: You can monitor the progress of the first preprocessing step with `watch -n 2 wc -l data/europarl-v7.nl-en.nl.sentences`. This will take a while, but it is certainly not needed to use all inputs. This is, after all, why you want to use a pre-trained language model. You can terminate the Python script at any time and the second step will only use the sentences processed so far._ 267 | 268 | ## Name Origin of RobBERT 269 | 270 | Most BERT-like models have the word *BERT* in their name (e.g. [RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html), [ALBERT](https://arxiv.org/abs/1909.11942), [CamemBERT](https://camembert-model.fr/), and [many, many others](https://huggingface.co/models?search=bert)). 271 | As such, we queried our newly trained model using its masked language model to name itself *\<mask\>bert* using [all](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=Mijn+naam+is+%3Cmask%3Ebert.) [kinds](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=Hallo%2C+ik+ben+%3Cmask%3Ebert.) [of](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=Leuk+je+te+ontmoeten%2C+ik+heet+%3Cmask%3Ebert.) [prompts](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=Niemand+weet%2C+niemand+weet%2C+dat+ik+%3Cmask%3Ebert+heet.), and it consistently called itself RobBERT.
272 | We thought it was really quite fitting, given that RobBERT is a [*very* Dutch name](https://en.wikipedia.org/wiki/Robbert) *(and thus clearly a Dutch language model)*, and additionally has a high similarity to its root architecture, namely [RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html). 273 | 274 | Since *"rob"* is the Dutch word for a seal, we decided to draw a seal and dress it up like [Bert from Sesame Street](https://muppet.fandom.com/wiki/Bert) for the [RobBERT logo](https://github.com/iPieter/RobBERT/blob/master/res/robbert_logo.png). 275 | 276 | ## Credits and citation 277 | 278 | This project was created by [Pieter Delobelle](https://github.com/iPieter), [Thomas Winters](https://github.com/twinters) and Bettina Berendt. 279 | 280 | We are grateful to Liesbeth Allein for her work on die-dat disambiguation, to Huggingface for their Transformers package, to Facebook for their Fairseq package, and to all other people whose work we could use. 281 | 282 | We release our models and this code under the MIT license. 283 | 284 | If you would like to cite our paper or model, you can use the following BibTeX code: 285 | 286 | ``` 287 | @inproceedings{delobelle2020robbert, 288 | title = "{R}ob{BERT}: a {D}utch {R}o{BERT}a-based {L}anguage {M}odel", 289 | author = "Delobelle, Pieter and 290 | Winters, Thomas and 291 | Berendt, Bettina", 292 | booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2020", 293 | month = nov, 294 | year = "2020", 295 | address = "Online", 296 | publisher = "Association for Computational Linguistics", 297 | url = "https://www.aclweb.org/anthology/2020.findings-emnlp.292", 298 | doi = "10.18653/v1/2020.findings-emnlp.292", 299 | pages = "3255--3265" 300 | } 301 | ``` 302 | -------------------------------------------------------------------------------- /examples/die_vs_data_rest_api/.gitignore: -------------------------------------------------------------------------------- 1 | venv 2 | app/__pycache__ 3 | -------------------------------------------------------------------------------- /examples/die_vs_data_rest_api/README.md: -------------------------------------------------------------------------------- 1 | # Demo: 'die' vs 'dat' as a REST endpoint with Flask 2 | 3 | As a demo, we release a small Flask example for a REST endpoint to analyse sentences. 4 | It will return whether a sentence uses 'die' or 'dat' correctly, according to RobBERT. 5 | 6 | By default, a Flask server will listen on port 5000. The endpoints are `/disambiguation/classifier` (finetuned classification model) and `/disambiguation/mlm` (zero-shot model). 7 | 8 | ## Get started 9 | First install the dependencies from the requirements.txt file using `pip install -r requirements.txt`. 10 | 11 | ```shell script 12 | $ python app.py --model-path DTAI-KULeuven/robbertje-shuffled-dutch-die-vs-dat --fast-model-path pdelobelle/robbert-v2-dutch-base 13 | ``` 14 | 15 | ## Classification model 16 | Simply make an HTTP POST request to `/disambiguation/classifier` with the parameter `sentence` filled in: 17 | 18 | ```shell script 19 | $ curl --data "sentence=Daar loopt _die_ meisje." localhost:5000/disambiguation/classifier 20 | ``` 21 | 22 | This should give you the following response: 23 | 24 | ```json 25 | { 26 | "rating": 1, 27 | "interpretation": "incorrect", 28 | "confidence": 5.222124099731445, 29 | "sentence": "Daar loopt _die_ meisje." 30 | } 31 | ``` 32 | 33 | ## Zero-shot model 34 | We also have a zero-shot model (using the RobBERT base masked language model head), which is faster and easier to use. There is a small drop in accuracy, but it should be quite limited.
35 | 36 | To use the faster zero-shot model, just make an HTTP POST request to `/disambiguation/mlm` with the same parameter `sentence` filled in: 37 | 38 | ```shell script 39 | $ curl --data "sentence=Daar loopt _die_ meisje." localhost:5000/disambiguation/mlm 40 | ``` 41 | 42 | This should give you the following response, which is similar, but also provides `die`, `dat`, `Die` or `Dat` as the rating value: 43 | 44 | ```json 45 | { 46 | "rating": "dat", 47 | "interpretation": "incorrect", 48 | "confidence": 2.567270278930664, 49 | "sentence": "Daar loopt _die_ meisje." 50 | } 51 | ``` -------------------------------------------------------------------------------- /examples/die_vs_data_rest_api/app.py: -------------------------------------------------------------------------------- 1 | from app import create_app 2 | import argparse 3 | 4 | def create_parser(): 5 | "Utility function to create the CLI argument parser" 6 | 7 | parser = argparse.ArgumentParser( 8 | description="Create a REST endpoint for 'die' vs 'dat' disambiguation." 9 | ) 10 | 11 | parser.add_argument("--model-path", help="Path or identifier of the finetuned RobBERT model.", required=False) 12 | parser.add_argument("--fast-model-path", help="Path or identifier of the RobBERT MLM model.", required=False) 13 | 14 | return parser 15 | 16 | if __name__ == "__main__": 17 | arg_parser = create_parser() 18 | args = arg_parser.parse_args() 19 | 20 | 21 | create_app(args.model_path, args.fast_model_path).run() -------------------------------------------------------------------------------- /examples/die_vs_data_rest_api/app/__init__.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, request 2 | import os 3 | from transformers import ( 4 | RobertaForSequenceClassification, 5 | RobertaForMaskedLM, 6 | RobertaTokenizer, 7 | ) 8 | import torch 9 | import nltk 10 | from nltk.tokenize.treebank import TreebankWordDetokenizer 11 | import json 12 | import re 13 | 14 | 15 | def replace_query_token(sentence): 16 | "Small utility function to turn a sentence containing `_die_` or `_dat_` into the proper RobBERT input." 17 | tokens = nltk.word_tokenize(sentence) 18 | tokens_swapped = nltk.word_tokenize(sentence) 19 | for i, word in enumerate(tokens): 20 | if word == "_die_": 21 | tokens[i] = "die" 22 | tokens_swapped[i] = "dat" 23 | 24 | elif word == "_dat_": 25 | tokens[i] = "dat" 26 | tokens_swapped[i] = "die" 27 | 28 | elif word == "_Dat_": 29 | tokens[i] = "Dat" 30 | tokens_swapped[i] = "Die" 31 | 32 | elif word == "_Die_": 33 | tokens[i] = "Die" 34 | 35 | tokens_swapped[i] = "Dat" 36 | 37 | if word.lower() == "_die_" or word.lower() == "_dat_": 38 | results = TreebankWordDetokenizer().detokenize(tokens) 39 | results_swapped = TreebankWordDetokenizer().detokenize(tokens_swapped) 40 | 41 | return "{} {}".format(results, results_swapped) 42 | 43 | # If we reach the end of the for loop, it means no query token was present 44 | raise ValueError("'die' or 'dat' should be surrounded by underscores.") 45 | 46 | 47 | def create_app(model_path: str, fast_model_path: str, device="cpu"): 48 | """ 49 | Create the flask app. 50 | 51 | :param model_path: Path to the finetuned model.
52 | :param device: Pytorch device, default CPU (switch to 'cuda' if a GPU is present) 53 | :return: the flask app 54 | """ 55 | app = Flask(__name__, instance_relative_config=True) 56 | 57 | print("initializing tokenizer and RobBERT.") 58 | if model_path: 59 | tokenizer: RobertaTokenizer = RobertaTokenizer.from_pretrained( 60 | model_path, use_auth_token=True 61 | ) 62 | robbert = RobertaForSequenceClassification.from_pretrained( 63 | model_path, use_auth_token=True 64 | ) 65 | robbert.eval() 66 | print("Loaded finetuned model") 67 | 68 | if fast_model_path: 69 | fast_tokenizer: RobertaTokenizer = RobertaTokenizer.from_pretrained( 70 | fast_model_path, use_auth_token=True 71 | ) 72 | fast_robbert = RobertaForMaskedLM.from_pretrained( 73 | fast_model_path, use_auth_token=True 74 | ) 75 | fast_robbert.eval() 76 | 77 | print("Loaded MLM model") 78 | 79 | possible_tokens = ["die", "dat", "Die", "Dat"] 80 | 81 | ids = fast_tokenizer.convert_tokens_to_ids(possible_tokens) 82 | 83 | mask_padding_with_zero = True 84 | block_size = 512 85 | 86 | # Disable dropout 87 | 88 | nltk.download("punkt") 89 | 90 | if fast_model_path: 91 | 92 | @app.route("/disambiguation/mlm/all", methods=["POST"]) 93 | def split(): 94 | sentence = request.form["sentence"] 95 | 96 | response = [] 97 | old_pos = 0 98 | for match in re.finditer(r"(die|dat|Die|Dat)+", sentence): 99 | print( 100 | "match", 101 | match.group(), 102 | "start index", 103 | match.start(), 104 | "End index", 105 | match.end(), 106 | ) 107 | with torch.no_grad(): 108 | query = ( 109 | sentence[: match.start()] + "" + sentence[match.end() :] 110 | ) 111 | print(query) 112 | if match.start() > 0: 113 | response.append({"part": sentence[old_pos: match.start()]}) 114 | 115 | old_pos = match.end() 116 | inputs = fast_tokenizer.encode_plus(query, return_tensors="pt") 117 | 118 | outputs = fast_robbert(**inputs) 119 | masked_position = torch.where( 120 | inputs["input_ids"] == fast_tokenizer.mask_token_id 121 | )[1] 122 | if len(masked_position) > 1: 123 | return "No two queries allowed in one sentence.", 400 124 | 125 | print(outputs.logits[0, masked_position, ids]) 126 | token = outputs.logits[0, masked_position, ids].argmax() 127 | 128 | confidence = float(outputs.logits[0, masked_position, ids].max()) 129 | 130 | response.append({ 131 | "predicted": possible_tokens[token], 132 | "input": match.group(), 133 | "interpretation": "correct" 134 | if possible_tokens[token] == match.group() 135 | else "incorrect", 136 | "confidence": confidence, 137 | "sentence": sentence, 138 | }) 139 | 140 | # This would be a good place for logging/storing queries + results 141 | print(response) 142 | 143 | # inputs = fast_tokenizer.encode_plus(query, return_tensors="pt") 144 | response.append({"part": sentence[match.end():]}) 145 | return json.dumps(response) 146 | 147 | @app.route("/disambiguation/mlm", methods=["POST"]) 148 | def fast(): 149 | sentence = request.form["sentence"] 150 | for i, x in enumerate(possible_tokens): 151 | if f"_{x}_" in sentence: 152 | masked_id = i 153 | query = sentence.replace(f"_{x}_", fast_tokenizer.mask_token) 154 | 155 | inputs = fast_tokenizer.encode_plus(query, return_tensors="pt") 156 | 157 | masked_position = torch.where( 158 | inputs["input_ids"] == fast_tokenizer.mask_token_id 159 | )[1] 160 | if len(masked_position) > 1: 161 | return "No two queries allowed in one sentence.", 400 162 | 163 | # self.examples.append([tokenizer.build_inputs_with_special_tokens(tokenized_text[0 : block_size]), [0], [0]]) 164 | with torch.no_grad(): 165 
| outputs = fast_robbert(**inputs) 166 | 167 | print(outputs.logits[0, masked_position, ids]) 168 | token = outputs.logits[0, masked_position, ids].argmax() 169 | 170 | confidence = float(outputs.logits[0, masked_position, ids].max()) 171 | 172 | response = { 173 | "rating": possible_tokens[token], 174 | "interpretation": "correct" if token == masked_id else "incorrect", 175 | "confidence": confidence, 176 | "sentence": sentence, 177 | } 178 | 179 | # This would be a good place for logging/storing queries + results 180 | print(response) 181 | 182 | return json.dumps(response) 183 | 184 | if model_path: 185 | 186 | @app.route("/disambiguation/classifier", methods=["POST"]) 187 | def main(): 188 | sentence = request.form["sentence"] 189 | query = replace_query_token(sentence) 190 | 191 | tokenized_text = tokenizer.encode( 192 | tokenizer.tokenize(query)[-block_size + 3 : -1] 193 | ) 194 | 195 | input_mask = [1 if mask_padding_with_zero else 0] * len(tokenized_text) 196 | 197 | pad_token = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) 198 | while len(tokenized_text) < block_size: 199 | tokenized_text.append(pad_token) 200 | input_mask.append(0 if mask_padding_with_zero else 1) 201 | # segment_ids.append(pad_token_segment_id) 202 | # p_mask.append(1) 203 | 204 | # self.examples.append([tokenizer.build_inputs_with_special_tokens(tokenized_text[0 : block_size]), [0], [0]]) 205 | batch = tuple( 206 | torch.tensor(t).to(torch.device(device)) 207 | for t in [ 208 | tokenized_text[0 : block_size - 3], 209 | input_mask[0 : block_size - 3], 210 | [0], 211 | [1][0], 212 | ] 213 | ) 214 | inputs = { 215 | "input_ids": batch[0].unsqueeze(0), 216 | "attention_mask": batch[1].unsqueeze(0), 217 | "labels": batch[3].unsqueeze(0), 218 | } 219 | with torch.no_grad(): 220 | outputs = robbert(**inputs) 221 | 222 | rating = outputs[1].argmax().item() 223 | confidence = outputs[1][0, rating].item() 224 | 225 | response = { 226 | "rating": rating, 227 | "interpretation": "incorrect" if rating == 1 else "correct", 228 | "confidence": confidence, 229 | "sentence": sentence, 230 | } 231 | 232 | # This would be a good place for logging/storing queries + results 233 | print(response) 234 | 235 | return json.dumps(response) 236 | 237 | return app 238 | -------------------------------------------------------------------------------- /examples/die_vs_data_rest_api/app/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/examples/die_vs_data_rest_api/app/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /examples/die_vs_data_rest_api/app/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "port": 5000 3 | } -------------------------------------------------------------------------------- /examples/die_vs_data_rest_api/requirements.txt: -------------------------------------------------------------------------------- 1 | boto3==1.12.36 2 | botocore==1.15.36 3 | certifi==2020.4.5.1 4 | chardet==3.0.4 5 | click==7.1.1 6 | docutils==0.15.2 7 | filelock==3.0.12 8 | Flask==1.1.2 9 | idna==2.9 10 | itsdangerous==1.1.0 11 | Jinja2==2.11.1 12 | jmespath==0.9.5 13 | joblib==0.14.1 14 | MarkupSafe==1.1.1 15 | nltk==3.4.5 16 | numpy==1.18.2 17 | python-dateutil==2.8.1 18 | regex==2020.4.4 19 | requests==2.23.0 20 | s3transfer==0.3.3 21 | sacremoses==0.0.38 22 | 
sentencepiece==0.1.85 23 | six==1.14.0 24 | tokenizers==0.5.2 25 | torch==1.4.0 26 | tqdm==4.45.0 27 | transformers==2.7.0 28 | urllib3==1.25.8 29 | Werkzeug==1.0.1 30 | -------------------------------------------------------------------------------- /model_cards/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | language: "nl" 3 | thumbnail: "https://github.com/iPieter/RobBERT/raw/master/res/robbert_logo.png" 4 | tags: 5 | - Dutch 6 | - Flemish 7 | - RoBERTa 8 | - RobBERT 9 | license: mit 10 | datasets: 11 | - oscar 12 | - oscar (NL) 13 | - dbrd 14 | - lassy-ud 15 | - europarl-mono 16 | - conll2002 17 | widget: 18 | - text: "Hallo, ik ben RobBERT, een taalmodel van de KU Leuven." 19 | --- 20 | 21 |

22 | RobBERT: A Dutch RoBERTa-based Language Model 23 |

24 | 25 | # RobBERT: Dutch RoBERTa-based Language Model. 26 | 27 | [RobBERT](https://github.com/iPieter/RobBERT) is the state-of-the-art Dutch BERT model. It is a large pre-trained general Dutch language model that can be fine-tuned on a given dataset to perform any text classification, regression or token-tagging task. As such, it has been successfully used by many [researchers](https://scholar.google.com/scholar?oi=bibs&hl=en&cites=7180110604335112086) and [practitioners](https://huggingface.co/models?search=robbert) for achieving state-of-the-art performance for a wide range of Dutch natural language processing tasks, including: 28 | 29 | - [Emotion detection](https://www.aclweb.org/anthology/2021.wassa-1.27/) 30 | - Sentiment analysis ([book reviews](https://arxiv.org/pdf/2001.06286.pdf), [news articles](https://biblio.ugent.be/publication/8704637/file/8704638.pdf)*) 31 | - [Coreference resolution](https://arxiv.org/pdf/2001.06286.pdf) 32 | - Named entity recognition ([CoNLL](https://arxiv.org/pdf/2001.06286.pdf), [job titles](https://arxiv.org/pdf/2004.02814.pdf)*, [SoNaR](https://github.com/proycon/deepfrog)) 33 | - Part-of-speech tagging ([Small UD Lassy](https://arxiv.org/pdf/2001.06286.pdf), [CGN](https://github.com/proycon/deepfrog)) 34 | - [Zero-shot word prediction](https://arxiv.org/pdf/2001.06286.pdf) 35 | - [Humor detection](https://arxiv.org/pdf/2010.13652.pdf) 36 | - [Cyberbullying detection](https://www.cambridge.org/core/journals/natural-language-engineering/article/abs/automatic-classification-of-participant-roles-in-cyberbullying-can-we-detect-victims-bullies-and-bystanders-in-social-media-text/A2079C2C738C29428E666810B8903342) 37 | - [Correcting dt-spelling mistakes](https://gitlab.com/spelfouten/dutch-simpletransformers/)* 38 | 39 | and also achieved outstanding, near-sota results for: 40 | 41 | - [Natural language inference](https://arxiv.org/pdf/2101.05716.pdf)* 42 | - [Review classification](https://medium.com/broadhorizon-cmotions/nlp-with-r-part-5-state-of-the-art-in-nlp-transformers-bert-3449e3cd7494)* 43 | 44 | \* *Note that several evaluations use RobBERT-v1, and that the second and improved RobBERT-v2 outperforms this first model on everything we tested* 45 | 46 | *(Also note that this list is not exhaustive. If you used RobBERT for your application, we are happy to know about it! Send us a mail, or add it yourself to this list by sending a pull request with the edit!)* 47 | 48 | More in-depth information about RobBERT can be found in our [blog post](https://people.cs.kuleuven.be/~pieter.delobelle/robbert/), [our paper](https://arxiv.org/abs/2001.06286) and [the RobBERT Github repository](https://github.com/iPieter/RobBERT). 49 | 50 | 51 | ## How to use 52 | 53 | RobBERT uses the [RoBERTa](https://arxiv.org/abs/1907.11692) architecture and pre-training but with a Dutch tokenizer and training data. RoBERTa is the robustly optimized English BERT model, making it even more powerful than the original BERT model. Given this same architecture, RobBERT can easily be finetuned and used for inference with [code to finetune RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html) models and most code used for BERT models, e.g. as provided by the [HuggingFace Transformers](https://huggingface.co/transformers/) library. 54 | 55 | By default, RobBERT has the masked language model head used in training. This can be used as a zero-shot way to fill masks in sentences.
It can be tested out for free on [RobBERT's Hosted infererence API of Huggingface](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=De+hoofdstad+van+Belgi%C3%AB+is+%3Cmask%3E.). You can also create a new prediction head for your own task by using any of HuggingFace's [RoBERTa-runners](https://huggingface.co/transformers/v2.7.0/examples.html#language-model-training), [their fine-tuning notebooks](https://huggingface.co/transformers/v4.1.1/notebooks.html) by changing the model name to `pdelobelle/robbert-v2-dutch-base`, or use the original fairseq [RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta) training regimes. 56 | 57 | Use the following code to download the base model and finetune it yourself, or use one of our finetuned models (documented on [our project site](https://people.cs.kuleuven.be/~pieter.delobelle/robbert/)). 58 | 59 | ```python 60 | from transformers import RobertaTokenizer, RobertaForSequenceClassification 61 | tokenizer = RobertaTokenizer.from_pretrained("pdelobelle/robbert-v2-dutch-base") 62 | model = RobertaForSequenceClassification.from_pretrained("pdelobelle/robbert-v2-dutch-base") 63 | ``` 64 | 65 | Starting with `transformers v2.4.0` (or installing from source), you can use AutoTokenizer and AutoModel. 66 | You can then use most of [HuggingFace's BERT-based notebooks](https://huggingface.co/transformers/v4.1.1/notebooks.html) for finetuning RobBERT on your type of Dutch language dataset. 67 | 68 | 69 | ## Technical Details From The Paper 70 | 71 | 72 | ### Our Performance Evaluation Results 73 | 74 | All experiments are described in more detail in our [paper](https://arxiv.org/abs/2001.06286), with the code in [our GitHub repository](https://github.com/iPieter/RobBERT). 75 | 76 | ### Sentiment analysis 77 | Predicting whether a review is positive or negative using the [Dutch Book Reviews Dataset](https://github.com/benjaminvdb/DBRD). 78 | 79 | | Model | Accuracy [%] | 80 | |-------------------|--------------------------| 81 | | ULMFiT | 93.8 | 82 | | BERTje | 93.0 | 83 | | RobBERT v2 | **95.1** | 84 | 85 | ### Die/Dat (coreference resolution) 86 | 87 | We measured how well the models are able to do coreference resolution by predicting whether "die" or "dat" should be filled into a sentence. 88 | For this, we used the [EuroParl corpus](https://www.statmt.org/europarl/). 89 | 90 | #### Finetuning on whole dataset 91 | 92 | | Model | Accuracy [%] | F1 [%] | 93 | |-------------------|--------------------------|--------------| 94 | | [Baseline](https://arxiv.org/abs/2001.02943) (LSTM) | | 75.03 | 95 | | mBERT | 98.285 | 98.033 | 96 | | BERTje | 98.268 | 98.014 | 97 | | RobBERT v2 | **99.232** | **99.121** | 98 | 99 | #### Finetuning on 10K examples 100 | 101 | We also measured the performance using only 10K training examples. 102 | This experiment clearly illustrates that RobBERT outperforms other models when there is little data available. 103 | 104 | | Model | Accuracy [%] | F1 [%] | 105 | |-------------------|--------------------------|--------------| 106 | | mBERT | 92.157 | 90.898 | 107 | | BERTje | 93.096 | 91.279 | 108 | | RobBERT v2 | **97.816** | **97.514** | 109 | 110 | #### Using zero-shot word masking task 111 | 112 | Since BERT models are pre-trained using the word masking task, we can use this to predict whether "die" or "dat" is more likely. 113 | This experiment shows that RobBERT has internalised more information about Dutch than other models. 
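To make this zero-shot setup concrete, the sketch below compares the masked-LM probabilities that RobBERT assigns to *die* and *dat* for a single sentence. It is only a minimal illustration: the example sentence is made up, we assume both candidates are single tokens in RobBERT's BPE vocabulary, and the paper's actual evaluation uses the repository's dedicated scripts. The table that follows summarises the aggregate zero-shot accuracy.

```python
# Minimal zero-shot die/dat sketch (illustrative only; not the paper's evaluation code).
import torch
from transformers import RobertaTokenizer, RobertaForMaskedLM

tokenizer = RobertaTokenizer.from_pretrained("pdelobelle/robbert-v2-dutch-base")
model = RobertaForMaskedLM.from_pretrained("pdelobelle/robbert-v2-dutch-base", return_dict=True)
model.eval()

sentence = f"Daar loopt {tokenizer.mask_token} meisje."  # made-up example sentence
input_ids = tokenizer.encode(sentence, return_tensors="pt")
mask_index = torch.where(input_ids == tokenizer.mask_token_id)[1]

with torch.no_grad():
    probs = model(input_ids).logits[0, mask_index, :].softmax(dim=-1).squeeze(0)

# Assumes " die" and " dat" each map to a single BPE token (note the leading space).
for word in (" die", " dat"):
    token_id = tokenizer.encode(word, add_special_tokens=False)[0]
    print(word.strip(), float(probs[token_id]))
```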
114 | 115 | | Model | Accuracy [%] | 116 | |-------------------|--------------------------| 117 | | ZeroR | 66.70 | 118 | | mBERT | 90.21 | 119 | | BERTje | 94.94 | 120 | | RobBERT v2 | **98.75** | 121 | 122 | ### Part-of-Speech Tagging. 123 | 124 | Using the [Lassy UD dataset](https://universaldependencies.org/treebanks/nl_lassysmall/index.html). 125 | 126 | 127 | | Model | Accuracy [%] | 128 | |-------------------|--------------------------| 129 | | Frog | 91.7 | 130 | | mBERT | **96.5** | 131 | | BERTje | 96.3 | 132 | | RobBERT v2 | 96.4 | 133 | 134 | Interestingly, we found that when dealing with **small data sets**, RobBERT v2 **significantly outperforms** other models. 135 | 136 |

137 | [Image: RobBERT's performance on smaller datasets] 138 |

139 | 140 | ### Named Entity Recognition 141 | 142 | Using the [CoNLL 2002 evaluation script](https://www.clips.uantwerpen.be/conll2002/ner/). 143 | 144 | 145 | | Model | Accuracy [%] | 146 | |-------------------|--------------------------| 147 | | Frog | 57.31 | 148 | | mBERT | **90.94** | 149 | | BERT-NL | 89.7 | 150 | | BERTje | 88.3 | 151 | | RobBERT v2 | 89.08 | 152 | 153 | 154 | ## Pre-Training Procedure Details 155 | 156 | We pre-trained RobBERT using the RoBERTa training regime. 157 | We pre-trained our model on the Dutch section of the [OSCAR corpus](https://oscar-corpus.com/), a large multilingual corpus which was obtained by language classification in the Common Crawl corpus. 158 | This Dutch corpus is 39GB large, with 6.6 billion words spread over 126 million lines of text, where each line could contain multiple sentences, thus using more data than concurrently developed Dutch BERT models. 159 | 160 | 161 | RobBERT shares its architecture with [RoBERTa's base model](https://github.com/pytorch/fairseq/tree/master/examples/roberta), which itself is a replication and improvement over BERT. 162 | Like BERT, it's architecture consists of 12 self-attention layers with 12 heads with 117M trainable parameters. 163 | One difference with the original BERT model is due to the different pre-training task specified by RoBERTa, using only the MLM task and not the NSP task. 164 | During pre-training, it thus only predicts which words are masked in certain positions of given sentences. 165 | The training process uses the Adam optimizer with polynomial decay of the learning rate l_r=10^-6 and a ramp-up period of 1000 iterations, with hyperparameters beta_1=0.9 166 | and RoBERTa's default beta_2=0.98. 167 | Additionally, a weight decay of 0.1 and a small dropout of 0.1 helps prevent the model from overfitting. 168 | 169 | 170 | RobBERT was trained on a computing cluster with 4 Nvidia P100 GPUs per node, where the number of nodes was dynamically adjusted while keeping a fixed batch size of 8192 sentences. 171 | At most 20 nodes were used (i.e. 80 GPUs), and the median was 5 nodes. 172 | By using gradient accumulation, the batch size could be set independently of the number of GPUs available, in order to maximally utilize the cluster. 173 | Using the [Fairseq library](https://github.com/pytorch/fairseq/tree/master/examples/roberta), the model trained for two epochs, which equals over 16k batches in total, which took about three days on the computing cluster. 174 | In between training jobs on the computing cluster, 2 Nvidia 1080 Ti's also covered some parameter updates for RobBERT v2. 175 | 176 | 177 | ## Investigating Limitations and Bias 178 | 179 | In the [RobBERT paper](https://arxiv.org/abs/2001.06286), we also investigated potential sources of bias in RobBERT. 180 | 181 | We found that the zeroshot model estimates the probability of *hij* (he) to be higher than *zij* (she) for most occupations in bleached template sentences, regardless of their actual job gender ratio in reality. 182 | 183 |

184 | [Image: RobBERT's performance on smaller datasets] 185 |
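For readers who want to reproduce this kind of probe, the snippet below is a minimal sketch using the `fill-mask` pipeline; the bleached template sentence is made up and the paper's actual templates and scoring procedure differ.

```python
# Illustrative bleached-template probe (the template below is a made-up example).
from transformers import pipeline

fill_mask = pipeline("fill-mask", model="pdelobelle/robbert-v2-dutch-base")

# List the highest-scoring fillers for the masked subject position and check
# whether "Hij" (he) is ranked above "Zij" (she).
for prediction in fill_mask("<mask> werkt als verpleegkundige."):  # "<mask> works as a nurse."
    print(prediction["token_str"], round(prediction["score"], 4))
```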

186 | 187 | By augmenting the DBRD Dutch Book Reviews sentiment analysis dataset with the stated gender of the author of the review, we found that highly positive reviews written by women were generally more accurately detected by RobBERT as being positive than those written by men. 188 | 189 |

190 | [Image: RobBERT's performance on smaller datasets] 191 |
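Measuring this difference only requires computing accuracy separately per author-gender group; a minimal sketch over hypothetical prediction records (not our actual data) looks like this:

```python
# Per-group accuracy sketch for the gender-split analysis (hypothetical records).
from collections import defaultdict

# Each record is (stated_gender, gold_label, predicted_label): made-up examples for illustration.
records = [("f", 1, 1), ("m", 1, 0), ("f", 0, 0), ("m", 1, 1), ("m", 0, 0)]

correct, total = defaultdict(int), defaultdict(int)
for gender, gold, pred in records:
    total[gender] += 1
    correct[gender] += int(gold == pred)

for gender in sorted(total):
    print(gender, correct[gender] / total[gender])
```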

192 | 193 | 194 | 195 | ## How to Replicate Our Paper Experiments 196 | Replicating our paper experiments is [described in detail on teh RobBERT repository README](https://github.com/iPieter/RobBERT#how-to-replicate-our-paper-experiments). 197 | 198 | ## Name Origin of RobBERT 199 | 200 | Most BERT-like models have the word *BERT* in their name (e.g. [RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html), [ALBERT](https://arxiv.org/abs/1909.11942), [CamemBERT](https://camembert-model.fr/), and [many, many others](https://huggingface.co/models?search=bert)). 201 | As such, we queried our newly trained model using its masked language model to name itself *\bert* using [all](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=Mijn+naam+is+%3Cmask%3Ebert.) [kinds](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=Hallo%2C+ik+ben+%3Cmask%3Ebert.) [of](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=Leuk+je+te+ontmoeten%2C+ik+heet+%3Cmask%3Ebert.) [prompts](https://huggingface.co/pdelobelle/robbert-v2-dutch-base?text=Niemand+weet%2C+niemand+weet%2C+dat+ik+%3Cmask%3Ebert+heet.), and it consistently called itself RobBERT. 202 | We thought it was really quite fitting, given that RobBERT is a [*very* Dutch name](https://en.wikipedia.org/wiki/Robbert) *(and thus clearly a Dutch language model)*, and additionally has a high similarity to its root architecture, namely [RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html). 203 | 204 | Since *"rob"* is a Dutch words to denote a seal, we decided to draw a seal and dress it up like [Bert from Sesame Street](https://muppet.fandom.com/wiki/Bert) for the [RobBERT logo](https://github.com/iPieter/RobBERT/blob/master/res/robbert_logo.png). 205 | 206 | ## Credits and citation 207 | 208 | This project is created by [Pieter Delobelle](https://people.cs.kuleuven.be/~pieter.delobelle), [Thomas Winters](https://thomaswinters.be) and [Bettina Berendt](https://people.cs.kuleuven.be/~bettina.berendt/). 209 | If you would like to cite our paper or model, you can use the following BibTeX: 210 | 211 | ``` 212 | @inproceedings{delobelle2020robbert, 213 | title = "{R}ob{BERT}: a {D}utch {R}o{BERT}a-based {L}anguage {M}odel", 214 | author = "Delobelle, Pieter and 215 | Winters, Thomas and 216 | Berendt, Bettina", 217 | booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2020", 218 | month = nov, 219 | year = "2020", 220 | address = "Online", 221 | publisher = "Association for Computational Linguistics", 222 | url = "https://www.aclweb.org/anthology/2020.findings-emnlp.292", 223 | doi = "10.18653/v1/2020.findings-emnlp.292", 224 | pages = "3255--3265" 225 | } 226 | ``` 227 | -------------------------------------------------------------------------------- /notebooks/demo_RobBERT_for_conll_ner.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "pycharm": { 7 | "name": "#%% md\n" 8 | } 9 | }, 10 | "source": [ 11 | "# Demo of RobBERT for Dutch named entity recognition\n", 12 | "We use a [RobBERT (Delobelle et al., 2020)](https://arxiv.org/abs/2001.06286) model for NER.\n", 13 | "\n", 14 | "**Dependencies**\n", 15 | "- tokenizers\n", 16 | "- torch\n", 17 | "- transformers" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "First we load our RobBERT model that was pretrained on OSCAR and finetuned on Dutch named entity recognition. 
We also load in RobBERT's tokenizer.\n", 25 | "\n", 26 | "Because we only want to get results, we have to disable dropout etc. So we add `model.eval()`." 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": { 33 | "collapsed": false, 34 | "jupyter": { 35 | "outputs_hidden": false 36 | }, 37 | "pycharm": { 38 | "is_executing": false, 39 | "name": "#%%\n" 40 | } 41 | }, 42 | "outputs": [ 43 | { 44 | "name": "stderr", 45 | "output_type": "stream", 46 | "text": [ 47 | "Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.\n" 48 | ] 49 | }, 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "RobBERT model loaded\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "import torch\n", 60 | "from transformers import RobertaTokenizer, RobertaForTokenClassification\n", 61 | "\n", 62 | "tokenizer = RobertaTokenizer.from_pretrained('pdelobelle/robbert-v2-dutch-ner')\n", 63 | "model = RobertaForTokenClassification.from_pretrained('pdelobelle/robbert-v2-dutch-ner', return_dict=True)\n", 64 | "model.eval()\n", 65 | "print(\"RobBERT model loaded\")" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "metadata": { 72 | "collapsed": false, 73 | "jupyter": { 74 | "outputs_hidden": false 75 | }, 76 | "pycharm": { 77 | "is_executing": false, 78 | "name": "#%% \n" 79 | } 80 | }, 81 | "outputs": [ 82 | { 83 | "name": "stdout", 84 | "output_type": "stream", 85 | "text": [ 86 | "input_ids:\n", 87 | "\ttensor([[ 0, 6079, 499, 38, 5, 13292, 11, 6422, 8, 7010,\n", 88 | " 9, 2617, 4, 2, 1],\n", 89 | " [ 0, 25907, 129, 1283, 8, 3971, 113, 28, 118, 71,\n", 90 | " 435, 38, 27600, 4, 2],\n", 91 | " [ 0, 9396, 89, 9, 797, 2877, 22, 11, 5, 4290,\n", 92 | " 445, 4, 2, 1, 1],\n", 93 | " [ 0, 7751, 6, 74, 458, 12, 3663, 14334, 342, 4,\n", 94 | " 2, 1, 1, 1, 1]])\n", 95 | "attention_mask:\n", 96 | "\ttensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],\n", 97 | " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", 98 | " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n", 99 | " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]])\n", 100 | "Tokens:\n", 101 | "\t['', 'Jan', 'Ġging', 'Ġnaar', 'Ġde', 'Ġbakker', 'Ġin', 'ĠLeuven', 'Ġen', 'Ġkocht', 'Ġeen', 'Ġbrood', '.', '', '']\n", 102 | "\t['', 'Bedrijven', 'Ġzoals', 'ĠGoogle', 'Ġen', 'ĠMicrosoft', 'Ġdoen', 'Ġook', 'Ġheel', 'Ġveel', 'Ġonderzoek', 'Ġnaar', 'ĠNLP', '.', '']\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "inputs = tokenizer.batch_encode_plus(\n", 108 | " [\"Jan ging naar de bakker in Leuven en kocht een brood.\",\n", 109 | " \"Bedrijven zoals Google en Microsoft doen ook heel veel onderzoek naar NLP.\",\n", 110 | " \"Men moet een gegeven paard niet in de bek kijken.\",\n", 111 | " \"Hallo, mijn naam is RobBERT.\"],\n", 112 | " return_tensors=\"pt\", padding=True)\n", 113 | "for key, value in inputs.items():\n", 114 | " print(\"{}:\\n\\t{}\".format(key, value))\n", 115 | "print(\"Tokens:\\n\\t{}\".format(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]) ))\n", 116 | "print(\"\\t{}\".format(tokenizer.convert_ids_to_tokens(inputs['input_ids'][1]) ))" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": { 122 | "pycharm": { 123 | "is_executing": false, 124 | "name": "#%% md\n" 125 | } 126 | }, 127 | "source": [ 128 | "In our model config, we stored what labels we use. 
\n", 129 | "We can load these in and automatically convert our predictions to a human-readable format.\n", 130 | "For reference, we have 4 types of named entities:\n", 131 | "\n", 132 | "- PER\n", 133 | "- LOC\n", 134 | "- ORG\n", 135 | "- MISC\n", 136 | "\n", 137 | "And we mark the first token with `B-` and then we mark a continuation with `I-`. " 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 4, 143 | "metadata": { 144 | "collapsed": false, 145 | "jupyter": { 146 | "outputs_hidden": false 147 | }, 148 | "pycharm": { 149 | "is_executing": false, 150 | "name": "#%%\n" 151 | } 152 | }, 153 | "outputs": [ 154 | { 155 | "name": "stdout", 156 | "output_type": "stream", 157 | "text": [ 158 | "{0: 'B-PER', 1: 'B-ORG', 2: 'B-LOC', 3: 'B-MISC', 4: 'I-PER', 5: 'I-ORG', 6: 'I-LOC', 7: 'I-MISC', 8: 'O'}\n" 159 | ] 160 | } 161 | ], 162 | "source": [ 163 | "print(model.config.id2label)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": { 169 | "pycharm": { 170 | "name": "#%% md\n" 171 | } 172 | }, 173 | "source": [ 174 | "Ok, let's do some predictions! Since we have a batch of 4 sentences, we can do this in one batch—as long as it fits on your GPU.\n", 175 | "\n", 176 | "_If the formatting of this fails, you can try to zoom out or make the window wider_" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 5, 182 | "metadata": { 183 | "collapsed": false, 184 | "jupyter": { 185 | "outputs_hidden": false 186 | }, 187 | "pycharm": { 188 | "is_executing": false, 189 | "name": "#%%\n" 190 | } 191 | }, 192 | "outputs": [ 193 | { 194 | "name": "stdout", 195 | "output_type": "stream", 196 | "text": [ 197 | "Sentence 0\n", 198 | " Jan Ġging Ġnaar Ġde Ġbakker Ġin ĠLeuven Ġen Ġkocht Ġeen Ġbrood . \n", 199 | "\n", 200 | "O B-PER O O O O O B-LOC O O O O O O O \n", 201 | "\n", 202 | "Sentence 1\n", 203 | " Bedrijven Ġzoals ĠGoogle Ġen ĠMicrosoft Ġdoen Ġook Ġheel Ġveel Ġonderzoek Ġnaar ĠNLP . \n", 204 | "\n", 205 | "O O O B-ORG O B-ORG O O O O O O B-MISC O O \n", 206 | "\n", 207 | "Sentence 2\n", 208 | " Men Ġmoet Ġeen Ġgegeven Ġpaard Ġniet Ġin Ġde Ġbek Ġkijken . \n", 209 | "\n", 210 | "O O O O O O O O O O O O O O O \n", 211 | "\n", 212 | "Sentence 3\n", 213 | " Hallo , Ġmijn Ġnaam Ġis ĠRob BER T . \n", 214 | "\n", 215 | "O O O O O O B-PER I-PER I-PER O O O O O I-PER \n", 216 | "\n" 217 | ] 218 | } 219 | ], 220 | "source": [ 221 | "with torch.no_grad():\n", 222 | " results = model(**inputs)\n", 223 | " for i, input in enumerate(inputs['input_ids']):\n", 224 | " print(f\"Sentence {i}\")\n", 225 | " [print(\"{:12}\".format(token), end=\"\") for token in tokenizer.convert_ids_to_tokens(input) ]\n", 226 | " print('\\n')\n", 227 | " [print(\"{:12}\".format(model.config.id2label[item.item()]), end=\"\") for item in results.logits[i].argmax(axis=1)]\n", 228 | " print('\\n')" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": { 234 | "pycharm": { 235 | "name": "#%% md\n" 236 | } 237 | }, 238 | "source": [ 239 | "Ok, this works nicely! We have 'Jan', 'Leuven' and companies like 'Google' that are all labeled correctly. 
\n", 240 | "In addition, RobBERT consists of multiple tokens (perhaps we should have added one with it's name) and that works with the `I-` token as well.\n", 241 | "\n" 242 | ] 243 | } 244 | ], 245 | "metadata": { 246 | "kernelspec": { 247 | "display_name": "Python 3", 248 | "language": "python", 249 | "name": "python3" 250 | }, 251 | "language_info": { 252 | "codemirror_mode": { 253 | "name": "ipython", 254 | "version": 3 255 | }, 256 | "file_extension": ".py", 257 | "mimetype": "text/x-python", 258 | "name": "python", 259 | "nbconvert_exporter": "python", 260 | "pygments_lexer": "ipython3", 261 | "version": "3.7.3" 262 | }, 263 | "pycharm": { 264 | "stem_cell": { 265 | "cell_type": "raw", 266 | "metadata": { 267 | "collapsed": false 268 | }, 269 | "source": [] 270 | } 271 | } 272 | }, 273 | "nbformat": 4, 274 | "nbformat_minor": 4 275 | } 276 | -------------------------------------------------------------------------------- /notebooks/demo_RobBERT_for_masked_LM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true, 7 | "pycharm": { 8 | "name": "#%% md\n" 9 | } 10 | }, 11 | "source": [ 12 | "# Demo of RobBERT for humour detection\n", 13 | "We use a [RobBERT (Delobelle et al., 2020)](https://arxiv.org/abs/2001.06286) model with the original pretraining head for MLM.\n", 14 | "\n", 15 | "**Dependencies**\n", 16 | "- tokenizers\n", 17 | "- torch\n", 18 | "- transformers" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "First we load our RobBERT model that was pretrained. We also load in RobBERT's tokenizer.\n", 26 | "\n", 27 | "Because we only want to get results, we have to disable dropout etc. So we add `model.eval()`." 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "*Note: we pretrained both RobBERT v1 and RobBERT v2 in [Fairseq](https://github.com/pytorch/fairseq) and converted these checkpoints to HuggingFace. 
The MLM task behaves a bit differently.*" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 1, 40 | "metadata": { 41 | "pycharm": { 42 | "is_executing": false, 43 | "name": "#%%\n" 44 | } 45 | }, 46 | "outputs": [ 47 | { 48 | "name": "stderr", 49 | "text": [ 50 | "Special tokens have been added in the vocabulary, make sure the associated word emebedding are fine-tuned or trained.\n" 51 | ], 52 | "output_type": "stream" 53 | }, 54 | { 55 | "data": { 56 | "text/plain": "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=469740689.0, style=ProgressStyle(descri…", 57 | "application/vnd.jupyter.widget-view+json": { 58 | "version_major": 2, 59 | "version_minor": 0, 60 | "model_id": "4156a63ac55d4844a22db25f1db7ed52" 61 | } 62 | }, 63 | "metadata": {}, 64 | "output_type": "display_data" 65 | }, 66 | { 67 | "name": "stdout", 68 | "text": [ 69 | "\n", 70 | "RobBERT model loaded\n" 71 | ], 72 | "output_type": "stream" 73 | } 74 | ], 75 | "source": [ 76 | "import torch\n", 77 | "from transformers import RobertaTokenizer, AutoModelForSequenceClassification, AutoConfig\n", 78 | "\n", 79 | "from transformers import RobertaTokenizer, RobertaForMaskedLM\n", 80 | "import torch\n", 81 | "tokenizer = RobertaTokenizer.from_pretrained('pdelobelle/robbert-v2-dutch-base')\n", 82 | "model = RobertaForMaskedLM.from_pretrained('pdelobelle/robbert-v2-dutch-base', return_dict=True)\n", 83 | "model = model.to( 'cuda' if torch.cuda.is_available() else 'cpu' )\n", 84 | "model.eval()\n", 85 | "#model = RobertaForMaskedLM.from_pretrained('pdelobelle/robbert-v2-dutch-base', return_dict=True)\n", 86 | "print(\"RobBERT model loaded\")" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 2, 92 | "metadata": { 93 | "pycharm": { 94 | "is_executing": false, 95 | "name": "#%% \n" 96 | } 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "sequence = f\"Er staat een {tokenizer.mask_token} in mijn tuin.\"" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 3, 106 | "metadata": { 107 | "pycharm": { 108 | "is_executing": false, 109 | "name": "#%%\n" 110 | } 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "input = tokenizer.encode(sequence, return_tensors=\"pt\").to( 'cuda' if torch.cuda.is_available() else 'cpu' )\n", 115 | "mask_token_index = torch.where(input == tokenizer.mask_token_id)[1]" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": { 121 | "pycharm": { 122 | "is_executing": false, 123 | "name": "#%% md\n" 124 | } 125 | }, 126 | "source": [ 127 | "Now that we have our tokenized input and the position of the masked token, we pass the input through RobBERT. \n", 128 | "\n", 129 | "This will give us a predicting for all tokens, but we're only interested in the `` token. 
" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 4, 135 | "metadata": { 136 | "pycharm": { 137 | "is_executing": false, 138 | "name": "#%%\n" 139 | } 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "with torch.no_grad():\n", 144 | " token_logits = model(input).logits" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 5, 150 | "metadata": { 151 | "pycharm": { 152 | "is_executing": false, 153 | "name": "#%%\n" 154 | } 155 | }, 156 | "outputs": [ 157 | { 158 | "name": "stdout", 159 | "text": [ 160 | "Ġboom | id = 2600 | p = 0.1416003555059433\n", 161 | "Ġvijver | id = 8217 | p = 0.13144515454769135\n", 162 | "Ġplant | id = 2721 | p = 0.043418534100055695\n", 163 | "Ġhuis | id = 251 | p = 0.01847737282514572\n", 164 | "Ġparkeerplaats | id = 6889 | p = 0.018001794815063477\n", 165 | "Ġbankje | id = 21620 | p = 0.016940612345933914\n", 166 | "Ġmuur | id = 2035 | p = 0.014668751507997513\n", 167 | "Ġmoestuin | id = 17446 | p = 0.0144038125872612\n", 168 | "Ġzonnebloem | id = 30757 | p = 0.014375611208379269\n", 169 | "Ġschutting | id = 15000 | p = 0.013991709798574448\n", 170 | "Ġpaal | id = 8626 | p = 0.01358739286661148\n", 171 | "Ġbloem | id = 3001 | p = 0.01199684850871563\n", 172 | "Ġstal | id = 7416 | p = 0.011224730871617794\n", 173 | "Ġfontein | id = 23425 | p = 0.011203107424080372\n", 174 | "Ġtuin | id = 671 | p = 0.010676783509552479\n" 175 | ], 176 | "output_type": "stream" 177 | } 178 | ], 179 | "source": [ 180 | "logits = token_logits[0, mask_token_index, :].squeeze()\n", 181 | "prob = logits.softmax(dim=0)\n", 182 | "values, indeces = prob.topk(k=15, dim=0)\n", 183 | "\n", 184 | "for index, token in enumerate(tokenizer.convert_ids_to_tokens(indeces)):\n", 185 | " print(f\"{token:20} | id = {indeces[index]:4} | p = {values[index]}\")" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": { 191 | "pycharm": { 192 | "is_executing": false, 193 | "name": "#%% md\n" 194 | } 195 | }, 196 | "source": [ 197 | "## RobBERT with pipelines\n", 198 | "We can also use the `fill-mask` pipeline from Huggingface, that does basically the same thing." 
199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 6, 204 | "outputs": [ 205 | { 206 | "name": "stderr", 207 | "text": [ 208 | "Special tokens have been added in the vocabulary, make sure the associated word emebedding are fine-tuned or trained.\n" 209 | ], 210 | "output_type": "stream" 211 | } 212 | ], 213 | "source": [ 214 | "from transformers import pipeline\n", 215 | "p = pipeline(\"fill-mask\", model=\"pdelobelle/robbert-v2-dutch-base\")" 216 | ], 217 | "metadata": { 218 | "collapsed": false, 219 | "pycharm": { 220 | "name": "#%%\n", 221 | "is_executing": false 222 | } 223 | } 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 7, 228 | "metadata": { 229 | "pycharm": { 230 | "is_executing": false, 231 | "name": "#%%\n" 232 | } 233 | }, 234 | "outputs": [ 235 | { 236 | "data": { 237 | "text/plain": "[{'sequence': 'Er staat een boomin mijn tuin.',\n 'score': 0.1416003555059433,\n 'token': 2600,\n 'token_str': 'Ġboom'},\n {'sequence': 'Er staat een vijverin mijn tuin.',\n 'score': 0.13144515454769135,\n 'token': 8217,\n 'token_str': 'Ġvijver'},\n {'sequence': 'Er staat een plantin mijn tuin.',\n 'score': 0.043418534100055695,\n 'token': 2721,\n 'token_str': 'Ġplant'},\n {'sequence': 'Er staat een huisin mijn tuin.',\n 'score': 0.01847737282514572,\n 'token': 251,\n 'token_str': 'Ġhuis'},\n {'sequence': 'Er staat een parkeerplaatsin mijn tuin.',\n 'score': 0.018001794815063477,\n 'token': 6889,\n 'token_str': 'Ġparkeerplaats'}]" 238 | }, 239 | "metadata": {}, 240 | "output_type": "execute_result", 241 | "execution_count": 7 242 | } 243 | ], 244 | "source": [ 245 | "p(sequence)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "source": [ 251 | "That's it for this demo of the MLM head. If you use RobBERT in your academic work, you can cite it!\n", 252 | "\n", 253 | "\n", 254 | "```\n", 255 | "@misc{delobelle2020robbert,\n", 256 | " title={{R}ob{BERT}: a {D}utch {R}o{BERT}a-based Language Model},\n", 257 | " author={Pieter Delobelle and Thomas Winters and Bettina Berendt},\n", 258 | " year={2020},\n", 259 | " eprint={2001.06286},\n", 260 | " archivePrefix={arXiv},\n", 261 | " primaryClass={cs.CL}\n", 262 | "}\n", 263 | "```\n" 264 | ], 265 | "metadata": { 266 | "collapsed": false, 267 | "pycharm": { 268 | "name": "#%% md\n" 269 | } 270 | } 271 | } 272 | ], 273 | "metadata": { 274 | "kernelspec": { 275 | "display_name": "Python 3", 276 | "language": "python", 277 | "name": "python3" 278 | }, 279 | "language_info": { 280 | "codemirror_mode": { 281 | "name": "ipython", 282 | "version": 3 283 | }, 284 | "file_extension": ".py", 285 | "mimetype": "text/x-python", 286 | "name": "python", 287 | "nbconvert_exporter": "python", 288 | "pygments_lexer": "ipython3", 289 | "version": "3.7.4" 290 | }, 291 | "pycharm": { 292 | "stem_cell": { 293 | "cell_type": "raw", 294 | "source": [], 295 | "metadata": { 296 | "collapsed": false 297 | } 298 | } 299 | } 300 | }, 301 | "nbformat": 4, 302 | "nbformat_minor": 1 303 | } -------------------------------------------------------------------------------- /notebooks/die_dat_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import nltk\n", 10 | "from nltk.tokenize.treebank import TreebankWordDetokenizer\n", 11 | "from fairseq.models.roberta import RobertaModel" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 
2, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "loading archive file ../\n", 24 | "| [input] dictionary: 50265 types\n", 25 | "| [label] dictionary: 9 types\n" 26 | ] 27 | }, 28 | { 29 | "data": { 30 | "text/plain": [ 31 | "RobertaHubInterface(\n", 32 | " (model): RobertaModel(\n", 33 | " (decoder): RobertaEncoder(\n", 34 | " (sentence_encoder): TransformerSentenceEncoder(\n", 35 | " (embed_tokens): Embedding(50265, 768, padding_idx=1)\n", 36 | " (embed_positions): LearnedPositionalEmbedding(514, 768, padding_idx=1)\n", 37 | " (layers): ModuleList(\n", 38 | " (0): TransformerSentenceEncoderLayer(\n", 39 | " (self_attn): MultiheadAttention(\n", 40 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", 41 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", 42 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", 43 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 44 | " )\n", 45 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 46 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 47 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 48 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 49 | " )\n", 50 | " (1): TransformerSentenceEncoderLayer(\n", 51 | " (self_attn): MultiheadAttention(\n", 52 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", 53 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", 54 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", 55 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 56 | " )\n", 57 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 58 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 59 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 60 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 61 | " )\n", 62 | " (2): TransformerSentenceEncoderLayer(\n", 63 | " (self_attn): MultiheadAttention(\n", 64 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", 65 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", 66 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", 67 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 68 | " )\n", 69 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 70 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 71 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 72 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 73 | " )\n", 74 | " (3): TransformerSentenceEncoderLayer(\n", 75 | " (self_attn): MultiheadAttention(\n", 76 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", 77 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", 78 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", 79 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 80 | " )\n", 81 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 82 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 83 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 84 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, 
elementwise_affine=True)\n", 85 | " )\n", 86 | " (4): TransformerSentenceEncoderLayer(\n", 87 | " (self_attn): MultiheadAttention(\n", 88 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", 89 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", 90 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", 91 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 92 | " )\n", 93 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 94 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 95 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 96 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 97 | " )\n", 98 | " (5): TransformerSentenceEncoderLayer(\n", 99 | " (self_attn): MultiheadAttention(\n", 100 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", 101 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", 102 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", 103 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 104 | " )\n", 105 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 106 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 107 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 108 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 109 | " )\n", 110 | " (6): TransformerSentenceEncoderLayer(\n", 111 | " (self_attn): MultiheadAttention(\n", 112 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", 113 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", 114 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", 115 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 116 | " )\n", 117 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 118 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 119 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 120 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 121 | " )\n", 122 | " (7): TransformerSentenceEncoderLayer(\n", 123 | " (self_attn): MultiheadAttention(\n", 124 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", 125 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", 126 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", 127 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 128 | " )\n", 129 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 130 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 131 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 132 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 133 | " )\n", 134 | " (8): TransformerSentenceEncoderLayer(\n", 135 | " (self_attn): MultiheadAttention(\n", 136 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", 137 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", 138 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", 139 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 140 | " )\n", 141 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 142 | " (fc1): 
Linear(in_features=768, out_features=3072, bias=True)\n", 143 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 144 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 145 | " )\n", 146 | " (9): TransformerSentenceEncoderLayer(\n", 147 | " (self_attn): MultiheadAttention(\n", 148 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", 149 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", 150 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", 151 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 152 | " )\n", 153 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 154 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 155 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 156 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 157 | " )\n", 158 | " (10): TransformerSentenceEncoderLayer(\n", 159 | " (self_attn): MultiheadAttention(\n", 160 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", 161 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", 162 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", 163 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 164 | " )\n", 165 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 166 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 167 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 168 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 169 | " )\n", 170 | " (11): TransformerSentenceEncoderLayer(\n", 171 | " (self_attn): MultiheadAttention(\n", 172 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", 173 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", 174 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", 175 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 176 | " )\n", 177 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 178 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 179 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 180 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 181 | " )\n", 182 | " )\n", 183 | " (emb_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 184 | " )\n", 185 | " (lm_head): RobertaLMHead(\n", 186 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 187 | " (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 188 | " )\n", 189 | " )\n", 190 | " (classification_heads): ModuleDict(\n", 191 | " (sentence_classification_head): RobertaClassificationHead(\n", 192 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 193 | " (dropout): Dropout(p=0.0, inplace=False)\n", 194 | " (out_proj): Linear(in_features=768, out_features=2, bias=True)\n", 195 | " )\n", 196 | " )\n", 197 | " )\n", 198 | ")" 199 | ] 200 | }, 201 | "execution_count": 2, 202 | "metadata": {}, 203 | "output_type": "execute_result" 204 | } 205 | ], 206 | "source": [ 207 | "roberta = RobertaModel.from_pretrained(\n", 208 | " '../',\n", 209 | " checkpoint_file='checkpoints/checkpoint_best.pt',\n", 210 | " data_name_or_path=\"./data\"\n", 211 | ")\n", 212 | "\n", 213 | "roberta.eval() # disable 
dropout" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 3, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "def replace_query_token(sentence):\n", 223 | " \"Small utility function to replace a sentence with `_die_` or `_dat_` with the proper RobBERT input.\"\n", 224 | " tokens = nltk.word_tokenize(sentence)\n", 225 | " tokens_swapped = nltk.word_tokenize(sentence)\n", 226 | " for i, word in enumerate(tokens):\n", 227 | " if word == \"_die_\":\n", 228 | " tokens[i] = \"die\"\n", 229 | " tokens_swapped[i] = \"dat\"\n", 230 | "\n", 231 | " elif word == \"_dat_\":\n", 232 | " tokens[i] = \"dat\"\n", 233 | " tokens_swapped[i] = \"die\"\n", 234 | "\n", 235 | " elif word == \"_Dat_\":\n", 236 | " tokens[i] = \"Dat\"\n", 237 | " tokens_swapped[i] = \"Die\"\n", 238 | "\n", 239 | " elif word == \"_Die_\":\n", 240 | " tokens[i] = \"Die\"\n", 241 | " tokens_swapped[i] = \"Dat\"\n", 242 | "\n", 243 | "\n", 244 | " if word.lower() == \"_die_\" or word.lower() == \"_dat_\":\n", 245 | " results = TreebankWordDetokenizer().detokenize(tokens)\n", 246 | " results_swapped = TreebankWordDetokenizer().detokenize(tokens_swapped)\n", 247 | "\n", 248 | " return \"{} {}\".format(results, results_swapped) \n" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 4, 254 | "metadata": {}, 255 | "outputs": [ 256 | { 257 | "name": "stdout", 258 | "output_type": "stream", 259 | "text": [ 260 | "Correct sentence: True\n" 261 | ] 262 | } 263 | ], 264 | "source": [ 265 | "sentence = \"Vervolgens zullen we de gebruikelijke procedure volgen, _dat_ wil zeggen dat we een voorstander en een tegenstander van dit verzoek het woord zullen geven.\"\n", 266 | "\n", 267 | "tokens = roberta.encode(replace_query_token(sentence))\n", 268 | "\n", 269 | "print(\"Correct sentence: \", roberta.predict('sentence_classification_head', tokens).argmax().item() == 1)" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 5, 275 | "metadata": {}, 276 | "outputs": [ 277 | { 278 | "name": "stdout", 279 | "output_type": "stream", 280 | "text": [ 281 | "Correct sentence: False\n" 282 | ] 283 | } 284 | ], 285 | "source": [ 286 | "sentence = \"Vervolgens zullen we de gebruikelijke procedure volgen, _die_ wil zeggen dat we een voorstander en een tegenstander van dit verzoek het woord zullen geven.\"\n", 287 | "\n", 288 | "tokens = roberta.encode(replace_query_token(sentence))\n", 289 | "\n", 290 | "print(\"Correct sentence: \", roberta.predict('sentence_classification_head', tokens).argmax().item() == 1)" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 6, 296 | "metadata": {}, 297 | "outputs": [ 298 | { 299 | "name": "stdout", 300 | "output_type": "stream", 301 | "text": [ 302 | "Correct sentence: False\n" 303 | ] 304 | } 305 | ], 306 | "source": [ 307 | "sentence = \"Daar loopt _die_ meisje.\"\n", 308 | "tokens = roberta.encode(replace_query_token(sentence))\n", 309 | "\n", 310 | "#print(roberta.predict('sentence_classification_head', tokens))\n", 311 | "print(\"Correct sentence: \", roberta.predict('sentence_classification_head', tokens).argmax().item() == 1)" 312 | ] 313 | } 314 | ], 315 | "metadata": { 316 | "kernelspec": { 317 | "display_name": "Python 3", 318 | "language": "python", 319 | "name": "python3" 320 | }, 321 | "language_info": { 322 | "codemirror_mode": { 323 | "name": "ipython", 324 | "version": 3 325 | }, 326 | "file_extension": ".py", 327 | "mimetype": "text/x-python", 328 | "name": "python", 329 | 
"nbconvert_exporter": "python", 330 | "pygments_lexer": "ipython3", 331 | "version": "3.7.4" 332 | } 333 | }, 334 | "nbformat": 4, 335 | "nbformat_minor": 4 336 | } -------------------------------------------------------------------------------- /notebooks/evaluate_zeroshot_wordlists_v2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true, 8 | "jupyter": { 9 | "outputs_hidden": true 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "from pathlib import Path\n", 15 | "import requests\n", 16 | "from fairseq.models.roberta import RobertaModel\n", 17 | "from src import evaluate_zeroshot_wordlist" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": { 24 | "collapsed": false, 25 | "jupyter": { 26 | "outputs_hidden": false 27 | }, 28 | "pycharm": { 29 | "name": "#%%\n" 30 | } 31 | }, 32 | "outputs": [ 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "loading archive file ../models/robbert.v2\n", 38 | "| dictionary: 39984 types\n", 39 | "RobBERT is loaded\n" 40 | ] 41 | } 42 | ], 43 | "source": [ 44 | "robbert = RobertaModel.from_pretrained(\n", 45 | " '../models/robbert.v2',\n", 46 | " checkpoint_file='model.pt',\n", 47 | " gpt2_encoder_json='../models/robbert.v2/vocab.json',\n", 48 | " gpt2_vocab_bpe='../models/robbert.v2/merges.txt',\n", 49 | ")\n", 50 | "robbert.eval()\n", 51 | "print(\"RobBERT is loaded\")" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "metadata": { 58 | "collapsed": false, 59 | "jupyter": { 60 | "outputs_hidden": false 61 | }, 62 | "pycharm": { 63 | "name": "#%%\n" 64 | } 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "path = Path(\"..\", \"data\", \"processed\", \"wordlist\",\"die-dat.test.tsv\")" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 4, 74 | "metadata": { 75 | "collapsed": false, 76 | "jupyter": { 77 | "outputs_hidden": false 78 | }, 79 | "pycharm": { 80 | "name": "#%%\n" 81 | } 82 | }, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "98.51% / 9851 / 10000 / 0 errors / die / die / Het is ook wegens de verwoestingen - afhankelijkheid, armoede, gezondheidsproblemen - drugsverslaving bij onze jongeren aanricht, dat het mij verheugt vast te stellen dat het Nederlandse voorzitterschap van de bestrijding van de internationale criminaliteit en de drugshandel een van zijn prioriteiten heeft gemaakt . 
Praktische samenwerking tussen de politiediensten, douanediensten en justitiële diensten om de drugshandel actief te bestrijden?\n", 89 | "98.61% / 19723 / 20000 / 0 errors / dat / dat / We zullen erop toezien dat wij Noorwegen en IJsland nimmer voor verrassingen of voor voldongen feiten stellen op punt.\n", 90 | "98.62% / 29585 / 30000 / 0 errors / dat / dat / Ook hoop ik dat Europa's industriëlen hiernaar luisteren en Total en Club Med in de voetsporen zullen treden van Heineken, Carlsberg, PepsiCo, Levi, Apple en Thomas Cook.\n", 91 | "98.62% / 39450 / 40000 / 0 errors / dat / dat / Ik wijs er in dit verband op een van de voorstellen die de Commissie aan de Intergouvernementele Conferentie heeft gedaan, de samensmelting betreft van de drie bestaande Gemeenschappen tot één enkele, wat zal leiden tot grotere duidelijkheid op dit gebied.\n", 92 | "98.63% / 49315 / 50000 / 0 errors / die / die / Het is zaak eigen Europese minimumnormen vast te stellen alleen een voldoende bescherming van passagiers, zeelieden en milieu zullen garanderen.\n", 93 | "98.65% / 59188 / 60000 / 0 errors / dat / dat / Zonder deze hervormingen zal de Unie verwateren . Indien de Unie wordt beroofd van sterke en coherente structuren vrees ik ze eens zal bezwijken onder de historische opdracht van de uitbreiding.\n", 94 | "98.61% / 69025 / 70000 / 0 errors / dat / dat / Uit het bovenstaande blijkt de begrotingsperikelen waarmee wij als Parlement kampen grotendeels het gevolg zijn van beslissingen die de Raad zelf heeft genomen.\n", 95 | "98.61% / 78891 / 80000 / 0 errors / dat / dat / Mijnheer de Voorzitter, ik wil hier alleen aangeven mijn naam niet op de presentielijst van gisteren voorkomt, vermoedelijk vanwege nalatigheid mijnerzijds, aangezien ik vergeten heb te tekenen.\n", 96 | "Error with B4-0825/97 van de leden André-Léonard, Fassa en Bertens, namens de ELDR-Fractie, over de terugtrekking van het VN-onderzoeksteam uit Congo; -B4-0832/97 van de leden Aelvoet en Telkämper, namens de V-Fractie, over de VN-onderzoeksreis naar Congo; -B4-0850/97 van de leden Dury en Swoboda, namens de PSE-Fractie, over de weigering om een onderzoekscommissie van de Verenigde Naties tot de Democratische Republiek Congo toe te laten; -B4-0856/97 van de leden Hory, Dell'Alba en Dupuis, namens de ARE-Fractie, over de onderzoekscommissie van de Verenigde Naties inzake de mensenrechten in de Democratische Republiek Congo; -B4-0863/97 van de leden Chanterie, Stasi, Tindemans, Verwaerde, Maij-Weggen en Oomen-Ruijten, namens de PPE-Fractie, over de situatie in de Democratische Republiek Congo; -B4-0877/97 van de leden Pettinari, Carnero González, Ojala en Sjöstedt, namens de GUE/NGL-Fractie, over de internationale onderzoeksdelegatie van de Verenigde Naties voor de schendingen van de rechten van de mens in het voormalige Zaïre; -B4-0890/97 van de leden Pasty en Azzolini, namens de UPE-Fractie, over de toestand in de Democratische Republiek Congo; -B4-0830/97 van de leden Bertens en Larive, namens de ELDR-Fractie, over het standpunt van de Europese Unie over de bevordering van de rechten van de mens in China; -B4-0847/97 van de heer Swoboda, namens de PSE-Fractie, over de bevordering van de mensenrechten in China; -B4-0855/97 van de leden Dupuis, Dell'Alba en Hory, namens de ARE-Fractie, over het standpunt van de Europese Unie inzake de bevordering van de mensenrechten in China; -B4-0862/97 van de leden McMillan-Scott en Habsburg-Lothringen, namens de PPE-Fractie, over het standpunt van de Europese Unie inzake de bevordering van de 
mensenrechten in China; -B4-0872/97 van de leden Aglietta en Schroedter, namens de V-Fractie, over de bevordering van de rechten van de mens in China; -B4-0828/97 van de leden Cars en La Malfa, namens de ELDR-Fractie, over de situatie in Kosovo; -B4-0837/97 van de leden Aelvoet, Cohn-Bendit, Gahrton, Müller en Tamino, namens de V-Fractie, over de situatie in Kosovo; -B4-0848/97 van de heer Swoboda, namens de PSE-Fractie, over de situatie in Kosovo; -B4-0854/97 van de leden Dupuis en Dell'Alba, namens de ARE-Fractie, over de situatie in Kosovo; -B4-0865/97 van de leden Oostlander, Pack, Habsburg-Lothringen, Maij-Weggen, Posselt en Oomen-Ruijten, namens de PPE-Fractie, over de situatie in Kosovo; -B4-0878/97 van de leden Manisco, Sjöstedt, Sierra González en Mohammed Alí, namens de GUE/NGLFractie, over de schendingen van de mensenrechten in Kosovo; -B4-0858/97 van de heer Pradier, namens de ARE-Fractie, over de omstandigheden in de detentie-inrichting van Khiam; -B4-0864/97 van de leden Soulier en Peijs, namens de PPE-Fractie, over de situatie van Souha Bechara, in hechtenis wordt gehouden in Zuid-Libanon; -B4-0879/97 van de leden Wurtz, Castellina, Marset Campos, Miranda, Ephremidis, Alavanos en Seppänen, namens de GUE/NGL-Fractie, over de vrijlating van Souha Bechara; -B4-0849/97 van de leden Hoff, Wiersma, Bösch en Swoboda, namens de PSE-Fractie, over de politieke situatie in Slowakije; -B4-0827/97 van de leden André-Léonard, Fassa, Bertens en Nordmann, namens de ELDR-Fractie, over Algerije.\t0\n", 97 | "\n", 98 | "98.61% / 88753 / 90000 / 1 errors / dat / dat / Met onze amendementen trachten wij ertoe bij te dragen dat het Fonds zo goed mogelijk wordt gecoördineerd, het additioneel en ter vergroting van het eigen potentieel wordt aangewend en dat de publieke opinie er kennis van neemt en het op zijn waarde kan schatten.\n", 99 | "98.63% / 98628 / 100000 / 1 errors / dat / dat / Ik verwacht van de Commissie zij daartegen optreedt.\n", 100 | "98.61% / 108475 / 110000 / 1 errors / dat / dat / Mijnheer de Voorzitter, dames en heren, het is goed het Eurodac-systeem de mogelijkheid biedt de misbruiken tegen te gaan die om de meest uiteenlopende redenen van het asielrecht worden gemaakt.\n", 101 | "98.62% / 118344 / 120000 / 1 errors / dat / dat / Zoals een van mijn kiezers uit Swindon vroeg:``De ruimte voor gekissebis over wat een``miniem gebrek aan gelijkvormigheid\"voorstelt, is grenzeloos en consumenten zonder scrupules worden er wellicht toe aangemoedigd om de onzekerheid uit te buiten die door deze voorstellen wordt gecreëerd, wat ertoe zal leiden de rechtbanken een overvloed aan zaken te verwerken zullen krijgen.\n", 102 | "98.63% / 128214 / 130000 / 1 errors / dat / dat / Mijnheer de Voorzitter, ik was zeer geïnteresseerd in uw antwoord aan de heer McMahon, waarin u zei , als de regels of de procedures zouden worden gewijzigd, u ze vanzelfsprekend zou verwijzen naar de Commissie Reglement, onderzoek geloofsbrieven en immuniteiten.\n", 103 | "98.62% / 138072 / 140000 / 1 errors / dat / dat / Wat betekent anders de lidstaten zelf kunnen beslissen welke grenswaarden zij vaststellen?\n", 104 | "98.64% / 147953 / 150000 / 1 errors / die / die / Wat betreft de overheidssteun aan de landbouw die nog altijd noodzakelijk is, zijn wij van mening dat op de lange termijn slechts door de Europese belastingbetaler en door onze concurrenten binnen de WHO zal worden geaccepteerd, als rekening wordt gehouden met gerechtvaardigde sociale, ecologische en streekgebonden zorgen.\n", 105 | "98.65% / 157843 / 
160000 / 1 errors / die / die / Als we dus Europese verkiezingen gaan steunen de nationaliteitsgrenzen overschrijden, keren we onze rug naar nationale vertegenwoordiging.\n", 106 | "98.65% / 167706 / 170000 / 1 errors / die / die / De EU is via verschillende lidstaten en via het voorzitterschap actief betrokken bij het werk van de contactgroep, een document met opties voor de toekomstige status van Kosovo heeft uitgewerkt en dit heeft overhandigd aan de strijdende partijen.\n", 107 | "98.67% / 177605 / 180000 / 1 errors / dat / dat / Volgens mijn berekening hebt u uw spreektijd met 20% overschreden . U bent echter zo onderhoudend dat wij niet erg vinden.\n", 108 | "98.68% / 187485 / 190000 / 1 errors / dat / dat / Het tweede probleem mij voor de kwijting voor 1996 relevant lijkt, is de vraag hoe de Commissie omgaat met disciplinaire procedures.\n", 109 | "98.69% / 197374 / 200000 / 1 errors / dat / dat / Zij heeft er in dit verband op gewezen amendementen die onverenigbaar zijn met het voorstel van de bevoegde commissie niet ontvankelijk zijn.\n", 110 | "Error with A4-0437/98 van de heer Elchlepp, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Litouwen, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4216/98 - COM (98) 0119 - C4-0592/98-98/0075 (CNS) ); -A4-0443/98 van de heer Seppänen, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Letland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4215/98 - COM (98) 0068 - C4-0593/98-98/0076 (CNS) ); -A4-0472/98 van de heer van Dam, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Estland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 63, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4214/98 - COM (98) 0118 - C4-0594/98-98/0077 (CNS) ) en-A4-0419/98 van de heer Schwaiger, over het voorstel voor een besluit van de Raad en de Commissie betreffende het standpunt dat de Gemeenschap zal innemen in de Associatieraad die is ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en Roemenië, anderzijds, die op 1 februari 1993 in Brussel werd ondertekend, inzake de vaststelling van de voorschriften voor de tenuitvoerlegging van artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst en voor de tenuitvoerlegging van artikel 9, lid 1, sub 1) en 2), van Protocol 2 betreffende EGKS-producten bij de Europa-Overeenkomst (COM (98) 0236 - C4-0275/98-98/0139 (CNS) ).\t0\n", 111 | "\n", 112 | "Error with A4-0437/98 van de heer Elchlepp, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese 
Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Litouwen, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4216/98 - COM (98) 0119 - C4-0592/98-98/0075 (CNS) ); -A4-0443/98 van de heer Seppänen, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Letland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4215/98 - COM (98) 0068 - C4-0593/98-98/0076 (CNS) ); -A4-0472/98 van de heer van Dam, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Estland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 63, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4214/98 - COM (98) 0118 - C4-0594/98-98/0077 (CNS) ) en-A4-0419/98 van de heer Schwaiger, over het voorstel voor een besluit van de Raad en de Commissie betreffende het standpunt dat de Gemeenschap zal innemen in de Associatieraad die is ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en Roemenië, anderzijds, die op 1 februari 1993 in Brussel werd ondertekend, inzake de vaststelling van de voorschriften voor de tenuitvoerlegging van artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst en voor de tenuitvoerlegging van artikel 9, lid 1, sub 1) en 2), van Protocol 2 betreffende EGKS-producten bij de Europa-Overeenkomst (COM (98) 0236 - C4-0275/98-98/0139 (CNS) ).\t0\n", 113 | "\n", 114 | "Error with A4-0437/98 van de heer Elchlepp, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Litouwen, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4216/98 - COM (98) 0119 - C4-0592/98-98/0075 (CNS) ); -A4-0443/98 van de heer Seppänen, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Letland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4215/98 - COM (98) 0068 - C4-0593/98-98/0076 (CNS) ); -A4-0472/98 van de heer van Dam, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Estland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 63, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4214/98 - COM (98) 0118 - C4-0594/98-98/0077 (CNS) ) en-A4-0419/98 van de heer Schwaiger, over 
het voorstel voor een besluit van de Raad en de Commissie betreffende het standpunt dat de Gemeenschap zal innemen in de Associatieraad die is ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en Roemenië, anderzijds, die op 1 februari 1993 in Brussel werd ondertekend, inzake de vaststelling van de voorschriften voor de tenuitvoerlegging van artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst en voor de tenuitvoerlegging van artikel 9, lid 1, sub 1) en 2), van Protocol 2 betreffende EGKS-producten bij de Europa-Overeenkomst (COM (98) 0236 - C4-0275/98-98/0139 (CNS) ).\t0\n", 115 | "\n", 116 | "Error with A4-0437/98 van de heer Elchlepp, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Litouwen, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4216/98 - COM (98) 0119 - C4-0592/98-98/0075 (CNS) ); -A4-0443/98 van de heer Seppänen, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Letland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4215/98 - COM (98) 0068 - C4-0593/98-98/0076 (CNS) ); -A4-0472/98 van de heer van Dam, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Estland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 63, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4214/98 - COM (98) 0118 - C4-0594/98-98/0077 (CNS) ) en-A4-0419/98 van de heer Schwaiger, over het voorstel voor een besluit van de Raad en de Commissie betreffende het standpunt de Gemeenschap zal innemen in de Associatieraad die is ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en Roemenië, anderzijds, die op 1 februari 1993 in Brussel werd ondertekend, inzake de vaststelling van de voorschriften voor de tenuitvoerlegging van artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst en voor de tenuitvoerlegging van artikel 9, lid 1, sub 1) en 2), van Protocol 2 betreffende EGKS-producten bij de Europa-Overeenkomst (COM (98) 0236 - C4-0275/98-98/0139 (CNS) ).\t1\n", 117 | "\n", 118 | "Error with A4-0437/98 van de heer Elchlepp, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Litouwen, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4216/98 - COM (98) 0119 - C4-0592/98-98/0075 (CNS) ); -A4-0443/98 van de heer Seppänen, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in 
te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Letland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4215/98 - COM (98) 0068 - C4-0593/98-98/0076 (CNS) ); -A4-0472/98 van de heer van Dam, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Estland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 63, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4214/98 - COM (98) 0118 - C4-0594/98-98/0077 (CNS) ) en-A4-0419/98 van de heer Schwaiger, over het voorstel voor een besluit van de Raad en de Commissie betreffende het standpunt dat de Gemeenschap zal innemen in de Associatieraad is ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en Roemenië, anderzijds, die op 1 februari 1993 in Brussel werd ondertekend, inzake de vaststelling van de voorschriften voor de tenuitvoerlegging van artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst en voor de tenuitvoerlegging van artikel 9, lid 1, sub 1) en 2), van Protocol 2 betreffende EGKS-producten bij de Europa-Overeenkomst (COM (98) 0236 - C4-0275/98-98/0139 (CNS) ).\t0\n", 119 | "\n", 120 | "Error with A4-0437/98 van de heer Elchlepp, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Litouwen, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4216/98 - COM (98) 0119 - C4-0592/98-98/0075 (CNS) ); -A4-0443/98 van de heer Seppänen, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Letland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4215/98 - COM (98) 0068 - C4-0593/98-98/0076 (CNS) ); -A4-0472/98 van de heer van Dam, over het voorstel voor een besluit van de Raad en de Commissie inzake het door de Gemeenschap in te nemen standpunt in de Associatieraad die werd ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en de Republiek Estland, anderzijds, ten aanzien van de vaststelling van uitvoeringsbepalingen voor artikel 63, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst (4214/98 - COM (98) 0118 - C4-0594/98-98/0077 (CNS) ) en-A4-0419/98 van de heer Schwaiger, over het voorstel voor een besluit van de Raad en de Commissie betreffende het standpunt dat de Gemeenschap zal innemen in de Associatieraad die is ingesteld bij de Europa-Overeenkomst tussen de Europese Gemeenschappen en hun lidstaten, enerzijds, en Roemenië, anderzijds, op 1 februari 1993 in Brussel werd ondertekend, inzake de vaststelling van de voorschriften voor de tenuitvoerlegging van artikel 
64, lid 1, sub i) en ii), en lid 2, van de Europa-Overeenkomst en voor de tenuitvoerlegging van artikel 9, lid 1, sub 1) en 2), van Protocol 2 betreffende EGKS-producten bij de Europa-Overeenkomst (COM (98) 0236 - C4-0275/98-98/0139 (CNS) ).\t0\n", 121 | "\n", 122 | "98.69% / 207253 / 210000 / 7 errors / dat / dat / Mevrouw de Voorzitter, dames en heren, allereerst wil ik u hartelijk danken voor het vriendelijke onthaal en voor het feit ik er in het Europees Parlement voor het eerst officieel als fungerend voorzitter van de Raad van ministers van Landbouw en Visserij bij mag zijn en het woord tot u mag richten.\n", 123 | "98.70% / 217134 / 220000 / 7 errors / die / die / Ik kan daarom geen amendementen overnemen dat op de helling zetten.\n", 124 | "98.70% / 227010 / 230000 / 7 errors / die / die / Strafheffingen de EU-landen echter verschillend treffen, beïnvloeden de comparatieve kosten- en concurrentiepositie op de interne markt.\n", 125 | "98.70% / 236890 / 240000 / 7 errors / die / die / De doelstellingen zijn: ten eerste het garanderen van het recht op regelmatige informatie en raadpleging van werknemers over de economische en strategische ontwikkelingen van de onderneming en over de voor hen belangrijke beslissingen; ten tweede de versterking van de sociale dialoog en het wederzijds vertrouwen binnen de onderneming, om bij te kunnen dragen aan het voorspellen van eventuele risico's, de ontwikkeling van de flexibiliteit in de arbeidsorganisatie in een zeker kader, de bewustmaking van de werknemers met betrekking tot de noodzaak van aanpassing en aan hun bereidheid om actief mee te werken aan maatregelen en acties ter verhoging van hun vakbekwaamheid en inzetbaarheid; ten derde de opneming van het werkgelegenheidsvraagstuk - de toestand daarvan en de te verwachten ontwikkeling - in de informatie- en raadplegingsprocedures; ten vierde de verzekering van voorafgaande informatie en raadpleging van de werknemers bij beslissingen kunnen leiden tot ingrijpende veranderingen in de arbeidsorganisatie of de arbeidsovereenkomsten en ten vijfde de verbetering van de doeltreffendheid van deze procedure via de vaststelling van specifieke sancties voor ernstige schendingen van de opgelegde verplichtingen.\n", 126 | "98.71% / 246769 / 250000 / 7 errors / die / die / Geen woord van veroordeling of zelfs maar spijt voor de ontoelaatbare militaire agressie tegen een Europese staat, door een NAVO blindelings vrouwen en kinderen doodt.\n", 127 | "98.72% / 256661 / 260000 / 7 errors / die / die / Met veel genoegen heb ik deelgenomen aan deze belangrijke democratische gebeurtenis, in de Europese politiek zijn weerga niet kent.\n", 128 | "98.73% / 266559 / 270000 / 7 errors / die / die / Maar de inwerkingtreding van dit Verdrag in de VS zou regionale kernmachten onder enorme druk hebben gezet om toe te treden tot het Verdrag . 
Denk maar aan India en Pakistan, bezig zijn hun eigen regionale afschrikkingsevenwicht op te bouwen, of Noord-Korea, Iran, Irak of Israël.\n", 129 | "98.73% / 276458 / 280000 / 7 errors / dat / dat / Maar het geeft hoop dat bad practices op basis van goede gegevens vergeleken kunnen worden en een humaan en effectief drugsbeleid ook in een geïntegreerde Europese Unie mogelijk worden.\n", 130 | "Die/Dat: 0.9874826436379628 285184 288799 7\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "die_dat_correct, die_dat_total, die_dat_errors = evaluate_zeroshot_wordlist.evaluate(\n", 136 | " [\"die\", \"dat\"], path=path, model=robbert, print_step=10000)\n", 137 | "print(\"Die/Dat:\", die_dat_correct/die_dat_total, die_dat_correct, die_dat_total, die_dat_errors)" 138 | ] 139 | } 140 | ], 141 | "metadata": { 142 | "kernelspec": { 143 | "name": "python3", 144 | "language": "python", 145 | "display_name": "Python 3" 146 | }, 147 | "language_info": { 148 | "codemirror_mode": { 149 | "name": "ipython", 150 | "version": 3 151 | }, 152 | "file_extension": ".py", 153 | "mimetype": "text/x-python", 154 | "name": "python", 155 | "nbconvert_exporter": "python", 156 | "pygments_lexer": "ipython3", 157 | "version": "3.7.4" 158 | } 159 | }, 160 | "nbformat": 4, 161 | "nbformat_minor": 4 162 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2022.9.14 2 | charset-normalizer==2.1.1 3 | click==8.1.3 4 | filelock==3.8.0 5 | Flask==2.2.2 6 | huggingface-hub==0.9.1 7 | idna==3.4 8 | importlib-metadata==4.12.0 9 | itsdangerous==2.1.2 10 | Jinja2==3.1.2 11 | joblib==1.2.0 12 | MarkupSafe==2.1.1 13 | nltk==3.7 14 | numpy==1.21.6 15 | packaging==21.3 16 | pyparsing==3.0.9 17 | PyYAML==6.0 18 | regex==2022.9.13 19 | requests==2.28.1 20 | tokenizers==0.12.1 21 | torch==1.12.1 22 | tqdm==4.64.1 23 | transformers==4.22.0 24 | typing-extensions==4.3.0 25 | urllib3==1.26.12 26 | Werkzeug==2.2.2 27 | zipp==3.8.1 28 | -------------------------------------------------------------------------------- /res/dbrd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/res/dbrd.png -------------------------------------------------------------------------------- /res/gender_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/res/gender_diff.png -------------------------------------------------------------------------------- /res/robbert_2022_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/res/robbert_2022_logo.png -------------------------------------------------------------------------------- /res/robbert_2022_logo_with_name.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/res/robbert_2022_logo_with_name.png -------------------------------------------------------------------------------- /res/robbert_2023_logo.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/res/robbert_2023_logo.png -------------------------------------------------------------------------------- /res/robbert_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/res/robbert_logo.png -------------------------------------------------------------------------------- /res/robbert_logo_with_name.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/res/robbert_logo_with_name.png -------------------------------------------------------------------------------- /res/robbert_pos_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/res/robbert_pos_accuracy.png -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/bert_masked_lm_adapter.py: -------------------------------------------------------------------------------- 1 | # Class to mimic the fill_mask functionality of the RobertaModel, but for BERT models. 2 | # Used to evaluate wordlist masks 3 | import numpy as np 4 | import torch 5 | 6 | 7 | class BertMaskedLMAdapter: 8 | def __init__(self, model_name=None, model=None, tokenizer=None): 9 | # Check imports 10 | from transformers import BertForMaskedLM, BertTokenizer 11 | 12 | self._tokenizer = ( 13 | tokenizer if tokenizer else BertTokenizer.from_pretrained(model_name) 14 | ) 15 | self._model = model if model else BertForMaskedLM.from_pretrained(model_name) 16 | self._model.eval() 17 | 18 | def fill_mask(self, text, topk=4): 19 | text = text.replace("<mask>", "[MASK]") 20 | if not text.startswith("[CLS]"): 21 | text = "[CLS] " + text 22 | if not text.endswith("[SEP]"): 23 | text = text + " [SEP]" 24 | 25 | tokenized_text = self._tokenizer.tokenize(text) 26 | masked_index = tokenized_text.index("[MASK]") 27 | indexed_tokens = self._tokenizer.convert_tokens_to_ids(tokenized_text) 28 | # Create the segments tensors. 
29 | segments_ids = [0] * len(tokenized_text) 30 | # Convert inputs to PyTorch tensors 31 | tokens_tensor = torch.tensor([indexed_tokens]) 32 | segments_tensors = torch.tensor([segments_ids]) 33 | 34 | with torch.no_grad(): 35 | predictions = self._model(tokens_tensor, segments_tensors) 36 | 37 | predicted_indices = list( 38 | np.argpartition(predictions[0][0][masked_index], -topk)[-topk:] 39 | ) 40 | predicted_indices.reverse() 41 | return [ 42 | self._convert_token(predicted_index) 43 | for predicted_index in predicted_indices 44 | ] 45 | 46 | def _convert_token(self, predicted_index): 47 | text = self._tokenizer.convert_ids_to_tokens([predicted_index])[0] 48 | return None, None, text if text.startswith("##") else (" " + text) 49 | 50 | 51 | if __name__ == "__main__": 52 | from pathlib import Path 53 | 54 | from src import evaluate_zeroshot_wordlist 55 | 56 | mlm = BertMaskedLMAdapter(model_name="bert-base-multilingual-uncased") 57 | result = mlm.fill_mask("Er is een meisje <mask> daar loopt.", 1024) 58 | print(result) 59 | 60 | from src.wordlistfiller import WordListFiller 61 | 62 | wordfiller = WordListFiller(["die", "dat"], model=mlm) 63 | print(wordfiller.find_optimal_word("Er is een meisje <mask> daar loopt.")) 64 | evaluate_zeroshot_wordlist.evaluate( 65 | ["die", "dat"], 66 | path=Path("..", "data", "processed", "wordlist", "die-dat.test.tsv"), 67 | model=mlm, 68 | print_step=10, 69 | ) 70 | -------------------------------------------------------------------------------- /src/convert_roberta_dict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | from pathlib import Path 6 | import json 7 | 8 | def create_arg_parser(): 9 | parser = argparse.ArgumentParser( 10 | description=main.__doc__) 11 | 12 | parser.add_argument("--path", help="Path to the corpus file.", metavar="path", 13 | default="../data/raw/UD_Dutch-LassySmall/") 14 | parser.add_argument( 15 | "--dict", 16 | help='path to dict.txt', 17 | required=True 18 | ) 19 | parser.add_argument( 20 | "--vocab-bpe", 21 | type=str, 22 | help='path to vocab.json', 23 | required=True 24 | ) 25 | 26 | parser.add_argument( 27 | "--output-vocab", 28 | type=str, 29 | help='Location for the output vocab.json', 30 | required=True 31 | ) 32 | 33 | return parser 34 | 35 | def load_roberta_mapping(file): 36 | "Returns a dict with the position of each word-id in the dict.txt file." 37 | 38 | # This file is basically an ordered count and we're not interested in the count. 39 | lines = {line.rstrip('\n').split()[0]: k for k, line in enumerate(file)} 40 | return lines 41 | 42 | 43 | def map_roberta(mapping, vocab): 44 | "Combine vocab.json and dict.txt contents." 
45 | inverse_vocab = {str(v): k for k, v in vocab.items()} 46 | 47 | # We add 4 extra tokens, so they also need to be added to the position id 48 | EXTRA_TOKENS = {'<s>': 0, '<pad>': 1, '</s>': 2, '<unk>': 3} 49 | offset = len(EXTRA_TOKENS) 50 | 51 | output_vocab = EXTRA_TOKENS 52 | for word_id, position in mapping.items(): 53 | if word_id in inverse_vocab: 54 | output_vocab[inverse_vocab[word_id]] = position + offset 55 | else: 56 | print("not found: {}".format(word_id)) 57 | output_vocab[word_id] = position + offset 58 | 59 | output_vocab['<mask>'] = len(output_vocab) 60 | 61 | for word in [ inverse_vocab[x] for x in (set([str(vocab[k]) for k in vocab])-set(mapping)-set(EXTRA_TOKENS.keys()))]: 62 | output_vocab[word] = len(output_vocab) 63 | 64 | return output_vocab 65 | 66 | def main(args: argparse.Namespace): 67 | "Merge a vocab.json file with a dict.txt created by Fairseq." 68 | 69 | # First we load the dict file created by Fairseq's Roberta 70 | with open(args.dict, encoding="utf-8") as dict_fp: 71 | mapping = load_roberta_mapping(dict_fp) 72 | 73 | # Now we load the vocab file 74 | with open(args.vocab_bpe, encoding="utf-8") as vocab_fp: 75 | vocab = json.load(vocab_fp) 76 | 77 | output_vocab = map_roberta(mapping, vocab) 78 | 79 | with open(args.output_vocab, "w", encoding="utf-8") as output_fp: 80 | json.dump(output_vocab, output_fp, ensure_ascii=False) 81 | 82 | 83 | if __name__ == '__main__': 84 | arg_parser = create_arg_parser() 85 | args = arg_parser.parse_args() 86 | 87 | main(args) -------------------------------------------------------------------------------- /src/evaluate_zeroshot_wordlist.py: -------------------------------------------------------------------------------- 1 | # Script for evaluating the base RobBERT model on a word list classification task without any fine-tuning. 2 | 3 | import argparse 4 | from pathlib import Path 5 | from typing import List 6 | 7 | from fairseq.models.roberta import RobertaModel 8 | 9 | from src.wordlistfiller import WordListFiller 10 | 11 | models_path = Path("..", "data", "processed", "wordlist") 12 | 13 | 14 | def evaluate(words: List[str], path: Path = None, model: RobertaModel = None, print_step: int = 1000): 15 | if not model: 16 | model = RobertaModel.from_pretrained( 17 | '../models/robbert', 18 | checkpoint_file='model.pt' 19 | ) 20 | model.eval() 21 | 22 | wordlistfiller = WordListFiller(words, model=model) 23 | 24 | dataset_path = path if path is not None else models_path / ("-".join(words) + ".tsv") 25 | 26 | correct = 0 27 | total = 0 28 | errors = 0 29 | 30 | with open(dataset_path) as input_file: 31 | for line in input_file: 32 | sentence, index = line.split('\t') 33 | expected = words[int(index.strip())] 34 | 35 | try: 36 | predicted = wordlistfiller.find_optimal_word(sentence) 37 | if predicted is None: 38 | errors += 1 39 | elif predicted == expected: 40 | correct += 1 41 | total += 1 42 | 43 | if total % print_step == 0: 44 | print("{0:.2f}%".format(100 * correct / total), 45 | correct, total, str(errors) + " errors", expected, predicted, sentence, sep=' / ') 46 | except Exception: 47 | print("Error with", line) 48 | errors += 1 49 | total += 1 50 | 51 | return correct, total, errors 52 | 53 | 54 | def create_parser(): 55 | parser = argparse.ArgumentParser( 56 | description="Evaluate a model zero-shot on a word list disambiguation task (e.g. die/dat)." 
57 | ) 58 | parser.add_argument("--words", help="List of comma-separated words to disambiguate", type=str, default="die,dat") 59 | parser.add_argument("--path", help="Path to the evaluation data", metavar="path", default=None) 60 | return parser 61 | 62 | 63 | if __name__ == "__main__": 64 | parser = create_parser() 65 | args = parser.parse_args() 66 | 67 | evaluate([x.strip() for x in args.words.split(",")], args.path) 68 | -------------------------------------------------------------------------------- /src/multiprocessing_bpe_encoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import argparse 9 | import contextlib 10 | import sys 11 | 12 | from collections import Counter 13 | from multiprocessing import Pool 14 | 15 | from fairseq.data.encoders.gpt2_bpe import get_encoder 16 | 17 | 18 | def main(): 19 | """ 20 | Helper script to encode raw text with the GPT-2 BPE using multiple processes. 21 | 22 | The encoder.json and vocab.bpe files can be obtained here: 23 | - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json 24 | - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe 25 | """ 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument( 28 | "--encoder-json", 29 | help='path to encoder.json', 30 | ) 31 | parser.add_argument( 32 | "--vocab-bpe", 33 | type=str, 34 | help='path to vocab.bpe', 35 | ) 36 | parser.add_argument( 37 | "--inputs", 38 | nargs="+", 39 | default=['-'], 40 | help="input files to filter/encode", 41 | ) 42 | parser.add_argument( 43 | "--outputs", 44 | nargs="+", 45 | default=['-'], 46 | help="path to save encoded outputs", 47 | ) 48 | parser.add_argument( 49 | "--keep-empty", 50 | action="store_true", 51 | help="keep empty lines", 52 | ) 53 | parser.add_argument("--workers", type=int, default=20) 54 | args = parser.parse_args() 55 | 56 | assert len(args.inputs) == len(args.outputs), \ 57 | "number of input and output paths should match" 58 | 59 | with contextlib.ExitStack() as stack: 60 | inputs = [ 61 | stack.enter_context(open(input, "r", encoding="utf-8")) 62 | if input != "-" else sys.stdin 63 | for input in args.inputs 64 | ] 65 | outputs = [ 66 | stack.enter_context(open(output, "w", encoding="utf-8")) 67 | if output != "-" else sys.stdout 68 | for output in args.outputs 69 | ] 70 | 71 | encoder = MultiprocessingEncoder(args) 72 | pool = Pool(args.workers, initializer=encoder.initializer) 73 | encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100) 74 | 75 | stats = Counter() 76 | for i, (filt, enc_lines) in enumerate(encoded_lines, start=1): 77 | if filt == "PASS": 78 | for enc_line, output_h in zip(enc_lines, outputs): 79 | print(enc_line, file=output_h) 80 | else: 81 | stats["num_filtered_" + filt] += 1 82 | if i % 10000 == 0: 83 | print("processed {} lines".format(i), file=sys.stderr) 84 | 85 | for k, v in stats.most_common(): 86 | print("[{}] filtered {} lines".format(k, v), file=sys.stderr) 87 | 88 | 89 | class MultiprocessingEncoder(object): 90 | 91 | def __init__(self, args): 92 | self.args = args 93 | 94 | def initializer(self): 95 | global bpe 96 | bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe) 97 | 98 | def encode(self, line): 99 | global bpe 100 | ids = bpe.encode(line) 101 | return list(map(str, ids)) 102 | 103 | 
def decode(self, tokens): 104 | global bpe 105 | return bpe.decode(tokens) 106 | 107 | def encode_lines(self, lines): 108 | """ 109 | Encode a set of lines. All lines will be encoded together. 110 | """ 111 | enc_lines = [] 112 | for line in lines: 113 | line = line.strip() 114 | if len(line) == 0 and not self.args.keep_empty: 115 | return ["EMPTY", None] 116 | tokens = self.encode(line) 117 | enc_lines.append(" ".join(tokens)) 118 | return ["PASS", enc_lines] 119 | 120 | def decode_lines(self, lines): 121 | dec_lines = [] 122 | for line in lines: 123 | tokens = map(int, line.strip().split()) 124 | dec_lines.append(self.decode(tokens)) 125 | return ["PASS", dec_lines] 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /src/preprocess_conll2002_ner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from src import preprocess_util 5 | 6 | ner_tags = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG", "B-MISC", "I-MISC"] 7 | 8 | type_map = { 9 | "dev": "testa", 10 | "test": "testb", 11 | "train": "train", 12 | } 13 | 14 | 15 | def create_arg_parser(): 16 | parser = argparse.ArgumentParser( 17 | description="Preprocess the CONLL2002 corpus for the NER task." 18 | ) 19 | parser.add_argument("--path", help="Path to the corpus file.", metavar="path", 20 | default="../data/raw/conll2002/") 21 | parser.add_argument( 22 | "--encoder-json", 23 | help='path to encoder.json', 24 | default="../models/robbert/encoder.json" 25 | ) 26 | parser.add_argument( 27 | "--vocab-bpe", 28 | type=str, 29 | help='path to vocab.bpe', 30 | default="../models/robbert/vocab.bpe" 31 | ) 32 | 33 | return parser 34 | 35 | 36 | def get_dataset_suffix(type): 37 | return type_map.get(type, None) 38 | 39 | 40 | def get_label_index(label_name): 41 | return ner_tags.index(label_name) 42 | 43 | 44 | def process_connl2002_ner(arguments, type, processed_data_path, raw_data_path): 45 | output_sentences_path, output_labels_path, output_tokenized_sentences_path, output_tokenized_labels_path = preprocess_util.get_sequence_file_paths( 46 | processed_data_path, type) 47 | tokenizer = preprocess_util.get_tokenizer(arguments) 48 | 49 | with open(output_sentences_path, mode='w') as output_sentences: 50 | with open(output_labels_path, mode='w') as output_labels: 51 | with open(output_tokenized_sentences_path, mode='w') as output_tokenized_sentences: 52 | with open(output_tokenized_labels_path, mode='w') as output_tokenized_labels: 53 | 54 | dataset_suffix = get_dataset_suffix(type) 55 | if dataset_suffix is None: 56 | raise Exception("Invalid type", type) 57 | 58 | file_path = raw_data_path / ("ned." 
+ dataset_suffix) 59 | with open(file_path) as file_to_read: 60 | # Add new line after seeing comments or new line 61 | has_content = False 62 | for line in file_to_read: 63 | if not line.startswith("-DOCSTART-") and len(line.strip()) > 0: 64 | has_content = True 65 | 66 | word, pos, ner = line.split(" ") 67 | 68 | # Write out normal word & label 69 | word = word.strip() 70 | label = str(get_label_index(ner.strip())) 71 | preprocess_util.write_sequence_word_label(word, label, tokenizer, output_sentences, 72 | output_labels, 73 | output_tokenized_sentences, 74 | output_tokenized_labels) 75 | elif has_content: 76 | output_sentences.write("\n") 77 | output_labels.write("\n") 78 | output_tokenized_sentences.write("\n") 79 | output_tokenized_labels.write("\n") 80 | has_content = False 81 | 82 | 83 | if __name__ == "__main__": 84 | arg_parser = create_arg_parser() 85 | args = arg_parser.parse_args() 86 | 87 | processed_data_path = Path("..", "data", "processed", "conll2002") 88 | processed_data_path.mkdir(parents=True, exist_ok=True) 89 | 90 | process_connl2002_ner(args, 'train', processed_data_path, Path(args.path)) 91 | process_connl2002_ner(args, 'dev', processed_data_path, Path(args.path)) 92 | process_connl2002_ner(args, 'test', processed_data_path, Path(args.path)) 93 | -------------------------------------------------------------------------------- /src/preprocess_dbrd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from os import listdir 4 | from os.path import isfile, join 5 | 6 | 7 | def create_arg_parser(): 8 | parser = argparse.ArgumentParser( 9 | description="Preprocess the Dutch Book Reviews Dataset corpus for the sentiment analysis tagging task." 10 | ) 11 | parser.add_argument("--path", help="Path to the corpus folder.", metavar="path", default="data/raw/DBRD/") 12 | 13 | return parser 14 | 15 | 16 | def get_file_id(a): 17 | return int(a.split('_')[0]) 18 | 19 | 20 | def get_files_of_folder(folder): 21 | return [f for f in listdir(folder) if isfile(join(folder, f))] 22 | 23 | 24 | def add_content_and_label(file_location, output_sentences, output_labels, label): 25 | with open(file_location) as file: 26 | content = " ".join(file.readlines()).replace('\n', ' ').replace('\r', '') 27 | single_spaced_content = ' '.join(content.split()) 28 | output_sentences.write(single_spaced_content + "\n") 29 | output_labels.write(str(label) + "\n") 30 | 31 | 32 | def process_dbrd(raw_data_path, test_or_train): 33 | processed_data_path = Path("..", "data", "processed", "dbrd") 34 | processed_data_path.mkdir(parents=True, exist_ok=True) 35 | 36 | output_sentences_path = processed_data_path / (test_or_train + ".sentences.txt") 37 | output_labels_path = processed_data_path / (test_or_train + ".labels.txt") 38 | 39 | with open(output_sentences_path, mode='w') as output_sentences: 40 | with open(output_labels_path, mode='w') as output_labels: 41 | pos_files_folder = raw_data_path / test_or_train / 'pos' 42 | neg_files_folder = raw_data_path / test_or_train / 'neg' 43 | 44 | pos_files = get_files_of_folder(pos_files_folder) 45 | neg_files = get_files_of_folder(neg_files_folder) 46 | 47 | pos_files.sort(key=get_file_id) 48 | neg_files.sort(key=get_file_id) 49 | 50 | 51 | assert len(pos_files) == len(neg_files) 52 | 53 | # process file by intertwining the files, such that the model can learn better 54 | for i in range(len(pos_files)): 55 | add_content_and_label(pos_files_folder / pos_files[i], output_sentences, 
output_labels, 1) 56 | add_content_and_label(neg_files_folder / neg_files[i], output_sentences, output_labels, 0) 57 | 58 | 59 | if __name__ == "__main__": 60 | arg_parser = create_arg_parser() 61 | args = arg_parser.parse_args() 62 | 63 | process_dbrd(Path(args.path), 'train') 64 | process_dbrd(Path(args.path), 'test') 65 | -------------------------------------------------------------------------------- /src/preprocess_diedat.py: -------------------------------------------------------------------------------- 1 | import random 2 | import nltk 3 | from nltk.tokenize.treebank import TreebankWordDetokenizer 4 | import argparse 5 | 6 | 7 | def replace_die_dat_full_sentence(line, flout, fout): 8 | count = 0 9 | for i, word in enumerate(nltk.word_tokenize(line)): 10 | tokens = nltk.word_tokenize(line) 11 | if word == "die": 12 | tokens[i] = "dat" 13 | elif word == "dat": 14 | tokens[i] = "die" 15 | elif word == "Dat": 16 | tokens[i] = "Die" 17 | elif word == "Die": 18 | tokens[i] = "Dat" 19 | 20 | if word.lower() == "die" or word.lower() == "dat": 21 | choice = random.getrandbits(1) 22 | results = TreebankWordDetokenizer().detokenize(tokens) 23 | 24 | if choice: 25 | output = "{} {}".format(results, line) 26 | else: 27 | output = "{} {}".format(line, results) 28 | 29 | fout.write(output + "\n") 30 | flout.write(str(choice) + "\n") 31 | count += 1 32 | 33 | return count 34 | 35 | 36 | def create_arg_parser(): 37 | parser = argparse.ArgumentParser( 38 | description="Preprocess the europarl corpus for the die-dat task." 39 | ) 40 | parser.add_argument("--path", help="Path to the corpus file.", metavar="path", 41 | default="data/raw/europarl/europarl-v7.nl-en.nl") 42 | parser.add_argument("--number", help="Number of examples in the output dataset", type=int, default=10000) 43 | 44 | return parser 45 | 46 | 47 | if __name__ == "__main__": 48 | parser = create_arg_parser() 49 | 50 | args = parser.parse_args() 51 | 52 | with open(args.path + '.labels', mode='a') as labels_output: 53 | with open(args.path + '.sentences', mode='a') as sentences_output: 54 | with open(args.path) as fp: 55 | lines_processed = 0 56 | for line in fp: 57 | line = line.replace('\n', '').replace('\r', '') 58 | count = replace_die_dat_full_sentence(line, labels_output, sentences_output) 59 | lines_processed += count 60 | if lines_processed >= args.number: 61 | break 62 | -------------------------------------------------------------------------------- /src/preprocess_diedat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | split -l $[ $(wc -l data/europarl-v7.nl-en.nl.labels|cut -d" " -f1) * 70 / 100 ] data/europarl-v7.nl-en.nl.labels 4 | mv xaa data/processed/diedat/train.labels 5 | mv xab data/processed/diedat/dev.labels 6 | 7 | split -l $[ $(wc -l data/europarl-v7.nl-en.nl.sentences|cut -d" " -f1) * 70 / 100 ] data/europarl-v7.nl-en.nl.sentences 8 | mv xaa data/processed/diedat/train.sentences 9 | mv xab data/processed/diedat/dev.sentences 10 | 11 | rm data/labels/dict.txt 12 | 13 | for SPLIT in train dev; do 14 | python src/multiprocessing_bpe_encoder.py\ 15 | --encoder-json data/encoder.json \ 16 | --vocab-bpe data/vocab.bpe \ 17 | --inputs "data/processed/diedat/$SPLIT.sentences" \ 18 | --outputs "data/processed/diedat/$SPLIT.sentences.bpe" \ 19 | --workers 24 \ 20 | --keep-empty 21 | done 22 | 23 | fairseq-preprocess \ 24 | --only-source \ 25 | --trainpref "data/processed/diedat/train.sentences.bpe" \ 26 | --validpref "data/processed/diedat/dev.sentences.bpe" \ 27 
| --destdir "data/input0" \ 28 | --workers 24 \ 29 | --srcdict data/dict.txt 30 | 31 | fairseq-preprocess \ 32 | --only-source \ 33 | --trainpref "data/processed/diedat/train.labels" \ 34 | --validpref "data/processed/diedat/dev.labels" \ 35 | --destdir "data/diedat/labels" \ 36 | --workers 24 37 | 38 | -------------------------------------------------------------------------------- /src/preprocess_lassy_ud.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from src import preprocess_util 5 | 6 | universal_pos_tags = ["ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", 7 | "NOUN", "NUM", "PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X"] 8 | 9 | 10 | def create_arg_parser(): 11 | parser = argparse.ArgumentParser( 12 | description="Preprocess the LASSY corpus for the POS-tagging task." 13 | ) 14 | parser.add_argument("--path", help="Path to the corpus file.", metavar="path", 15 | default="../data/raw/UD_Dutch-LassySmall/") 16 | parser.add_argument( 17 | "--encoder-json", 18 | help='path to encoder.json', 19 | default="../models/robbert.large/encoder.json" 20 | ) 21 | parser.add_argument( 22 | "--vocab-bpe", 23 | type=str, 24 | help='path to vocab.bpe', 25 | default="../models/robbert/vocab.bpe" 26 | ) 27 | 28 | return parser 29 | 30 | 31 | def get_label_index(label_name): 32 | return universal_pos_tags.index(label_name) 33 | 34 | 35 | def process_lassy_ud(arguments, type, processed_data_path, raw_data_path): 36 | output_sentences_path, output_labels_path, output_tokenized_sentences_path, output_tokenized_labels_path = preprocess_util.get_sequence_file_paths( 37 | processed_data_path, type) 38 | tokenizer = preprocess_util.get_tokenizer(arguments) 39 | 40 | with open(output_sentences_path, mode='w') as output_sentences: 41 | with open(output_labels_path, mode='w') as output_labels: 42 | with open(output_tokenized_sentences_path, mode='w') as output_tokenized_sentences: 43 | with open(output_tokenized_labels_path, mode='w') as output_tokenized_labels: 44 | 45 | file_path = raw_data_path / ("nl_lassysmall-ud-" + type + ".conllu") 46 | with open(file_path) as file_to_read: 47 | # For removing first blank line 48 | has_content = False 49 | for line in file_to_read: 50 | if not line.startswith("#") and len(line.strip()) > 0: 51 | has_content = True 52 | 53 | index, word, main_word, universal_pos, detailed_tag, details, number, english_tag, \ 54 | number_and_english_tag, space_after = line.split("\t") 55 | 56 | # Write out normal word & label 57 | word = word.strip() 58 | label = str(get_label_index(universal_pos.strip())) 59 | preprocess_util.write_sequence_word_label(word, label, tokenizer, output_sentences, 60 | output_labels, 61 | output_tokenized_sentences, 62 | output_tokenized_labels) 63 | elif has_content: 64 | output_sentences.write("\n") 65 | output_labels.write("\n") 66 | output_tokenized_sentences.write("\n") 67 | output_tokenized_labels.write("\n") 68 | has_content = False 69 | 70 | 71 | if __name__ == "__main__": 72 | arg_parser = create_arg_parser() 73 | args = arg_parser.parse_args() 74 | 75 | processed_data_path = Path("..", "data", "processed", "lassy_ud") 76 | processed_data_path.mkdir(parents=True, exist_ok=True) 77 | 78 | process_lassy_ud(args, 'train', processed_data_path, Path(args.path)) 79 | process_lassy_ud(args, 'dev', processed_data_path, Path(args.path)) 80 | process_lassy_ud(args, 'test', processed_data_path, Path(args.path)) 81 | 
-------------------------------------------------------------------------------- /src/preprocess_util.py: -------------------------------------------------------------------------------- 1 | from src.multiprocessing_bpe_encoder import MultiprocessingEncoder 2 | 3 | seperator_token = "\t" 4 | 5 | 6 | def get_sequence_file_paths(processed_data_path, type): 7 | output_sentences_path = processed_data_path / (type + ".sentences.tsv") 8 | output_labels_path = processed_data_path / (type + ".labels.tsv") 9 | output_tokenized_sentences_path = processed_data_path / (type + ".sentences.bpe") 10 | output_tokenized_labels_path = processed_data_path / (type + ".labels.bpe") 11 | return output_sentences_path, output_labels_path, output_tokenized_sentences_path, output_tokenized_labels_path 12 | 13 | 14 | def write_sequence_word_label(word, label, tokenizer, output_sentences, output_labels, output_tokenized_sentences, 15 | output_tokenized_labels): 16 | output_sentences.write(word.strip() + seperator_token) 17 | output_labels.write(label + seperator_token) 18 | 19 | # Write tokenized 20 | tokenized_word = tokenizer.encode(word.strip()) 21 | output_tokenized_sentences.write(seperator_token.join(tokenized_word) + seperator_token) 22 | output_tokenized_labels.write(len(tokenized_word) * (label + seperator_token)) 23 | 24 | 25 | def get_tokenizer(arguments): 26 | tokenizer = MultiprocessingEncoder(arguments) 27 | tokenizer.initializer() 28 | return tokenizer 29 | -------------------------------------------------------------------------------- /src/preprocess_wordlist_mask.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from src.wordlistfiller import WordListFiller 5 | 6 | 7 | def create_arg_parser(): 8 | parser = argparse.ArgumentParser( 9 | description="Preprocess the europarl corpus for the die-dat task." 10 | ) 11 | parser.add_argument("--path", help="Path to the corpus file.", metavar="path", 12 | default="../data/raw/europarl-v7.nl-en.nl") 13 | parser.add_argument("--filename", help="Extra for file name", metavar="path", 14 | default="") 15 | parser.add_argument("--words", help="List of comma-separated words to disambiguate", type=str, default="die,dat") 16 | parser.add_argument("--number", help="Number of examples in the output dataset", type=int, default=10000000) 17 | 18 | return parser 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = create_arg_parser() 23 | args = parser.parse_args() 24 | 25 | models_path = Path("..", "data", "processed", "wordlist") 26 | models_path.mkdir(parents=True, exist_ok=True) 27 | 28 | words = [x.strip() for x in args.words.split(",")] 29 | wordlistfiller = WordListFiller(words) 30 | 31 | output_path = models_path / (args.words.replace(',', '-') 32 | + (('.' 
+ args.filename) if args.filename else '') + ".tsv") 33 | 34 | with open(output_path, mode='w') as output: 35 | with open(args.path) as input_file: 36 | number_of_lines_to_add = args.number 37 | for line in input_file: 38 | line = line.strip() 39 | sentences = wordlistfiller.occlude_target_words_index(line) 40 | number_of_sentences = len(sentences) 41 | 42 | for i in range(min(number_of_lines_to_add, number_of_sentences)): 43 | sentence = sentences[i] 44 | output.write(sentence[0] + "\t" + str(sentence[1]) + '\n') 45 | number_of_lines_to_add -= number_of_sentences 46 | 47 | if number_of_lines_to_add <= 0: 48 | break 49 | -------------------------------------------------------------------------------- /src/pretrain.pbs: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -l partition=gpu 3 | #PBS -l walltime=12:00:00 4 | #PBS -l nodes=5:ppn=36:gpus=4 5 | #PBS -l pmem=5gb 6 | #PBS -l excludenodes=r24g39:r22g41 7 | #PBS -N alBERT 8 | #PBS -A lp_dtai1 9 | 10 | 11 | echo "changing dir" 12 | cd $VSC_DATA 13 | source miniconda3/bin/activate torch 14 | cd alBERT 15 | 16 | DATA_DIR=/scratch/leuven/x/x/data-bin/nl_dedup 17 | 18 | MASTER_ADDR=$(cat $PBS_NODEFILE | head -n 1) 19 | MASTER_PORT=6666 20 | WORLD_SIZE=20 21 | 22 | echo "Master node is" $MASTER_ADDR 23 | 24 | pbsdsh -u sh $PBS_O_WORKDIR/run_node.sh $MASTER_ADDR $MASTER_PORT $WORLD_SIZE 25 | 26 | echo "Done training, cleaning up syncfile." 27 | rm /scratch/leuven/x/x/sync_torch/syncfile 28 | -------------------------------------------------------------------------------- /src/split_dbrd_training.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd data/processed/dbrd 4 | 5 | mv train.sentences.txt traineval.sentences.txt 6 | mv train.labels.txt traineval.labels.txt 7 | echo "renamed files to traineval.*.txt" 8 | 9 | head traineval.sentences.txt -n -500 > train.sentences.txt 10 | head traineval.labels.txt -n -500 > train.labels.txt 11 | tail -n 500 traineval.sentences.txt > eval.sentences.txt 12 | tail -n 500 traineval.labels.txt > eval.labels.txt 13 | -------------------------------------------------------------------------------- /src/textdataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import pickle 4 | import torch 5 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class TextDataset(Dataset): 11 | def __init__(self, tokenizer, model_name_or_path, file_path="train", block_size=512, overwrite_cache=True, mask_padding_with_zero=True): 12 | assert os.path.isfile(file_path + '.sentences.txt') 13 | 14 | assert os.path.isfile(file_path + '.labels.txt') 15 | 16 | directory, filename = os.path.split(file_path) 17 | cached_features_file = os.path.join( 18 | directory, model_name_or_path + "_cached_lm_" + str(block_size) + "_" + filename 19 | ) 20 | 21 | if os.path.exists(cached_features_file) and not overwrite_cache: 22 | logger.info("Loading features from cached file %s", cached_features_file) 23 | with open(cached_features_file, "rb") as handle: 24 | self.examples = pickle.load(handle) 25 | else: 26 | logger.info("Creating features from dataset file at %s", directory) 27 | 28 | self.examples = [] 29 | with open(file_path + ".labels.txt", encoding="utf-8") as flabel: 30 | with open(file_path + ".sentences.txt", encoding="utf-8") as f: 31 | for sentence in f: 32 | 
tokenized_text = tokenizer.encode(tokenizer.tokenize(sentence)[-block_size + 3 : -1]) 33 | 34 | input_mask = [1 if mask_padding_with_zero else 0] * len(tokenized_text) 35 | 36 | pad_token = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) 37 | while len(tokenized_text) < block_size: 38 | tokenized_text.append(pad_token) 39 | input_mask.append(0 if mask_padding_with_zero else 1) 40 | #segment_ids.append(pad_token_segment_id) 41 | #p_mask.append(1) 42 | 43 | #self.examples.append([tokenizer.build_inputs_with_special_tokens(tokenized_text[0 : block_size]), [0], [0]]) 44 | label = next(flabel) 45 | self.examples.append([tokenized_text[0 : block_size - 3], input_mask[0 : block_size - 3], [1] if label.startswith("1") else [0]]) 46 | 47 | logger.info("Saving features into cached file %s", cached_features_file) 48 | with open(cached_features_file, "wb") as handle: 49 | pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL) 50 | 51 | def __len__(self): 52 | return len(self.examples) 53 | 54 | def __getitem__(self, item): 55 | return [torch.tensor(self.examples[item][0]), torch.tensor(self.examples[item][1]), torch.tensor([0]), torch.tensor(self.examples[item][2])] 56 | 57 | 58 | def load_and_cache_examples(model_name_or_path, tokenizer, data_file): 59 | dataset = TextDataset( 60 | tokenizer, 61 | model_name_or_path, 62 | file_path=data_file, 63 | block_size=512 64 | ) 65 | return dataset 66 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | from transformers import AdamW 2 | import torch 3 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler, DistributedSampler 4 | import os 5 | from tqdm import tqdm, trange 6 | import pandas as pd 7 | from pycm import ConfusionMatrix 8 | 9 | try: 10 | from torch.utils.tensorboard import SummaryWriter 11 | except ImportError: 12 | from tensorboardX import SummaryWriter 13 | 14 | import json 15 | 16 | from transformers import get_linear_schedule_with_warmup 17 | import logging 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | class Train: 22 | def train(args, train_dataset, model, tokenizer, evaluate_fn=None): 23 | """ Train the model """ 24 | 25 | model.train() 26 | 27 | if args.local_rank in [-1, 0]: 28 | tb_writer = SummaryWriter() 29 | 30 | # Setup CUDA, GPU & distributed training 31 | if args.local_rank == -1 or args.no_cuda: 32 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 33 | args.n_gpu = torch.cuda.device_count() 34 | else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs 35 | torch.cuda.set_device(args.local_rank) 36 | device = torch.device("cuda", args.local_rank) 37 | torch.distributed.init_process_group(backend="nccl") 38 | args.n_gpu = 1 39 | args.device = device 40 | 41 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 42 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 43 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 44 | 45 | if args.max_steps > 0: 46 | t_total = args.max_steps 47 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 48 | else: 49 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 50 | 51 | # Prepare optimizer and schedule (linear warmup and 
decay) 52 | no_decay = ["bias", "LayerNorm.weight"] 53 | optimizer_grouped_parameters = [ 54 | { 55 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 56 | "weight_decay": args.weight_decay, 57 | }, 58 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 59 | ] 60 | 61 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 62 | scheduler = get_linear_schedule_with_warmup( 63 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 64 | ) 65 | 66 | # Check if saved optimizer or scheduler states exist 67 | if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile( 68 | os.path.join(args.model_name_or_path, "scheduler.pt") 69 | ): 70 | # Load in optimizer and scheduler states 71 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) 72 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) 73 | 74 | if args.fp16: 75 | try: 76 | from apex import amp 77 | except ImportError: 78 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 79 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 80 | 81 | # multi-gpu training (should be after apex fp16 initialization) 82 | if args.n_gpu > 1: 83 | model = torch.nn.DataParallel(model) 84 | 85 | # Distributed training (should be after apex fp16 initialization) 86 | if args.local_rank != -1: 87 | model = torch.nn.parallel.DistributedDataParallel( 88 | model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True, 89 | ) 90 | 91 | if args.local_rank == 0: 92 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 93 | 94 | model.to(args.device) 95 | 96 | # Train! 97 | logger.info("***** Running training *****") 98 | logger.info(" Num examples = %d", len(train_dataset)) 99 | logger.info(" Num Epochs = %d", args.num_train_epochs) 100 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 101 | logger.info( 102 | " Total train batch size (w. 
parallel, distributed & accumulation) = %d", 103 | args.train_batch_size 104 | * args.gradient_accumulation_steps 105 | * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), 106 | ) 107 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 108 | logger.info(" Total optimization steps = %d", t_total) 109 | 110 | global_step = 0 111 | epochs_trained = 0 112 | steps_trained_in_current_epoch = 0 113 | # Check if continuing training from a checkpoint 114 | if os.path.exists(args.model_name_or_path): 115 | # set global_step to gobal_step of last saved checkpoint from model path 116 | global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0]) 117 | epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) 118 | steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) 119 | 120 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 121 | logger.info(" Continuing training from epoch %d", epochs_trained) 122 | logger.info(" Continuing training from global step %d", global_step) 123 | logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) 124 | 125 | tr_loss, logging_loss = 0.0, 0.0 126 | model.zero_grad() 127 | train_iterator = trange( 128 | epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0], 129 | ) 130 | #set_seed(args) # Added here for reproductibility 131 | for _ in train_iterator: 132 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0], position=0, leave=True) 133 | for step, batch in enumerate(epoch_iterator): 134 | 135 | # Skip past any already trained steps if resuming training 136 | if steps_trained_in_current_epoch > 0: 137 | steps_trained_in_current_epoch -= 1 138 | continue 139 | 140 | model.train() 141 | batch = tuple(t.to(args.device) for t in batch) 142 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]} 143 | if args.model_type != "distilbert": 144 | inputs["token_type_ids"] = ( 145 | batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None 146 | ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids 147 | outputs = model(**inputs) 148 | loss = outputs[0] # model outputs are always tuple in transformers (see doc) 149 | 150 | if args.n_gpu > 1: 151 | loss = loss.mean() # mean() to average on multi-gpu parallel training 152 | if args.gradient_accumulation_steps > 1: 153 | loss = loss / args.gradient_accumulation_steps 154 | 155 | if args.fp16: 156 | with amp.scale_loss(loss, optimizer) as scaled_loss: 157 | scaled_loss.backward() 158 | else: 159 | loss.backward() 160 | 161 | tr_loss += loss.item() 162 | if (step + 1) % args.gradient_accumulation_steps == 0: 163 | if args.fp16: 164 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 165 | else: 166 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 167 | 168 | optimizer.step() 169 | scheduler.step() # Update learning rate schedule 170 | model.zero_grad() 171 | global_step += 1 172 | 173 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 174 | logs = {} 175 | if ( 176 | args.local_rank == -1 and args.evaluate_during_training 177 | ): # Only evaluate when single GPU otherwise metrics may not average well 178 | results = evaluate(args, model, tokenizer) 179 | for key, value in 
results.items(): 180 | eval_key = "eval_{}".format(key) 181 | logs[eval_key] = value 182 | 183 | loss_scalar = (tr_loss - logging_loss) / args.logging_steps 184 | learning_rate_scalar = scheduler.get_lr()[0] 185 | logs["learning_rate"] = learning_rate_scalar 186 | logs["loss"] = loss_scalar 187 | logging_loss = tr_loss 188 | 189 | for key, value in logs.items(): 190 | tb_writer.add_scalar(key, value, global_step) 191 | epoch_iterator.set_postfix({**logs, **{"step": global_step}}) 192 | 193 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: 194 | # Save model checkpoint 195 | output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) 196 | if not os.path.exists(output_dir): 197 | os.makedirs(output_dir) 198 | model_to_save = ( 199 | model.module if hasattr(model, "module") else model 200 | ) # Take care of distributed/parallel training 201 | model_to_save.save_pretrained(output_dir) 202 | tokenizer.save_pretrained(output_dir) 203 | 204 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 205 | logger.info("Saving model checkpoint to %s", output_dir) 206 | 207 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 208 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 209 | logger.info("Saving optimizer and scheduler states to %s", output_dir) 210 | 211 | if args.max_steps > 0 and global_step > args.max_steps: 212 | epoch_iterator.close() 213 | break 214 | if args.max_steps > 0 and global_step > args.max_steps: 215 | train_iterator.close() 216 | break 217 | 218 | if evaluate_fn is not None: 219 | results = pd.DataFrame(evaluate_fn(args.evaluate_dataset, model)) 220 | cm = ConfusionMatrix(actual_vector=results['true'].values, predict_vector=results['predicted'].values) 221 | logs = {} 222 | logs["eval_f1_macro"] = cm.F1_Macro 223 | logs["eval_acc_macro"] = cm.ACC_Macro 224 | logs["eval_acc_overall"] = cm.Overall_ACC 225 | logger.info("Results on eval: {}".format(logs)) 226 | 227 | 228 | if args.local_rank in [-1, 0]: 229 | tb_writer.close() 230 | 231 | return global_step, tr_loss / global_step 232 | -------------------------------------------------------------------------------- /src/train_config.py: -------------------------------------------------------------------------------- 1 | class Config: 2 | 3 | def __init__(self): 4 | "Ugly hack to make the args object work without spending too much effort." 5 | self.local_rank = -1 6 | self.per_gpu_train_batch_size = 4 7 | self.gradient_accumulation_steps = 8 8 | self.n_gpu = 1 9 | self.max_steps = 2000 10 | self.weight_decay = 0 11 | self.learning_rate = 5e-5 12 | self.adam_epsilon = 2e-8 13 | self.warmup_steps = 250 14 | self.model_name_or_path = "bert" 15 | self.fp16 = False 16 | self.set_seed = "42" 17 | self.device=0 18 | self.model_type="roberta" 19 | self.no_cuda=False 20 | self.max_grad_norm=1.0 21 | self.logging_steps=1 22 | self.evaluate_during_training=False 23 | self.save_steps=2500 24 | self.output_dir="./" 25 | self.evaluate_dataset="" 26 | -------------------------------------------------------------------------------- /src/train_diedat.sh: -------------------------------------------------------------------------------- 1 | TOTAL_NUM_UPDATES=7812 # 10 epochs 2 | WARMUP_UPDATES=469 # 6 percent of the number of updates 3 | LR=1e-05 # Peak LR for polynomial LR scheduler. 4 | HEAD_NAME=die-dat-head # Custom name for the classification head. 5 | NUM_CLASSES=2 # Number of classes for the classification task. 
6 | MAX_SENTENCES=8 # Batch size. 7 | ROBERTA_PATH=data/checkpoint_best.pt 8 | 9 | fairseq-train data \ 10 | --restore-file $ROBERTA_PATH \ 11 | --max-positions 512 \ 12 | --max-sentences $MAX_SENTENCES \ 13 | --max-tokens 4400 \ 14 | --task sentence_prediction \ 15 | --reset-optimizer --reset-dataloader --reset-meters \ 16 | --required-batch-size-multiple 1 \ 17 | --init-token 0 --separator-token 2 \ 18 | --arch roberta_base \ 19 | --criterion sentence_prediction \ 20 | --num-classes $NUM_CLASSES \ 21 | --dropout 0.1 --attention-dropout 0.1 \ 22 | --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ 23 | --clip-norm 0.0 \ 24 | --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ 25 | --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ 26 | --max-epoch 10 \ 27 | --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \ 28 | --truncate-sequence \ 29 | --find-unused-parameters \ 30 | --update-freq 4 \ 31 | --save-interval 1 \ 32 | --save-dir checkpoints 33 | -------------------------------------------------------------------------------- /src/wordlistfiller.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | from fairseq.models.roberta import RobertaHubInterface 4 | from nltk import word_tokenize 5 | from nltk.tokenize.treebank import TreebankWordDetokenizer 6 | 7 | from src.bert_masked_lm_adapter import BertMaskedLMAdapter 8 | 9 | 10 | class WordListFiller: 11 | def __init__(self, target_words: List[str], 12 | model: Union[RobertaHubInterface, BertMaskedLMAdapter] = None, 13 | detokenizer: TreebankWordDetokenizer = TreebankWordDetokenizer(), 14 | topk_limit=2048): 15 | self._model = model 16 | self._target_words = [x.lower().strip() for x in target_words] 17 | self._target_words_spaced = set([" " + word for word in self._target_words]) 18 | self._detokenizer = detokenizer 19 | self._top_k_limit = topk_limit 20 | 21 | def find_optimal_word(self, text: str) -> Optional[str]: 22 | if self._model is None: 23 | raise AttributeError("No model given to find the optimal word") 24 | topk = 4 25 | result = None 26 | while result is None and topk <= self._top_k_limit: 27 | filler_words = self._model.fill_mask(text, topk=topk) 28 | result = next((x[2].strip() for x in filler_words if x[2].lower() in self._target_words_spaced), None) 29 | topk *= 2 30 | return result 31 | 32 | # Transforms a sentence into a list of (masked sentence, original word) tuples, one for each target word it contains 33 | def occlude_target_words(self, input_sentence: str) -> List[Tuple[str, str]]: 34 | tokenized = word_tokenize(input_sentence) 35 | result = [] 36 | for i in range(len(tokenized)): 37 | if tokenized[i] in self._target_words: 38 | new_sentence_tokens = tokenized[:i] + ["<mask>"] + tokenized[i + 1:] 39 | new_sentence = self._detokenizer.detokenize(new_sentence_tokens) 40 | result.append((new_sentence, tokenized[i])) 41 | return result 42 | 43 | def occlude_target_words_index(self, input_sentence: str) -> List[Tuple[str, int]]: 44 | return [(s[0], self._target_words.index(s[1].strip().lower())) for s in 45 | self.occlude_target_words(input_sentence)] 46 | -------------------------------------------------------------------------------- /tests/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/iPieter/RobBERT/8f562fe3e79ec8c0ea04051277b6ae86e7e382e9/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_convert_roberta_dict.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from io import StringIO 3 | 4 | from src.convert_roberta_dict import load_roberta_mapping, map_roberta 5 | 6 | 7 | class ConvertRobertaTestCase(unittest.TestCase): 8 | def test_load_roberta_mapping(self): 9 | # Create a mock file 10 | file = StringIO(initial_value="3 12\n1 8\n2 7") 11 | mapping = load_roberta_mapping(file) 12 | 13 | # test our little mapping 14 | self.assertEqual(mapping['3'], 0, msg="First element in dict.txt is 3, should have id = 0") 15 | self.assertEqual(mapping['1'], 1, msg="Second element in dict.txt is 1, should have id = 1") 16 | self.assertEqual(mapping['2'], 2, msg="Third element in dict.txt is 2, should have id = 2") 17 | 18 | def test_map_roberta(self): 19 | file = StringIO(initial_value="3 12\n1 8\n2 7") 20 | mapping = load_roberta_mapping(file) 21 | 22 | vocab = {"Een": 3, "Twee": 2, "Drie": 1} 23 | 24 | output_vocab = map_roberta(mapping, vocab) 25 | 26 | self.assertEqual(output_vocab['<s>'], 0, msg="Extra tokens") 27 | self.assertEqual(output_vocab['<pad>'], 1, msg="Extra tokens") 28 | self.assertEqual(output_vocab['</s>'], 2, msg="Extra tokens") 29 | self.assertEqual(output_vocab['<unk>'], 3, msg="Extra tokens") 30 | self.assertEqual(output_vocab['Een'], 4, msg="'Een' has vocab_id = 3, which is mapped to 0 (+4)") 31 | self.assertEqual(output_vocab['Twee'], 6, msg="'Twee' has vocab_id = 2, which is mapped to 2 (+4)") 32 | self.assertEqual(output_vocab['Drie'], 5, msg="'Drie' has vocab_id = 1, which is mapped to 1 (+4)") 33 | 34 | 35 | def test_map_roberta_unused_tokens(self): 36 | """ 37 | The fast HuggingFace tokenizer requires that all tokens in the merges.txt are also present in the 38 | vocab.json. When converting a Fairseq dict.txt, this is not necessarily the case in a naive implementation. 39 | 40 | More info: https://github.com/huggingface/transformers/issues/9290 41 | """ 42 | 43 | file = StringIO(initial_value="3 12\n1 8\n2 7") 44 | mapping = load_roberta_mapping(file) 45 | 46 | # Tokens "Vier" and "Vijf" are not used in the mapping (= dict.txt) 47 | vocab = {"Een": 3, "Twee": 2, "Drie": 1, "Vier": 5, "Vijf": 4} 48 | 49 | output_vocab = map_roberta(mapping, vocab) 50 | 51 | self.assertEqual(output_vocab['<s>'], 0, msg="Extra tokens") 52 | self.assertEqual(output_vocab['<pad>'], 1, msg="Extra tokens") 53 | self.assertEqual(output_vocab['</s>'], 2, msg="Extra tokens") 54 | self.assertEqual(output_vocab['<unk>'], 3, msg="Extra tokens") 55 | self.assertEqual(output_vocab['Een'], 4, msg="'Een' has vocab_id = 3, which is mapped to 0 (+4)") 56 | self.assertEqual(output_vocab['Twee'], 6, msg="'Twee' has vocab_id = 2, which is mapped to 2 (+4)") 57 | self.assertEqual(output_vocab['Drie'], 5, msg="'Drie' has vocab_id = 1, which is mapped to 1 (+4)") 58 | self.assertIn(output_vocab['Vier'], [7, 8], msg="'Vier' has vocab_id = 5, which is mapped to the next available value") 59 | self.assertIn(output_vocab['Vijf'], [8, 7], msg="'Vijf' has vocab_id = 4, which is mapped to the next available value") 60 | self.assertNotEqual(output_vocab['Vijf'], output_vocab['Vier'], msg="Unused tokens must have different values") 61 | 62 | def test_tokenization(self): 63 | sample_input = "De tweede poging: nog een test van de tokenizer met nummers."
64 | expected_output = [0, 62, 488, 5351, 30, 49, 9, 2142, 7, 5, 905, 859, 10605, 15, 3316, 4, 2] 65 | # expected_output lists the token ids the RobBERT tokenizer should produce for sample_input; no assertion is performed yet. 66 | 67 | if __name__ == '__main__': 68 | unittest.main() 69 | --------------------------------------------------------------------------------