├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── discussion.md │ └── feature_request.md ├── demo.png ├── logo.png ├── logo.svg ├── loss.png └── stale.yml ├── .gitignore ├── 3rd └── gdown.pl │ ├── LICENSE.txt │ ├── README.md │ └── gdown.pl ├── LICENSE ├── README.md ├── README_CN.md ├── configs ├── base.json ├── large.json └── mega.json ├── dataset ├── README.md ├── prepare_data.py └── prepare_data.sh ├── dockerfiles └── gpu-jupyter.Dockerfile ├── pretrained_model_demo.ipynb ├── requirements-gpu.txt ├── requirements-tpu.txt ├── scripts ├── demo.py └── down_gdrive_file.py ├── tokenization ├── __init__.py ├── bert-base-chinese-vocab.txt ├── bert-large-cased-whole-word-masking-vocab.txt ├── clue-vocab.txt └── tokenization.py └── train ├── __init__.py ├── dataloader.py ├── modeling.py ├── optimization_adafactor.py ├── train_tpu.py ├── train_tpu_adafactor.sh └── utils.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a bug report to help us improve GPT2-ML 4 | title: "[Bug] name your bug" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | > Below is the issue template. You can fill in each part and then submit your issue. 11 | > Or you can just delete all of this and describe your question in your own style. 12 | > But please remember: the more detailed the information you offer, the more likely your problem is to be solved. 😜 13 | 14 | Please write a clear and concise description of what the bug is. 15 | 16 | ## Expected behavior 17 | 18 | Please write a clear and concise description of what you expected to happen. 19 | 20 | ## Environment 21 | 22 | - Python version: 23 | - OS: 24 | - (Optional) Other libraries and their versions: 25 | 26 | ## Error messages, stack traces, or logs 27 | 28 | ``` 29 | # error messages, stack traces, or logs 30 | ``` 31 | 32 | ## Steps to reproduce 33 | 34 | 1. 35 | 2. 36 | 3. 37 | 38 | ## Reproducible examples (optional) 39 | 40 | ```python 41 | # python code 42 | ``` 43 | 44 | ## Additional context (optional) 45 | 46 | Please add any other context or screenshots about the problem here.
47 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/discussion.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Discussion 3 | about: Share ideas or discuss theoretical questions 4 | title: "[Discussion] your question" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for GPT2-ML 4 | title: "[Feature] your feature name" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Description 11 | 12 | 13 | 14 | ## Additional information 15 | 16 | 17 | -------------------------------------------------------------------------------- /.github/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imcaspar/gpt2-ml/f6286b16cbbee6dddbe1ba557fffb47eaf998cd1/.github/demo.png -------------------------------------------------------------------------------- /.github/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imcaspar/gpt2-ml/f6286b16cbbee6dddbe1ba557fffb47eaf998cd1/.github/logo.png -------------------------------------------------------------------------------- /.github/logo.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imcaspar/gpt2-ml/f6286b16cbbee6dddbe1ba557fffb47eaf998cd1/.github/loss.png -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 365 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 30 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | # Label to use when marking an issue as stale 10 | staleLabel: wontfix 11 | # Comment to post when marking an issue as stale. Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue.
Set to `false` to disable 17 | closeComment: true -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # vscode 132 | .vscode/ 133 | 134 | # dataset 135 | dataset/raw/ 136 | 137 | # models 138 | models/ -------------------------------------------------------------------------------- /3rd/gdown.pl/LICENSE.txt: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. 
By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 
76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. 
However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 
196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 
256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 
309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 
360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 
421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 
476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year> <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <http://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail.
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program> Copyright (C) <year> <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <http://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <http://www.gnu.org/licenses/why-not-lgpl.html>. 675 | -------------------------------------------------------------------------------- /3rd/gdown.pl/README.md: -------------------------------------------------------------------------------- 1 | gdown.pl 2 | ======== 3 | 4 | Google Drive direct download of big files 5 | 6 | Requirements 7 | ============ 8 | 9 | *wget* and *Perl* must be in the PATH. 10 | **Windows** and **Linux** compatible. 11 | 12 | Usage 13 | ===== 14 | 15 | Use Google Drive shareable links, viewable by anyone: 16 | 17 | $ ./gdown.pl 'gdrive file url' ['desired file name'] 18 | 19 | Example 20 | ======= 21 | 22 | For example, to download [this video](https://drive.google.com/file/d/0B1L_hFrWJfRhLUJZdXdSdTdfSWs/edit) from my [axolotl project](https://circulosmeos.wordpress.com/2015/03/04/axolotl-a-simple-plain-text-documentation-system/), just copy the URL, and give a file name if desired: 23 | 24 | $ ./gdown.pl https://drive.google.com/file/d/0B1L_hFrWJfRhLUJZdXdSdTdfSWs/edit axolotl.mp4 25 | 26 | Resuming a download 27 | =================== 28 | 29 | If you need to resume a download, please use [**gdown.pl v2.0** here](https://github.com/circulosmeos/gdown.pl/tree/with-resume). 30 | As long as a file name is indicated as the second parameter, *gdown.pl v2.0* **will try to resume the partially downloaded file** if a local incomplete file with that name already exists. 31 | 32 | Version 33 | ======= 34 | 35 | This version is **v1.4**. 36 | 37 | ### Warning 38 | 39 | Please note that v1.2 (available between Jan 12 and Jan 31, 2019) **should not be used**, as it contains a bug that could result in unusable downloaded files. Overwrite it with v1.3 if you have it. 40 | 41 | Docker 42 | ====== 43 | 44 | A simple Dockerfile is provided for building a Docker image with gdown.pl. 45 | This has been used for pre-pulling data from Google Drive to Kubernetes persistent volumes. Thanks, @anton-khodak. 46 | 47 | License 48 | ======= 49 | 50 | Distributed [under GPL 3](http://www.gnu.org/licenses/gpl-3.0.html) 51 | 52 | Disclaimer 53 | ========== 54 | 55 | This software is provided "as is", without warranty of any kind, express or implied.
56 | 57 | More info 58 | ========= 59 | 60 | [https://circulosmeos.wordpress.com/2014/04/12/google-drive-direct-download-of-big-files](https://circulosmeos.wordpress.com/2014/04/12/google-drive-direct-download-of-big-files) 61 | 62 | Contact 63 | ======= 64 | 65 | by [circulosmeos](mailto:loopidle@gmail.com) 66 | -------------------------------------------------------------------------------- /3rd/gdown.pl/gdown.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # Google Drive direct download of big files 4 | # ./gdown.pl 'gdrive file url' ['desired file name'] 5 | # 6 | # v1.0 by circulosmeos 04-2014. 7 | # v1.1 by circulosmeos 01-2017. 8 | # v1.2, 2.0 by circulosmeos 01-2019. 9 | # v2.1 by circulosmeos 12-2020. 10 | # //circulosmeos.wordpress.com/2014/04/12/google-drive-direct-download-of-big-files 11 | # Distributed under GPL 3 (//www.gnu.org/licenses/gpl-3.0.html) 12 | # 13 | use strict; 14 | use POSIX; 15 | 16 | my $TEMP='gdown.cookie.temp'; 17 | my $COMMAND; 18 | my $confirm; 19 | my $check; 20 | sub execute_command(); 21 | 22 | my $URL=shift; 23 | die "\n./gdown.pl 'gdrive file url' [desired file name]\n\n" if $URL eq ''; 24 | 25 | my $FILENAME=shift; 26 | my $TEMP_FILENAME='gdown.'.strftime("%Y%m%d%H%M%S", localtime).'.'.substr(rand,2); 27 | 28 | if ($URL=~m#^https?://drive.google.com/file/d/([^/]+)#) { 29 | $URL="https://docs.google.com/uc?id=$1&export=download"; 30 | } 31 | elsif ($URL=~m#^https?://drive.google.com/open\?id=([^/]+)#) { 32 | $URL="https://docs.google.com/uc?id=$1&export=download"; 33 | } 34 | 35 | execute_command(); 36 | 37 | while (-s $TEMP_FILENAME < 100000) { # only if the file isn't the download yet 38 | open fFILENAME, '<', $TEMP_FILENAME; 39 | $check=0; 40 | foreach (<fFILENAME>) { # read the downloaded HTML page line by line 41 | if (/href="(\/uc\?export=download[^"]+)/) { 42 | $URL='https://docs.google.com'.$1; 43 | $URL=~s/&amp;/&/g; 44 | $confirm=''; 45 | $check=1; 46 | last; 47 | } 48 | if (/confirm=([^;&]+)/) { 49 | $confirm=$1; 50 | $check=1; 51 | last; 52 | } 53 | if (/"downloadUrl":"([^"]+)/) { 54 | $URL=$1; 55 | $URL=~s/\\u003d/=/g; 56 | $URL=~s/\\u0026/&/g; 57 | $confirm=''; 58 | $check=1; 59 | last; 60 | } 61 | } 62 | close fFILENAME; 63 | die "Couldn't download the file :-(\n" if ($check==0); 64 | $URL=~s/confirm=([^;&]+)/confirm=$confirm/ if $confirm ne ''; 65 | 66 | execute_command(); 67 | 68 | } 69 | 70 | unlink $TEMP; 71 | 72 | sub execute_command() { 73 | my $OUTPUT_FILENAME = $TEMP_FILENAME; 74 | my $CONTINUE = ''; 75 | 76 | # check contents before download & if a $FILENAME has been indicated resume on content download 77 | # please, note that for this to work, wget must correctly provide --spider with --server-response (-S) 78 | if ( length($FILENAME) > 0 ) { 79 | $COMMAND="wget -q -S --no-check-certificate --spider --load-cookie $TEMP --save-cookie $TEMP \"$URL\" 2>&1"; 80 | my @HEADERS=`$COMMAND`; 81 | foreach my $header (@HEADERS) { 82 | if ( ( $header =~ /Content-Type: (.+)/ && $1 !~ 'text/html' ) || 83 | $header =~ 'HTTP/1.1 405 Method Not Allowed' 84 | ) { 85 | $OUTPUT_FILENAME = $FILENAME; 86 | $CONTINUE = '-c'; 87 | last; 88 | } 89 | } 90 | } 91 | 92 | $COMMAND="wget $CONTINUE --progress=dot:giga --no-check-certificate --load-cookie $TEMP --save-cookie $TEMP \"$URL\""; 93 | $COMMAND.=" -O \"$OUTPUT_FILENAME\""; 94 | my $OUTPUT = system( $COMMAND ); 95 | if ( $OUTPUT == 2 ) { # do a clean exit with Ctrl+C 96 | unlink $TEMP; 97 | die "\nDownloading interrupted by user\n\n"; 98 | } elsif ( $OUTPUT == 0 && length($CONTINUE)>0 ) { #
do a clean exit with $FILENAME provided 99 | unlink $TEMP; 100 | die "\nDownloading complete\n\n"; 101 | } 102 | return 1; 103 | } 104 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # **GPT2** for Multiple Languages 4 | 5 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/imcaspar/gpt2-ml/blob/master/pretrained_model_demo.ipynb) 6 | [![GitHub](https://img.shields.io/github/license/imcaspar/gpt2-ml)](https://github.com/imcaspar/gpt2-ml) 7 | [![GitHub All Releases](https://img.shields.io/github/downloads/imcaspar/gpt2-ml/total)](https://github.com/imcaspar/gpt2-ml/releases) 8 | [![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/imcaspar/gpt2-ml/issues) 9 | [![GitHub stars](https://img.shields.io/github/stars/imcaspar/gpt2-ml?style=social)](https://github.com/imcaspar/gpt2-ml) 10 | 11 | [**中文说明**](./README_CN.md) | [**English**](./README.md) 12 | 13 | - [x] Simplifed GPT2 train scripts(based on Grover, supporting TPUs) 14 | - [x] Ported bert tokenizer, multilingual corpus compatible 15 | - [x] 1.5B GPT2 pretrained Chinese model ( ~15G corpus, 10w steps ) 16 | - [x] Batteries-included Colab demo [#](https://github.com/imcaspar/gpt2-ml#google-colab) 17 | - [x] 1.5B GPT2 pretrained Chinese model ( ~30G corpus, 22w steps ) 18 | 19 | 20 | ## Pretrained Model 21 | | Size | Language | Corpus | Vocab | Link1 | Link2 | SHA256 | 22 | | ---- | -------- | ------ | ----- | ----- | ----- | ------ | 23 | | 1.5B Params | Chinese | ~30G | CLUE ( 8021 tokens ) | [**Google Drive**](https://drive.google.com/file/d/1mT_qCQg4AWnAXTwKfsyyRWCRpgPrBJS3) | [**Baidu Pan (ffz6)**](https://pan.baidu.com/s/1yiuTHXUr2DpyBqmFYLJH6A) | e698cc97a7f5f706f84f58bb469d614e
51d3c0ce5f9ab9bf77e01e3fcb41d482 | 24 | | 1.5B Params | Chinese | ~15G | Bert ( 21128 tokens ) | [**Google Drive**](https://drive.google.com/file/d/1IzWpQ6I2IgfV7CldZvFJnZ9byNDZdO4n) | [**Baidu Pan (q9vr)**](https://pan.baidu.com/s/1TA_3e-u2bXg_hcx_NwVbGw) | 4a6e5124df8db7ac2bdd902e6191b807
a6983a7f5d09fb10ce011f9a073b183e | 25 | 26 | Corpus from [THUCNews](http://thuctc.thunlp.org/#%E4%B8%AD%E6%96%87%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%95%B0%E6%8D%AE%E9%9B%86THUCNews) and [nlp_chinese_corpus](https://github.com/brightmart/nlp_chinese_corpus) 27 | 28 | Using [Cloud TPU Pod v3-256](https://cloud.google.com/tpu/docs/types-zones#types) to train 22w steps 29 | 30 | ![loss](./.github/loss.png) 31 | 32 | 33 | ## Google Colab 34 | With just 2 clicks (not including Colab auth process), the 1.5B pretrained Chinese model demo is ready to go: 35 | 36 | [**[Colab Notebook]**](https://colab.research.google.com/github/imcaspar/gpt2-ml/blob/master/pretrained_model_demo.ipynb) 37 | 38 | 39 | 40 | ## Train 41 | 42 | ## Disclaimer 43 | The contents in this repository are for academic research purpose, and we do not provide any conclusive remarks. 44 | 45 | ## Citation 46 | 47 | ``` 48 | @misc{GPT2-ML, 49 | author = {Zhibo Zhang}, 50 | title = {GPT2-ML: GPT-2 for Multiple Languages}, 51 | year = {2019}, 52 | publisher = {GitHub}, 53 | journal = {GitHub repository}, 54 | howpublished = {\url{https://github.com/imcaspar/gpt2-ml}}, 55 | } 56 | ``` 57 | 58 | ## Reference 59 | https://github.com/google-research/bert 60 | 61 | https://github.com/rowanz/grover 62 | 63 | Research supported with Cloud TPUs from Google's TensorFlow Research Cloud (TFRC) 64 | 65 | ## Press 66 | [[机器之心] 只需单击三次,让中文GPT-2为你生成定制故事](https://mp.weixin.qq.com/s/FpoSNNKZSQOE2diPvJDHog) 67 | 68 | [[科学空间] 现在可以用Keras玩中文GPT2了](https://kexue.fm/archives/7292) 69 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # **GPT2** for Multiple Languages 4 | 5 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/imcaspar/gpt2-ml/blob/master/pretrained_model_demo.ipynb) 6 | [![GitHub](https://img.shields.io/github/license/imcaspar/gpt2-ml)](https://github.com/imcaspar/gpt2-ml) 7 | [![GitHub All Releases](https://img.shields.io/github/downloads/imcaspar/gpt2-ml/total)](https://github.com/imcaspar/gpt2-ml/releases) 8 | [![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/imcaspar/gpt2-ml/issues) 9 | [![GitHub stars](https://img.shields.io/github/stars/imcaspar/gpt2-ml?style=social)](https://github.com/imcaspar/gpt2-ml) 10 | 11 | [**中文说明**](./README_CN.md) | [**English**](./README.md) 12 | 13 | - [x] 简化整理 GPT2 训练代码(based on Grover, supporting TPUs) 14 | - [x] 移植 bert tokenizer,添加多语言支持 15 | - [x] 15亿参数 GPT2 中文预训练模型( 15G 语料,训练 10w 步 ) 16 | - [x] 开箱即用的模型生成效果 demo [#](https://github.com/imcaspar/gpt2-ml#google-colab) 17 | - [x] 15亿参数 GPT2 中文预训练模型( 30G 语料,训练 22w 步 ) 18 | 19 | 20 | ## 预训练模型 21 | | Size | Language | Corpus | Vocab | Link1 | Link2 | SHA256 | 22 | | ---- | -------- | ------ | ----- | ----- | ----- | ------ | 23 | | 1.5B Params | Chinese | ~30G | CLUE ( 8021 tokens ) | [**Google Drive**](https://drive.google.com/file/d/1mT_qCQg4AWnAXTwKfsyyRWCRpgPrBJS3) | [**Baidu Pan (ffz6)**](https://pan.baidu.com/s/1yiuTHXUr2DpyBqmFYLJH6A) | e698cc97a7f5f706f84f58bb469d614e
51d3c0ce5f9ab9bf77e01e3fcb41d482 | 24 | | 1.5B Params | Chinese | ~15G | Bert ( 21128 tokens ) | [**Google Drive**](https://drive.google.com/file/d/1IzWpQ6I2IgfV7CldZvFJnZ9byNDZdO4n) | [**Baidu Pan (q9vr)**](https://pan.baidu.com/s/1TA_3e-u2bXg_hcx_NwVbGw) | 4a6e5124df8db7ac2bdd902e6191b807
a6983a7f5d09fb10ce011f9a073b183e | 25 | 26 | 训练语料来自 [THUCNews](http://thuctc.thunlp.org/#%E4%B8%AD%E6%96%87%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%95%B0%E6%8D%AE%E9%9B%86THUCNews) 以及 [nlp_chinese_corpus](https://github.com/brightmart/nlp_chinese_corpus),清洗后总文本量约 15G 27 | 28 | 使用 [Cloud TPU Pod v3-256](https://cloud.google.com/tpu/docs/types-zones#types) 训练 22w 步 29 | 30 | ![loss](./.github/loss.png) 31 | 32 | 33 | ## Google Colab 34 | 只需两次鼠标点击(不包括 Colab 授权流程),体验 15 亿参数中文预训练模型生成效果: 35 | 36 | [**[Colab Notebook]**](https://colab.research.google.com/github/imcaspar/gpt2-ml/blob/master/pretrained_model_demo.ipynb) 37 | 38 | 39 | 40 | ## 训练 41 | 42 | ## 免责声明 43 | 该项目中的内容仅供技术研究参考,不作为任何结论性依据。 44 | 45 | ## Citation 46 | 47 | ``` 48 | @misc{GPT2-ML, 49 | author = {Zhibo Zhang}, 50 | title = {GPT2-ML: GPT-2 for Multiple Languages}, 51 | year = {2019}, 52 | publisher = {GitHub}, 53 | journal = {GitHub repository}, 54 | howpublished = {\url{https://github.com/imcaspar/gpt2-ml}}, 55 | } 56 | ``` 57 | 58 | ## Reference 59 | https://github.com/google-research/bert 60 | 61 | https://github.com/rowanz/grover 62 | 63 | Research supported with Cloud TPUs from Google's TensorFlow Research Cloud (TFRC) 64 | 65 | ## Press 66 | [[机器之心] 只需单击三次,让中文GPT-2为你生成定制故事](https://mp.weixin.qq.com/s/FpoSNNKZSQOE2diPvJDHog) 67 | 68 | [[科学空间] 现在可以用Keras玩中文GPT2了](https://kexue.fm/archives/7292) -------------------------------------------------------------------------------- /configs/base.json: -------------------------------------------------------------------------------- 1 | { 2 | "vocab_size": 50270, 3 | "hidden_size": 768, 4 | "attention_probs_dropout_prob": 0.1, 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 1024, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12 12 | } -------------------------------------------------------------------------------- /configs/large.json: -------------------------------------------------------------------------------- 1 | { 2 | "vocab_size": 8021, 3 | "hidden_size": 1024, 4 | "attention_probs_dropout_prob": 0.1, 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "initializer_range": 0.02, 8 | "intermediate_size": 4096, 9 | "max_position_embeddings": 1024, 10 | "num_attention_heads": 16, 11 | "num_hidden_layers": 24 12 | } -------------------------------------------------------------------------------- /configs/mega.json: -------------------------------------------------------------------------------- 1 | { 2 | "vocab_size": 8021, 3 | "hidden_size": 1536, 4 | "attention_probs_dropout_prob": 0.1, 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "initializer_range": 0.014142135623731, 8 | "intermediate_size": 6144, 9 | "max_position_embeddings": 1024, 10 | "num_attention_heads": 24, 11 | "num_hidden_layers": 48 12 | } 13 | -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | curl -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda2-latest-Linux-x86_64.sh 2 | chmod +x ~/miniconda.sh 3 | ~/miniconda.sh -b -p ~/conda 4 | rm ~/miniconda.sh 5 | ~/conda/bin/conda install -y python=3.7 6 | ~/conda/bin/conda init // exit shell 7 | 8 | conda create -n gpt3 python=3.7 9 | 10 | 11 | sudo apt install parallel 12 | pip install ujson==2.0.3 13 | 14 | export PYTHONPATH=$(pwd) //project path 
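
The setup above only prepares the environment; sharding is driven by `prepare_data.sh` (below), which takes the corpus directory and an output GCS bucket, e.g. `bash prepare_data.sh corpus_dir my-bucket`. Once a fold has been written, one record can be read back and decoded as a quick sanity check. The sketch below is illustrative only, not part of the repo: it assumes a fold written locally (i.e. `prepare_data.py` run with a local `-base_fn` rather than the `gs://` bucket used by the shell script) and a shard name following the script's defaults (`base_fn` = `news2016zh_`, fold 0). Run it from the repo root with the `PYTHONPATH` export above.

```python
# Decode one record from a generated shard (the file name is an assumption
# based on prepare_data.py's defaults; adjust to your base_fn / fold).
import tensorflow.compat.v1 as tf

from tokenization import tokenization

tokenizer = tokenization.FullTokenizer(
    vocab_file="tokenization/clue-vocab.txt", do_lower_case=True)

record_iter = tf.python_io.tf_record_iterator(
    "news2016zh_train_wiki19_0000.tfrecord")
example = tf.train.Example.FromString(next(record_iter))
input_ids = list(example.features.feature["input_ids"].int64_list.value)

# Each record stores exactly max_seq_length + 1 ids (articles are truncated
# or zero-padded to this length by prepare_data.py below).
print(len(input_ids))  # 1025 for the default max_seq_length of 1024
print("".join(tokenizer.convert_ids_to_tokens(input_ids[:50])))
```
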
-------------------------------------------------------------------------------- /dataset/prepare_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Turn a merged corpus into tfrecord files. 3 | 4 | NOTE: You will want to do this using several processes. I did this on an AWS machine with 72 CPUs using GNU parallel 5 | as that's where I had the deduplicated RealNews dataset. 6 | """ 7 | import argparse 8 | import ujson as json 9 | # from sample.encoder import get_encoder, tokenize_for_grover_training, detokenize, sliding_window, create_int_feature 10 | import random 11 | import tensorflow.compat.v1 as tf 12 | import collections 13 | import os 14 | from tempfile import TemporaryDirectory 15 | 16 | from tokenization import tokenization 17 | 18 | parser = argparse.ArgumentParser(description='SCRAPE!') 19 | parser.add_argument( 20 | '-fold', 21 | dest='fold', 22 | default=0, 23 | type=int, 24 | help='which fold we are on' 25 | ) 26 | parser.add_argument( 27 | '-num_folds', 28 | dest='num_folds', 29 | default=1, 30 | type=int, 31 | help='Number of folds (corresponding to both the number of training files and the number of testing files)', 32 | ) 33 | parser.add_argument( 34 | '-seed', 35 | dest='seed', 36 | default=1337, 37 | type=int, 38 | help='which seed to use' 39 | ) 40 | parser.add_argument( 41 | '-base_fn', 42 | dest='base_fn', 43 | default='news2016zh_', 44 | type=str, 45 | help='We will output files that are like {base_fn}_{n}.tfrecord for n in 0, ..., 1023' 46 | ) 47 | 48 | parser.add_argument( 49 | '-input_fn', 50 | dest='input_fn', 51 | default='realnews.jsonl', 52 | type=str, 53 | help='Base filename to use. THIS MUST BE A LOCAL FILE.' 54 | ) 55 | parser.add_argument( 56 | '-max_seq_length', 57 | dest='max_seq_length', 58 | default=1024, 59 | type=int, 60 | help='Max sequence length', 61 | ) 62 | 63 | 64 | args = parser.parse_args() 65 | random.seed(args.seed + args.fold) 66 | 67 | #encoder = get_encoder() 68 | tokenizer = tokenization.FullTokenizer( 69 | vocab_file="clue-vocab.txt", do_lower_case=True) 70 | 71 | 72 | class TFRecordWriter(object): 73 | def __init__(self, fn): 74 | self.fn = fn 75 | if fn.startswith('gs://'): 76 | from google.cloud import storage 77 | self.s3client = None 78 | self.gclient = storage.Client() 79 | self.storage_dir = TemporaryDirectory() 80 | self.writer = tf.python_io.TFRecordWriter( 81 | os.path.join(self.storage_dir.name, 'temp.tfrecord')) 82 | self.bucket_name, self.file_name = self.fn.split( 83 | 'gs://', 1)[1].split('/', 1) 84 | 85 | else: 86 | self.s3client = None 87 | self.gclient = None 88 | self.bucket_name = None 89 | self.file_name = None 90 | self.storage_dir = None 91 | self.writer = tf.python_io.TFRecordWriter(fn) 92 | 93 | def write(self, x): 94 | self.writer.write(x) 95 | 96 | def close(self): 97 | self.writer.close() 98 | 99 | if self.gclient is not None: 100 | bucket = self.gclient.get_bucket(self.bucket_name) 101 | blob = bucket.blob(self.file_name) 102 | blob.upload_from_filename(os.path.join( 103 | self.storage_dir.name, 'temp.tfrecord')) 104 | self.storage_dir.cleanup() 105 | 106 | def __enter__(self): 107 | # Called when entering "with" context. 108 | return self 109 | 110 | def __exit__(self, *_): 111 | # Called when exiting "with" context. 
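        # (close() below flushes the local TFRecord writer and, for gs:// output
        # paths, uploads the temp file to the bucket before removing the temp dir.)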
112 | # Upload shit 113 | print("CALLING CLOSE") 114 | self.close() 115 | 116 | 117 | def article_iterator(tokenizer): 118 | """ Iterate through the provided filename + tokenize""" 119 | assert os.path.exists(args.input_fn) 120 | for (dirpath, dirnames, filenames) in os.walk(args.input_fn): 121 | for filename in filenames: 122 | with open(os.path.join(dirpath, filename), 'r') as f: 123 | for l_no, l in enumerate(f): 124 | if l_no % args.num_folds == args.fold: 125 | article = json.loads(l) 126 | 127 | line = tokenization.convert_to_unicode( 128 | article['text']) # for news2016zh text body 129 | tokens = tokenizer.tokenize(line) 130 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 131 | 132 | article['input_ids'] = input_ids 133 | 134 | article['inst_index'] = (l_no // args.num_folds) 135 | if article['inst_index'] < 100: 136 | print('---\nINPUT{}. {}\n---\nTokens: {}\n'.format(article['inst_index'], 137 | tokens, 138 | input_ids 139 | ), flush=True) 140 | if len(article['input_ids']) <= 64: # min size of article 141 | continue 142 | yield article 143 | 144 | 145 | def create_int_feature(values): 146 | feature = tf.train.Feature( 147 | int64_list=tf.train.Int64List(value=list(values))) 148 | return feature 149 | 150 | 151 | def buffered_and_sliding_window_article_iterator(tokenizer, final_desired_size=1025): 152 | """ We apply a sliding window to fix long sequences, and use a buffer that combines short sequences.""" 153 | for article in article_iterator(tokenizer): 154 | if len(article['input_ids']) >= final_desired_size: 155 | article['input_ids'] = article['input_ids'][0:final_desired_size-1] 156 | while len(article['input_ids']) < final_desired_size: 157 | article['input_ids'].append(0) 158 | yield article 159 | 160 | 161 | # OK now write the tfrecord file 162 | total_written = 0 163 | train_file = args.base_fn + 'train_wiki19_{:04d}.tfrecord'.format(args.fold) 164 | with TFRecordWriter(train_file) as train_writer: 165 | for article in buffered_and_sliding_window_article_iterator(tokenizer, 166 | final_desired_size=args.max_seq_length + 1): 167 | writer2use = train_writer 168 | assert len(article['input_ids']) == (args.max_seq_length + 1) 169 | 170 | features = collections.OrderedDict() 171 | features["input_ids"] = create_int_feature(article['input_ids']) 172 | tf_example = tf.train.Example( 173 | features=tf.train.Features(feature=features)) 174 | 175 | writer2use.write(tf_example.SerializeToString()) 176 | total_written += 1 177 | 178 | # DEBUG 179 | if article['inst_index'] < 5: 180 | print("~~~\nIndex {}. 
ARTICLE: {}\n---\nTokens: {}\n\n".format(article['inst_index'], 181 | tokenizer.convert_ids_to_tokens( 182 | article['input_ids']), 183 | article['input_ids'] 184 | ), flush=True) 185 | if article['inst_index'] % 1000 == 0: 186 | print("{} articles, {} written".format( 187 | article['inst_index'], total_written), flush=True) 188 | print("DONE UPLOADING", flush=True) 189 | -------------------------------------------------------------------------------- /dataset/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | NUM_FOLDS=1024 4 | MAX_SEQ_LENGTH=1024 5 | FN=${1} 6 | OUT_BUCKET=${2} 7 | 8 | rm -rf logs_${MAX_SEQ_LENGTH} 9 | mkdir logs_${MAX_SEQ_LENGTH} 10 | parallel -j $(nproc --all) --will-cite "python prepare_data.py -fold {1} -num_folds ${NUM_FOLDS} -base_fn gs://${OUT_BUCKET}/data_${MAX_SEQ_LENGTH}/ -input_fn ${FN} -max_seq_length ${MAX_SEQ_LENGTH} > logs_${MAX_SEQ_LENGTH}/log{1}.txt" ::: $(seq 0 $((${NUM_FOLDS}-1))) 11 | -------------------------------------------------------------------------------- /dockerfiles/gpu-jupyter.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:1.15.0-gpu-py3-jupyter 2 | 3 | RUN apt update && apt install -y --no-install-recommends git 4 | RUN git clone -q https://github.com/imcaspar/gpt2-ml && mkdir -p gpt2-ml/models/mega 5 | 6 | WORKDIR /gpt2-ml 7 | 8 | RUN perl 3rd/gdown.pl/gdown.pl https://drive.google.com/open?id=1n_5-tgPpQ1gqbyLPbP1PwiFi2eo7SWw_ models/mega/model.ckpt-100000.data-00000-of-00001 9 | RUN wget -q --show-progress https://github.com/imcaspar/gpt2-ml/releases/download/v0.5/model.ckpt-100000.index -P models/mega 10 | RUN wget -q --show-progress https://github.com/imcaspar/gpt2-ml/releases/download/v0.5/model.ckpt-100000.meta -P models/mega 11 | 12 | CMD ["bash", "-c", "jupyter notebook --ip 0.0.0.0 --no-browser --allow-root pretrained_model_demo.ipynb"] -------------------------------------------------------------------------------- /pretrained_model_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "[![GitHub stars](https://img.shields.io/github/stars/imcaspar/gpt2-ml?style=social)](https://github.com/imcaspar/gpt2-ml)\n", 8 | "[![GitHub](https://img.shields.io/github/license/imcaspar/gpt2-ml)](https://github.com/imcaspar/gpt2-ml)\n", 9 | "[![Twitter URL](https://img.shields.io/twitter/url?style=social&url=https%3A%2F%2Fgithub.com%2Fimcaspar%2Fgpt2-ml)](https://twitter.com/intent/tweet?text=Wow:&url=https://github.com/imcaspar/gpt2-ml)\n", 10 | "### Instructions for running:\n", 11 | "\n", 12 | "* Press the ▶️button on the left of each of the cells\n", 13 | "* View the code: Double click any of the cells\n", 14 | "* Hide the code: Double click the right side of the cell" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "#@title #Prerequisites\n", 24 | "#%tensorflow_version 1.x\n", 25 | "!pip install -I tensorflow-gpu==1.15.4 &> tmp.log\n", 26 | "!git clone -q https://github.com/imcaspar/gpt2-ml\n", 27 | "%cd /content/gpt2-ml\n", 28 | "!mkdir -p /content/gpt2-ml/models/mega\n", 29 | "\n", 30 | "!perl 3rd/gdown.pl/gdown.pl https://drive.google.com/open?id=1mT_qCQg4AWnAXTwKfsyyRWCRpgPrBJS3 models/mega/model.ckpt-220000.data-00000-of-00001\n", 31 | "!wget -q 
--show-progress https://github.com/imcaspar/gpt2-ml/releases/download/v1.0/model.ckpt-220000.index -P models/mega\n", 32 | "!wget -q --show-progress https://github.com/imcaspar/gpt2-ml/releases/download/v1.0/model.ckpt-220000.meta -P models/mega\n", 33 | "!echo 'Download finished.🍺'\n", 34 | "# If gdown.pl failed, please uncomment following code & exec\n", 35 | "# !python scripts/down_gdrive_file.py -file_id='1mT_qCQg4AWnAXTwKfsyyRWCRpgPrBJS3' -file_path='models/mega/model.ckpt-220000.data-00000-of-00001'" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "#@title #Inference\n", 45 | "min_len = 150#@param {type:\"number\", min:5, max:1024, step:1}\n", 46 | "sp_num = 5#@param {type:\"number\", min:1, max:50, step:1}\n", 47 | "!PYTHONPATH=$(pwd) python scripts/demo.py -ckpt_fn models/mega/model.ckpt-220000 -min_len $min_len -samples $sp_num" 48 | ] 49 | } 50 | ], 51 | "metadata": { 52 | "colab": { 53 | "name": "15 亿参数 GPT2 中文预训练模型 | 1.5B GPT2 Pretrained Chinese Model", 54 | "provenance": [], 55 | "collapsed_sections": [] 56 | }, 57 | "kernelspec": { 58 | "name": "python3", 59 | "display_name": "Python 3" 60 | }, 61 | "accelerator": "GPU" 62 | }, 63 | "nbformat": 4, 64 | "nbformat_minor": 0 65 | } -------------------------------------------------------------------------------- /requirements-gpu.txt: -------------------------------------------------------------------------------- 1 | pandas==0.24.2 2 | regex==2019.4.14 3 | h5py==2.10.0 4 | numpy==1.18.4 5 | tensorboard==1.15.0 6 | tensorflow-gpu==1.15.4 7 | tensorflow-estimator==1.15.1 8 | tqdm==4.31.1 9 | requests==2.22.0 10 | ujson==2.0.3 -------------------------------------------------------------------------------- /requirements-tpu.txt: -------------------------------------------------------------------------------- 1 | pandas==0.24.2 2 | regex==2019.4.14 3 | h5py==2.10.0 4 | numpy==1.18.4 5 | tensorboard==1.15.0 6 | tensorflow==1.15.4 7 | tensorflow-estimator==1.15.1 8 | tqdm==4.31.1 9 | requests==2.22.0 10 | ujson==2.0.3 -------------------------------------------------------------------------------- /scripts/demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | import json 5 | import re 6 | 7 | import tensorflow.compat.v1 as tf 8 | import numpy as np 9 | 10 | from train.modeling import GroverModel, GroverConfig, sample 11 | from tokenization import tokenization 12 | 13 | ##### ignore tf deprecated warning temporarily 14 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 15 | tf.logging.set_verbosity(tf.logging.DEBUG) 16 | from tensorflow.python.util import deprecation 17 | deprecation._PRINT_DEPRECATION_WARNINGS = False 18 | try: 19 | from tensorflow.python.util import module_wrapper as deprecation 20 | except ImportError: 21 | from tensorflow.python.util import deprecation_wrapper as deprecation 22 | deprecation._PER_MODULE_WARNING_LIMIT = 0 23 | ##### 24 | 25 | parser = argparse.ArgumentParser(description='Contextual generation (aka given some metadata we will generate articles') 26 | parser.add_argument( 27 | '-metadata_fn', 28 | dest='metadata_fn', 29 | type=str, 30 | help='Path to a JSONL containing metadata', 31 | ) 32 | parser.add_argument( 33 | '-out_fn', 34 | dest='out_fn', 35 | type=str, 36 | help='Out jsonl, which will contain the completed jsons', 37 | ) 38 | parser.add_argument( 39 | '-input', 40 | dest='input', 41 | type=str, 42 | help='Text to 
complete', 43 | ) 44 | parser.add_argument( 45 | '-config_fn', 46 | dest='config_fn', 47 | default='configs/mega.json', 48 | type=str, 49 | help='Configuration JSON for the model', 50 | ) 51 | parser.add_argument( 52 | '-ckpt_fn', 53 | dest='ckpt_fn', 54 | default='../models/mega/model.ckpt', 55 | type=str, 56 | help='checkpoint file for the model', 57 | ) 58 | parser.add_argument( 59 | '-target', 60 | dest='target', 61 | default='article', 62 | type=str, 63 | help='What to generate for each item in metadata_fn. can be article (body), title, etc.', 64 | ) 65 | parser.add_argument( 66 | '-batch_size', 67 | dest='batch_size', 68 | default=1, 69 | type=int, 70 | help='How many things to generate per context. will split into chunks if need be', 71 | ) 72 | parser.add_argument( 73 | '-num_folds', 74 | dest='num_folds', 75 | default=1, 76 | type=int, 77 | help='Number of folds. useful if we want to split up a big file into multiple jobs.', 78 | ) 79 | parser.add_argument( 80 | '-fold', 81 | dest='fold', 82 | default=0, 83 | type=int, 84 | help='which fold we are on. useful if we want to split up a big file into multiple jobs.' 85 | ) 86 | parser.add_argument( 87 | '-max_batch_size', 88 | dest='max_batch_size', 89 | default=None, 90 | type=int, 91 | help='max batch size. You can leave this out and we will infer one based on the number of hidden layers', 92 | ) 93 | parser.add_argument( 94 | '-top_p', 95 | dest='top_p', 96 | default=0.95, 97 | type=float, 98 | help='p to use for top p sampling. if this isn\'t none, use this for everthing' 99 | ) 100 | parser.add_argument( 101 | '-min_len', 102 | dest='min_len', 103 | default=1024, 104 | type=int, 105 | help='min length of sample', 106 | ) 107 | parser.add_argument( 108 | '-eos_token', 109 | dest='eos_token', 110 | default=102, 111 | type=int, 112 | help='eos token id', 113 | ) 114 | parser.add_argument( 115 | '-samples', 116 | dest='samples', 117 | default=5, 118 | type=int, 119 | help='num_samples', 120 | ) 121 | 122 | def extract_generated_target(output_tokens, tokenizer): 123 | """ 124 | Given some tokens that were generated, extract the target 125 | :param output_tokens: [num_tokens] thing that was generated 126 | :param encoder: how they were encoded 127 | :param target: the piece of metadata we wanted to generate! 128 | :return: 129 | """ 130 | # Filter out first instance of start token 131 | assert output_tokens.ndim == 1 132 | 133 | start_ind = 0 134 | end_ind = output_tokens.shape[0] 135 | 136 | return { 137 | 'extraction': tokenization.printable_text(''.join(tokenizer.convert_ids_to_tokens(output_tokens))), 138 | 'start_ind': start_ind, 139 | 'end_ind': end_ind, 140 | } 141 | 142 | args = parser.parse_args() 143 | proj_root_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 144 | vocab_file_path = os.path.join(proj_root_path, "tokenization/clue-vocab.txt") 145 | tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file_path , do_lower_case=True) 146 | news_config = GroverConfig.from_json_file(args.config_fn) 147 | 148 | # We might have to split the batch into multiple chunks if the batch size is too large 149 | default_mbs = {12: 32, 24: 16, 48: 3} 150 | max_batch_size = args.max_batch_size if args.max_batch_size is not None else default_mbs[news_config.num_hidden_layers] 151 | 152 | # factorize args.batch_size = (num_chunks * batch_size_per_chunk) s.t. 
batch_size_per_chunk < max_batch_size 153 | num_chunks = int(np.ceil(args.batch_size / max_batch_size)) 154 | batch_size_per_chunk = int(np.ceil(args.batch_size / num_chunks)) 155 | 156 | # This controls the top p for each generation. 157 | top_p = np.ones((num_chunks, batch_size_per_chunk), dtype=np.float32) * args.top_p 158 | 159 | tf_config = tf.ConfigProto(allow_soft_placement=True) 160 | 161 | with tf.Session(config=tf_config, graph=tf.Graph()) as sess: 162 | initial_context = tf.placeholder(tf.int32, [batch_size_per_chunk, None]) 163 | p_for_topp = tf.placeholder(tf.float32, [batch_size_per_chunk]) 164 | eos_token = tf.placeholder(tf.int32, []) 165 | min_len = tf.placeholder(tf.int32, []) 166 | tokens, probs = sample(news_config=news_config, initial_context=initial_context, 167 | eos_token=eos_token, min_len=min_len, ignore_ids=None, p_for_topp=p_for_topp, 168 | do_topk=False) 169 | 170 | saver = tf.train.Saver() 171 | saver.restore(sess, args.ckpt_fn) 172 | print('🍺Model loaded. \nInput something please:⬇️') 173 | text = input() 174 | while text != "": 175 | for i in range(args.samples): 176 | print("Sample,", i + 1, " of ", args.samples) 177 | line = tokenization.convert_to_unicode(text) 178 | bert_tokens = tokenizer.tokenize(line) 179 | encoded = tokenizer.convert_tokens_to_ids(bert_tokens) 180 | context_formatted = [] 181 | context_formatted.extend(encoded) 182 | # Format context end 183 | 184 | gens = [] 185 | gens_raw = [] 186 | gen_probs = [] 187 | 188 | for chunk_i in range(num_chunks): 189 | tokens_out, probs_out = sess.run([tokens, probs], 190 | feed_dict={initial_context: [context_formatted] * batch_size_per_chunk, 191 | eos_token: args.eos_token, min_len: args.min_len, 192 | p_for_topp: top_p[chunk_i]}) 193 | 194 | for t_i, p_i in zip(tokens_out, probs_out): 195 | extraction = extract_generated_target(output_tokens=t_i, tokenizer=tokenizer) 196 | gens.append(extraction['extraction']) 197 | 198 | l = re.findall('.{1,70}', gens[0].replace('[UNK]', '').replace('##', '')) 199 | print("\n".join(l)) 200 | print('Next try:⬇️') 201 | text = input() 202 | -------------------------------------------------------------------------------- /scripts/down_gdrive_file.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from google.colab import auth 4 | from googleapiclient.discovery import build 5 | from apiclient.http import MediaIoBaseDownload 6 | from tqdm import tqdm 7 | 8 | parser = argparse.ArgumentParser(description='Simple file download script for Google Drive') 9 | parser.add_argument( 10 | '-file_id', 11 | dest='file_id', 12 | type=str, 13 | help='File id in Google Drive URL', 14 | ) 15 | parser.add_argument( 16 | '-file_path', 17 | dest='file_path', 18 | type=str, 19 | help='Output file path', 20 | ) 21 | 22 | args = parser.parse_args() 23 | 24 | auth.authenticate_user() 25 | drive_service = build('drive', 'v3') 26 | 27 | # file_id, file_ext = ('1n_5-tgPpQ1gqbyLPbP1PwiFi2eo7SWw_', '.data-00000-of-00001') 28 | # filename = '%s/model.ckpt-%d%s' % (model_dir, 100000, file_ext) 29 | req = drive_service.files().get_media(fileId=args.file_id) 30 | with open(args.file_path, 'wb') as f: 31 | downloader = MediaIoBaseDownload(f, req, chunksize=100*1024*1024) 32 | done = False 33 | pbar = tqdm(total=100, desc='%s' % args.file_path) 34 | progress = 0 35 | while done is False: 36 | status, done = downloader.next_chunk() 37 | new_progress = int(status.progress() * 100) 38 | pbar.update(new_progress - progress) 39 | progress = 
new_progress 40 | pbar.close() 41 | -------------------------------------------------------------------------------- /tokenization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imcaspar/gpt2-ml/f6286b16cbbee6dddbe1ba557fffb47eaf998cd1/tokenization/__init__.py -------------------------------------------------------------------------------- /tokenization/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow.compat.v1 as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." 
% (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenziation.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower 
casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like the all of the other languages. 
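        # The ranges tested below are: CJK Unified Ideographs (U+4E00-U+9FFF),
        # Extensions A-E (U+3400-U+4DBF, U+20000-U+2A6DF, U+2A700-U+2B73F,
        # U+2B740-U+2B81F, U+2B820-U+2CEAF), CJK Compatibility Ideographs
        # (U+F900-U+FAFF) and the Compatibility Ideographs Supplement
        # (U+2F800-U+2FA1F).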
274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenziation.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `chars` is a whitespace character.""" 364 | # \t, \n, and \r are technically contorl characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `chars` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters. 378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat in ("Cc", "Cf"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `chars` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 
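    # The four ASCII ranges tested below are 33-47 ("!".."/"), 58-64 (":".."@"),
    # 91-96 ("[".."`") and 123-126 ("{".."~").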
390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyways, for 392 | # consistency. 393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imcaspar/gpt2-ml/f6286b16cbbee6dddbe1ba557fffb47eaf998cd1/train/__init__.py -------------------------------------------------------------------------------- /train/dataloader.py: -------------------------------------------------------------------------------- 1 | # Original work Copyright 2018 The Google AI Language Team Authors. 2 | # Modified work Copyright 2019 Rowan Zellers 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import collections 17 | import tensorflow.compat.v1 as tf 18 | 19 | 20 | def _decode_record(record, name_to_features): 21 | """Decodes a record to a TensorFlow example.""" 22 | example = tf.parse_single_example(record, name_to_features) 23 | 24 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 25 | # So cast all int64 to int32. 26 | for name in list(example.keys()): 27 | t = example[name] 28 | if t.dtype == tf.int64: 29 | t = tf.cast(t, tf.int32) 30 | example[name] = t 31 | return example 32 | 33 | 34 | def input_fn_builder(input_files, 35 | seq_length, 36 | is_training, 37 | num_cpu_threads=4, 38 | evaluate_for_fixed_number_of_steps=True): 39 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 40 | 41 | def input_fn(params): 42 | """The actual input function.""" 43 | batch_size = params["batch_size"] 44 | name_to_features = { 45 | "input_ids": tf.FixedLenFeature([seq_length + 1], tf.int64), 46 | } 47 | 48 | # For training, we want a lot of parallel reading and shuffling. 49 | # For eval, we want no shuffling and parallel reading doesn't matter. 50 | if is_training: 51 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 52 | d = d.repeat() 53 | d = d.shuffle(buffer_size=len(input_files)) 54 | 55 | # `cycle_length` is the number of parallel files that get read. 56 | cycle_length = min(num_cpu_threads, len(input_files)) 57 | 58 | # `sloppy` mode means that the interleaving is not exact. This adds 59 | # even more randomness to the training pipeline. 60 | d = d.apply( 61 | tf.data.experimental.parallel_interleave( 62 | tf.data.TFRecordDataset, 63 | sloppy=is_training, 64 | cycle_length=cycle_length)) 65 | d = d.shuffle(buffer_size=100) 66 | else: 67 | d = tf.data.TFRecordDataset(input_files) 68 | # If we evaluate for a fixed number of steps we don't want to encounter 69 | # out-of-range exceptions. 
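            # Repeating the finite eval set avoids OutOfRange errors; the
            # estimator stops after the requested number of eval steps anyway.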
70 | if evaluate_for_fixed_number_of_steps: 71 | d = d.repeat() 72 | 73 | # We must `drop_remainder` on training because the TPU requires fixed 74 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU 75 | # and we *don't* want to drop the remainder, otherwise we wont cover 76 | # every sample. 77 | d = d.apply( 78 | tf.data.experimental.map_and_batch( 79 | lambda record: _decode_record(record, name_to_features), 80 | batch_size=batch_size, 81 | num_parallel_batches=num_cpu_threads, 82 | drop_remainder=True)) 83 | return d 84 | 85 | return input_fn 86 | 87 | 88 | # ~~~~~~~~~~~~~~ This is for classification / AF ~~~~~~~~~~~~~~~~~~ 89 | def classification_convert_examples_to_features( 90 | examples, max_seq_length, batch_size, encoder, output_file, labels, pad_extra_examples=False, 91 | chop_from_front_if_needed=True): 92 | """Convert a set of `InputExample`s to a TFRecord file.""" 93 | 94 | writer = tf.python_io.TFRecordWriter(output_file) 95 | 96 | label_map = {label: i for i, label in enumerate(labels)} 97 | 98 | for (ex_index, example) in enumerate(examples): 99 | if ex_index % 10000 == 0: 100 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 101 | 102 | # begin_summary is our [CLS] token 103 | tokens = example['ids'] + [encoder.begin_summary] 104 | 105 | if len(tokens) > max_seq_length: 106 | if chop_from_front_if_needed: 107 | tokens = tokens[-max_seq_length:] 108 | else: 109 | tokens = example['ids'][:(max_seq_length-1)] + [encoder.begin_summary] 110 | elif len(tokens) < max_seq_length: 111 | tokens.extend([encoder.padding] * (max_seq_length - len(tokens))) 112 | 113 | features = collections.OrderedDict() 114 | features['input_ids'] = tf.train.Feature(int64_list=tf.train.Int64List(value=tokens)) 115 | features['label_ids'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[label_map[example['label']]])) 116 | features['is_real_example'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[1])) 117 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 118 | writer.write(tf_example.SerializeToString()) 119 | 120 | if pad_extra_examples: 121 | for x in range(len(examples) % batch_size): 122 | features = collections.OrderedDict() 123 | features['input_ids'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[0]*max_seq_length)) 124 | features['label_ids'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[0])) 125 | features['is_real_example'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[0])) 126 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 127 | writer.write(tf_example.SerializeToString()) 128 | writer.close() 129 | 130 | 131 | def classification_input_fn_builder(input_file, seq_length, is_training, 132 | drop_remainder, 133 | buffer_size=100): 134 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 135 | 136 | name_to_features = { 137 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 138 | "label_ids": tf.FixedLenFeature([], tf.int64), 139 | "is_real_example": tf.FixedLenFeature([], tf.int64), 140 | } 141 | 142 | def input_fn(params): 143 | """The actual input function.""" 144 | batch_size = params["batch_size"] 145 | 146 | # For training, we want a lot of parallel reading and shuffling. 147 | # For eval, we want no shuffling and parallel reading doesn't matter. 
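        # Classification examples live in a single TFRecord file, so a plain
        # TFRecordDataset (no parallel interleave) is enough here.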
148 | d = tf.data.TFRecordDataset(input_file) 149 | if is_training: 150 | d = d.repeat() 151 | d = d.shuffle(buffer_size=buffer_size) 152 | 153 | d = d.apply( 154 | tf.data.experimental.map_and_batch( 155 | lambda record: _decode_record(record, name_to_features), 156 | batch_size=batch_size, 157 | drop_remainder=drop_remainder)) 158 | 159 | return d 160 | 161 | return input_fn 162 | -------------------------------------------------------------------------------- /train/modeling.py: -------------------------------------------------------------------------------- 1 | # Original work Copyright 2018 The Google AI Language Team Authors. 2 | # Modified work Copyright 2019 Rowan Zellers 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import copy 17 | import json 18 | import math 19 | 20 | import six 21 | import tensorflow.compat.v1 as tf 22 | 23 | from train import optimization_adafactor 24 | from train.utils import get_assignment_map_from_checkpoint, get_shape_list, get_attention_mask, gelu, layer_norm, dropout, \ 25 | construct_scalar_host_call 26 | 27 | class GroverConfig(object): 28 | """Configuration for `GroverModel`""" 29 | 30 | def __init__(self, 31 | vocab_size, 32 | hidden_size=768, 33 | num_hidden_layers=12, 34 | num_attention_heads=12, 35 | intermediate_size=3072, 36 | hidden_act="gelu", 37 | hidden_dropout_prob=0.1, 38 | attention_probs_dropout_prob=0.1, 39 | max_position_embeddings=512, 40 | initializer_range=0.02): 41 | """Constructs NewsConfig. 42 | 43 | Args: 44 | vocab_size: Vocabulary size of `inputs_ids` in `GroverModel`. 45 | hidden_size: Size of the layers 46 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 47 | num_attention_heads: Number of attention heads for each attention layer in 48 | the Transformer encoder. 49 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 50 | layer in the Transformer encoder. 51 | hidden_act: The non-linear activation function (function or string) in the 52 | encoder and pooler. 53 | hidden_dropout_prob: The dropout probability for all fully connected 54 | layers in the embeddings, encoder, and pooler. 55 | attention_probs_dropout_prob: The dropout ratio for the attention 56 | probabilities. 57 | max_position_embeddings: The maximum sequence length that this model might 58 | ever be used with. Typically set this to something large just in case 59 | (e.g., 512 or 1024 or 2048). 60 | initializer_range: The stdev of the truncated_normal_initializer for 61 | initializing all weight matrices. 
62 | """ 63 | self.vocab_size = vocab_size 64 | self.hidden_size = hidden_size 65 | self.num_hidden_layers = num_hidden_layers 66 | self.num_attention_heads = num_attention_heads 67 | self.hidden_act = hidden_act 68 | self.intermediate_size = intermediate_size 69 | self.hidden_dropout_prob = hidden_dropout_prob 70 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 71 | self.max_position_embeddings = max_position_embeddings 72 | self.initializer_range = initializer_range 73 | self.pad_token_id = 0 74 | 75 | @classmethod 76 | def from_dict(cls, json_object): 77 | """Constructs a `NewsConfig` from a Python dictionary of parameters.""" 78 | config = GroverConfig(vocab_size=None) 79 | for (key, value) in six.iteritems(json_object): 80 | config.__dict__[key] = value 81 | return config 82 | 83 | @classmethod 84 | def from_json_file(cls, json_file): 85 | """Constructs a `NewsConfig` from a json file of parameters.""" 86 | with tf.gfile.GFile(json_file, "r") as reader: 87 | text = reader.read() 88 | return cls.from_dict(json.loads(text)) 89 | 90 | def to_dict(self): 91 | """Serializes this instance to a Python dictionary.""" 92 | output = copy.deepcopy(self.__dict__) 93 | return output 94 | 95 | def to_json_string(self): 96 | """Serializes this instance to a JSON string.""" 97 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 98 | 99 | 100 | def mask_attention_for_ltr(attention_scores, attention_mask): 101 | """ 102 | Mask attention so that we're only predicting going forward 103 | :param attention_scores: [batch, heads, dst_sequence, src_sequence], where information flows from src to dst. 104 | :param attention_mask [query_length, key_length] 105 | :return: masked attention 106 | """ 107 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 108 | # masked positions, this operation will create a tensor which is 0.0 for 109 | # positions we want to attend and -10000.0 for masked positions. 
110 | mask = attention_mask[None, None] 111 | return attention_scores * mask - tf.cast(1e10, attention_scores.dtype) * (1 - mask) 112 | 113 | 114 | def create_initializer(initializer_range=0.02): 115 | """Creates a `truncated_normal_initializer` with the given range.""" 116 | return tf.truncated_normal_initializer(stddev=initializer_range) 117 | 118 | 119 | def _attention_projection_and_transpose(x_flat, batch_size, seq_length, num_attention_heads, size_per_head, 120 | name, initializer_range=0.02): 121 | """ 122 | :param x_flat: [batch_size*seq_length, width] 123 | :return: A fixed up tensor of size [batch_size, num_attention_heads, seq_length, size_per_head] 124 | """ 125 | batch_size_seq_length, dim = get_shape_list(x_flat, expected_rank=2) 126 | 127 | if dim != size_per_head * num_attention_heads: 128 | raise ValueError("passed in a tensor of shape {} when size_per_head={} and num_attention_heads={}".format( 129 | (batch_size_seq_length, dim), size_per_head, num_attention_heads 130 | )) 131 | 132 | projected = tf.layers.dense( 133 | x_flat, 134 | num_attention_heads * size_per_head, 135 | name=name, 136 | kernel_initializer=create_initializer(initializer_range)) 137 | 138 | projected = tf.reshape( 139 | projected, [batch_size, seq_length, num_attention_heads, size_per_head]) 140 | output_tensor = tf.transpose(projected, [0, 2, 1, 3]) 141 | return output_tensor 142 | 143 | 144 | def attention_layer(x_flat, attention_mask, batch_size, seq_length, size_per_head=512, num_attention_heads=1, *, 145 | cache=None, 146 | initializer_range=0.02, hidden_dropout_prob=0.1, 147 | attention_probs_dropout_prob=0.1, do_cache=False): 148 | """ 149 | 150 | :param x_flat: Tensor input, should be [batch_size*seq_length, dim] 151 | :param attention_mask: Attention mask to use of size [seq_length, seq_length+cached_length] 152 | :param size_per_head: dim = size_per_head * num_attention_heads 153 | :param num_attention_heads: dim = size_per_head * num_attention_heads 154 | :param cache: Optionally some past (cached) things of size 155 | [batch, 2, heads, sequence, features], where 2 is [k, v] 156 | :param do_cache: True if we should return cache 157 | :return: A new tensor of shape [batch_size*seq_length, dim] 158 | as well as a new cache "cached_keys_and_values" that will be of size 159 | [batch_size, 2, num_attention_heads, seq_length, size_per_head] 160 | """ 161 | batch_size_seq_length, dim = get_shape_list(x_flat, expected_rank=2) 162 | 163 | if dim != size_per_head * num_attention_heads: 164 | raise ValueError("passed in a tensor of shape {} when size_per_head={} and num_attention_heads={}".format( 165 | (batch_size_seq_length, dim), size_per_head, num_attention_heads 166 | )) 167 | 168 | query = _attention_projection_and_transpose(x_flat, batch_size=batch_size, seq_length=seq_length, 169 | num_attention_heads=num_attention_heads, size_per_head=size_per_head, 170 | name='query_layer', 171 | initializer_range=initializer_range) 172 | key = _attention_projection_and_transpose(x_flat, batch_size=batch_size, seq_length=seq_length, 173 | num_attention_heads=num_attention_heads, size_per_head=size_per_head, 174 | name='key_layer', 175 | initializer_range=initializer_range) 176 | 177 | value = _attention_projection_and_transpose(x_flat, batch_size=batch_size, seq_length=seq_length, 178 | num_attention_heads=num_attention_heads, size_per_head=size_per_head, 179 | name='value_layer', 180 | initializer_range=initializer_range) 181 | 182 | # Add to cache 183 | cached_keys_and_values = tf.stack([key, value], axis=1) if 
do_cache else None 184 | 185 | # Things that were relevant from the cache 186 | if cache is not None: 187 | pk, pv = tf.unstack(cache, axis=1) 188 | key = tf.concat([pk, key], axis=-2) 189 | value = tf.concat([pv, value], axis=-2) 190 | 191 | # Multiply [batch_size, num_attention_heads, seq_length, size_per_head] with 192 | # [batch_size, num_attention_heads, size_per_head, seq_length+cached_length] -> 193 | # [batch_size, num_attention_heads, seq_length, seq_length+cached_length] 194 | attention_scores = tf.matmul(query, key, transpose_b=True) 195 | attention_scores = tf.multiply(attention_scores, 196 | 1.0 / math.sqrt(float(size_per_head))) 197 | attention_scores = mask_attention_for_ltr(attention_scores, attention_mask) 198 | attention_probs = tf.nn.softmax(attention_scores) 199 | 200 | # This is actually dropping out entire tokens to attend to, which might 201 | # seem a bit unusual, but is taken from the original Transformer paper. 202 | # NOPENOPENOPENOPE 203 | # attention_probs = factoreddropout(attention_probs, attention_probs_dropout_prob) 204 | 205 | # Multiply [batch_size, num_attention_heads, seq_length, seq_length+cached_length] with 206 | # [batch_size, num_attention_heads, seq_length+cached_length, size_per_head] -> 207 | # [batch_size, num_attention_heads, seq_length, size_per_head] -> 208 | context_layer = tf.matmul(attention_probs, value) 209 | 210 | # `context_layer` = [batch_size, seq_length, num_attention_heads, size_per_head] 211 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) 212 | context_layer = tf.reshape(context_layer, [batch_size * seq_length, num_attention_heads * size_per_head]) 213 | 214 | context_layer_projected = tf.layers.dense( 215 | context_layer, 216 | num_attention_heads * size_per_head, 217 | kernel_initializer=create_initializer(initializer_range), 218 | name='context_projection_layer' 219 | ) 220 | context_layer_projected = dropout(context_layer_projected, hidden_dropout_prob) 221 | 222 | return context_layer_projected, cached_keys_and_values 223 | 224 | 225 | def residual_mlp_layer(x_flat, intermediate_size, initializer_range=0.02, hidden_dropout_prob=0.1): 226 | """ 227 | :param x_flat: The attention output. It should be [batch_size*seq_length, dim] 228 | :param intermediate_size: the hidden projection. By default this is the input_dim * 4. 229 | 230 | in the original GPT we would return layer_norm(x_norm + h1) rather than layer_norm(x + h1) 231 | 232 | :return: layer output of shape [batch_size*seq_length, hidden_size] 233 | """ 234 | batch_size_seq_length, hidden_size = get_shape_list(x_flat, expected_rank=2) 235 | x_norm = layer_norm(x_flat, name='mlp_ln0') 236 | 237 | intermediate_output = tf.layers.dense( 238 | x_norm, 239 | intermediate_size, 240 | activation=gelu, 241 | kernel_initializer=create_initializer(initializer_range), 242 | name='intermediate', 243 | ) 244 | 245 | output_for_residual = tf.layers.dense( 246 | intermediate_output, 247 | hidden_size, 248 | name='output', 249 | kernel_initializer=create_initializer(initializer_range)) 250 | output_for_residual = dropout(output_for_residual, hidden_dropout_prob) 251 | 252 | layer_output = layer_norm(x_flat + output_for_residual, name='mlp_ln1') 253 | return layer_output 254 | 255 | 256 | def embed(input_ids, 257 | vocab_size, 258 | embedding_size, 259 | position_offset=0, 260 | initializer_range=0.02, 261 | max_position_embeddings=512, 262 | use_one_hot_embeddings=True): 263 | """Word and position embeddings 264 | :param input_ids: int Tensor of shape [batch_size, seq_length]. 
265 | :param vocab_size: number of words in vocab 266 | :param embedding_size: dimensionality of the embedding 267 | :param position_offset: aka number of cached tokens. 268 | :param initializer_range: float. Range of the weight initialization. 269 | :param max_position_embeddings: int. Maximum sequence length. 270 | :param use_one_hot_embeddings: probably want this to be true 271 | :return: [batch_size, seq_length, embedding_size] embedded tensor 272 | """ 273 | (batch_size, seq_length) = get_shape_list(input_ids, expected_rank=2) 274 | 275 | embedding_table = tf.get_variable( 276 | name='word_embed', 277 | shape=[vocab_size, embedding_size], 278 | initializer=create_initializer(initializer_range), 279 | ) 280 | 281 | assert_op = tf.assert_less_equal(tf.reduce_max(input_ids), vocab_size - 1) 282 | with tf.control_dependencies([assert_op]): 283 | if use_one_hot_embeddings: 284 | flat_input_ids = tf.reshape(input_ids, [-1]) 285 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) 286 | output_flat = tf.matmul(one_hot_input_ids, embedding_table) 287 | else: 288 | output_flat = tf.nn.embedding_lookup(embedding_table, input_ids) 289 | 290 | embedded_input = tf.reshape(output_flat, [batch_size, seq_length, embedding_size]) 291 | 292 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings) 293 | 294 | with tf.control_dependencies([assert_op]): 295 | full_position_embeddings = tf.get_variable( 296 | name='pos_embed', 297 | shape=[max_position_embeddings, embedding_size], 298 | initializer=create_initializer(initializer_range), 299 | ) 300 | # Since the position embedding table is a learned variable, we create it 301 | # using a (long) sequence length `max_position_embeddings`. The actual 302 | # sequence length might be shorter than this, for faster training of 303 | # tasks that do not have long sequences. 304 | # 305 | # So `full_position_embeddings` is effectively an embedding table 306 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current 307 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just 308 | # perform a slice. 309 | if position_offset == 0: 310 | embedded_input += tf.slice(full_position_embeddings, [0, 0], [seq_length, -1])[None] 311 | else: 312 | # Tensorflow is too stupid to allow slicing 313 | flat_pos_ids = (tf.range(seq_length, dtype=tf.int32) + position_offset) 314 | one_hot_pos_ids = tf.one_hot(flat_pos_ids, depth=max_position_embeddings) 315 | 316 | # [seq_length, full_position_embeddings], [full_position_embeddings, dim] 317 | seq_embeds = tf.matmul(one_hot_pos_ids, full_position_embeddings) 318 | embedded_input += seq_embeds[None] 319 | 320 | # embedded_input += tf.slice(full_position_embeddings[position_offset:], [0, 0], [seq_length, -1])[None] 321 | 322 | return layer_norm(embedded_input, name='embed_norm'), embedding_table 323 | 324 | 325 | def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9): 326 | """ 327 | Does top-p sampling. if ignore_ids is on, then we will zero out those logits. 328 | :param logits: [batch_size, vocab_size] tensor 329 | :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict, 330 | like padding maybe 331 | :param p: topp threshold to use, either a float or a [batch_size] vector 332 | :return: [batch_size, num_samples] samples 333 | 334 | # TODO FIGURE OUT HOW TO DO THIS ON TPUS. 
IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK 335 | """ 336 | with tf.variable_scope('top_p_sample'): 337 | batch_size, vocab_size = get_shape_list(logits, expected_rank=2) 338 | 339 | probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10, 340 | axis=-1) 341 | 342 | if isinstance(p, float) and p > 0.999999: 343 | # Don't do top-p sampling in this case 344 | print("Top-p sampling DISABLED", flush=True) 345 | return { 346 | 'probs': probs, 347 | 'sample': tf.random.categorical( 348 | logits=logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10, 349 | num_samples=num_samples, dtype=tf.int32), 350 | } 351 | 352 | # [batch_size, vocab_perm] 353 | indices = tf.argsort(probs, direction='DESCENDING') 354 | cumulative_probabilities = tf.math.cumsum(tf.batch_gather(probs, indices), axis=-1, exclusive=False) 355 | 356 | # find the top pth index to cut off. careful we don't want to cut off everything! 357 | # result will be [batch_size, vocab_perm] 358 | p_expanded = p if isinstance(p, float) else p[:, None] 359 | exclude_mask = tf.logical_not( 360 | tf.logical_or(cumulative_probabilities < p_expanded, tf.range(vocab_size)[None] < 1)) 361 | 362 | # OPTION A - sample in the sorted space, then unsort. 363 | logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10 364 | sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples) 365 | sample = tf.batch_gather(indices, sample_perm) 366 | 367 | # OPTION B - unsort first - Indices need to go back to 0 -> N-1 -- then sample 368 | # unperm_indices = tf.argsort(indices, direction='ASCENDING') 369 | # include_mask_unperm = tf.batch_gather(include_mask, unperm_indices) 370 | # logits_to_use = logits - (1 - tf.cast(include_mask_unperm, tf.float32)) * 1e10 371 | # sample = tf.random.categorical(logits=logits_to_use, num_samples=num_samples, dtype=tf.int32) 372 | 373 | return { 374 | 'probs': probs, 375 | 'sample': sample, 376 | } 377 | 378 | 379 | def _top_k_sample(logits, ignore_ids=None, num_samples=1, k=10): 380 | """ 381 | Does top-k sampling. if ignore_ids is on, then we will zero out those logits. 382 | :param logits: [batch_size, vocab_size] tensor 383 | :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict, 384 | like padding maybe 385 | :param k: top-k cutoff to use, either an int or a [batch_size] vector 386 | :return: [batch_size, num_samples] samples 387 | 388 | # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK 389 | """ 390 | with tf.variable_scope('top_k_sample'): 391 | batch_size, vocab_size = get_shape_list(logits, expected_rank=2) 392 | 393 | probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10, 394 | axis=-1) 395 | # [batch_size, vocab_perm] 396 | indices = tf.argsort(probs, direction='DESCENDING') 397 | 398 | # find the top kth index to cut off. careful we don't want to cut off everything! 399 | # result will be [batch_size, vocab_perm] 400 | k_expanded = k if isinstance(k, int) else k[:, None] 401 | exclude_mask = tf.range(vocab_size)[None] >= k_expanded 402 | 403 | # OPTION A - sample in the sorted space, then unsort. 
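        # Worked example (editorial sketch, not in the original source): with
        # k=2 and descending-sorted probs [0.5, 0.3, 0.2], exclude_mask above
        # is [False, False, True], so the logit of the third-ranked token is
        # pushed down by ~1e10 below and sampling is effectively restricted
        # to the top two candidates.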
404 | logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10 405 | sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples) 406 | sample = tf.batch_gather(indices, sample_perm) 407 | 408 | return { 409 | 'probs': probs, 410 | 'sample': sample, 411 | } 412 | 413 | 414 | class GroverModel(object): 415 | def __init__(self, 416 | config: GroverConfig, 417 | is_training, 418 | input_ids, 419 | cache=None, 420 | do_cache=False, 421 | pad_token_id=0, 422 | chop_off_last_token=True, 423 | scope=None, 424 | reuse=False): 425 | """ 426 | :param config: 427 | :param is_training: 428 | :param input_ids: Tensor thats of size [batch_size, seq_length] 429 | :param cache: Optionally, a tensor to use that will contain cached information of the size 430 | [batch_size, num_layers, 2, num_heads, cache_length, features] 431 | :param do_cache: Whether to cache again. 432 | :param pad_token_id: Which token will be used for padding (probably 0.) 433 | :param chop_off_last_token: True if we will end up using this for TRAINING only. False if we want to generate. 434 | it means the last token in input_ids will not be processed by the model as input 435 | :param scope: scope to run this on 436 | """ 437 | self.config = copy.deepcopy(config) 438 | self.is_training = is_training 439 | self.pad_token_id = pad_token_id 440 | 441 | if not is_training: 442 | self.config.hidden_dropout_prob = 0.0 443 | self.config.attention_probs_dropout_prob = 0.0 444 | 445 | if chop_off_last_token: 446 | self.target_ids = input_ids[:, 1:] 447 | self.input_ids = input_ids[:, :-1] 448 | else: 449 | self.input_ids = input_ids 450 | self.target_ids = tf.concat((input_ids[:, 1:], 451 | tf.constant(self.pad_token_id, dtype=self.input_ids.dtype, 452 | shape=[get_shape_list(self.input_ids, 2)[0], 1])), 1) 453 | 454 | self.batch_size, self.seq_length = get_shape_list(self.input_ids, 2) 455 | 456 | if cache is None: 457 | caches = [None] * config.num_hidden_layers 458 | self.cache_length = 0 459 | else: 460 | batch_size_, num_layers_, two_, num_heads_, self.cache_length, features_ = get_shape_list( 461 | cache, expected_rank=6) 462 | assert batch_size_ == self.batch_size 463 | assert num_layers_ == config.num_hidden_layers 464 | assert two_ == 2 465 | assert num_heads_ == config.num_attention_heads 466 | assert features_ == (config.hidden_size // config.num_attention_heads) 467 | caches = tf.unstack(cache, axis=1) 468 | 469 | with tf.variable_scope(scope, default_name='newslm', reuse=reuse): 470 | with tf.variable_scope("embeddings"): 471 | embeddings, self.embedding_table = embed(self.input_ids, config.vocab_size, 472 | config.hidden_size, 473 | position_offset=self.cache_length, 474 | initializer_range=config.initializer_range, 475 | max_position_embeddings=config.max_position_embeddings, 476 | use_one_hot_embeddings=True) 477 | 478 | mask = get_attention_mask(self.seq_length, self.seq_length + self.cache_length, dtype=embeddings.dtype) 479 | 480 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 481 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on 482 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 483 | # help the optimizer. 
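                # Shape sketch (editorial note, not in the original source):
                # with batch_size=2, seq_length=3, hidden_size=768,
                # `embeddings` is [2, 3, 768] and the reshape below yields a
                # [6, 768] `hidden_state`; every layer in the stack consumes
                # and produces this flat 2D layout.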
484 | hidden_state = tf.reshape(embeddings, [self.batch_size * self.seq_length, self.config.hidden_size]) 485 | new_kvs = [] 486 | for layer_idx, layer_cache in enumerate(caches): 487 | with tf.variable_scope('layer{:02d}'.format(layer_idx)): 488 | # [batch_size * seq_length, hidden_size] 489 | attention_output, new_kv = attention_layer( 490 | hidden_state, 491 | mask, 492 | batch_size=self.batch_size, 493 | seq_length=self.seq_length, 494 | size_per_head=config.hidden_size // config.num_attention_heads, 495 | num_attention_heads=config.num_attention_heads, 496 | initializer_range=config.initializer_range, 497 | hidden_dropout_prob=self.config.hidden_dropout_prob, 498 | attention_probs_dropout_prob=self.config.attention_probs_dropout_prob, 499 | do_cache=do_cache, 500 | cache=layer_cache, 501 | ) 502 | new_kvs.append(new_kv) 503 | 504 | # [batch_size * seq_length, hidden_size] 505 | hidden_state = residual_mlp_layer(hidden_state + attention_output, 506 | intermediate_size=config.intermediate_size, 507 | hidden_dropout_prob=self.config.hidden_dropout_prob) 508 | self.hidden_state = hidden_state 509 | 510 | self.new_kvs = tf.stack(new_kvs, axis=1) if do_cache else None 511 | 512 | # Note that the hidden state is still flat ([batch_size*seq_length, hidden_size]) 513 | self.logits_flat = tf.matmul(self.hidden_state, self.embedding_table, transpose_b=True) 514 | 515 | # THE OUTPUT BIAS DOES NOT SPARK JOY 516 | # output_bias = tf.get_variable('output_bias', shape=[config.vocab_size], initializer=tf.zeros_initializer()) 517 | # self.logits_flat = tf.nn.bias_add(self.logits_flat, output_bias) 518 | 519 | @property 520 | def log_probs(self): 521 | logprobs_flat = tf.nn.log_softmax(self.logits_flat, axis=-1) 522 | return tf.reshape(logprobs_flat, [self.batch_size, self.seq_length, -1]) 523 | 524 | def lm_loss(self): 525 | """ 526 | :return: scalar language modeling loss, averaged over the non-padding tokens 527 | """ 528 | target_ids_flat = tf.reshape(self.target_ids, [-1]) 529 | 530 | # 1 if it's valid and 0 otherwise. 
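        # e.g. (editorial sketch, not in the original source) with
        # pad_token_id=0 and target_ids_flat = [17, 95, 0, 0], label_weights
        # below is [1., 1., 0., 0.], so padding positions contribute nothing
        # to either the numerator or the denominator of the loss.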
531 | label_weights = tf.cast(tf.not_equal(target_ids_flat, self.pad_token_id), dtype=self.logits_flat.dtype) 532 | 533 | # [batch_size * seq_length, vocab_size] 534 | one_hot_labels = tf.one_hot(target_ids_flat, 535 | depth=self.config.vocab_size, 536 | dtype=self.logits_flat.dtype) 537 | 538 | # [batch_size * seq_length, vocab_size] 539 | logprobs_flat = tf.nn.log_softmax(self.logits_flat, axis=-1) 540 | 541 | per_example_loss = -tf.reduce_sum(logprobs_flat * one_hot_labels, axis=[-1]) 542 | 543 | # per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits_flat, labels=target_ids_flat) 544 | 545 | numerator = tf.reduce_sum(label_weights * per_example_loss) 546 | denominator = tf.reduce_sum(label_weights) + 1e-5 547 | loss = numerator / denominator 548 | return loss 549 | 550 | def pooled_output(self, clf_token): 551 | """ 552 | Extract pooled output given a token that says where we should look 553 | :param clf_token: 554 | :return: 555 | """ 556 | pool_idx = tf.cast(tf.argmax(tf.cast(tf.equal(self.input_ids, clf_token), tf.float32), 1), tf.int32) 557 | return tf.gather(self.hidden_state, tf.range(self.batch_size, dtype=tf.int32) * self.seq_length + pool_idx) 558 | 559 | 560 | def model_fn_builder(config: GroverConfig, init_checkpoint, learning_rate, 561 | num_train_steps, num_warmup_steps, use_tpu): 562 | """Returns `model_fn` closure for TPUEstimator.""" 563 | 564 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 565 | """The `model_fn` for TPUEstimator.""" 566 | 567 | tf.logging.info("*** Features ***") 568 | for name in sorted(features.keys()): 569 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 570 | 571 | input_ids = features["input_ids"] 572 | 573 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 574 | 575 | model = GroverModel( 576 | config=config, 577 | is_training=is_training, 578 | input_ids=input_ids, 579 | pad_token_id=config.pad_token_id, 580 | chop_off_last_token=True, 581 | ) 582 | 583 | total_loss = model.lm_loss() 584 | 585 | if is_training: 586 | train_op, train_metrics = optimization_adafactor.create_optimizer( 587 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 588 | tvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 589 | else: 590 | train_op = None 591 | train_metrics = {} 592 | tvars = tf.trainable_variables() 593 | 594 | initialized_variable_names = {} 595 | scaffold_fn = None 596 | if init_checkpoint: 597 | (assignment_map, initialized_variable_names 598 | ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint) 599 | if use_tpu: 600 | def tpu_scaffold(): 601 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 602 | return tf.train.Scaffold() 603 | 604 | scaffold_fn = tpu_scaffold 605 | else: 606 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 607 | 608 | tf.logging.info("**** Trainable Variables ****") 609 | for var in tvars: 610 | init_string = "" 611 | if var.name in initialized_variable_names: 612 | init_string = ", *INIT_FROM_CKPT*" 613 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 614 | init_string) 615 | 616 | output_spec = None 617 | if mode == tf.estimator.ModeKeys.TRAIN: 618 | if use_tpu: 619 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 620 | mode=mode, 621 | loss=total_loss, 622 | train_op=train_op, 623 | host_call=construct_scalar_host_call(metric_dict=train_metrics, model_dir=params['model_dir'], 624 | prefix='training/'), 625 | scaffold_fn=scaffold_fn) 626 | else: 627 | 
output_spec = tf.contrib.tpu.TPUEstimatorSpec( 628 | mode=mode, 629 | loss=total_loss, 630 | train_op=train_op, 631 | training_hooks=[ 632 | tf.train.LoggingTensorHook({'loss': tf.metrics.mean(total_loss)[1]}, every_n_iter=100)], 633 | scaffold_fn=scaffold_fn) 634 | 635 | elif mode == tf.estimator.ModeKeys.EVAL: 636 | def metric_fn(total_loss): 637 | loss = tf.metrics.mean(values=total_loss) 638 | return { 639 | "eval_loss": loss, 640 | } 641 | 642 | eval_metrics = (metric_fn, 643 | [total_loss]) 644 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 645 | mode=mode, 646 | loss=total_loss, 647 | eval_metrics=eval_metrics, 648 | scaffold_fn=scaffold_fn) 649 | else: 650 | gt_logprobs = tf.squeeze(tf.batch_gather(model.log_probs, model.target_ids[:, :, None]), axis=2) 651 | 652 | # Need top-p required under topp sampling! 653 | better_than_gt = model.log_probs > gt_logprobs[:, :, None] 654 | top_p_required = tf.reduce_sum(tf.cast(better_than_gt, tf.float32) * tf.exp(model.log_probs), axis=2) 655 | 656 | # No top-p sampling for now, since this seems to be too slow on TPUs 657 | if use_tpu: 658 | predictions = tf.reshape( 659 | tf.random.categorical(logits=model.logits_flat, num_samples=1), 660 | get_shape_list(model.target_ids), 661 | ) 662 | else: 663 | # Argmax 664 | # predictions = tf.math.argmax(model.log_probs, axis=-1, output_type=tf.int32) 665 | predictions = tf.reshape( 666 | _top_p_sample(model.logits_flat, num_samples=1, p=0.99)['sample'], 667 | get_shape_list(model.target_ids), 668 | ) 669 | pred_logprobs = tf.squeeze(tf.batch_gather(model.log_probs, predictions[:, :, None]), axis=2) 670 | 671 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 672 | mode=mode, 673 | predictions={'gt_logprobs': gt_logprobs, 674 | 'top_p_required': top_p_required, 675 | 'predictions': predictions, 676 | 'pred_logprobs': pred_logprobs, 677 | 'labels': input_ids}, 678 | scaffold_fn=scaffold_fn) 679 | return output_spec 680 | 681 | return model_fn 682 | 683 | 684 | def sample_step(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False): 685 | """ 686 | Helper function that samples from grover for a single step 687 | :param tokens: [batch_size, n_ctx_b] tokens that we will predict from 688 | :param ignore_ids: [n_vocab] mask of the tokens we don't want to predict 689 | :param news_config: config for the GroverModel 690 | :param batch_size: batch size to use 691 | :param p_for_topp: top-p or top-k threshold 692 | :param cache: [batch_size, news_config.num_hidden_layers, 2, 693 | news_config.num_attention_heads, n_ctx_a, 694 | news_config.hidden_size // news_config.num_attention_heads] OR, None 695 | :return: new_tokens, size [batch_size] 696 | new_probs, also size [batch_size] 697 | new_cache, size [batch_size, news_config.num_hidden_layers, 2, n_ctx_b, 698 | news_config.num_attention_heads, news_config.hidden_size // news_config.num_attention_heads] 699 | """ 700 | model = GroverModel( 701 | config=news_config, 702 | is_training=False, 703 | input_ids=tokens, 704 | reuse=tf.AUTO_REUSE, 705 | scope='newslm', 706 | chop_off_last_token=False, 707 | do_cache=True, 708 | cache=cache, 709 | ) 710 | 711 | # Extract the FINAL SEQ LENGTH 712 | batch_size_times_seq_length, vocab_size = get_shape_list(model.logits_flat, expected_rank=2) 713 | next_logits = tf.reshape(model.logits_flat, [batch_size, -1, vocab_size])[:, -1] 714 | 715 | if do_topk: 716 | sample_info = _top_k_sample(next_logits, num_samples=1, k=tf.cast(p_for_topp, dtype=tf.int32)) 717 | else: 718 | sample_info = 
_top_p_sample(next_logits, ignore_ids=ignore_ids, num_samples=1, p=p_for_topp) 719 | 720 | new_tokens = tf.squeeze(sample_info['sample'], 1) 721 | new_probs = tf.squeeze(tf.batch_gather(sample_info['probs'], sample_info['sample']), 1) 722 | return { 723 | 'new_tokens': new_tokens, 724 | 'new_probs': new_probs, 725 | 'new_cache': model.new_kvs, 726 | } 727 | 728 | 729 | def initialize_from_context(initial_context, ignore_ids, news_config, p_for_topp=0.95, do_topk=False): 730 | """ same signature as sample_step""" 731 | batch_size, _ = get_shape_list(initial_context, expected_rank=2) 732 | 733 | context_output = sample_step(tokens=initial_context, ignore_ids=ignore_ids, news_config=news_config, 734 | batch_size=batch_size, p_for_topp=p_for_topp, cache=None, do_topk=do_topk) 735 | return { 736 | 'tokens': tf.concat([initial_context, context_output['new_tokens'][:, None]], 1), 737 | 'cache': context_output['new_cache'], 738 | 'probs': context_output['new_probs'][:, None] 739 | } 740 | 741 | 742 | def sample(news_config: GroverConfig, initial_context, eos_token, min_len, ignore_ids=None, p_for_topp=0.95, 743 | do_topk=False): 744 | """ 745 | V1 version of: sample outputs from a model, and do it all at once 746 | :param news_config: Configuration used to construct the model 747 | :param initial_context: [batch_size, seq_length] that we'll start generating with 748 | :param eos_token: Stop generating if you see this (tf scalar) 749 | :param min_len: min length of sample 750 | :param ignore_ids: NEVER GENERATE THESE [vocab_size] 751 | :return: 752 | """ 753 | batch_size, _ = get_shape_list(initial_context, expected_rank=2) 754 | 755 | if ignore_ids is None: 756 | ignore_ids = tf.constant([x == 0 for x in range(news_config.vocab_size)], dtype=tf.bool) 757 | 758 | with tf.name_scope('sample_sequence'): 759 | # Initial call to get cache 760 | context_output = initialize_from_context(initial_context, ignore_ids=ignore_ids, news_config=news_config, 761 | p_for_topp=p_for_topp, 762 | do_topk=do_topk) 763 | ctx = context_output['tokens'] 764 | cache = context_output['cache'] 765 | probs = context_output['probs'] 766 | 767 | def body(ctx, cache, probs): 768 | """ for whatever reason this didn't work when I ran it on more than one at once... 
ugh.""" 769 | next_outputs = sample_step(ctx[:, -1][:, None], ignore_ids=ignore_ids, news_config=news_config, 770 | batch_size=batch_size, p_for_topp=p_for_topp, cache=cache, 771 | do_topk=do_topk) 772 | 773 | # Update everything 774 | new_cache = tf.concat([cache, next_outputs['new_cache']], axis=-2) 775 | new_ids = tf.concat([ctx, next_outputs['new_tokens'][:, None]], axis=1) 776 | new_probs = tf.concat([probs, next_outputs['new_probs'][:, None]], axis=1) 777 | return [new_ids, new_cache, new_probs] 778 | 779 | def cond(ctx, cache, probs): 780 | # ctx = tf.Print(ctx,[tf.shape(ctx)]) 781 | is_eos = tf.reduce_all(tf.reduce_any(tf.equal(ctx[:,-1:], eos_token), axis=1)) 782 | is_len = tf.greater(get_shape_list(ctx)[1], min_len) 783 | return tf.logical_not(tf.logical_and(is_eos, is_len)) 784 | 785 | tokens, cache, probs = tf.while_loop( 786 | cond=cond, body=body, maximum_iterations=1025 - get_shape_list(ctx)[1], 787 | loop_vars=[ctx, cache, probs], 788 | shape_invariants=[tf.TensorShape([batch_size, None]), 789 | tf.TensorShape( 790 | [batch_size, news_config.num_hidden_layers, 2, 791 | news_config.num_attention_heads, 792 | None, news_config.hidden_size // news_config.num_attention_heads]), 793 | tf.TensorShape([batch_size, None]), 794 | ], 795 | back_prop=False, 796 | ) 797 | return tokens, probs 798 | -------------------------------------------------------------------------------- /train/optimization_adafactor.py: -------------------------------------------------------------------------------- 1 | # Original work Copyright 2018 The Google AI Language Team Authors. 2 | # Modified work Copyright 2019 Rowan Zellers 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import re 16 | import tensorflow as tf 17 | from train.utils import get_shape_list 18 | 19 | 20 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 21 | """Creates an optimizer training op.""" 22 | global_step = tf.train.get_or_create_global_step() 23 | 24 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 25 | 26 | # Implements linear decay of the learning rate. 27 | learning_rate = tf.train.polynomial_decay( 28 | learning_rate, 29 | global_step, 30 | num_train_steps, 31 | end_learning_rate=0.0, 32 | power=1.0, 33 | cycle=False) 34 | 35 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 36 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 
37 | if num_warmup_steps: 38 | global_steps_int = tf.cast(global_step, tf.int32) 39 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 40 | 41 | global_steps_float = tf.cast(global_steps_int, tf.float32) 42 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 43 | 44 | warmup_percent_done = global_steps_float / warmup_steps_float 45 | warmup_learning_rate = init_lr * warmup_percent_done 46 | 47 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 48 | learning_rate = ( 49 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 50 | 51 | # It is recommended that you use this optimizer for fine tuning, since this 52 | # is how the model was trained (note that the Adam m/v variables are NOT 53 | # loaded from init_checkpoint.) 54 | optimizer = AdaFactorOptimizer( 55 | learning_rate=learning_rate, 56 | weight_decay_rate=0.01, 57 | beta_1=0.9, 58 | beta_2=0.999, 59 | epsilon=1e-6, 60 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 61 | 62 | if use_tpu: 63 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 64 | 65 | tvars = tf.trainable_variables() 66 | grads = tf.gradients(loss, tvars) 67 | 68 | # You could do this, but instead we don't because a) it's slow and b) we already did the 'update clipping' 69 | # (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 70 | 71 | train_op = optimizer.apply_gradients( 72 | zip(grads, tvars), global_step=global_step) 73 | 74 | # Normally the global step update is done inside of `apply_gradients`. 75 | # However, `AdaFactorOptimizer` doesn't do this. But if you use 76 | # a different optimizer, you should probably take this line out. 77 | new_global_step = global_step + 1 78 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 79 | 80 | train_metrics = { 81 | 'learning_rate': learning_rate, 82 | 'minibatch_loss': loss, 83 | # 'minibatch_ppl': tf.math.exp(loss), 84 | } 85 | return train_op, train_metrics 86 | 87 | 88 | class AdaFactorOptimizer(tf.compat.v1.train.Optimizer): 89 | """here's the optimizer we'll use""" 90 | 91 | def __init__(self, 92 | learning_rate, 93 | weight_decay_rate=0.0, 94 | beta_1=0.9, 95 | beta_2=0.999, 96 | epsilon=1e-6, 97 | exclude_from_weight_decay=None, 98 | clipping_rate=1.0, 99 | name="AdaFactorOptimizer"): 100 | """Constructs a AdaFactorOptimizer.""" 101 | super(AdaFactorOptimizer, self).__init__(False, name) 102 | 103 | self.learning_rate = learning_rate 104 | self.weight_decay_rate = weight_decay_rate 105 | self.beta_1 = beta_1 106 | self.beta_2 = beta_2 107 | self.epsilon = epsilon 108 | self.epsilon1 = 1e-30 109 | self.epsilon2 = 0.001 110 | self.clipping_rate = clipping_rate 111 | self.exclude_from_weight_decay = exclude_from_weight_decay 112 | self.use_locking = False 113 | 114 | def _use_factored(self, shape): 115 | return len(shape) >= 2 116 | 117 | def _parameter_scale(self, var): 118 | """Estimate the scale of the parameters from the current values. 119 | We include a minimum value of 0.001 to give it a chance to escape 0 120 | if it was zero-initialized. 121 | Instead of using the value, we could impute the scale from the shape, 122 | as initializers do. 123 | Args: 124 | var: a variable or Tensor. 
125 | Returns: 126 | a Scalar 127 | """ 128 | return tf.maximum(reduce_rms(var), self.epsilon2) 129 | 130 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 131 | """See base class.""" 132 | assignments = [] 133 | for (grad, param) in grads_and_vars: 134 | if grad is None or param is None: 135 | continue 136 | 137 | param_name = self._get_variable_name(param.name) 138 | shape_list = get_shape_list(param, expected_rank=[1, 2]) 139 | 140 | # decay_rate = 1 - tf.pow(tf.cast(tf.train.get_or_create_global_step(), tf.float32) + 1.0, -0.8) 141 | decay_rate = self.beta_2 142 | grad_squared = tf.square(grad) + self.epsilon1 143 | 144 | update_scale = self.learning_rate 145 | # update_scale = self.learning_rate * tf.cast(self._parameter_scale(param), dtype=tf.float32) 146 | 147 | # HACK: Make things dependent on grad. 148 | # This confounds the XLA rewriter and keeps it from fusing computations 149 | # across different variables. This fusion is bad for HBM usage, since 150 | # it causes the gradients to persist in memory. 151 | grad_squared_mean = tf.reduce_mean(grad_squared) 152 | decay_rate += grad_squared_mean * 1e-30 153 | update_scale += grad_squared_mean * 1e-30 154 | 155 | # END HACK 156 | 157 | if self._use_factored(shape_list): 158 | num_rows, num_columns = shape_list 159 | 160 | vr = tf.get_variable( 161 | name=param_name + "/adafactor_vr", 162 | shape=[num_rows], 163 | dtype=tf.float32, 164 | trainable=False, 165 | initializer=tf.zeros_initializer()) 166 | vc = tf.get_variable( 167 | name=param_name + "/adafactor_vc", 168 | shape=[num_columns], 169 | dtype=tf.float32, 170 | trainable=False, 171 | initializer=tf.zeros_initializer()) 172 | 173 | next_vr = decay_rate * vr + (1 - decay_rate) * tf.reduce_mean(grad_squared, 1) 174 | next_vc = decay_rate * vc + (1 - decay_rate) * tf.reduce_mean(grad_squared, 0) 175 | 176 | long_term_mean = tf.reduce_mean(next_vr, -1, keepdims=True) 177 | r_factor = tf.rsqrt(next_vr / long_term_mean + self.epsilon1) 178 | c_factor = tf.rsqrt(next_vc + self.epsilon1) 179 | update = grad * tf.expand_dims(r_factor, -1) * tf.expand_dims(c_factor, -2) 180 | 181 | assignments.append(vr.assign(next_vr, use_locking=self.use_locking)) 182 | assignments.append(vc.assign(next_vc, use_locking=self.use_locking)) 183 | else: 184 | v = tf.get_variable( 185 | name=param_name + "/adafactor_v", 186 | shape=shape_list, 187 | dtype=tf.float32, 188 | trainable=False, 189 | initializer=tf.zeros_initializer()) 190 | next_v = decay_rate * v + (1 - decay_rate) * grad_squared 191 | 192 | assignments.append(v.assign(next_v, use_locking=self.use_locking)) 193 | update = grad * tf.rsqrt(next_v + self.epsilon1) 194 | 195 | clipping_denom = tf.maximum(1.0, reduce_rms(update) / self.clipping_rate) 196 | update /= clipping_denom 197 | 198 | # Do weight decay 199 | # Just adding the square of the weights to the loss function is *not* 200 | # the correct way of using L2 regularization/weight decay with Adam, 201 | # since that will interact with the m and v parameters in strange ways. 202 | # 203 | # Instead we want to decay the weights in a manner that doesn't interact 204 | # with the m/v parameters. This is equivalent to adding the square 205 | # of the weights to the loss with plain (non-momentum) SGD. 
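            # Numeric sketch (editorial, not in the original source): with
            # weight_decay_rate=0.01, a weight w receives an extra update of
            # 0.01 * w, scaled by update_scale (the learning rate) below,
            # independently of the second-moment statistics, i.e. decoupled
            # weight decay rather than plain L2 regularization in the loss.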
206 | if self._do_use_weight_decay(param_name): 207 | update += self.weight_decay_rate * param 208 | 209 | update_with_lr = update_scale * update 210 | next_param = param - update_with_lr 211 | 212 | assignments.append(param.assign(next_param, use_locking=self.use_locking)) 213 | return tf.group(*assignments, name=name) 214 | 215 | def _do_use_weight_decay(self, param_name): 216 | """Whether to use L2 weight decay for `param_name`.""" 217 | if not self.weight_decay_rate: 218 | return False 219 | if self.exclude_from_weight_decay: 220 | for r in self.exclude_from_weight_decay: 221 | if re.search(r, param_name) is not None: 222 | return False 223 | return True 224 | 225 | def _get_variable_name(self, param_name): 226 | """Get the variable name from the tensor name.""" 227 | m = re.match("^(.*):\\d+$", param_name) 228 | if m is not None: 229 | param_name = m.group(1) 230 | return param_name 231 | 232 | 233 | def reduce_rms(x): 234 | return tf.sqrt(tf.reduce_mean(tf.square(x))) 235 | -------------------------------------------------------------------------------- /train/train_tpu.py: -------------------------------------------------------------------------------- 1 | # Original work Copyright 2018 The Google AI Language Team Authors. 2 | # Modified work Copyright 2019 Rowan Zellers 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ Training script! """ 17 | 18 | import tensorflow.compat.v1 as tf 19 | 20 | from train.dataloader import input_fn_builder 21 | from train.modeling import model_fn_builder, GroverConfig 22 | 23 | flags = tf.flags 24 | 25 | FLAGS = flags.FLAGS 26 | 27 | ## Required parameters 28 | flags.DEFINE_string( 29 | "config_file", 'configs/base.json', 30 | "The config json file corresponding to the pre-trained news model. " 31 | "This specifies the model architecture.") 32 | 33 | flags.DEFINE_string( 34 | "input_file", None, 35 | "Input TF example files (can be a glob or comma separated).") 36 | 37 | flags.DEFINE_string( 38 | "output_dir", None, 39 | "The output directory where the model checkpoints will be written.") 40 | 41 | ## Other parameters 42 | flags.DEFINE_string( 43 | "init_checkpoint", None, 44 | "Initial checkpoint (usually from a pre-trained model).") 45 | 46 | flags.DEFINE_integer( 47 | "max_seq_length", 1024, 48 | "The maximum total input sequence length after BPE tokenization. " 49 | "Sequences longer than this will be truncated, and sequences shorter " 50 | "than this will be padded. 
Must match data generation.") 51 | 52 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 53 | 54 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for adafactor.") 55 | 56 | flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.") 57 | 58 | flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.") 59 | 60 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 61 | "How often to save the model checkpoint.") 62 | 63 | flags.DEFINE_integer("iterations_per_loop", 1000, 64 | "How many steps to make in each estimator call.") 65 | 66 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.") 67 | 68 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 69 | 70 | flags.DEFINE_string( 71 | "tpu_name", None, 72 | "The Cloud TPU to use for training. This should be either the name " 73 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 74 | "url.") 75 | 76 | flags.DEFINE_string( 77 | "tpu_zone", None, 78 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 79 | "specified, we will attempt to automatically detect the GCE project from " 80 | "metadata.") 81 | 82 | flags.DEFINE_string( 83 | "gcp_project", None, 84 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 85 | "specified, we will attempt to automatically detect the GCE project from " 86 | "metadata.") 87 | 88 | flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 89 | 90 | flags.DEFINE_integer( 91 | "num_tpu_cores", 8, 92 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 93 | 94 | 95 | def main(_): 96 | tf.logging.set_verbosity(tf.logging.INFO) 97 | 98 | news_config = GroverConfig.from_json_file(FLAGS.config_file) 99 | 100 | tf.gfile.MakeDirs(FLAGS.output_dir) 101 | 102 | input_files = [] 103 | for input_pattern in FLAGS.input_file.split(","): 104 | input_files.extend(tf.gfile.Glob(input_pattern)) 105 | 106 | tf.logging.info("*** Input Files ***") 107 | for input_file in input_files: 108 | tf.logging.info(" %s" % input_file) 109 | 110 | tpu_cluster_resolver = None 111 | if FLAGS.use_tpu and FLAGS.tpu_name: 112 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 113 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 114 | 115 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 116 | run_config = tf.contrib.tpu.RunConfig( 117 | cluster=tpu_cluster_resolver, 118 | master=FLAGS.master, 119 | model_dir=FLAGS.output_dir, 120 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 121 | keep_checkpoint_max=None, 122 | tpu_config=tf.contrib.tpu.TPUConfig( 123 | iterations_per_loop=FLAGS.iterations_per_loop, 124 | num_shards=FLAGS.num_tpu_cores, 125 | per_host_input_for_training=is_per_host)) 126 | 127 | model_fn = model_fn_builder(news_config, init_checkpoint=FLAGS.init_checkpoint, 128 | learning_rate=FLAGS.learning_rate, 129 | num_train_steps=FLAGS.num_train_steps, 130 | num_warmup_steps=FLAGS.num_warmup_steps, 131 | use_tpu=FLAGS.use_tpu, 132 | ) 133 | 134 | # If TPU is not available, this will fall back to normal Estimator on CPU 135 | # or GPU. 
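  # For example (editorial sketch; the paths are hypothetical), a CPU/GPU
  # smoke test can reuse this entry point with `use_tpu` left at False:
  #   PYTHONPATH=. python train/train_tpu.py --config_file=configs/base.json \
  #     --input_file=/tmp/train.tfrecord --output_dir=/tmp/gpt2_out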
136 | estimator = tf.contrib.tpu.TPUEstimator( 137 | use_tpu=FLAGS.use_tpu, 138 | model_fn=model_fn, 139 | config=run_config, 140 | train_batch_size=FLAGS.train_batch_size, 141 | eval_batch_size=FLAGS.train_batch_size, 142 | params={'model_dir': FLAGS.output_dir} 143 | ) 144 | 145 | tf.logging.info("***** Running training *****") 146 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 147 | train_input_fn = input_fn_builder( 148 | input_files=input_files, 149 | seq_length=FLAGS.max_seq_length, 150 | is_training=True) 151 | 152 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) 153 | 154 | if __name__ == "__main__": 155 | flags.mark_flag_as_required("input_file") 156 | flags.mark_flag_as_required("output_dir") 157 | tf.app.run() 158 | -------------------------------------------------------------------------------- /train/train_tpu_adafactor.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export PYTHONPATH=../ 4 | 5 | learning_rate=1e-4 6 | init_checkpoint="" 7 | max_seq_length=1024 8 | save_checkpoint_steps=1000 9 | 10 | # You can customize the training here 11 | # mega, large, or base (must match a json file in configs/) 12 | model_type="mega" 13 | OUTPUT_DIR="gs://" # put your output directory here 14 | input_file="gs://" # put your input files here, it can also be something like "*.tfrecord" 15 | 16 | if [ ${model_type} == "base" ]; then 17 | num_tpu_cores=32 18 | batch_size_per_core=16 19 | elif [ ${model_type} == "large" ]; then 20 | num_tpu_cores=128 21 | batch_size_per_core=4 22 | elif [ ${model_type} == "mega" ]; then 23 | num_tpu_cores=256 24 | batch_size_per_core=2 25 | fi 26 | 27 | 28 | # there are 20k * 1024 examples so this translates to 20 epochs. seems ok and i can run for more if needed 29 | num_train_steps=800000 30 | 31 | # Make sure batch size scales. 32 | let batch_size="$batch_size_per_core * $num_tpu_cores" 33 | 34 | python train_tpu.py \ 35 | --config_file=configs/${model_type}.json \ 36 | --input_file=${input_file} \ 37 | --output_dir=${OUTPUT_DIR} \ 38 | --max_seq_length=${max_seq_length} \ 39 | --train_batch_size=${batch_size} \ 40 | --learning_rate=${learning_rate} \ 41 | --num_train_steps=${num_train_steps} \ 42 | --num_warmup_steps=10000 \ 43 | --save_checkpoints_steps=${save_checkpoint_steps} \ 44 | --iterations_per_loop=${save_checkpoint_steps} \ 45 | --use_tpu=True \ 46 | --tpu_name=$(hostname) \ 47 | --num_tpu_cores=$num_tpu_cores \ 48 | --init_checkpoint=${init_checkpoint} 49 | -------------------------------------------------------------------------------- /train/utils.py: -------------------------------------------------------------------------------- 1 | # Original work Copyright 2018 The Google AI Language Team Authors. 2 | # Modified work Copyright 2019 Rowan Zellers 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
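# Usage sketches for the helpers below (editorial notes, not in the
# original source):
#   get_shape_list(t, expected_rank=2) returns e.g. [8, 1024], with Python
#     ints for static dimensions and scalar Tensors for dynamic ones.
#   get_attention_mask(3, 3, dtype=tf.float32) returns the causal mask
#     [[1, 0, 0], [1, 1, 0], [1, 1, 1]].
#   gelu(x) computes x * 0.5 * (1 + erf(x / sqrt(2))), a smooth ReLU variant.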
15 | 16 | import collections 17 | import re 18 | 19 | import six 20 | import tensorflow.compat.v1 as tf 21 | import numpy as np 22 | from tensorflow.python.lib.io import file_io 23 | 24 | 25 | def _save_np(absolute_fn, array): 26 | if absolute_fn.startswith('gs://'): 27 | with file_io.FileIO(absolute_fn, 'w') as f: 28 | np.save(f, array) 29 | else: 30 | np.save(absolute_fn, array) 31 | 32 | 33 | def assert_rank(tensor, expected_rank, name=None): 34 | """Raises an exception if the tensor rank is not of the expected rank. 35 | 36 | Args: 37 | tensor: A tf.Tensor to check the rank of. 38 | expected_rank: Python integer or list of integers, expected rank. 39 | name: Optional name of the tensor for the error message. 40 | 41 | Raises: 42 | ValueError: If the expected rank doesn't match the actual rank. 43 | """ 44 | if name is None: 45 | name = tensor.name 46 | 47 | expected_rank_dict = {} 48 | if isinstance(expected_rank, six.integer_types): 49 | expected_rank_dict[expected_rank] = True 50 | else: 51 | for x in expected_rank: 52 | expected_rank_dict[x] = True 53 | 54 | actual_rank = tensor.shape.ndims 55 | if actual_rank not in expected_rank_dict: 56 | scope_name = tf.get_variable_scope().name 57 | raise ValueError( 58 | "For the tensor `%s` in scope `%s`, the actual rank " 59 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 60 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 61 | 62 | 63 | def get_shape_list(tensor, expected_rank=None, name=None): 64 | """Returns a list of the shape of tensor, preferring static dimensions. 65 | 66 | Args: 67 | tensor: A tf.Tensor object to find the shape of. 68 | expected_rank: (optional) int or list of ints. The expected rank of `tensor`. If this is 69 | specified and the `tensor` has a different rank, an exception will be 70 | thrown. 71 | name: Optional name of the tensor for the error message. 72 | 73 | Returns: 74 | A list of dimensions of the shape of tensor. All static dimensions will 75 | be returned as python integers, and dynamic dimensions will be returned 76 | as tf.Tensor scalars. 77 | """ 78 | if name is None: 79 | name = tensor.name 80 | 81 | if expected_rank is not None: 82 | assert_rank(tensor, expected_rank, name) 83 | 84 | shape = tensor.shape.as_list() 85 | 86 | non_static_indexes = [] 87 | for (index, dim) in enumerate(shape): 88 | if dim is None: 89 | non_static_indexes.append(index) 90 | 91 | if not non_static_indexes: 92 | return shape 93 | 94 | dyn_shape = tf.shape(tensor) 95 | for index in non_static_indexes: 96 | shape[index] = dyn_shape[index] 97 | return shape 98 | 99 | 100 | def gelu(input_tensor): 101 | """Gaussian Error Linear Unit. 102 | 103 | This is a smoother version of the RELU. 104 | Original paper: https://arxiv.org/abs/1606.08415 105 | 106 | Args: 107 | input_tensor: float Tensor to perform activation. 108 | 109 | Returns: 110 | `input_tensor` with the GELU activation applied. 
111 | """ 112 | cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0))) 113 | return input_tensor * cdf 114 | 115 | 116 | def layer_norm(input_tensor, name=None, epsilon=1e-5): 117 | """Run layer normalization on the last dimension of the tensor.""" 118 | name2use = f'LayerNorm_{name}' if name is not None else name 119 | with tf.variable_scope(name2use, default_name='LayerNorm'): 120 | dim = input_tensor.shape[-1].value 121 | gamma = tf.get_variable('gamma', [dim], initializer=tf.constant_initializer(1)) 122 | beta = tf.get_variable('beta', [dim], initializer=tf.constant_initializer(0)) 123 | mean = tf.reduce_mean(input_tensor, axis=-1, keepdims=True) 124 | std = tf.reduce_mean(tf.square(input_tensor - mean), axis=-1, keepdims=True) 125 | input_tensor = (input_tensor - mean) * tf.rsqrt(std + epsilon) 126 | input_tensor = input_tensor * gamma + beta 127 | return input_tensor 128 | 129 | 130 | def dropout(input_tensor, dropout_prob): 131 | """Perform dropout. 132 | 133 | Args: 134 | input_tensor: float Tensor. 135 | dropout_prob: Python float. The probability of dropping out a value (NOT of 136 | *keeping* a dimension as in `tf.nn.dropout`). 137 | 138 | Returns: 139 | A version of `input_tensor` with dropout applied. 140 | """ 141 | if dropout_prob is None or dropout_prob == 0.0: 142 | return input_tensor 143 | output = tf.nn.dropout(input_tensor, rate=dropout_prob) 144 | return output 145 | 146 | 147 | def get_attention_mask(nd, ns, *, dtype): 148 | """ 149 | this is a TPU compatible version of tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd) 150 | where the lower right triangle contains 1s 151 | """ 152 | i = tf.range(nd)[:, None] 153 | j = tf.range(ns) 154 | m = i >= j - ns + nd 155 | return tf.cast(m, dtype) 156 | 157 | 158 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): 159 | """Compute the union of the current variables and checkpoint variables.""" 160 | assignment_map = {} 161 | initialized_variable_names = {} 162 | 163 | name_to_variable = collections.OrderedDict() 164 | for var in tvars: 165 | name = var.name 166 | m = re.match("^(.*):\\d+$", name) 167 | if m is not None: 168 | name = m.group(1) 169 | name_to_variable[name] = var 170 | 171 | init_vars = tf.train.list_variables(init_checkpoint) 172 | 173 | assignment_map = collections.OrderedDict() 174 | for x in init_vars: 175 | (name, var) = (x[0], x[1]) 176 | if name not in name_to_variable: 177 | continue 178 | assignment_map[name] = name 179 | initialized_variable_names[name] = 1 180 | initialized_variable_names[name + ":0"] = 1 181 | return (assignment_map, initialized_variable_names) 182 | 183 | 184 | def construct_scalar_host_call(metric_dict, model_dir, prefix=""): 185 | """Construct a host call to log scalars when training on TPU. 186 | 187 | Args: 188 | metric_dict: A dict of the tensors to be logged. 189 | model_dir: The location to write the summary. 190 | prefix: The prefix (if any) to prepend to the metric names. 191 | 192 | Returns: 193 | A tuple of (function, args_to_be_passed_to_said_function) 194 | """ 195 | metric_names = list(metric_dict.keys()) 196 | 197 | def host_call_fn(global_step, *args): 198 | """Training host call. Creates scalar summaries for training metrics. 199 | 200 | This function is executed on the CPU and should not directly reference 201 | any Tensors in the rest of the `model_fn`. To pass Tensors from the 202 | model to the `metric_fn`, provide as part of the `host_call`. 
See 203 | https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec 204 | for more information. 205 | 206 | Arguments should match the list of `Tensor` objects passed as the second 207 | element in the tuple passed to `host_call`. 208 | 209 | Args: 210 | global_step: `Tensor` with shape `[batch]` holding the global step 211 | *args: Remaining tensors to log. 212 | 213 | Returns: 214 | List of summary ops to run on the CPU host. 215 | """ 216 | step = global_step[0] 217 | with tf.contrib.summary.create_file_writer( 218 | logdir=model_dir, filename_suffix=".host_call").as_default(): 219 | with tf.contrib.summary.always_record_summaries(): 220 | for i, name in enumerate(metric_names): 221 | tf.contrib.summary.scalar(prefix + name, args[i][0], step=step) 222 | 223 | return tf.contrib.summary.all_summary_ops() 224 | 225 | # To log the current learning rate and gradient norm to TensorBoard, the 226 | # summary op needs to be run on the host CPU via host_call. host_call 227 | # expects [batch_size, ...] Tensors, thus reshape to introduce a batch 228 | # dimension. These Tensors are implicitly concatenated to 229 | # [params['batch_size']]. 230 | global_step_tensor = tf.reshape( 231 | tf.compat.v1.train.get_or_create_global_step(), [1]) 232 | other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names] 233 | 234 | return host_call_fn, [global_step_tensor] + other_tensors 235 | --------------------------------------------------------------------------------
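# Closing editorial sketch (not part of the repository): how the pieces in
# train/modeling.py compose for generation. The config/checkpoint paths, the
# eos token id, and the seed token ids are hypothetical placeholders; see
# scripts/demo.py for the project's actual sampling entry point.
import tensorflow.compat.v1 as tf
from train.modeling import GroverConfig, sample

news_config = GroverConfig.from_json_file('configs/mega.json')

with tf.Session() as sess:
    # [batch_size, seq_length] int32 ids of an already-tokenized prompt
    initial_context = tf.placeholder(tf.int32, [1, None])
    tokens, probs = sample(news_config=news_config,
                           initial_context=initial_context,
                           eos_token=102,      # assumption: vocab-dependent
                           min_len=50,
                           p_for_topp=0.95)
    tf.train.Saver().restore(sess, 'models/mega/model.ckpt')  # hypothetical path
    out, out_probs = sess.run([tokens, probs],
                              feed_dict={initial_context: [[101, 2769]]})  # hypothetical ids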