├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── discussion.md
│   │   └── feature_request.md
│   ├── demo.png
│   ├── logo.png
│   ├── logo.svg
│   ├── loss.png
│   └── stale.yml
├── .gitignore
├── 3rd
│   └── gdown.pl
│       ├── LICENSE.txt
│       ├── README.md
│       └── gdown.pl
├── LICENSE
├── README.md
├── README_CN.md
├── configs
│   ├── base.json
│   ├── large.json
│   └── mega.json
├── dataset
│   ├── README.md
│   ├── prepare_data.py
│   └── prepare_data.sh
├── dockerfiles
│   └── gpu-jupyter.Dockerfile
├── pretrained_model_demo.ipynb
├── requirements-gpu.txt
├── requirements-tpu.txt
├── scripts
│   ├── demo.py
│   └── down_gdrive_file.py
├── tokenization
│   ├── __init__.py
│   ├── bert-base-chinese-vocab.txt
│   ├── bert-large-cased-whole-word-masking-vocab.txt
│   ├── clue-vocab.txt
│   └── tokenization.py
└── train
    ├── __init__.py
    ├── dataloader.py
    ├── modeling.py
    ├── optimization_adafactor.py
    ├── train_tpu.py
    ├── train_tpu_adafactor.sh
    └── utils.py
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a bug report to help us improve GPT2-ML
4 | title: "[Bug] name your bug"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | > Below is the issue template. You can fill in each part and then submit your issue,
11 | > or you can delete all of it and describe your problem in your own words.
12 | > Either way, please remember: the more detail you provide, the more likely your problem will be solved. 😜
13 |
14 | Please write a clear and concise description of what the bug is.
15 |
16 | ## Expected behavior
17 |
18 | Please write a clear and concise description of what you expected to happen.
19 |
20 | ## Environment
21 |
22 | - Python version:
23 | - OS:
24 | - (Optional) Other libraries and their versions:
25 |
26 | ## Error messages, stack traces, or logs
27 |
28 | ```
29 | # error messages, stack traces, or logs
30 | ```
31 |
32 | ## Steps to reproduce
33 |
34 | 1.
35 | 2.
36 | 3.
37 |
38 | ## Reproducible examples (optional)
39 |
40 | ```python
41 | # python code
42 | ```
43 |
44 | ## Additional context (optional)
45 |
46 | Please add any other context or screenshots about the problem here.
47 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/discussion.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Discussion
3 | about: Ideas sharing or theorical question solving
4 | title: "[Discussion] your question"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 |
11 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for GPT2-ML
4 | title: "[Feature] your feature name"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | ## Description
11 |
12 |
13 |
14 | ## Additional information
15 |
16 |
17 |
--------------------------------------------------------------------------------
/.github/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imcaspar/gpt2-ml/f6286b16cbbee6dddbe1ba557fffb47eaf998cd1/.github/demo.png
--------------------------------------------------------------------------------
/.github/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imcaspar/gpt2-ml/f6286b16cbbee6dddbe1ba557fffb47eaf998cd1/.github/logo.png
--------------------------------------------------------------------------------
/.github/logo.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.github/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imcaspar/gpt2-ml/f6286b16cbbee6dddbe1ba557fffb47eaf998cd1/.github/loss.png
--------------------------------------------------------------------------------
/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # Number of days of inactivity before an issue becomes stale
2 | daysUntilStale: 365
3 | # Number of days of inactivity before a stale issue is closed
4 | daysUntilClose: 30
5 | # Issues with these labels will never be considered stale
6 | exemptLabels:
7 | - pinned
8 | - security
9 | # Label to use when marking an issue as stale
10 | staleLabel: wontfix
11 | # Comment to post when marking an issue as stale. Set to `false` to disable
12 | markComment: >
13 | This issue has been automatically marked as stale because it has not had
14 | recent activity. It will be closed if no further activity occurs. Thank you
15 | for your contributions.
16 | # Comment to post when closing a stale issue. Set to `false` to disable
17 | closeComment: true
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # vscode
132 | .vscode/
133 |
134 | # dataset
135 | dataset/raw/
136 |
137 | # models
138 | models/
--------------------------------------------------------------------------------
/3rd/gdown.pl/LICENSE.txt:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/3rd/gdown.pl/README.md:
--------------------------------------------------------------------------------
1 | gdown.pl
2 | ========
3 |
4 | Google Drive direct download of big files
5 |
6 | Requirements
7 | ============
8 |
9 | *wget* and *Perl* must be in the PATH.
10 | Compatible with **Windows** and **Linux**.
11 |
12 | Usage
13 | =====
14 |
15 | Use Google Drive shareable links, viewable by anyone:
16 |
17 | $ ./gdown.pl 'gdrive file url' ['desired file name']
18 |
19 | Example
20 | =======
21 |
22 | For example, to download [this video](https://drive.google.com/file/d/0B1L_hFrWJfRhLUJZdXdSdTdfSWs/edit) from my [axolotl project](https://circulosmeos.wordpress.com/2015/03/04/axolotl-a-simple-plain-text-documentation-system/), just copy the URL and give a file name if desired:
23 |
24 | $ ./gdown.pl https://drive.google.com/file/d/0B1L_hFrWJfRhLUJZdXdSdTdfSWs/edit axolotl.mp4
25 |
26 | Resuming a download
27 | ===================
28 |
29 | If you need to resume a download, please use [**gdown.pl v2.0** here](https://github.com/circulosmeos/gdown.pl/tree/with-resume).
30 | As long as a file name is given as the second parameter, *gdown.pl v2.0* **will try to resume the partially downloaded file** if a local incomplete file with that name already exists.
31 |
32 | Version
33 | =======
34 |
35 | This version is **v1.4**.
36 |
37 | ### Warning
38 |
39 | Please note that v1.2 (available between 12 and 31 January 2019) **should not be used**, as it contains a bug that could result in unusable downloaded files. Overwrite it with v1.3 if you have it.
40 |
41 | Docker
42 | ======
43 |
44 | A simple Dockerfile is provided to build a Docker image with gdown.pl.
45 | It has been used to pre-pull data from Google Drive to Kubernetes persistent volumes. Thanks @anton-khodak
46 |
47 | License
48 | =======
49 |
50 | Distributed [under GPL 3](http://www.gnu.org/licenses/gpl-3.0.html)
51 |
52 | Disclaimer
53 | ==========
54 |
55 | This software is provided "as is", without warranty of any kind, express or implied.
56 |
57 | More info
58 | =========
59 |
60 | [https://circulosmeos.wordpress.com/2014/04/12/google-drive-direct-download-of-big-files](https://circulosmeos.wordpress.com/2014/04/12/google-drive-direct-download-of-big-files)
61 |
62 | Contact
63 | =======
64 |
65 | by [circulosmeos](mailto:loopidle@gmail.com)
66 |
--------------------------------------------------------------------------------
/3rd/gdown.pl/gdown.pl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env perl
2 | #
3 | # Google Drive direct download of big files
4 | # ./gdown.pl 'gdrive file url' ['desired file name']
5 | #
6 | # v1.0 by circulosmeos 04-2014.
7 | # v1.1 by circulosmeos 01-2017.
8 | # v1.2, 2.0 by circulosmeos 01-2019.
9 | # v2.1 by circulosmeos 12-2020.
10 | # //circulosmeos.wordpress.com/2014/04/12/google-drive-direct-download-of-big-files
11 | # Distributed under GPL 3 (//www.gnu.org/licenses/gpl-3.0.html)
12 | #
13 | use strict;
14 | use POSIX;
15 |
16 | my $TEMP='gdown.cookie.temp';
17 | my $COMMAND;
18 | my $confirm;
19 | my $check;
20 | sub execute_command();
21 |
22 | my $URL=shift;
23 | die "\n./gdown.pl 'gdrive file url' [desired file name]\n\n" if $URL eq '';
24 |
25 | my $FILENAME=shift;
26 | my $TEMP_FILENAME='gdown.'.strftime("%Y%m%d%H%M%S", localtime).'.'.substr(rand,2);
27 |
28 | if ($URL=~m#^https?://drive.google.com/file/d/([^/]+)#) {
29 | $URL="https://docs.google.com/uc?id=$1&export=download";
30 | }
31 | elsif ($URL=~m#^https?://drive.google.com/open\?id=([^/]+)#) {
32 | $URL="https://docs.google.com/uc?id=$1&export=download";
33 | }
34 |
35 | execute_command();
36 |
37 | while (-s $TEMP_FILENAME < 100000) { # loop while the fetched file is still the small HTML page, not the actual download
38 | open fFILENAME, '<', $TEMP_FILENAME;
39 | $check=0;
40 | foreach (<fFILENAME>) {
41 | if (/href="(\/uc\?export=download[^"]+)/) {
42 | $URL='https://docs.google.com'.$1;
43 | $URL=~s/&amp;/&/g;
44 | $confirm='';
45 | $check=1;
46 | last;
47 | }
48 | if (/confirm=([^;&]+)/) {
49 | $confirm=$1;
50 | $check=1;
51 | last;
52 | }
53 | if (/"downloadUrl":"([^"]+)/) {
54 | $URL=$1;
55 | $URL=~s/\\u003d/=/g;
56 | $URL=~s/\\u0026/&/g;
57 | $confirm='';
58 | $check=1;
59 | last;
60 | }
61 | }
62 | close fFILENAME;
63 | die "Couldn't download the file :-(\n" if ($check==0);
64 | $URL=~s/confirm=([^;&]+)/confirm=$confirm/ if $confirm ne '';
65 |
66 | execute_command();
67 |
68 | }
69 |
70 | unlink $TEMP;
71 |
72 | sub execute_command() {
73 | my $OUTPUT_FILENAME = $TEMP_FILENAME;
74 | my $CONTINUE = '';
75 |
76 | # probe the response headers before downloading; if a $FILENAME was given, resume the content download
77 | # note that for this to work, wget must support --spider together with --server-response (-S)
78 | if ( length($FILENAME) > 0 ) {
79 | $COMMAND="wget -q -S --no-check-certificate --spider --load-cookie $TEMP --save-cookie $TEMP \"$URL\" 2>&1";
80 | my @HEADERS=`$COMMAND`;
81 | foreach my $header (@HEADERS) {
82 | if ( ( $header =~ /Content-Type: (.+)/ && $1 !~ 'text/html' ) ||
83 | $header =~ 'HTTP/1.1 405 Method Not Allowed'
84 | ) {
85 | $OUTPUT_FILENAME = $FILENAME;
86 | $CONTINUE = '-c';
87 | last;
88 | }
89 | }
90 | }
91 |
92 | $COMMAND="wget $CONTINUE --progress=dot:giga --no-check-certificate --load-cookie $TEMP --save-cookie $TEMP \"$URL\"";
93 | $COMMAND.=" -O \"$OUTPUT_FILENAME\"";
94 | my $OUTPUT = system( $COMMAND );
95 | if ( $OUTPUT == 2 ) { # do a clean exit with Ctrl+C
96 | unlink $TEMP;
97 | die "\nDownloading interrupted by user\n\n";
98 | } elsif ( $OUTPUT == 0 && length($CONTINUE)>0 ) { # do a clean exit with $FILENAME provided
99 | unlink $TEMP;
100 | die "\nDownloading complete\n\n";
101 | }
102 | return 1;
103 | }
104 |
--------------------------------------------------------------------------------
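In short, gdown.pl rewrites a Google Drive share link into a `docs.google.com/uc?export=download` URL, fetches it once with wget to collect cookies, scrapes a `confirm` token out of the HTML interstitial that Google returns for large files, and re-requests the URL with that token until the response is the file itself. As a rough illustration in Python (the language of the rest of this repository), here is a minimal sketch of the same flow using the third-party `requests` library. All names here are illustrative and not part of the repository, and Google has changed this endpoint's behavior over the years, so treat the sketch as documentation of the script's logic rather than a maintained downloader.

```python
#!/usr/bin/env python3
"""Minimal sketch of the Google Drive confirm-token flow used by gdown.pl."""
import re
import sys

import requests  # third-party; pip install requests

DOWNLOAD_URL = "https://docs.google.com/uc?export=download"


def file_id_from_url(url: str) -> str:
    # Accept the same two share-link shapes that gdown.pl rewrites:
    #   https://drive.google.com/file/d/<id>/...
    #   https://drive.google.com/open?id=<id>
    m = re.search(r"/file/d/([^/]+)", url) or re.search(r"[?&]id=([^&/]+)", url)
    if not m:
        raise ValueError(f"unrecognized Google Drive URL: {url}")
    return m.group(1)


def download(url: str, filename: str) -> None:
    # A Session keeps the cookies that gdown.pl stores in gdown.cookie.temp.
    session = requests.Session()
    file_id = file_id_from_url(url)
    resp = session.get(DOWNLOAD_URL, params={"id": file_id}, stream=True)

    # Large files come back as an HTML page containing a confirm token
    # instead of the file itself; extract the token and re-request.
    if "text/html" in resp.headers.get("Content-Type", ""):
        m = re.search(r"confirm=([0-9A-Za-z_]+)", resp.text)
        if not m:
            raise RuntimeError("no confirm token found; is the link shared publicly?")
        resp = session.get(
            DOWNLOAD_URL,
            params={"id": file_id, "confirm": m.group(1)},
            stream=True,
        )

    resp.raise_for_status()
    with open(filename, "wb") as out:
        for chunk in resp.iter_content(chunk_size=1 << 20):  # 1 MiB chunks
            out.write(chunk)


if __name__ == "__main__":
    download(sys.argv[1], sys.argv[2])
```

Usage mirrors the Perl script: `python gdrive_download.py 'gdrive file url' 'desired file name'`.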
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # **GPT2** for Multiple Languages
4 |
5 | [](https://colab.research.google.com/github/imcaspar/gpt2-ml/blob/master/pretrained_model_demo.ipynb)
6 | [](https://github.com/imcaspar/gpt2-ml)
7 | [](https://github.com/imcaspar/gpt2-ml/releases)
8 | [](https://github.com/imcaspar/gpt2-ml/issues)
9 | [](https://github.com/imcaspar/gpt2-ml)
10 |
11 | [**中文说明**](./README_CN.md) | [**English**](./README.md)
12 |
13 | - [x] Simplified GPT-2 training scripts (based on Grover, supporting TPUs)
14 | - [x] Ported BERT tokenizer, compatible with multilingual corpora
15 | - [x] 1.5B GPT2 pretrained Chinese model ( ~15G corpus, 100k steps )
16 | - [x] Batteries-included Colab demo [#](https://github.com/imcaspar/gpt2-ml#google-colab)
17 | - [x] 1.5B GPT2 pretrained Chinese model ( ~30G corpus, 220k steps )
18 |
19 |
20 | ## Pretrained Model
21 | | Size | Language | Corpus | Vocab | Link1 | Link2 | SHA256 |
22 | | ---- | -------- | ------ | ----- | ----- | ----- | ------ |
23 | | 1.5B Params | Chinese | ~30G | CLUE ( 8021 tokens ) | [**Google Drive**](https://drive.google.com/file/d/1mT_qCQg4AWnAXTwKfsyyRWCRpgPrBJS3) | [**Baidu Pan (ffz6)**](https://pan.baidu.com/s/1yiuTHXUr2DpyBqmFYLJH6A) | e698cc97a7f5f706f84f58bb469d614e51d3c0ce5f9ab9bf77e01e3fcb41d482 |
24 | | 1.5B Params | Chinese | ~15G | Bert ( 21128 tokens ) | [**Google Drive**](https://drive.google.com/file/d/1IzWpQ6I2IgfV7CldZvFJnZ9byNDZdO4n) | [**Baidu Pan (q9vr)**](https://pan.baidu.com/s/1TA_3e-u2bXg_hcx_NwVbGw) | 4a6e5124df8db7ac2bdd902e6191b807a6983a7f5d09fb10ce011f9a073b183e |
25 |
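To verify a downloaded checkpoint against the SHA256 column above, you can hash it locally. A minimal sketch (it assumes the `.data` shard was saved under `models/mega/`, as in the Colab demo):

```python
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    """Stream the file so multi-GB checkpoints need not fit in memory."""
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

print(sha256_of('models/mega/model.ckpt-220000.data-00000-of-00001'))
```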
26 | Corpus from [THUCNews](http://thuctc.thunlp.org/#%E4%B8%AD%E6%96%87%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%95%B0%E6%8D%AE%E9%9B%86THUCNews) and [nlp_chinese_corpus](https://github.com/brightmart/nlp_chinese_corpus)
27 |
28 | Trained for 220k steps on a [Cloud TPU Pod v3-256](https://cloud.google.com/tpu/docs/types-zones#types)
29 |
30 | 
31 |
32 |
33 | ## Google Colab
34 | With just two clicks (not counting the Colab auth process), the 1.5B pretrained Chinese model demo is ready to go:
35 |
36 | [**[Colab Notebook]**](https://colab.research.google.com/github/imcaspar/gpt2-ml/blob/master/pretrained_model_demo.ipynb)
37 |
38 |
39 |
40 | ## Train
41 |
42 | ## Disclaimer
43 | The contents of this repository are for academic research purposes only; we do not provide any conclusive remarks.
44 |
45 | ## Citation
46 |
47 | ```
48 | @misc{GPT2-ML,
49 | author = {Zhibo Zhang},
50 | title = {GPT2-ML: GPT-2 for Multiple Languages},
51 | year = {2019},
52 | publisher = {GitHub},
53 | journal = {GitHub repository},
54 | howpublished = {\url{https://github.com/imcaspar/gpt2-ml}},
55 | }
56 | ```
57 |
58 | ## Reference
59 | https://github.com/google-research/bert
60 |
61 | https://github.com/rowanz/grover
62 |
63 | Research supported with Cloud TPUs from Google's TensorFlow Research Cloud (TFRC)
64 |
65 | ## Press
66 | [[机器之心] 只需单击三次,让中文GPT-2为你生成定制故事](https://mp.weixin.qq.com/s/FpoSNNKZSQOE2diPvJDHog)
67 |
68 | [[科学空间] 现在可以用Keras玩中文GPT2了](https://kexue.fm/archives/7292)
69 |
--------------------------------------------------------------------------------
/README_CN.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # **GPT2** for Multiple Languages
4 |
5 | [](https://colab.research.google.com/github/imcaspar/gpt2-ml/blob/master/pretrained_model_demo.ipynb)
6 | [](https://github.com/imcaspar/gpt2-ml)
7 | [](https://github.com/imcaspar/gpt2-ml/releases)
8 | [](https://github.com/imcaspar/gpt2-ml/issues)
9 | [](https://github.com/imcaspar/gpt2-ml)
10 |
11 | [**中文说明**](./README_CN.md) | [**English**](./README.md)
12 |
13 | - [x] 简化整理 GPT2 训练代码(based on Grover, supporting TPUs)
14 | - [x] 移植 bert tokenizer,添加多语言支持
15 | - [x] 15亿参数 GPT2 中文预训练模型( 15G 语料,训练 10w 步 )
16 | - [x] 开箱即用的模型生成效果 demo [#](https://github.com/imcaspar/gpt2-ml#google-colab)
17 | - [x] 15亿参数 GPT2 中文预训练模型( 30G 语料,训练 22w 步 )
18 |
19 |
20 | ## 预训练模型
21 | | Size | Language | Corpus | Vocab | Link1 | Link2 | SHA256 |
22 | | ---- | -------- | ------ | ----- | ----- | ----- | ------ |
23 | | 1.5B Params | Chinese | ~30G | CLUE ( 8021 tokens ) | [**Google Drive**](https://drive.google.com/file/d/1mT_qCQg4AWnAXTwKfsyyRWCRpgPrBJS3) | [**Baidu Pan (ffz6)**](https://pan.baidu.com/s/1yiuTHXUr2DpyBqmFYLJH6A) | e698cc97a7f5f706f84f58bb469d614e51d3c0ce5f9ab9bf77e01e3fcb41d482 |
24 | | 1.5B Params | Chinese | ~15G | Bert ( 21128 tokens ) | [**Google Drive**](https://drive.google.com/file/d/1IzWpQ6I2IgfV7CldZvFJnZ9byNDZdO4n) | [**Baidu Pan (q9vr)**](https://pan.baidu.com/s/1TA_3e-u2bXg_hcx_NwVbGw) | 4a6e5124df8db7ac2bdd902e6191b807a6983a7f5d09fb10ce011f9a073b183e |
25 |
26 | 训练语料来自 [THUCNews](http://thuctc.thunlp.org/#%E4%B8%AD%E6%96%87%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%95%B0%E6%8D%AE%E9%9B%86THUCNews) 以及 [nlp_chinese_corpus](https://github.com/brightmart/nlp_chinese_corpus),清洗后总文本量约 15G
27 |
28 | 使用 [Cloud TPU Pod v3-256](https://cloud.google.com/tpu/docs/types-zones#types) 训练 22w 步
29 |
30 | 
31 |
32 |
33 | ## Google Colab
34 | 只需两次鼠标点击(不包括 Colab 授权流程),体验 15 亿参数中文预训练模型生成效果:
35 |
36 | [**[Colab Notebook]**](https://colab.research.google.com/github/imcaspar/gpt2-ml/blob/master/pretrained_model_demo.ipynb)
37 |
38 |
39 |
40 | ## 训练
41 |
42 | ## 免责声明
43 | 该项目中的内容仅供技术研究参考,不作为任何结论性依据。
44 |
45 | ## Citation
46 |
47 | ```
48 | @misc{GPT2-ML,
49 | author = {Zhibo Zhang},
50 | title = {GPT2-ML: GPT-2 for Multiple Languages},
51 | year = {2019},
52 | publisher = {GitHub},
53 | journal = {GitHub repository},
54 | howpublished = {\url{https://github.com/imcaspar/gpt2-ml}},
55 | }
56 | ```
57 |
58 | ## Reference
59 | https://github.com/google-research/bert
60 |
61 | https://github.com/rowanz/grover
62 |
63 | Research supported with Cloud TPUs from Google's TensorFlow Research Cloud (TFRC)
64 |
65 | ## Press
66 | [[机器之心] 只需单击三次,让中文GPT-2为你生成定制故事](https://mp.weixin.qq.com/s/FpoSNNKZSQOE2diPvJDHog)
67 |
68 | [[科学空间] 现在可以用Keras玩中文GPT2了](https://kexue.fm/archives/7292)
--------------------------------------------------------------------------------
/configs/base.json:
--------------------------------------------------------------------------------
1 | {
2 | "vocab_size": 50270,
3 | "hidden_size": 768,
4 | "attention_probs_dropout_prob": 0.1,
5 | "hidden_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "initializer_range": 0.02,
8 | "intermediate_size": 3072,
9 | "max_position_embeddings": 1024,
10 | "num_attention_heads": 12,
11 | "num_hidden_layers": 12
12 | }
--------------------------------------------------------------------------------
/configs/large.json:
--------------------------------------------------------------------------------
1 | {
2 | "vocab_size": 8021,
3 | "hidden_size": 1024,
4 | "attention_probs_dropout_prob": 0.1,
5 | "hidden_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "initializer_range": 0.02,
8 | "intermediate_size": 4096,
9 | "max_position_embeddings": 1024,
10 | "num_attention_heads": 16,
11 | "num_hidden_layers": 24
12 | }
--------------------------------------------------------------------------------
/configs/mega.json:
--------------------------------------------------------------------------------
1 | {
2 | "vocab_size": 8021,
3 | "hidden_size": 1536,
4 | "attention_probs_dropout_prob": 0.1,
5 | "hidden_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "initializer_range": 0.014142135623731,
8 | "intermediate_size": 6144,
9 | "max_position_embeddings": 1024,
10 | "num_attention_heads": 24,
11 | "num_hidden_layers": 48
12 | }
13 |
--------------------------------------------------------------------------------
/dataset/README.md:
--------------------------------------------------------------------------------
1 | curl -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda2-latest-Linux-x86_64.sh
2 | chmod +x ~/miniconda.sh
3 | ~/miniconda.sh -b -p ~/conda
4 | rm ~/miniconda.sh
5 | ~/conda/bin/conda install -y python=3.7
6 | ~/conda/bin/conda init  # then exit and reopen the shell
7 |
8 | conda create -n gpt3 python=3.7
9 |
10 |
11 | sudo apt install parallel
12 | pip install ujson==2.0.3
13 |
14 | export PYTHONPATH=$(pwd)  # run from the project root
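
prepare_data.py walks the directory given by `-input_fn` and expects one JSON object per line, each with a `text` field (news2016zh style). A minimal sketch of writing a toy input file (the path and contents here are illustrative only):

```python
import json

# One JSON object per line; article_iterator reads article['text'].
articles = [{'text': 'First example article body.'},
            {'text': 'Second example article body.'}]
with open('corpus/part-000.jsonl', 'w') as f:
    for a in articles:
        f.write(json.dumps(a, ensure_ascii=False) + '\n')
```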
--------------------------------------------------------------------------------
/dataset/prepare_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Turn a merged corpus into tfrecord files.
3 |
4 | NOTE: You will want to do this using several processes. I did this on an AWS machine with 72 CPUs using GNU parallel
5 | as that's where I had the deduplicated RealNews dataset.
6 | """
7 | import argparse
8 | import ujson as json
9 | # from sample.encoder import get_encoder, tokenize_for_grover_training, detokenize, sliding_window, create_int_feature
10 | import random
11 | import tensorflow.compat.v1 as tf
12 | import collections
13 | import os
14 | from tempfile import TemporaryDirectory
15 |
16 | from tokenization import tokenization
17 |
18 | parser = argparse.ArgumentParser(description='SCRAPE!')
19 | parser.add_argument(
20 | '-fold',
21 | dest='fold',
22 | default=0,
23 | type=int,
24 | help='which fold we are on'
25 | )
26 | parser.add_argument(
27 | '-num_folds',
28 | dest='num_folds',
29 | default=1,
30 | type=int,
31 | help='Number of folds (corresponding to both the number of training files and the number of testing files)',
32 | )
33 | parser.add_argument(
34 | '-seed',
35 | dest='seed',
36 | default=1337,
37 | type=int,
38 | help='which seed to use'
39 | )
40 | parser.add_argument(
41 | '-base_fn',
42 | dest='base_fn',
43 | default='news2016zh_',
44 | type=str,
45 |     help='We will output files like {base_fn}train_wiki19_{fold:04d}.tfrecord'
46 | )
47 |
48 | parser.add_argument(
49 | '-input_fn',
50 | dest='input_fn',
51 | default='realnews.jsonl',
52 | type=str,
53 | help='Base filename to use. THIS MUST BE A LOCAL FILE.'
54 | )
55 | parser.add_argument(
56 | '-max_seq_length',
57 | dest='max_seq_length',
58 | default=1024,
59 | type=int,
60 | help='Max sequence length',
61 | )
62 |
63 |
64 | args = parser.parse_args()
65 | random.seed(args.seed + args.fold)
66 |
67 | #encoder = get_encoder()
68 | tokenizer = tokenization.FullTokenizer(
69 | vocab_file="clue-vocab.txt", do_lower_case=True)
70 |
71 |
72 | class TFRecordWriter(object):
73 | def __init__(self, fn):
74 | self.fn = fn
75 | if fn.startswith('gs://'):
76 | from google.cloud import storage
77 | self.s3client = None
78 | self.gclient = storage.Client()
79 | self.storage_dir = TemporaryDirectory()
80 | self.writer = tf.python_io.TFRecordWriter(
81 | os.path.join(self.storage_dir.name, 'temp.tfrecord'))
82 | self.bucket_name, self.file_name = self.fn.split(
83 | 'gs://', 1)[1].split('/', 1)
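            # e.g. fn='gs://my-bucket/data_1024/train.tfrecord' ->
            #   bucket_name='my-bucket', file_name='data_1024/train.tfrecord'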
84 |
85 | else:
86 | self.s3client = None
87 | self.gclient = None
88 | self.bucket_name = None
89 | self.file_name = None
90 | self.storage_dir = None
91 | self.writer = tf.python_io.TFRecordWriter(fn)
92 |
93 | def write(self, x):
94 | self.writer.write(x)
95 |
96 | def close(self):
97 | self.writer.close()
98 |
99 | if self.gclient is not None:
100 | bucket = self.gclient.get_bucket(self.bucket_name)
101 | blob = bucket.blob(self.file_name)
102 | blob.upload_from_filename(os.path.join(
103 | self.storage_dir.name, 'temp.tfrecord'))
104 | self.storage_dir.cleanup()
105 |
106 | def __enter__(self):
107 | # Called when entering "with" context.
108 | return self
109 |
110 | def __exit__(self, *_):
111 | # Called when exiting "with" context.
112 |         # Upload the local tfrecord to GCS if needed.
113 | print("CALLING CLOSE")
114 | self.close()
115 |
116 |
117 | def article_iterator(tokenizer):
118 |     """ Iterate over files under args.input_fn and tokenize each article."""
119 | assert os.path.exists(args.input_fn)
120 | for (dirpath, dirnames, filenames) in os.walk(args.input_fn):
121 | for filename in filenames:
122 | with open(os.path.join(dirpath, filename), 'r') as f:
123 | for l_no, l in enumerate(f):
124 | if l_no % args.num_folds == args.fold:
125 | article = json.loads(l)
126 |
127 | line = tokenization.convert_to_unicode(
128 | article['text']) # for news2016zh text body
129 | tokens = tokenizer.tokenize(line)
130 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
131 |
132 | article['input_ids'] = input_ids
133 |
134 | article['inst_index'] = (l_no // args.num_folds)
135 | if article['inst_index'] < 100:
136 | print('---\nINPUT{}. {}\n---\nTokens: {}\n'.format(article['inst_index'],
137 | tokens,
138 | input_ids
139 | ), flush=True)
140 | if len(article['input_ids']) <= 64: # min size of article
141 | continue
142 | yield article
143 |
144 |
145 | def create_int_feature(values):
146 | feature = tf.train.Feature(
147 | int64_list=tf.train.Int64List(value=list(values)))
148 | return feature
149 |
150 |
151 | def buffered_and_sliding_window_article_iterator(tokenizer, final_desired_size=1025):
152 |     """ Truncate long articles to final_desired_size ids and zero-pad short ones; final_desired_size is max_seq_length + 1 so each example holds both the inputs and the one-step-shifted LM targets."""
153 | for article in article_iterator(tokenizer):
154 | if len(article['input_ids']) >= final_desired_size:
155 | article['input_ids'] = article['input_ids'][0:final_desired_size-1]
156 | while len(article['input_ids']) < final_desired_size:
157 | article['input_ids'].append(0)
158 | yield article
159 |
160 |
161 | # OK now write the tfrecord file
162 | total_written = 0
163 | train_file = args.base_fn + 'train_wiki19_{:04d}.tfrecord'.format(args.fold)
164 | with TFRecordWriter(train_file) as train_writer:
165 | for article in buffered_and_sliding_window_article_iterator(tokenizer,
166 | final_desired_size=args.max_seq_length + 1):
167 | writer2use = train_writer
168 | assert len(article['input_ids']) == (args.max_seq_length + 1)
169 |
170 | features = collections.OrderedDict()
171 | features["input_ids"] = create_int_feature(article['input_ids'])
172 | tf_example = tf.train.Example(
173 | features=tf.train.Features(feature=features))
174 |
175 | writer2use.write(tf_example.SerializeToString())
176 | total_written += 1
177 |
178 | # DEBUG
179 | if article['inst_index'] < 5:
180 | print("~~~\nIndex {}. ARTICLE: {}\n---\nTokens: {}\n\n".format(article['inst_index'],
181 | tokenizer.convert_ids_to_tokens(
182 | article['input_ids']),
183 | article['input_ids']
184 | ), flush=True)
185 | if article['inst_index'] % 1000 == 0:
186 | print("{} articles, {} written".format(
187 | article['inst_index'], total_written), flush=True)
188 | print("DONE UPLOADING", flush=True)
189 |
--------------------------------------------------------------------------------
/dataset/prepare_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | NUM_FOLDS=1024
4 | MAX_SEQ_LENGTH=1024
5 | FN=${1}
6 | OUT_BUCKET=${2}
7 |
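# Each of the NUM_FOLDS jobs reads every NUM_FOLDS-th line of ${FN} and writes
# one shard: gs://${OUT_BUCKET}/data_${MAX_SEQ_LENGTH}/train_wiki19_<fold>.tfrecord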
8 | rm -rf logs_${MAX_SEQ_LENGTH}
9 | mkdir logs_${MAX_SEQ_LENGTH}
10 | parallel -j $(nproc --all) --will-cite "python prepare_data.py -fold {1} -num_folds ${NUM_FOLDS} -base_fn gs://${OUT_BUCKET}/data_${MAX_SEQ_LENGTH}/ -input_fn ${FN} -max_seq_length ${MAX_SEQ_LENGTH} > logs_${MAX_SEQ_LENGTH}/log{1}.txt" ::: $(seq 0 $((${NUM_FOLDS}-1)))
11 |
--------------------------------------------------------------------------------
/dockerfiles/gpu-jupyter.Dockerfile:
--------------------------------------------------------------------------------
1 | FROM tensorflow/tensorflow:1.15.0-gpu-py3-jupyter
2 |
3 | RUN apt update && apt install -y --no-install-recommends git
4 | RUN git clone -q https://github.com/imcaspar/gpt2-ml && mkdir -p gpt2-ml/models/mega
5 |
6 | WORKDIR /gpt2-ml
7 |
8 | RUN perl 3rd/gdown.pl/gdown.pl https://drive.google.com/open?id=1n_5-tgPpQ1gqbyLPbP1PwiFi2eo7SWw_ models/mega/model.ckpt-100000.data-00000-of-00001
9 | RUN wget -q --show-progress https://github.com/imcaspar/gpt2-ml/releases/download/v0.5/model.ckpt-100000.index -P models/mega
10 | RUN wget -q --show-progress https://github.com/imcaspar/gpt2-ml/releases/download/v0.5/model.ckpt-100000.meta -P models/mega
11 |
12 | CMD ["bash", "-c", "jupyter notebook --ip 0.0.0.0 --no-browser --allow-root pretrained_model_demo.ipynb"]
--------------------------------------------------------------------------------
/pretrained_model_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "[](https://github.com/imcaspar/gpt2-ml)\n",
8 | "[](https://github.com/imcaspar/gpt2-ml)\n",
9 | "[](https://twitter.com/intent/tweet?text=Wow:&url=https://github.com/imcaspar/gpt2-ml)\n",
10 | "### Instructions for running:\n",
11 | "\n",
12 |     "* Press the ▶️ button on the left of each of the cells\n",
13 | "* View the code: Double click any of the cells\n",
14 | "* Hide the code: Double click the right side of the cell"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "#@title #Prerequisites\n",
24 | "#%tensorflow_version 1.x\n",
25 | "!pip install -I tensorflow-gpu==1.15.4 &> tmp.log\n",
26 | "!git clone -q https://github.com/imcaspar/gpt2-ml\n",
27 | "%cd /content/gpt2-ml\n",
28 | "!mkdir -p /content/gpt2-ml/models/mega\n",
29 | "\n",
30 | "!perl 3rd/gdown.pl/gdown.pl https://drive.google.com/open?id=1mT_qCQg4AWnAXTwKfsyyRWCRpgPrBJS3 models/mega/model.ckpt-220000.data-00000-of-00001\n",
31 | "!wget -q --show-progress https://github.com/imcaspar/gpt2-ml/releases/download/v1.0/model.ckpt-220000.index -P models/mega\n",
32 | "!wget -q --show-progress https://github.com/imcaspar/gpt2-ml/releases/download/v1.0/model.ckpt-220000.meta -P models/mega\n",
33 | "!echo 'Download finished.🍺'\n",
34 | "# If gdown.pl failed, please uncomment following code & exec\n",
35 | "# !python scripts/down_gdrive_file.py -file_id='1mT_qCQg4AWnAXTwKfsyyRWCRpgPrBJS3' -file_path='models/mega/model.ckpt-220000.data-00000-of-00001'"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "#@title #Inference\n",
45 | "min_len = 150#@param {type:\"number\", min:5, max:1024, step:1}\n",
46 | "sp_num = 5#@param {type:\"number\", min:1, max:50, step:1}\n",
47 | "!PYTHONPATH=$(pwd) python scripts/demo.py -ckpt_fn models/mega/model.ckpt-220000 -min_len $min_len -samples $sp_num"
48 | ]
49 | }
50 | ],
51 | "metadata": {
52 | "colab": {
53 | "name": "15 亿参数 GPT2 中文预训练模型 | 1.5B GPT2 Pretrained Chinese Model",
54 | "provenance": [],
55 | "collapsed_sections": []
56 | },
57 | "kernelspec": {
58 | "name": "python3",
59 | "display_name": "Python 3"
60 | },
61 | "accelerator": "GPU"
62 | },
63 | "nbformat": 4,
64 | "nbformat_minor": 0
65 | }
--------------------------------------------------------------------------------
/requirements-gpu.txt:
--------------------------------------------------------------------------------
1 | pandas==0.24.2
2 | regex==2019.4.14
3 | h5py==2.10.0
4 | numpy==1.18.4
5 | tensorboard==1.15.0
6 | tensorflow-gpu==1.15.4
7 | tensorflow-estimator==1.15.1
8 | tqdm==4.31.1
9 | requests==2.22.0
10 | ujson==2.0.3
--------------------------------------------------------------------------------
/requirements-tpu.txt:
--------------------------------------------------------------------------------
1 | pandas==0.24.2
2 | regex==2019.4.14
3 | h5py==2.10.0
4 | numpy==1.18.4
5 | tensorboard==1.15.0
6 | tensorflow==1.15.4
7 | tensorflow-estimator==1.15.1
8 | tqdm==4.31.1
9 | requests==2.22.0
10 | ujson==2.0.3
--------------------------------------------------------------------------------
/scripts/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import argparse
4 | import json
5 | import re
6 |
7 | import tensorflow.compat.v1 as tf
8 | import numpy as np
9 |
10 | from train.modeling import GroverModel, GroverConfig, sample
11 | from tokenization import tokenization
12 |
13 | ##### ignore tf deprecation warnings temporarily
14 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
15 | tf.logging.set_verbosity(tf.logging.DEBUG)
16 | from tensorflow.python.util import deprecation
17 | deprecation._PRINT_DEPRECATION_WARNINGS = False
18 | try:
19 | from tensorflow.python.util import module_wrapper as deprecation
20 | except ImportError:
21 | from tensorflow.python.util import deprecation_wrapper as deprecation
22 | deprecation._PER_MODULE_WARNING_LIMIT = 0
23 | #####
24 |
25 | parser = argparse.ArgumentParser(description='Contextual generation (given some metadata, we will generate articles)')
26 | parser.add_argument(
27 | '-metadata_fn',
28 | dest='metadata_fn',
29 | type=str,
30 | help='Path to a JSONL containing metadata',
31 | )
32 | parser.add_argument(
33 | '-out_fn',
34 | dest='out_fn',
35 | type=str,
36 | help='Out jsonl, which will contain the completed jsons',
37 | )
38 | parser.add_argument(
39 | '-input',
40 | dest='input',
41 | type=str,
42 | help='Text to complete',
43 | )
44 | parser.add_argument(
45 | '-config_fn',
46 | dest='config_fn',
47 | default='configs/mega.json',
48 | type=str,
49 | help='Configuration JSON for the model',
50 | )
51 | parser.add_argument(
52 | '-ckpt_fn',
53 | dest='ckpt_fn',
54 | default='../models/mega/model.ckpt',
55 | type=str,
56 | help='checkpoint file for the model',
57 | )
58 | parser.add_argument(
59 | '-target',
60 | dest='target',
61 | default='article',
62 | type=str,
63 | help='What to generate for each item in metadata_fn. can be article (body), title, etc.',
64 | )
65 | parser.add_argument(
66 | '-batch_size',
67 | dest='batch_size',
68 | default=1,
69 | type=int,
70 | help='How many things to generate per context. will split into chunks if need be',
71 | )
72 | parser.add_argument(
73 | '-num_folds',
74 | dest='num_folds',
75 | default=1,
76 | type=int,
77 | help='Number of folds. useful if we want to split up a big file into multiple jobs.',
78 | )
79 | parser.add_argument(
80 | '-fold',
81 | dest='fold',
82 | default=0,
83 | type=int,
84 | help='which fold we are on. useful if we want to split up a big file into multiple jobs.'
85 | )
86 | parser.add_argument(
87 | '-max_batch_size',
88 | dest='max_batch_size',
89 | default=None,
90 | type=int,
91 | help='max batch size. You can leave this out and we will infer one based on the number of hidden layers',
92 | )
93 | parser.add_argument(
94 | '-top_p',
95 | dest='top_p',
96 | default=0.95,
97 | type=float,
98 |     help='p to use for top-p sampling. If this isn\'t None, use this for everything'
99 | )
100 | parser.add_argument(
101 | '-min_len',
102 | dest='min_len',
103 | default=1024,
104 | type=int,
105 | help='min length of sample',
106 | )
107 | parser.add_argument(
108 | '-eos_token',
109 | dest='eos_token',
110 | default=102,
111 | type=int,
112 | help='eos token id',
113 | )
114 | parser.add_argument(
115 | '-samples',
116 | dest='samples',
117 | default=5,
118 | type=int,
119 | help='num_samples',
120 | )
121 |
122 | def extract_generated_target(output_tokens, tokenizer):
123 | """
124 | Given some tokens that were generated, extract the target
125 | :param output_tokens: [num_tokens] thing that was generated
126 |     :param tokenizer: tokenizer used to convert the generated ids back to text
128 | :return:
129 | """
130 | # Filter out first instance of start token
131 | assert output_tokens.ndim == 1
132 |
133 | start_ind = 0
134 | end_ind = output_tokens.shape[0]
135 |
136 | return {
137 | 'extraction': tokenization.printable_text(''.join(tokenizer.convert_ids_to_tokens(output_tokens))),
138 | 'start_ind': start_ind,
139 | 'end_ind': end_ind,
140 | }
141 |
142 | args = parser.parse_args()
143 | proj_root_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
144 | vocab_file_path = os.path.join(proj_root_path, "tokenization/clue-vocab.txt")
145 | tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file_path , do_lower_case=True)
146 | news_config = GroverConfig.from_json_file(args.config_fn)
147 |
148 | # We might have to split the batch into multiple chunks if the batch size is too large
149 | default_mbs = {12: 32, 24: 16, 48: 3}
150 | max_batch_size = args.max_batch_size if args.max_batch_size is not None else default_mbs[news_config.num_hidden_layers]
151 |
152 | # split args.batch_size into num_chunks chunks of batch_size_per_chunk, s.t. batch_size_per_chunk <= max_batch_size
153 | num_chunks = int(np.ceil(args.batch_size / max_batch_size))
154 | batch_size_per_chunk = int(np.ceil(args.batch_size / num_chunks))
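# e.g. batch_size=10, max_batch_size=3 -> num_chunks=4, batch_size_per_chunk=3:
# four chunks of three cover the ten requested samples without exceeding the cap.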
155 |
156 | # This controls the top p for each generation.
157 | top_p = np.ones((num_chunks, batch_size_per_chunk), dtype=np.float32) * args.top_p
158 |
159 | tf_config = tf.ConfigProto(allow_soft_placement=True)
160 |
161 | with tf.Session(config=tf_config, graph=tf.Graph()) as sess:
162 | initial_context = tf.placeholder(tf.int32, [batch_size_per_chunk, None])
163 | p_for_topp = tf.placeholder(tf.float32, [batch_size_per_chunk])
164 | eos_token = tf.placeholder(tf.int32, [])
165 | min_len = tf.placeholder(tf.int32, [])
166 | tokens, probs = sample(news_config=news_config, initial_context=initial_context,
167 | eos_token=eos_token, min_len=min_len, ignore_ids=None, p_for_topp=p_for_topp,
168 | do_topk=False)
169 |
170 | saver = tf.train.Saver()
171 | saver.restore(sess, args.ckpt_fn)
172 | print('🍺Model loaded. \nInput something please:⬇️')
173 | text = input()
174 | while text != "":
175 | for i in range(args.samples):
176 |             print("Sample", i + 1, "of", args.samples)
177 | line = tokenization.convert_to_unicode(text)
178 | bert_tokens = tokenizer.tokenize(line)
179 | encoded = tokenizer.convert_tokens_to_ids(bert_tokens)
180 | context_formatted = []
181 | context_formatted.extend(encoded)
182 | # Format context end
183 |
184 | gens = []
185 | gens_raw = []
186 | gen_probs = []
187 |
188 | for chunk_i in range(num_chunks):
189 | tokens_out, probs_out = sess.run([tokens, probs],
190 | feed_dict={initial_context: [context_formatted] * batch_size_per_chunk,
191 | eos_token: args.eos_token, min_len: args.min_len,
192 | p_for_topp: top_p[chunk_i]})
193 |
194 | for t_i, p_i in zip(tokens_out, probs_out):
195 | extraction = extract_generated_target(output_tokens=t_i, tokenizer=tokenizer)
196 | gens.append(extraction['extraction'])
197 |
198 | l = re.findall('.{1,70}', gens[0].replace('[UNK]', '').replace('##', ''))
199 | print("\n".join(l))
200 | print('Next try:⬇️')
201 | text = input()
202 |
--------------------------------------------------------------------------------
/scripts/down_gdrive_file.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from google.colab import auth
4 | from googleapiclient.discovery import build
5 | from apiclient.http import MediaIoBaseDownload
6 | from tqdm import tqdm
7 |
8 | parser = argparse.ArgumentParser(description='Simple file download script for Google Drive')
9 | parser.add_argument(
10 | '-file_id',
11 | dest='file_id',
12 | type=str,
13 | help='File id in Google Drive URL',
14 | )
15 | parser.add_argument(
16 | '-file_path',
17 | dest='file_path',
18 | type=str,
19 | help='Output file path',
20 | )
21 |
22 | args = parser.parse_args()
23 |
24 | auth.authenticate_user()
25 | drive_service = build('drive', 'v3')
26 |
27 | # file_id, file_ext = ('1n_5-tgPpQ1gqbyLPbP1PwiFi2eo7SWw_', '.data-00000-of-00001')
28 | # filename = '%s/model.ckpt-%d%s' % (model_dir, 100000, file_ext)
29 | req = drive_service.files().get_media(fileId=args.file_id)
30 | with open(args.file_path, 'wb') as f:
31 | downloader = MediaIoBaseDownload(f, req, chunksize=100*1024*1024)
32 | done = False
33 | pbar = tqdm(total=100, desc='%s' % args.file_path)
34 | progress = 0
35 |     while not done:
36 | status, done = downloader.next_chunk()
37 | new_progress = int(status.progress() * 100)
38 | pbar.update(new_progress - progress)
39 | progress = new_progress
40 | pbar.close()
41 |
--------------------------------------------------------------------------------
/tokenization/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imcaspar/gpt2-ml/f6286b16cbbee6dddbe1ba557fffb47eaf998cd1/tokenization/__init__.py
--------------------------------------------------------------------------------
/tokenization/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 | import unicodedata
24 | import six
25 | import tensorflow.compat.v1 as tf
26 |
27 |
28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
29 | """Checks whether the casing config is consistent with the checkpoint name."""
30 |
31 | # The casing has to be passed in by the user and there is no explicit check
32 | # as to whether it matches the checkpoint. The casing information probably
33 | # should have been stored in the bert_config.json file, but it's not, so
34 | # we have to heuristically detect it to validate.
35 |
36 | if not init_checkpoint:
37 | return
38 |
39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
40 | if m is None:
41 | return
42 |
43 | model_name = m.group(1)
44 |
45 | lower_models = [
46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
48 | ]
49 |
50 | cased_models = [
51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
52 | "multi_cased_L-12_H-768_A-12"
53 | ]
54 |
55 | is_bad_config = False
56 | if model_name in lower_models and not do_lower_case:
57 | is_bad_config = True
58 | actual_flag = "False"
59 | case_name = "lowercased"
60 | opposite_flag = "True"
61 |
62 | if model_name in cased_models and do_lower_case:
63 | is_bad_config = True
64 | actual_flag = "True"
65 | case_name = "cased"
66 | opposite_flag = "False"
67 |
68 | if is_bad_config:
69 | raise ValueError(
70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
71 | "However, `%s` seems to be a %s model, so you "
72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
73 |         "how the model was pre-trained. If this error is wrong, please "
74 | "just comment out this check." % (actual_flag, init_checkpoint,
75 | model_name, case_name, opposite_flag))
76 |
77 |
78 | def convert_to_unicode(text):
79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
80 | if six.PY3:
81 | if isinstance(text, str):
82 | return text
83 | elif isinstance(text, bytes):
84 | return text.decode("utf-8", "ignore")
85 | else:
86 | raise ValueError("Unsupported string type: %s" % (type(text)))
87 | elif six.PY2:
88 | if isinstance(text, str):
89 | return text.decode("utf-8", "ignore")
90 | elif isinstance(text, unicode):
91 | return text
92 | else:
93 | raise ValueError("Unsupported string type: %s" % (type(text)))
94 | else:
95 |     raise ValueError("Not running on Python 2 or Python 3?")
96 |
97 |
98 | def printable_text(text):
99 | """Returns text encoded in a way suitable for print or `tf.logging`."""
100 |
101 | # These functions want `str` for both Python2 and Python3, but in one case
102 | # it's a Unicode string and in the other it's a byte string.
103 | if six.PY3:
104 | if isinstance(text, str):
105 | return text
106 | elif isinstance(text, bytes):
107 | return text.decode("utf-8", "ignore")
108 | else:
109 | raise ValueError("Unsupported string type: %s" % (type(text)))
110 | elif six.PY2:
111 | if isinstance(text, str):
112 | return text
113 | elif isinstance(text, unicode):
114 | return text.encode("utf-8")
115 | else:
116 | raise ValueError("Unsupported string type: %s" % (type(text)))
117 | else:
118 |     raise ValueError("Not running on Python 2 or Python 3?")
119 |
120 |
121 | def load_vocab(vocab_file):
122 | """Loads a vocabulary file into a dictionary."""
123 | vocab = collections.OrderedDict()
124 | index = 0
125 | with tf.gfile.GFile(vocab_file, "r") as reader:
126 | while True:
127 | token = convert_to_unicode(reader.readline())
128 | if not token:
129 | break
130 | token = token.strip()
131 | vocab[token] = index
132 | index += 1
133 | return vocab
134 |
135 |
136 | def convert_by_vocab(vocab, items):
137 | """Converts a sequence of [tokens|ids] using the vocab."""
138 | output = []
139 | for item in items:
140 | output.append(vocab[item])
141 | return output
142 |
143 |
144 | def convert_tokens_to_ids(vocab, tokens):
145 | return convert_by_vocab(vocab, tokens)
146 |
147 |
148 | def convert_ids_to_tokens(inv_vocab, ids):
149 | return convert_by_vocab(inv_vocab, ids)
150 |
151 |
152 | def whitespace_tokenize(text):
153 | """Runs basic whitespace cleaning and splitting on a piece of text."""
154 | text = text.strip()
155 | if not text:
156 | return []
157 | tokens = text.split()
158 | return tokens
159 |
160 |
161 | class FullTokenizer(object):
162 |   """Runs end-to-end tokenization."""
163 |
164 | def __init__(self, vocab_file, do_lower_case=True):
165 | self.vocab = load_vocab(vocab_file)
166 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
169 |
170 | def tokenize(self, text):
171 | split_tokens = []
172 | for token in self.basic_tokenizer.tokenize(text):
173 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
174 | split_tokens.append(sub_token)
175 |
176 | return split_tokens
177 |
178 | def convert_tokens_to_ids(self, tokens):
179 | return convert_by_vocab(self.vocab, tokens)
180 |
181 | def convert_ids_to_tokens(self, ids):
182 | return convert_by_vocab(self.inv_vocab, ids)
183 |
184 |
185 | class BasicTokenizer(object):
186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
187 |
188 | def __init__(self, do_lower_case=True):
189 | """Constructs a BasicTokenizer.
190 |
191 | Args:
192 | do_lower_case: Whether to lower case the input.
193 | """
194 | self.do_lower_case = do_lower_case
195 |
196 | def tokenize(self, text):
197 | """Tokenizes a piece of text."""
198 | text = convert_to_unicode(text)
199 | text = self._clean_text(text)
200 |
201 | # This was added on November 1st, 2018 for the multilingual and Chinese
202 | # models. This is also applied to the English models now, but it doesn't
203 | # matter since the English models were not trained on any Chinese data
204 |     # and generally don't have any Chinese data in them (there are Chinese
205 |     # characters in the vocabulary because the English Wikipedia does contain
206 |     # some Chinese words).
207 | text = self._tokenize_chinese_chars(text)
208 |
209 | orig_tokens = whitespace_tokenize(text)
210 | split_tokens = []
211 | for token in orig_tokens:
212 | if self.do_lower_case:
213 | token = token.lower()
214 | token = self._run_strip_accents(token)
215 | split_tokens.extend(self._run_split_on_punc(token))
216 |
217 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
218 | return output_tokens
219 |
220 | def _run_strip_accents(self, text):
221 | """Strips accents from a piece of text."""
222 | text = unicodedata.normalize("NFD", text)
223 | output = []
224 | for char in text:
225 | cat = unicodedata.category(char)
226 | if cat == "Mn":
227 | continue
228 | output.append(char)
229 | return "".join(output)
230 |
231 | def _run_split_on_punc(self, text):
232 | """Splits punctuation on a piece of text."""
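    # e.g. "hello,world!" -> ["hello", ",", "world", "!"]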
233 | chars = list(text)
234 | i = 0
235 | start_new_word = True
236 | output = []
237 | while i < len(chars):
238 | char = chars[i]
239 | if _is_punctuation(char):
240 | output.append([char])
241 | start_new_word = True
242 | else:
243 | if start_new_word:
244 | output.append([])
245 | start_new_word = False
246 | output[-1].append(char)
247 | i += 1
248 |
249 | return ["".join(x) for x in output]
250 |
251 | def _tokenize_chinese_chars(self, text):
252 | """Adds whitespace around any CJK character."""
253 | output = []
254 | for char in text:
255 | cp = ord(char)
256 | if self._is_chinese_char(cp):
257 | output.append(" ")
258 | output.append(char)
259 | output.append(" ")
260 | else:
261 | output.append(char)
262 | return "".join(output)
263 |
264 | def _is_chinese_char(self, cp):
265 | """Checks whether CP is the codepoint of a CJK character."""
266 | # This defines a "chinese character" as anything in the CJK Unicode block:
267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
268 | #
269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
270 | # despite its name. The modern Korean Hangul alphabet is a different block,
271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
272 | # space-separated words, so they are not treated specially and handled
273 |     # like all of the other languages.
274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
275 | (cp >= 0x3400 and cp <= 0x4DBF) or #
276 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
277 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
278 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
280 | (cp >= 0xF900 and cp <= 0xFAFF) or #
281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
282 | return True
283 |
284 | return False
285 |
286 | def _clean_text(self, text):
287 | """Performs invalid character removal and whitespace cleanup on text."""
288 | output = []
289 | for char in text:
290 | cp = ord(char)
291 | if cp == 0 or cp == 0xfffd or _is_control(char):
292 | continue
293 | if _is_whitespace(char):
294 | output.append(" ")
295 | else:
296 | output.append(char)
297 | return "".join(output)
298 |
299 |
300 | class WordpieceTokenizer(object):
301 |   """Runs WordPiece tokenization."""
302 |
303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
304 | self.vocab = vocab
305 | self.unk_token = unk_token
306 | self.max_input_chars_per_word = max_input_chars_per_word
307 |
308 | def tokenize(self, text):
309 | """Tokenizes a piece of text into its word pieces.
310 |
311 | This uses a greedy longest-match-first algorithm to perform tokenization
312 | using the given vocabulary.
313 |
314 | For example:
315 | input = "unaffable"
316 | output = ["un", "##aff", "##able"]
317 |
318 | Args:
319 | text: A single token or whitespace separated tokens. This should have
320 |         already been passed through `BasicTokenizer`.
321 |
322 | Returns:
323 | A list of wordpiece tokens.
324 | """
325 |
326 | text = convert_to_unicode(text)
327 |
328 | output_tokens = []
329 | for token in whitespace_tokenize(text):
330 | chars = list(token)
331 | if len(chars) > self.max_input_chars_per_word:
332 | output_tokens.append(self.unk_token)
333 | continue
334 |
335 | is_bad = False
336 | start = 0
337 | sub_tokens = []
338 | while start < len(chars):
339 | end = len(chars)
340 | cur_substr = None
341 | while start < end:
342 | substr = "".join(chars[start:end])
343 | if start > 0:
344 | substr = "##" + substr
345 | if substr in self.vocab:
346 | cur_substr = substr
347 | break
348 | end -= 1
349 | if cur_substr is None:
350 | is_bad = True
351 | break
352 | sub_tokens.append(cur_substr)
353 | start = end
354 |
355 | if is_bad:
356 | output_tokens.append(self.unk_token)
357 | else:
358 | output_tokens.extend(sub_tokens)
359 | return output_tokens
360 |
361 |
362 | def _is_whitespace(char):
363 | """Checks whether `chars` is a whitespace character."""
364 |   # \t, \n, and \r are technically control characters but we treat them
365 | # as whitespace since they are generally considered as such.
366 | if char == " " or char == "\t" or char == "\n" or char == "\r":
367 | return True
368 | cat = unicodedata.category(char)
369 | if cat == "Zs":
370 | return True
371 | return False
372 |
373 |
374 | def _is_control(char):
375 | """Checks whether `chars` is a control character."""
376 | # These are technically control characters but we count them as whitespace
377 | # characters.
378 | if char == "\t" or char == "\n" or char == "\r":
379 | return False
380 | cat = unicodedata.category(char)
381 | if cat in ("Cc", "Cf"):
382 | return True
383 | return False
384 |
385 |
386 | def _is_punctuation(char):
387 | """Checks whether `chars` is a punctuation character."""
388 | cp = ord(char)
389 | # We treat all non-letter/number ASCII as punctuation.
390 | # Characters such as "^", "$", and "`" are not in the Unicode
391 | # Punctuation class but we treat them as punctuation anyways, for
392 | # consistency.
393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
395 | return True
396 | cat = unicodedata.category(char)
397 | if cat.startswith("P"):
398 | return True
399 | return False
400 |
--------------------------------------------------------------------------------
/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imcaspar/gpt2-ml/f6286b16cbbee6dddbe1ba557fffb47eaf998cd1/train/__init__.py
--------------------------------------------------------------------------------
/train/dataloader.py:
--------------------------------------------------------------------------------
1 | # Original work Copyright 2018 The Google AI Language Team Authors.
2 | # Modified work Copyright 2019 Rowan Zellers
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import collections
17 | import tensorflow.compat.v1 as tf
18 |
19 |
20 | def _decode_record(record, name_to_features):
21 | """Decodes a record to a TensorFlow example."""
22 | example = tf.parse_single_example(record, name_to_features)
23 |
24 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
25 | # So cast all int64 to int32.
26 | for name in list(example.keys()):
27 | t = example[name]
28 | if t.dtype == tf.int64:
29 | t = tf.cast(t, tf.int32)
30 | example[name] = t
31 | return example
32 |
33 |
34 | def input_fn_builder(input_files,
35 | seq_length,
36 | is_training,
37 | num_cpu_threads=4,
38 | evaluate_for_fixed_number_of_steps=True):
39 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
40 |
41 | def input_fn(params):
42 | """The actual input function."""
43 | batch_size = params["batch_size"]
44 | name_to_features = {
45 | "input_ids": tf.FixedLenFeature([seq_length + 1], tf.int64),
46 | }
47 |
48 | # For training, we want a lot of parallel reading and shuffling.
49 | # For eval, we want no shuffling and parallel reading doesn't matter.
50 | if is_training:
51 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
52 | d = d.repeat()
53 | d = d.shuffle(buffer_size=len(input_files))
54 |
55 | # `cycle_length` is the number of parallel files that get read.
56 | cycle_length = min(num_cpu_threads, len(input_files))
57 |
58 | # `sloppy` mode means that the interleaving is not exact. This adds
59 | # even more randomness to the training pipeline.
60 | d = d.apply(
61 | tf.data.experimental.parallel_interleave(
62 | tf.data.TFRecordDataset,
63 | sloppy=is_training,
64 | cycle_length=cycle_length))
65 | d = d.shuffle(buffer_size=100)
66 | else:
67 | d = tf.data.TFRecordDataset(input_files)
68 | # If we evaluate for a fixed number of steps we don't want to encounter
69 | # out-of-range exceptions.
70 | if evaluate_for_fixed_number_of_steps:
71 | d = d.repeat()
72 |
73 | # We must `drop_remainder` on training because the TPU requires fixed
74 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU
75 |     # and we *don't* want to drop the remainder, otherwise we won't cover
76 | # every sample.
77 | d = d.apply(
78 | tf.data.experimental.map_and_batch(
79 | lambda record: _decode_record(record, name_to_features),
80 | batch_size=batch_size,
81 | num_parallel_batches=num_cpu_threads,
82 | drop_remainder=True))
83 | return d
84 |
85 | return input_fn
86 |
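# Usage sketch (hypothetical shard path; TPUEstimator supplies params["batch_size"]):
#   train_input_fn = input_fn_builder(
#       input_files=["gs://my-bucket/data_1024/train_wiki19_0000.tfrecord"],
#       seq_length=1024, is_training=True)
#   dataset = train_input_fn({"batch_size": 8})  # int32 "input_ids" batches of shape [8, 1025]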
87 |
88 | # ~~~~~~~~~~~~~~ This is for classification / AF ~~~~~~~~~~~~~~~~~~
89 | def classification_convert_examples_to_features(
90 | examples, max_seq_length, batch_size, encoder, output_file, labels, pad_extra_examples=False,
91 | chop_from_front_if_needed=True):
92 | """Convert a set of `InputExample`s to a TFRecord file."""
93 |
94 | writer = tf.python_io.TFRecordWriter(output_file)
95 |
96 | label_map = {label: i for i, label in enumerate(labels)}
97 |
98 | for (ex_index, example) in enumerate(examples):
99 | if ex_index % 10000 == 0:
100 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
101 |
102 | # begin_summary is our [CLS] token
103 | tokens = example['ids'] + [encoder.begin_summary]
104 |
105 | if len(tokens) > max_seq_length:
106 | if chop_from_front_if_needed:
107 | tokens = tokens[-max_seq_length:]
108 | else:
109 | tokens = example['ids'][:(max_seq_length-1)] + [encoder.begin_summary]
110 | elif len(tokens) < max_seq_length:
111 | tokens.extend([encoder.padding] * (max_seq_length - len(tokens)))
112 |
113 | features = collections.OrderedDict()
114 | features['input_ids'] = tf.train.Feature(int64_list=tf.train.Int64List(value=tokens))
115 | features['label_ids'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[label_map[example['label']]]))
116 | features['is_real_example'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[1]))
117 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
118 | writer.write(tf_example.SerializeToString())
119 |
120 | if pad_extra_examples:
121 | for x in range(len(examples) % batch_size):
122 | features = collections.OrderedDict()
123 | features['input_ids'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[0]*max_seq_length))
124 | features['label_ids'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[0]))
125 | features['is_real_example'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[0]))
126 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
127 | writer.write(tf_example.SerializeToString())
128 | writer.close()
129 |
130 |
131 | def classification_input_fn_builder(input_file, seq_length, is_training,
132 | drop_remainder,
133 | buffer_size=100):
134 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
135 |
136 | name_to_features = {
137 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
138 | "label_ids": tf.FixedLenFeature([], tf.int64),
139 | "is_real_example": tf.FixedLenFeature([], tf.int64),
140 | }
141 |
142 | def input_fn(params):
143 | """The actual input function."""
144 | batch_size = params["batch_size"]
145 |
146 | # For training, we want a lot of parallel reading and shuffling.
147 | # For eval, we want no shuffling and parallel reading doesn't matter.
148 | d = tf.data.TFRecordDataset(input_file)
149 | if is_training:
150 | d = d.repeat()
151 | d = d.shuffle(buffer_size=buffer_size)
152 |
153 | d = d.apply(
154 | tf.data.experimental.map_and_batch(
155 | lambda record: _decode_record(record, name_to_features),
156 | batch_size=batch_size,
157 | drop_remainder=drop_remainder))
158 |
159 | return d
160 |
161 | return input_fn
162 |
--------------------------------------------------------------------------------
/train/modeling.py:
--------------------------------------------------------------------------------
1 | # Original work Copyright 2018 The Google AI Language Team Authors.
2 | # Modified work Copyright 2019 Rowan Zellers
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import copy
17 | import json
18 | import math
19 |
20 | import six
21 | import tensorflow.compat.v1 as tf
22 |
23 | from train import optimization_adafactor
24 | from train.utils import get_assignment_map_from_checkpoint, get_shape_list, get_attention_mask, gelu, layer_norm, dropout, \
25 | construct_scalar_host_call
26 |
27 | class GroverConfig(object):
28 | """Configuration for `GroverModel`"""
29 |
30 | def __init__(self,
31 | vocab_size,
32 | hidden_size=768,
33 | num_hidden_layers=12,
34 | num_attention_heads=12,
35 | intermediate_size=3072,
36 | hidden_act="gelu",
37 | hidden_dropout_prob=0.1,
38 | attention_probs_dropout_prob=0.1,
39 | max_position_embeddings=512,
40 | initializer_range=0.02):
41 |         """Constructs GroverConfig.
42 |
43 | Args:
44 | vocab_size: Vocabulary size of `inputs_ids` in `GroverModel`.
45 | hidden_size: Size of the layers
46 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
47 | num_attention_heads: Number of attention heads for each attention layer in
48 | the Transformer encoder.
49 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
50 | layer in the Transformer encoder.
51 |           hidden_act: The non-linear activation function (function or string) used in
52 |             the Transformer layers.
53 |           hidden_dropout_prob: The dropout probability for all fully connected
54 |             layers in the embeddings and Transformer layers.
55 | attention_probs_dropout_prob: The dropout ratio for the attention
56 | probabilities.
57 | max_position_embeddings: The maximum sequence length that this model might
58 | ever be used with. Typically set this to something large just in case
59 | (e.g., 512 or 1024 or 2048).
60 | initializer_range: The stdev of the truncated_normal_initializer for
61 | initializing all weight matrices.
62 | """
63 | self.vocab_size = vocab_size
64 | self.hidden_size = hidden_size
65 | self.num_hidden_layers = num_hidden_layers
66 | self.num_attention_heads = num_attention_heads
67 | self.hidden_act = hidden_act
68 | self.intermediate_size = intermediate_size
69 | self.hidden_dropout_prob = hidden_dropout_prob
70 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
71 | self.max_position_embeddings = max_position_embeddings
72 | self.initializer_range = initializer_range
73 | self.pad_token_id = 0
74 |
75 | @classmethod
76 | def from_dict(cls, json_object):
77 | """Constructs a `NewsConfig` from a Python dictionary of parameters."""
78 | config = GroverConfig(vocab_size=None)
79 | for (key, value) in six.iteritems(json_object):
80 | config.__dict__[key] = value
81 | return config
82 |
83 | @classmethod
84 | def from_json_file(cls, json_file):
85 | """Constructs a `NewsConfig` from a json file of parameters."""
86 | with tf.gfile.GFile(json_file, "r") as reader:
87 | text = reader.read()
88 | return cls.from_dict(json.loads(text))
89 |
90 | def to_dict(self):
91 | """Serializes this instance to a Python dictionary."""
92 | output = copy.deepcopy(self.__dict__)
93 | return output
94 |
95 | def to_json_string(self):
96 | """Serializes this instance to a JSON string."""
97 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
98 |
99 |
100 | def mask_attention_for_ltr(attention_scores, attention_mask):
101 | """
102 | Mask attention so that we're only predicting going forward
103 | :param attention_scores: [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
104 |     :param attention_mask: [query_length, key_length]
105 | :return: masked attention
106 | """
107 |     # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
108 |     # masked positions, this operation leaves attended scores unchanged and
109 |     # pushes masked positions down to roughly -1e10 (effectively -inf for the softmax).
110 | mask = attention_mask[None, None]
111 | return attention_scores * mask - tf.cast(1e10, attention_scores.dtype) * (1 - mask)
112 |
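# Editor's note (worked example, values illustrative): with a causal mask
# [[1, 0], [1, 1]] and scores [[0.5, 0.7], [0.1, 0.2]], the expression above
# yields [[0.5, -1e10], [0.1, 0.2]], so the masked position contributes
# essentially zero probability after the softmax.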
113 |
114 | def create_initializer(initializer_range=0.02):
115 | """Creates a `truncated_normal_initializer` with the given range."""
116 | return tf.truncated_normal_initializer(stddev=initializer_range)
117 |
118 |
119 | def _attention_projection_and_transpose(x_flat, batch_size, seq_length, num_attention_heads, size_per_head,
120 | name, initializer_range=0.02):
121 | """
122 | :param x_flat: [batch_size*seq_length, width]
123 | :return: A fixed up tensor of size [batch_size, num_attention_heads, seq_length, size_per_head]
124 | """
125 | batch_size_seq_length, dim = get_shape_list(x_flat, expected_rank=2)
126 |
127 | if dim != size_per_head * num_attention_heads:
128 | raise ValueError("passed in a tensor of shape {} when size_per_head={} and num_attention_heads={}".format(
129 | (batch_size_seq_length, dim), size_per_head, num_attention_heads
130 | ))
131 |
132 | projected = tf.layers.dense(
133 | x_flat,
134 | num_attention_heads * size_per_head,
135 | name=name,
136 | kernel_initializer=create_initializer(initializer_range))
137 |
138 | projected = tf.reshape(
139 | projected, [batch_size, seq_length, num_attention_heads, size_per_head])
140 | output_tensor = tf.transpose(projected, [0, 2, 1, 3])
141 | return output_tensor
142 |
143 |
144 | def attention_layer(x_flat, attention_mask, batch_size, seq_length, size_per_head=512, num_attention_heads=1, *,
145 | cache=None,
146 | initializer_range=0.02, hidden_dropout_prob=0.1,
147 | attention_probs_dropout_prob=0.1, do_cache=False):
148 | """
149 |
150 | :param x_flat: Tensor input, should be [batch_size*seq_length, dim]
151 | :param attention_mask: Attention mask to use of size [seq_length, seq_length+cached_length]
152 | :param size_per_head: dim = size_per_head * num_attention_heads
153 | :param num_attention_heads: dim = size_per_head * num_attention_heads
154 | :param cache: Optionally some past (cached) things of size
155 | [batch, 2, heads, sequence, features], where 2 is [k, v]
156 | :param do_cache: True if we should return cache
157 |     :return: A new tensor of shape [batch_size * seq_length, dim]
158 |     as well as a new cache "cached_keys_and_values" that will be of size
159 |     [batch_size, 2, num_attention_heads, seq_length, size_per_head]
160 | """
161 | batch_size_seq_length, dim = get_shape_list(x_flat, expected_rank=2)
162 |
163 | if dim != size_per_head * num_attention_heads:
164 | raise ValueError("passed in a tensor of shape {} when size_per_head={} and num_attention_heads={}".format(
165 | (batch_size_seq_length, dim), size_per_head, num_attention_heads
166 | ))
167 |
168 | query = _attention_projection_and_transpose(x_flat, batch_size=batch_size, seq_length=seq_length,
169 | num_attention_heads=num_attention_heads, size_per_head=size_per_head,
170 | name='query_layer',
171 | initializer_range=initializer_range)
172 | key = _attention_projection_and_transpose(x_flat, batch_size=batch_size, seq_length=seq_length,
173 | num_attention_heads=num_attention_heads, size_per_head=size_per_head,
174 | name='key_layer',
175 | initializer_range=initializer_range)
176 |
177 | value = _attention_projection_and_transpose(x_flat, batch_size=batch_size, seq_length=seq_length,
178 | num_attention_heads=num_attention_heads, size_per_head=size_per_head,
179 | name='value_layer',
180 | initializer_range=initializer_range)
181 |
182 | # Add to cache
183 | cached_keys_and_values = tf.stack([key, value], axis=1) if do_cache else None
184 |
185 | # Things that were relevant from the cache
186 | if cache is not None:
187 | pk, pv = tf.unstack(cache, axis=1)
188 | key = tf.concat([pk, key], axis=-2)
189 | value = tf.concat([pv, value], axis=-2)
190 |
191 | # Multiply [batch_size, num_attention_heads, seq_length, size_per_head] with
192 | # [batch_size, num_attention_heads, size_per_head, seq_length+cached_length] ->
193 | # [batch_size, num_attention_heads, seq_length, seq_length+cached_length]
194 | attention_scores = tf.matmul(query, key, transpose_b=True)
195 | attention_scores = tf.multiply(attention_scores,
196 | 1.0 / math.sqrt(float(size_per_head)))
197 | attention_scores = mask_attention_for_ltr(attention_scores, attention_mask)
198 | attention_probs = tf.nn.softmax(attention_scores)
199 |
200 | # This is actually dropping out entire tokens to attend to, which might
201 | # seem a bit unusual, but is taken from the original Transformer paper.
202 |     # Deliberately disabled in this implementation (see the commented-out call below).
203 | # attention_probs = factoreddropout(attention_probs, attention_probs_dropout_prob)
204 |
205 | # Multiply [batch_size, num_attention_heads, seq_length, seq_length+cached_length] with
206 | # [batch_size, num_attention_heads, seq_length+cached_length, size_per_head] ->
207 | # [batch_size, num_attention_heads, seq_length, size_per_head] ->
208 | context_layer = tf.matmul(attention_probs, value)
209 |
210 | # `context_layer` = [batch_size, seq_length, num_attention_heads, size_per_head]
211 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
212 | context_layer = tf.reshape(context_layer, [batch_size * seq_length, num_attention_heads * size_per_head])
213 |
214 | context_layer_projected = tf.layers.dense(
215 | context_layer,
216 | num_attention_heads * size_per_head,
217 | kernel_initializer=create_initializer(initializer_range),
218 | name='context_projection_layer'
219 | )
220 | context_layer_projected = dropout(context_layer_projected, hidden_dropout_prob)
221 |
222 | return context_layer_projected, cached_keys_and_values
223 |
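# Editor's note (illustrative shapes, not from the source): in the cached
# decoding path with batch=2, heads=12, size_per_head=64 and one new token,
#   cache       : [2, 2, 12, 7, 64]   (k and v for 7 previously seen tokens)
#   key / value : [2, 12, 1, 64]      (projections of the new token)
# after the concat on axis -2 they become [2, 12, 8, 64], so
# attention_scores is [2, 12, 1, 8]: the new token attends to all 8 positions.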
224 |
225 | def residual_mlp_layer(x_flat, intermediate_size, initializer_range=0.02, hidden_dropout_prob=0.1):
226 | """
227 |     :param x_flat: The attention output. It should be [batch_size*seq_length, dim]
228 |     :param intermediate_size: size of the hidden (feed-forward) projection. By default this is the input_dim * 4.
229 |
230 |     In the original GPT we would return layer_norm(x_norm + h1) rather than layer_norm(x + h1)
231 |
232 | :return:
233 | """
234 | batch_size_seq_length, hidden_size = get_shape_list(x_flat, expected_rank=2)
235 | x_norm = layer_norm(x_flat, name='mlp_ln0')
236 |
237 | intermediate_output = tf.layers.dense(
238 | x_norm,
239 | intermediate_size,
240 | activation=gelu,
241 | kernel_initializer=create_initializer(initializer_range),
242 | name='intermediate',
243 | )
244 |
245 | output_for_residual = tf.layers.dense(
246 | intermediate_output,
247 | hidden_size,
248 | name='output',
249 | kernel_initializer=create_initializer(initializer_range))
250 | output_for_residual = dropout(output_for_residual, hidden_dropout_prob)
251 |
252 | layer_output = layer_norm(x_flat + output_for_residual, name='mlp_ln1')
253 | return layer_output
254 |
255 |
256 | def embed(input_ids,
257 | vocab_size,
258 | embedding_size,
259 | position_offset=0,
260 | initializer_range=0.02,
261 | max_position_embeddings=512,
262 | use_one_hot_embeddings=True):
263 | """reur and position embeddings
264 | :param input_ids: int Tensor of shape [batch_size, seq_length].
265 | :param vocab_size: number of words in vocab
266 | :param embedding_size: dimensionality of the embedding
267 | :param position_offset: aka number of cached tokens.
268 | :param initializer_range: float. Range of the weight initialization.
269 | :param max_position_embeddings: int. Maximum sequence length.
270 |     :param use_one_hot_embeddings: if True, look up embeddings via a one-hot matmul (faster on TPUs); probably want this to be True
271 | :return: [batch_size, seq_length, embedding_size] embedded tensor
272 | """
273 | (batch_size, seq_length) = get_shape_list(input_ids, expected_rank=2)
274 |
275 | embedding_table = tf.get_variable(
276 | name='word_embed',
277 | shape=[vocab_size, embedding_size],
278 | initializer=create_initializer(initializer_range),
279 | )
280 |
281 | assert_op = tf.assert_less_equal(tf.reduce_max(input_ids), vocab_size - 1)
282 | with tf.control_dependencies([assert_op]):
283 | if use_one_hot_embeddings:
284 | flat_input_ids = tf.reshape(input_ids, [-1])
285 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
286 | output_flat = tf.matmul(one_hot_input_ids, embedding_table)
287 | else:
288 | output_flat = tf.nn.embedding_lookup(embedding_table, input_ids)
289 |
290 | embedded_input = tf.reshape(output_flat, [batch_size, seq_length, embedding_size])
291 |
292 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
293 |
294 | with tf.control_dependencies([assert_op]):
295 | full_position_embeddings = tf.get_variable(
296 | name='pos_embed',
297 | shape=[max_position_embeddings, embedding_size],
298 | initializer=create_initializer(initializer_range),
299 | )
300 | # Since the position embedding table is a learned variable, we create it
301 | # using a (long) sequence length `max_position_embeddings`. The actual
302 | # sequence length might be shorter than this, for faster training of
303 | # tasks that do not have long sequences.
304 | #
305 | # So `full_position_embeddings` is effectively an embedding table
306 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
307 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
308 | # perform a slice.
309 | if position_offset == 0:
310 | embedded_input += tf.slice(full_position_embeddings, [0, 0], [seq_length, -1])[None]
311 | else:
312 |             # TensorFlow won't let us slice with a dynamic offset here, so emulate the slice with a one-hot matmul
313 | flat_pos_ids = (tf.range(seq_length, dtype=tf.int32) + position_offset)
314 | one_hot_pos_ids = tf.one_hot(flat_pos_ids, depth=max_position_embeddings)
315 |
316 | # [seq_length, full_position_embeddings], [full_position_embeddings, dim]
317 | seq_embeds = tf.matmul(one_hot_pos_ids, full_position_embeddings)
318 | embedded_input += seq_embeds[None]
319 |
320 | # embedded_input += tf.slice(full_position_embeddings[position_offset:], [0, 0], [seq_length, -1])[None]
321 |
322 | return layer_norm(embedded_input, name='embed_norm'), embedding_table
323 |
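# Editor's note (illustrative, not from the source): with position_offset=3
# and seq_length=2, flat_pos_ids is [3, 4] and one_hot_pos_ids is a
# [2, max_position_embeddings] matrix, so the matmul simply gathers rows 3
# and 4 of pos_embed -- a slice expressed as a dense product, which lowers
# cleanly to TPU-friendly ops.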
324 |
325 | def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9):
326 | """
327 |     Does top-p (nucleus) sampling. If ignore_ids is given, those logits are masked out and never sampled.
328 | :param logits: [batch_size, vocab_size] tensor
329 | :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
330 | like padding maybe
331 |     :param p: top-p threshold to use, either a float or a [batch_size] vector
332 | :return: [batch_size, num_samples] samples
333 |
334 | # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
335 | """
336 | with tf.variable_scope('top_p_sample'):
337 | batch_size, vocab_size = get_shape_list(logits, expected_rank=2)
338 |
339 | probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
340 | axis=-1)
341 |
342 | if isinstance(p, float) and p > 0.999999:
343 | # Don't do top-p sampling in this case
344 | print("Top-p sampling DISABLED", flush=True)
345 | return {
346 | 'probs': probs,
347 | 'sample': tf.random.categorical(
348 | logits=logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
349 | num_samples=num_samples, dtype=tf.int32),
350 | }
351 |
352 | # [batch_size, vocab_perm]
353 | indices = tf.argsort(probs, direction='DESCENDING')
354 | cumulative_probabilities = tf.math.cumsum(tf.batch_gather(probs, indices), axis=-1, exclusive=False)
355 |
356 |         # find the top pth index to cut off. careful we don't want to cut off everything!
357 | # result will be [batch_size, vocab_perm]
358 | p_expanded = p if isinstance(p, float) else p[:, None]
359 | exclude_mask = tf.logical_not(
360 | tf.logical_or(cumulative_probabilities < p_expanded, tf.range(vocab_size)[None] < 1))
361 |
362 | # OPTION A - sample in the sorted space, then unsort.
363 | logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
364 | sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
365 | sample = tf.batch_gather(indices, sample_perm)
366 |
367 | # OPTION B - unsort first - Indices need to go back to 0 -> N-1 -- then sample
368 | # unperm_indices = tf.argsort(indices, direction='ASCENDING')
369 | # include_mask_unperm = tf.batch_gather(include_mask, unperm_indices)
370 | # logits_to_use = logits - (1 - tf.cast(include_mask_unperm, tf.float32)) * 1e10
371 | # sample = tf.random.categorical(logits=logits_to_use, num_samples=num_samples, dtype=tf.int32)
372 |
373 | return {
374 | 'probs': probs,
375 | 'sample': sample,
376 | }
377 |
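# Editor's sketch (not in the original repo): a plain-NumPy reference for the
# "sort, keep until the cumulative probability reaches p, sample, map back"
# logic above, handy as a unit-test oracle for the TF graph.
def _editor_np_top_p_reference(logits, p=0.9, seed=0):
    import numpy as np
    rng = np.random.default_rng(seed)
    z = np.exp(logits - logits.max())
    probs = z / z.sum()
    order = np.argsort(-probs)            # descending, like tf.argsort above
    keep = np.cumsum(probs[order]) < p    # inclusive cumsum, matching exclusive=False
    keep[0] = True                        # never cut off everything
    kept = order[keep]
    return int(rng.choice(kept, p=probs[kept] / probs[kept].sum()))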
378 |
379 | def _top_k_sample(logits, ignore_ids=None, num_samples=1, k=10):
380 | """
381 |     Does top-k sampling. If ignore_ids is given, those logits are masked out and never sampled.
382 | :param logits: [batch_size, vocab_size] tensor
383 | :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
384 | like padding maybe
385 |     :param k: top-k cutoff to use, either an int or a [batch_size] vector
386 | :return: [batch_size, num_samples] samples
387 |
388 | # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
389 | """
390 |     with tf.variable_scope('top_k_sample'):
391 | batch_size, vocab_size = get_shape_list(logits, expected_rank=2)
392 |
393 | probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
394 | axis=-1)
395 | # [batch_size, vocab_perm]
396 | indices = tf.argsort(probs, direction='DESCENDING')
397 |
398 |         # find the top kth index to cut off. careful we don't want to cut off everything!
399 | # result will be [batch_size, vocab_perm]
400 | k_expanded = k if isinstance(k, int) else k[:, None]
401 | exclude_mask = tf.range(vocab_size)[None] >= k_expanded
402 |
403 | # OPTION A - sample in the sorted space, then unsort.
404 | logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
405 | sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
406 | sample = tf.batch_gather(indices, sample_perm)
407 |
408 | return {
409 | 'probs': probs,
410 | 'sample': sample,
411 | }
412 |
413 |
414 | class GroverModel(object):
415 | def __init__(self,
416 | config: GroverConfig,
417 | is_training,
418 | input_ids,
419 | cache=None,
420 | do_cache=False,
421 | pad_token_id=0,
422 | chop_off_last_token=True,
423 | scope=None,
424 | reuse=False):
425 | """
426 | :param config:
427 | :param is_training:
428 |         :param input_ids: Tensor of size [batch_size, seq_length]
429 | :param cache: Optionally, a tensor to use that will contain cached information of the size
430 | [batch_size, num_layers, 2, num_heads, cache_length, features]
431 |         :param do_cache: Whether to return new key/value caches for incremental decoding.
432 | :param pad_token_id: Which token will be used for padding (probably 0.)
433 |         :param chop_off_last_token: True if we will use this for TRAINING only; False if we want to generate.
434 |                                     If True, the last token in input_ids serves only as a target and is never fed to the model as input.
435 | :param scope: scope to run this on
436 | """
437 | self.config = copy.deepcopy(config)
438 | self.is_training = is_training
439 | self.pad_token_id = pad_token_id
440 |
441 | if not is_training:
442 | self.config.hidden_dropout_prob = 0.0
443 | self.config.attention_probs_dropout_prob = 0.0
444 |
445 | if chop_off_last_token:
446 | self.target_ids = input_ids[:, 1:]
447 | self.input_ids = input_ids[:, :-1]
448 | else:
449 | self.input_ids = input_ids
450 | self.target_ids = tf.concat((input_ids[:, 1:],
451 | tf.constant(self.pad_token_id, dtype=self.input_ids.dtype,
452 | shape=[get_shape_list(self.input_ids, 2)[0], 1])), 1)
453 |
454 | self.batch_size, self.seq_length = get_shape_list(self.input_ids, 2)
455 |
456 | if cache is None:
457 | caches = [None] * config.num_hidden_layers
458 | self.cache_length = 0
459 | else:
460 | batch_size_, num_layers_, two_, num_heads_, self.cache_length, features_ = get_shape_list(
461 | cache, expected_rank=6)
462 | assert batch_size_ == self.batch_size
463 | assert num_layers_ == config.num_hidden_layers
464 | assert two_ == 2
465 | assert num_heads_ == config.num_attention_heads
466 | assert features_ == (config.hidden_size // config.num_attention_heads)
467 | caches = tf.unstack(cache, axis=1)
468 |
469 | with tf.variable_scope(scope, default_name='newslm', reuse=reuse):
470 | with tf.variable_scope("embeddings"):
471 | embeddings, self.embedding_table = embed(self.input_ids, config.vocab_size,
472 | config.hidden_size,
473 | position_offset=self.cache_length,
474 | initializer_range=config.initializer_range,
475 | max_position_embeddings=config.max_position_embeddings,
476 | use_one_hot_embeddings=True)
477 |
478 | mask = get_attention_mask(self.seq_length, self.seq_length + self.cache_length, dtype=embeddings.dtype)
479 |
480 | # We keep the representation as a 2D tensor to avoid re-shaping it back and
481 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
482 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
483 | # help the optimizer.
484 | hidden_state = tf.reshape(embeddings, [self.batch_size * self.seq_length, self.config.hidden_size])
485 | new_kvs = []
486 | for layer_idx, layer_cache in enumerate(caches):
487 | with tf.variable_scope('layer{:02d}'.format(layer_idx)):
488 | # [batch_size * seq_length, hidden_size]
489 | attention_output, new_kv = attention_layer(
490 | hidden_state,
491 | mask,
492 | batch_size=self.batch_size,
493 | seq_length=self.seq_length,
494 | size_per_head=config.hidden_size // config.num_attention_heads,
495 | num_attention_heads=config.num_attention_heads,
496 | initializer_range=config.initializer_range,
497 | hidden_dropout_prob=self.config.hidden_dropout_prob,
498 | attention_probs_dropout_prob=self.config.attention_probs_dropout_prob,
499 | do_cache=do_cache,
500 | cache=layer_cache,
501 | )
502 | new_kvs.append(new_kv)
503 |
504 | # [batch_size * seq_length, hidden_size]
505 | hidden_state = residual_mlp_layer(hidden_state + attention_output,
506 | intermediate_size=config.intermediate_size,
507 | hidden_dropout_prob=self.config.hidden_dropout_prob)
508 | self.hidden_state = hidden_state
509 |
510 | self.new_kvs = tf.stack(new_kvs, axis=1) if do_cache else None
511 |
512 |         # Note that the hidden state is still flat, of shape [batch_size*seq_length, hidden_size]
513 | self.logits_flat = tf.matmul(self.hidden_state, self.embedding_table, transpose_b=True)
514 |
515 | # THE OUTPUT BIAS DOES NOT SPARK JOY
516 | # output_bias = tf.get_variable('output_bias', shape=[config.vocab_size], initializer=tf.zeros_initializer())
517 | # self.logits_flat = tf.nn.bias_add(self.logits_flat, output_bias)
518 |
519 | @property
520 | def log_probs(self):
521 | logprobs_flat = tf.nn.log_softmax(self.logits_flat, axis=-1)
522 | return tf.reshape(logprobs_flat, [self.batch_size, self.seq_length, -1])
523 |
524 | def lm_loss(self):
525 | """
526 |         :return: scalar LM loss, averaged over non-padding target tokens
527 | """
528 | target_ids_flat = tf.reshape(self.target_ids, [-1])
529 |
530 | # 1 if it's valid and 0 otherwise.
531 | label_weights = tf.cast(tf.not_equal(target_ids_flat, self.pad_token_id), dtype=self.logits_flat.dtype)
532 |
533 | # [batch_size * seq_length, vocab_size]
534 | one_hot_labels = tf.one_hot(target_ids_flat,
535 | depth=self.config.vocab_size,
536 | dtype=self.logits_flat.dtype)
537 |
538 | # [batch_size * seq_length, vocab_size]
539 | logprobs_flat = tf.nn.log_softmax(self.logits_flat, axis=-1)
540 |
541 | per_example_loss = -tf.reduce_sum(logprobs_flat * one_hot_labels, axis=[-1])
542 |
543 | # per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits_flat, labels=target_ids_flat)
544 |
545 | numerator = tf.reduce_sum(label_weights * per_example_loss)
546 | denominator = tf.reduce_sum(label_weights) + 1e-5
547 | loss = numerator / denominator
548 | return loss
549 |
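# Editor's note: with targets [t1, t2, PAD] the label weights above are
# [1, 1, 0], so the loss is (loss1 + loss2) / (2 + 1e-5) -- padding
# positions contribute nothing to either the numerator or the denominator.
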
550 | def pooled_output(self, clf_token):
551 | """
552 | Extract pooled output given a token that says where we should look
553 |         :param clf_token: id of the classification token to locate in input_ids
554 |         :return: [batch_size, hidden_size] hidden states at the first occurrence of clf_token
555 | """
556 | pool_idx = tf.cast(tf.argmax(tf.cast(tf.equal(self.input_ids, clf_token), tf.float32), 1), tf.int32)
557 | return tf.gather(self.hidden_state, tf.range(self.batch_size, dtype=tf.int32) * self.seq_length + pool_idx)
558 |
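# Editor's sketch (not in the original repo): constructing a GroverModel for a
# training-loss computation on placeholder inputs. The config path assumes
# configs/base.json from this repository.
def _editor_demo_lm_loss():
    config = GroverConfig.from_json_file('configs/base.json')
    input_ids = tf.placeholder(tf.int32, [4, 128])  # [batch_size, seq_length]
    model = GroverModel(config=config, is_training=True, input_ids=input_ids)
    return model.lm_loss()  # scalar loss tensor, averaged over non-pad targets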
559 |
560 | def model_fn_builder(config: GroverConfig, init_checkpoint, learning_rate,
561 | num_train_steps, num_warmup_steps, use_tpu):
562 | """Returns `model_fn` closure for TPUEstimator."""
563 |
564 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
565 | """The `model_fn` for TPUEstimator."""
566 |
567 | tf.logging.info("*** Features ***")
568 | for name in sorted(features.keys()):
569 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
570 |
571 | input_ids = features["input_ids"]
572 |
573 | is_training = (mode == tf.estimator.ModeKeys.TRAIN)
574 |
575 | model = GroverModel(
576 | config=config,
577 | is_training=is_training,
578 | input_ids=input_ids,
579 | pad_token_id=config.pad_token_id,
580 | chop_off_last_token=True,
581 | )
582 |
583 | total_loss = model.lm_loss()
584 |
585 | if is_training:
586 | train_op, train_metrics = optimization_adafactor.create_optimizer(
587 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
588 | tvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
589 | else:
590 | train_op = None
591 | train_metrics = {}
592 | tvars = tf.trainable_variables()
593 |
594 | initialized_variable_names = {}
595 | scaffold_fn = None
596 | if init_checkpoint:
597 | (assignment_map, initialized_variable_names
598 | ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint)
599 | if use_tpu:
600 | def tpu_scaffold():
601 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
602 | return tf.train.Scaffold()
603 |
604 | scaffold_fn = tpu_scaffold
605 | else:
606 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
607 |
608 | tf.logging.info("**** Trainable Variables ****")
609 | for var in tvars:
610 | init_string = ""
611 | if var.name in initialized_variable_names:
612 | init_string = ", *INIT_FROM_CKPT*"
613 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
614 | init_string)
615 |
616 | output_spec = None
617 | if mode == tf.estimator.ModeKeys.TRAIN:
618 | if use_tpu:
619 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
620 | mode=mode,
621 | loss=total_loss,
622 | train_op=train_op,
623 | host_call=construct_scalar_host_call(metric_dict=train_metrics, model_dir=params['model_dir'],
624 | prefix='training/'),
625 | scaffold_fn=scaffold_fn)
626 | else:
627 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
628 | mode=mode,
629 | loss=total_loss,
630 | train_op=train_op,
631 | training_hooks=[
632 | tf.train.LoggingTensorHook({'loss': tf.metrics.mean(total_loss)[1]}, every_n_iter=100)],
633 | scaffold_fn=scaffold_fn)
634 |
635 | elif mode == tf.estimator.ModeKeys.EVAL:
636 | def metric_fn(total_loss):
637 | loss = tf.metrics.mean(values=total_loss)
638 | return {
639 | "eval_loss": loss,
640 | }
641 |
642 | eval_metrics = (metric_fn,
643 | [total_loss])
644 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
645 | mode=mode,
646 | loss=total_loss,
647 | eval_metrics=eval_metrics,
648 | scaffold_fn=scaffold_fn)
649 | else:
650 | gt_logprobs = tf.squeeze(tf.batch_gather(model.log_probs, model.target_ids[:, :, None]), axis=2)
651 |
652 |             # Compute the top-p mass that would be required for the ground-truth token to be sampled under nucleus sampling.
653 | better_than_gt = model.log_probs > gt_logprobs[:, :, None]
654 | top_p_required = tf.reduce_sum(tf.cast(better_than_gt, tf.float32) * tf.exp(model.log_probs), axis=2)
655 |
656 | # No top-p sampling for now, since this seems to be too slow on TPUs
657 | if use_tpu:
658 | predictions = tf.reshape(
659 | tf.random.categorical(logits=model.logits_flat, num_samples=1),
660 | get_shape_list(model.target_ids),
661 | )
662 | else:
663 | # Argmax
664 | # predictions = tf.math.argmax(model.log_probs, axis=-1, output_type=tf.int32)
665 | predictions = tf.reshape(
666 | _top_p_sample(model.logits_flat, num_samples=1, p=0.99)['sample'],
667 | get_shape_list(model.target_ids),
668 | )
669 | pred_logprobs = tf.squeeze(tf.batch_gather(model.log_probs, predictions[:, :, None]), axis=2)
670 |
671 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
672 | mode=mode,
673 | predictions={'gt_logprobs': gt_logprobs,
674 | 'top_p_required': top_p_required,
675 | 'predictions': predictions,
676 | 'pred_logprobs': pred_logprobs,
677 | 'labels': input_ids},
678 | scaffold_fn=scaffold_fn)
679 | return output_spec
680 |
681 | return model_fn
682 |
683 |
684 | def sample_step(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False):
685 | """
686 | Helper function that samples from grover for a single step
687 | :param tokens: [batch_size, n_ctx_b] tokens that we will predict from
688 | :param ignore_ids: [n_vocab] mask of the tokens we don't want to predict
689 | :param news_config: config for the GroverModel
690 | :param batch_size: batch size to use
691 | :param p_for_topp: top-p or top-k threshold
692 | :param cache: [batch_size, news_config.num_hidden_layers, 2,
693 | news_config.num_attention_heads, n_ctx_a,
694 | news_config.hidden_size // news_config.num_attention_heads] OR, None
695 | :return: new_tokens, size [batch_size]
696 | new_probs, also size [batch_size]
697 |              new_cache, size [batch_size, news_config.num_hidden_layers, 2, news_config.num_attention_heads,
698 |                               n_ctx_b, news_config.hidden_size // news_config.num_attention_heads]
699 | """
700 | model = GroverModel(
701 | config=news_config,
702 | is_training=False,
703 | input_ids=tokens,
704 | reuse=tf.AUTO_REUSE,
705 | scope='newslm',
706 | chop_off_last_token=False,
707 | do_cache=True,
708 | cache=cache,
709 | )
710 |
711 | # Extract the FINAL SEQ LENGTH
712 | batch_size_times_seq_length, vocab_size = get_shape_list(model.logits_flat, expected_rank=2)
713 | next_logits = tf.reshape(model.logits_flat, [batch_size, -1, vocab_size])[:, -1]
714 |
715 | if do_topk:
716 | sample_info = _top_k_sample(next_logits, num_samples=1, k=tf.cast(p_for_topp, dtype=tf.int32))
717 | else:
718 | sample_info = _top_p_sample(next_logits, ignore_ids=ignore_ids, num_samples=1, p=p_for_topp)
719 |
720 | new_tokens = tf.squeeze(sample_info['sample'], 1)
721 | new_probs = tf.squeeze(tf.batch_gather(sample_info['probs'], sample_info['sample']), 1)
722 | return {
723 | 'new_tokens': new_tokens,
724 | 'new_probs': new_probs,
725 | 'new_cache': model.new_kvs,
726 | }
727 |
728 |
729 | def initialize_from_context(initial_context, ignore_ids, news_config, p_for_topp=0.95, do_topk=False):
730 | """ same signature as sample_step"""
731 | batch_size, _ = get_shape_list(initial_context, expected_rank=2)
732 |
733 | context_output = sample_step(tokens=initial_context, ignore_ids=ignore_ids, news_config=news_config,
734 | batch_size=batch_size, p_for_topp=p_for_topp, cache=None, do_topk=do_topk)
735 | return {
736 | 'tokens': tf.concat([initial_context, context_output['new_tokens'][:, None]], 1),
737 | 'cache': context_output['new_cache'],
738 | 'probs': context_output['new_probs'][:, None]
739 | }
740 |
741 |
742 | def sample(news_config: GroverConfig, initial_context, eos_token, min_len, ignore_ids=None, p_for_topp=0.95,
743 | do_topk=False):
744 | """
745 |     V1 version: sample outputs from a model, all at once
746 | :param news_config: Configuration used to construct the model
747 | :param initial_context: [batch_size, seq_length] that we'll start generating with
748 | :param eos_token: Stop generating if you see this (tf scalar)
749 | :param min_len: min length of sample
750 | :param ignore_ids: NEVER GENERATE THESE [vocab_size]
751 | :return:
752 | """
753 | batch_size, _ = get_shape_list(initial_context, expected_rank=2)
754 |
755 | if ignore_ids is None:
756 | ignore_ids = tf.constant([x == 0 for x in range(news_config.vocab_size)], dtype=tf.bool)
757 |
758 | with tf.name_scope('sample_sequence'):
759 | # Initial call to get cache
760 | context_output = initialize_from_context(initial_context, ignore_ids=ignore_ids, news_config=news_config,
761 | p_for_topp=p_for_topp,
762 | do_topk=do_topk)
763 | ctx = context_output['tokens']
764 | cache = context_output['cache']
765 | probs = context_output['probs']
766 |
767 | def body(ctx, cache, probs):
768 | """ for whatever reason this didn't work when I ran it on more than one at once... ugh."""
769 | next_outputs = sample_step(ctx[:, -1][:, None], ignore_ids=ignore_ids, news_config=news_config,
770 | batch_size=batch_size, p_for_topp=p_for_topp, cache=cache,
771 | do_topk=do_topk)
772 |
773 | # Update everything
774 | new_cache = tf.concat([cache, next_outputs['new_cache']], axis=-2)
775 | new_ids = tf.concat([ctx, next_outputs['new_tokens'][:, None]], axis=1)
776 | new_probs = tf.concat([probs, next_outputs['new_probs'][:, None]], axis=1)
777 | return [new_ids, new_cache, new_probs]
778 |
779 | def cond(ctx, cache, probs):
780 | # ctx = tf.Print(ctx,[tf.shape(ctx)])
781 | is_eos = tf.reduce_all(tf.reduce_any(tf.equal(ctx[:,-1:], eos_token), axis=1))
782 | is_len = tf.greater(get_shape_list(ctx)[1], min_len)
783 | return tf.logical_not(tf.logical_and(is_eos, is_len))
784 |
785 | tokens, cache, probs = tf.while_loop(
786 | cond=cond, body=body, maximum_iterations=1025 - get_shape_list(ctx)[1],
787 | loop_vars=[ctx, cache, probs],
788 | shape_invariants=[tf.TensorShape([batch_size, None]),
789 | tf.TensorShape(
790 | [batch_size, news_config.num_hidden_layers, 2,
791 | news_config.num_attention_heads,
792 | None, news_config.hidden_size // news_config.num_attention_heads]),
793 | tf.TensorShape([batch_size, None]),
794 | ],
795 | back_prop=False,
796 | )
797 | return tokens, probs
798 |
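# Editor's sketch (not in the original repo): end-to-end generation with
# sample(). The context ids, eos id, and weight handling are placeholders;
# real use restores a trained checkpoint instead of random initialization.
def _editor_demo_generate(news_config, context_ids, eos_token_id):
    initial_context = tf.constant([context_ids], dtype=tf.int32)  # [1, ctx_len]
    tokens, probs = sample(news_config, initial_context,
                           eos_token=eos_token_id, min_len=50, p_for_topp=0.95)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        return sess.run([tokens, probs])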
--------------------------------------------------------------------------------
/train/optimization_adafactor.py:
--------------------------------------------------------------------------------
1 | # Original work Copyright 2018 The Google AI Language Team Authors.
2 | # Modified work Copyright 2019 Rowan Zellers
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import re
16 | import tensorflow as tf
17 | from train.utils import get_shape_list
18 |
19 |
20 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
21 | """Creates an optimizer training op."""
22 | global_step = tf.train.get_or_create_global_step()
23 |
24 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
25 |
26 | # Implements linear decay of the learning rate.
27 | learning_rate = tf.train.polynomial_decay(
28 | learning_rate,
29 | global_step,
30 | num_train_steps,
31 | end_learning_rate=0.0,
32 | power=1.0,
33 | cycle=False)
34 |
35 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
36 | # learning rate will be `global_step/num_warmup_steps * init_lr`.
37 | if num_warmup_steps:
38 | global_steps_int = tf.cast(global_step, tf.int32)
39 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
40 |
41 | global_steps_float = tf.cast(global_steps_int, tf.float32)
42 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
43 |
44 | warmup_percent_done = global_steps_float / warmup_steps_float
45 | warmup_learning_rate = init_lr * warmup_percent_done
46 |
47 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
48 | learning_rate = (
49 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
50 |
51 | # It is recommended that you use this optimizer for fine tuning, since this
52 |     # is how the model was trained (note that the optimizer's accumulator
53 |     # variables are NOT loaded from init_checkpoint.)
54 | optimizer = AdaFactorOptimizer(
55 | learning_rate=learning_rate,
56 | weight_decay_rate=0.01,
57 | beta_1=0.9,
58 | beta_2=0.999,
59 | epsilon=1e-6,
60 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
61 |
62 | if use_tpu:
63 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
64 |
65 | tvars = tf.trainable_variables()
66 | grads = tf.gradients(loss, tvars)
67 |
68 | # You could do this, but instead we don't because a) it's slow and b) we already did the 'update clipping'
69 | # (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
70 |
71 | train_op = optimizer.apply_gradients(
72 | zip(grads, tvars), global_step=global_step)
73 |
74 | # Normally the global step update is done inside of `apply_gradients`.
75 | # However, `AdaFactorOptimizer` doesn't do this. But if you use
76 | # a different optimizer, you should probably take this line out.
77 | new_global_step = global_step + 1
78 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
79 |
80 | train_metrics = {
81 | 'learning_rate': learning_rate,
82 | 'minibatch_loss': loss,
83 | # 'minibatch_ppl': tf.math.exp(loss),
84 | }
85 | return train_op, train_metrics
86 |
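# Editor's note (worked example of the schedule above; numbers illustrative):
# with init_lr=1e-4, num_train_steps=800000, num_warmup_steps=10000,
#   step   5,000: warmup -> 1e-4 * 5000/10000           = 5.0e-5
#   step  10,000: decay  -> 1e-4 * (1 - 10000/800000)  ~= 9.88e-5
#   step 800,000: decay reaches end_learning_rate       = 0.0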
87 |
88 | class AdaFactorOptimizer(tf.compat.v1.train.Optimizer):
89 | """here's the optimizer we'll use"""
90 |
91 | def __init__(self,
92 | learning_rate,
93 | weight_decay_rate=0.0,
94 | beta_1=0.9,
95 | beta_2=0.999,
96 | epsilon=1e-6,
97 | exclude_from_weight_decay=None,
98 | clipping_rate=1.0,
99 | name="AdaFactorOptimizer"):
100 | """Constructs a AdaFactorOptimizer."""
101 | super(AdaFactorOptimizer, self).__init__(False, name)
102 |
103 | self.learning_rate = learning_rate
104 | self.weight_decay_rate = weight_decay_rate
105 | self.beta_1 = beta_1
106 | self.beta_2 = beta_2
107 | self.epsilon = epsilon
108 | self.epsilon1 = 1e-30
109 | self.epsilon2 = 0.001
110 | self.clipping_rate = clipping_rate
111 | self.exclude_from_weight_decay = exclude_from_weight_decay
112 | self.use_locking = False
113 |
114 | def _use_factored(self, shape):
115 | return len(shape) >= 2
116 |
117 | def _parameter_scale(self, var):
118 | """Estimate the scale of the parameters from the current values.
119 | We include a minimum value of 0.001 to give it a chance to escape 0
120 | if it was zero-initialized.
121 | Instead of using the value, we could impute the scale from the shape,
122 | as initializers do.
123 | Args:
124 | var: a variable or Tensor.
125 | Returns:
126 | a Scalar
127 | """
128 | return tf.maximum(reduce_rms(var), self.epsilon2)
129 |
130 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
131 | """See base class."""
132 | assignments = []
133 | for (grad, param) in grads_and_vars:
134 | if grad is None or param is None:
135 | continue
136 |
137 | param_name = self._get_variable_name(param.name)
138 | shape_list = get_shape_list(param, expected_rank=[1, 2])
139 |
140 | # decay_rate = 1 - tf.pow(tf.cast(tf.train.get_or_create_global_step(), tf.float32) + 1.0, -0.8)
141 | decay_rate = self.beta_2
142 | grad_squared = tf.square(grad) + self.epsilon1
143 |
144 | update_scale = self.learning_rate
145 | # update_scale = self.learning_rate * tf.cast(self._parameter_scale(param), dtype=tf.float32)
146 |
147 | # HACK: Make things dependent on grad.
148 | # This confounds the XLA rewriter and keeps it from fusing computations
149 |             # across different variables. This fusion is bad for HBM usage, since
150 | # it causes the gradients to persist in memory.
151 | grad_squared_mean = tf.reduce_mean(grad_squared)
152 | decay_rate += grad_squared_mean * 1e-30
153 | update_scale += grad_squared_mean * 1e-30
154 |
155 | # END HACK
156 |
157 | if self._use_factored(shape_list):
158 | num_rows, num_columns = shape_list
159 |
160 | vr = tf.get_variable(
161 | name=param_name + "/adafactor_vr",
162 | shape=[num_rows],
163 | dtype=tf.float32,
164 | trainable=False,
165 | initializer=tf.zeros_initializer())
166 | vc = tf.get_variable(
167 | name=param_name + "/adafactor_vc",
168 | shape=[num_columns],
169 | dtype=tf.float32,
170 | trainable=False,
171 | initializer=tf.zeros_initializer())
172 |
173 | next_vr = decay_rate * vr + (1 - decay_rate) * tf.reduce_mean(grad_squared, 1)
174 | next_vc = decay_rate * vc + (1 - decay_rate) * tf.reduce_mean(grad_squared, 0)
175 |
176 | long_term_mean = tf.reduce_mean(next_vr, -1, keepdims=True)
177 | r_factor = tf.rsqrt(next_vr / long_term_mean + self.epsilon1)
178 | c_factor = tf.rsqrt(next_vc + self.epsilon1)
179 | update = grad * tf.expand_dims(r_factor, -1) * tf.expand_dims(c_factor, -2)
180 |
181 | assignments.append(vr.assign(next_vr, use_locking=self.use_locking))
182 | assignments.append(vc.assign(next_vc, use_locking=self.use_locking))
183 | else:
184 | v = tf.get_variable(
185 | name=param_name + "/adafactor_v",
186 | shape=shape_list,
187 | dtype=tf.float32,
188 | trainable=False,
189 | initializer=tf.zeros_initializer())
190 | next_v = decay_rate * v + (1 - decay_rate) * grad_squared
191 |
192 | assignments.append(v.assign(next_v, use_locking=self.use_locking))
193 | update = grad * tf.rsqrt(next_v + self.epsilon1)
194 |
195 | clipping_denom = tf.maximum(1.0, reduce_rms(update) / self.clipping_rate)
196 | update /= clipping_denom
197 |
198 | # Do weight decay
199 | # Just adding the square of the weights to the loss function is *not*
200 | # the correct way of using L2 regularization/weight decay with Adam,
201 | # since that will interact with the m and v parameters in strange ways.
202 | #
203 |             # Instead we want to decay the weights in a manner that doesn't interact
204 |             # with the m/v parameters. This is equivalent to adding the square
205 |             # of the weights to the loss with plain (non-momentum) SGD.
206 | if self._do_use_weight_decay(param_name):
207 | update += self.weight_decay_rate * param
208 |
209 | update_with_lr = update_scale * update
210 | next_param = param - update_with_lr
211 |
212 | assignments.append(param.assign(next_param, use_locking=self.use_locking))
213 | return tf.group(*assignments, name=name)
214 |
215 | def _do_use_weight_decay(self, param_name):
216 | """Whether to use L2 weight decay for `param_name`."""
217 | if not self.weight_decay_rate:
218 | return False
219 | if self.exclude_from_weight_decay:
220 | for r in self.exclude_from_weight_decay:
221 | if re.search(r, param_name) is not None:
222 | return False
223 | return True
224 |
225 | def _get_variable_name(self, param_name):
226 | """Get the variable name from the tensor name."""
227 | m = re.match("^(.*):\\d+$", param_name)
228 | if m is not None:
229 | param_name = m.group(1)
230 | return param_name
231 |
232 |
233 | def reduce_rms(x):
234 | return tf.sqrt(tf.reduce_mean(tf.square(x)))
235 |
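# Editor's sketch (not in the original repo): the factored second-moment
# update above in plain NumPy, to make the memory saving concrete -- an
# (n, m) weight matrix needs n + m accumulator entries rather than n * m.
def _editor_np_factored_update(grad, vr, vc, lr, decay=0.999, eps1=1e-30):
    import numpy as np
    g2 = np.square(grad) + eps1
    vr = decay * vr + (1 - decay) * g2.mean(axis=1)  # per-row statistics, (n,)
    vc = decay * vc + (1 - decay) * g2.mean(axis=0)  # per-column statistics, (m,)
    r = 1.0 / np.sqrt(vr / vr.mean() + eps1)
    c = 1.0 / np.sqrt(vc + eps1)
    return vr, vc, -lr * grad * r[:, None] * c[None, :]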
--------------------------------------------------------------------------------
/train/train_tpu.py:
--------------------------------------------------------------------------------
1 | # Original work Copyright 2018 The Google AI Language Team Authors.
2 | # Modified work Copyright 2019 Rowan Zellers
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """ Training script! """
17 |
18 | import tensorflow.compat.v1 as tf
19 |
20 | from train.dataloader import input_fn_builder
21 | from train.modeling import model_fn_builder, GroverConfig
22 |
23 | flags = tf.flags
24 |
25 | FLAGS = flags.FLAGS
26 |
27 | ## Required parameters
28 | flags.DEFINE_string(
29 | "config_file", 'configs/base.json',
30 | "The config json file corresponding to the pre-trained news model. "
31 | "This specifies the model architecture.")
32 |
33 | flags.DEFINE_string(
34 | "input_file", None,
35 | "Input TF example files (can be a glob or comma separated).")
36 |
37 | flags.DEFINE_string(
38 | "output_dir", None,
39 | "The output directory where the model checkpoints will be written.")
40 |
41 | ## Other parameters
42 | flags.DEFINE_string(
43 | "init_checkpoint", None,
44 | "Initial checkpoint (usually from a pre-trained model).")
45 |
46 | flags.DEFINE_integer(
47 | "max_seq_length", 1024,
48 | "The maximum total input sequence length after BPE tokenization. "
49 | "Sequences longer than this will be truncated, and sequences shorter "
50 | "than this will be padded. Must match data generation.")
51 |
52 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
53 |
54 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for adafactor.")
55 |
56 | flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.")
57 |
58 | flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.")
59 |
60 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
61 | "How often to save the model checkpoint.")
62 |
63 | flags.DEFINE_integer("iterations_per_loop", 1000,
64 | "How many steps to make in each estimator call.")
65 |
66 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
67 |
68 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
69 |
70 | flags.DEFINE_string(
71 | "tpu_name", None,
72 | "The Cloud TPU to use for training. This should be either the name "
73 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
74 | "url.")
75 |
76 | flags.DEFINE_string(
77 | "tpu_zone", None,
78 | "[Optional] GCE zone where the Cloud TPU is located in. If not "
79 | "specified, we will attempt to automatically detect the GCE project from "
80 | "metadata.")
81 |
82 | flags.DEFINE_string(
83 | "gcp_project", None,
84 | "[Optional] Project name for the Cloud TPU-enabled project. If not "
85 | "specified, we will attempt to automatically detect the GCE project from "
86 | "metadata.")
87 |
88 | flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
89 |
90 | flags.DEFINE_integer(
91 | "num_tpu_cores", 8,
92 | "Only used if `use_tpu` is True. Total number of TPU cores to use.")
93 |
94 |
95 | def main(_):
96 | tf.logging.set_verbosity(tf.logging.INFO)
97 |
98 | news_config = GroverConfig.from_json_file(FLAGS.config_file)
99 |
100 | tf.gfile.MakeDirs(FLAGS.output_dir)
101 |
102 | input_files = []
103 | for input_pattern in FLAGS.input_file.split(","):
104 | input_files.extend(tf.gfile.Glob(input_pattern))
105 |
106 | tf.logging.info("*** Input Files ***")
107 | for input_file in input_files:
108 | tf.logging.info(" %s" % input_file)
109 |
110 | tpu_cluster_resolver = None
111 | if FLAGS.use_tpu and FLAGS.tpu_name:
112 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
113 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
114 |
115 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
116 | run_config = tf.contrib.tpu.RunConfig(
117 | cluster=tpu_cluster_resolver,
118 | master=FLAGS.master,
119 | model_dir=FLAGS.output_dir,
120 | save_checkpoints_steps=FLAGS.save_checkpoints_steps,
121 | keep_checkpoint_max=None,
122 | tpu_config=tf.contrib.tpu.TPUConfig(
123 | iterations_per_loop=FLAGS.iterations_per_loop,
124 | num_shards=FLAGS.num_tpu_cores,
125 | per_host_input_for_training=is_per_host))
126 |
127 | model_fn = model_fn_builder(news_config, init_checkpoint=FLAGS.init_checkpoint,
128 | learning_rate=FLAGS.learning_rate,
129 | num_train_steps=FLAGS.num_train_steps,
130 | num_warmup_steps=FLAGS.num_warmup_steps,
131 | use_tpu=FLAGS.use_tpu,
132 | )
133 |
134 | # If TPU is not available, this will fall back to normal Estimator on CPU
135 | # or GPU.
136 | estimator = tf.contrib.tpu.TPUEstimator(
137 | use_tpu=FLAGS.use_tpu,
138 | model_fn=model_fn,
139 | config=run_config,
140 | train_batch_size=FLAGS.train_batch_size,
141 | eval_batch_size=FLAGS.train_batch_size,
142 | params={'model_dir': FLAGS.output_dir}
143 | )
144 |
145 | tf.logging.info("***** Running training *****")
146 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
147 | train_input_fn = input_fn_builder(
148 | input_files=input_files,
149 | seq_length=FLAGS.max_seq_length,
150 | is_training=True)
151 |
152 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
153 |
154 | if __name__ == "__main__":
155 | flags.mark_flag_as_required("input_file")
156 | flags.mark_flag_as_required("output_dir")
157 | tf.app.run()
158 |
--------------------------------------------------------------------------------
/train/train_tpu_adafactor.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export PYTHONPATH=../
4 |
5 | learning_rate=1e-4
6 | init_checkpoint=""
7 | max_seq_length=1024
8 | save_checkpoint_steps=1000
9 |
10 | # You can customize the training here
11 | # mega, large, or base (must match a config in configs/)
12 | model_type="mega"
13 | OUTPUT_DIR="gs://" # put your output directory here
14 | input_file="gs://" # put your input files here, it can also be something like "*.tfrecord"
15 |
16 | if [ ${model_type} == "base" ]; then
17 | num_tpu_cores=32
18 | batch_size_per_core=16
19 | elif [ ${model_type} == "medium" ]; then
20 | num_tpu_cores=128
21 | batch_size_per_core=4
22 | elif [ ${model_type} == "mega" ]; then
23 | num_tpu_cores=256
24 | batch_size_per_core=2
25 | fi
26 |
27 |
28 | # There are 20k * 1024 examples, so this translates to roughly 20 epochs. Seems OK, and we can run for more if needed.
29 | num_train_steps=800000
30 |
31 | # Make sure batch size scales.
32 | let batch_size="$batch_size_per_core * $num_tpu_cores"
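# e.g. for mega: 256 cores * 2 sequences per core = a global batch of 512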
33 |
34 | python train_tpu.py \
35 | --config_file=configs/${model_type}.json \
36 | --input_file=${input_file} \
37 | --output_dir=${OUTPUT_DIR} \
38 | --max_seq_length=${max_seq_length} \
39 | --train_batch_size=${batch_size} \
40 | --learning_rate=${learning_rate} \
41 | --num_train_steps=${num_train_steps} \
42 | --num_warmup_steps=10000 \
43 | --save_checkpoints_steps=${save_checkpoint_steps} \
44 | --iterations_per_loop=${save_checkpoint_steps} \
45 | --use_tpu=True \
46 | --tpu_name=$(hostname) \
47 | --num_tpu_cores=$num_tpu_cores \
48 | --init_checkpoint=${init_checkpoint}
49 |
--------------------------------------------------------------------------------
/train/utils.py:
--------------------------------------------------------------------------------
1 | # Original work Copyright 2018 The Google AI Language Team Authors.
2 | # Modified work Copyright 2019 Rowan Zellers
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import collections
17 | import re
18 |
19 | import six
20 | import tensorflow.compat.v1 as tf
21 | import numpy as np
22 | from tensorflow.python.lib.io import file_io
23 |
24 |
25 | def _save_np(absolute_fn, array):
26 | if absolute_fn.startswith('gs://'):
27 | with file_io.FileIO(absolute_fn, 'w') as f:
28 | np.save(f, array)
29 | else:
30 | np.save(absolute_fn, array)
31 |
32 |
33 | def assert_rank(tensor, expected_rank, name=None):
34 | """Raises an exception if the tensor rank is not of the expected rank.
35 |
36 | Args:
37 | tensor: A tf.Tensor to check the rank of.
38 | expected_rank: Python integer or list of integers, expected rank.
39 | name: Optional name of the tensor for the error message.
40 |
41 | Raises:
42 | ValueError: If the expected shape doesn't match the actual shape.
43 | """
44 | if name is None:
45 | name = tensor.name
46 |
47 | expected_rank_dict = {}
48 | if isinstance(expected_rank, six.integer_types):
49 | expected_rank_dict[expected_rank] = True
50 | else:
51 | for x in expected_rank:
52 | expected_rank_dict[x] = True
53 |
54 | actual_rank = tensor.shape.ndims
55 | if actual_rank not in expected_rank_dict:
56 | scope_name = tf.get_variable_scope().name
57 | raise ValueError(
58 | "For the tensor `%s` in scope `%s`, the actual rank "
59 | "`%d` (shape = %s) is not equal to the expected rank `%s`" %
60 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
61 |
62 |
63 | def get_shape_list(tensor, expected_rank=None, name=None):
64 | """Returns a list of the shape of tensor, preferring static dimensions.
65 |
66 | Args:
67 | tensor: A tf.Tensor object to find the shape of.
68 | expected_rank: (optional) int. The expected rank of `tensor`. If this is
69 |       specified and the `tensor` has a different rank, an exception will be
70 | thrown.
71 | name: Optional name of the tensor for the error message.
72 |
73 | Returns:
74 | A list of dimensions of the shape of tensor. All static dimensions will
75 | be returned as python integers, and dynamic dimensions will be returned
76 | as tf.Tensor scalars.
77 | """
78 | if name is None:
79 | name = tensor.name
80 |
81 | if expected_rank is not None:
82 | assert_rank(tensor, expected_rank, name)
83 |
84 | shape = tensor.shape.as_list()
85 |
86 | non_static_indexes = []
87 | for (index, dim) in enumerate(shape):
88 | if dim is None:
89 | non_static_indexes.append(index)
90 |
91 | if not non_static_indexes:
92 | return shape
93 |
94 | dyn_shape = tf.shape(tensor)
95 | for index in non_static_indexes:
96 | shape[index] = dyn_shape[index]
97 | return shape
98 |
99 |
100 | def gelu(input_tensor):
101 | """Gaussian Error Linear Unit.
102 |
103 |     This is a smoother version of the ReLU.
104 | Original paper: https://arxiv.org/abs/1606.08415
105 |
106 | Args:
107 | input_tensor: float Tensor to perform activation.
108 |
109 | Returns:
110 | `input_tensor` with the GELU activation applied.
111 | """
112 | cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0)))
113 | return input_tensor * cdf
114 |
115 |
116 | def layer_norm(input_tensor, name=None, epsilon=1e-5):
117 | """Run layer normalization on the last dimension of the tensor."""
118 | name2use = f'LayerNorm_{name}' if name is not None else name
119 | with tf.variable_scope(name2use, default_name='LayerNorm'):
120 | dim = input_tensor.shape[-1].value
121 | gamma = tf.get_variable('gamma', [dim], initializer=tf.constant_initializer(1))
122 | beta = tf.get_variable('beta', [dim], initializer=tf.constant_initializer(0))
123 | mean = tf.reduce_mean(input_tensor, axis=-1, keepdims=True)
124 |         variance = tf.reduce_mean(tf.square(input_tensor - mean), axis=-1, keepdims=True)
125 |         input_tensor = (input_tensor - mean) * tf.rsqrt(variance + epsilon)
126 | input_tensor = input_tensor * gamma + beta
127 | return input_tensor
128 |
129 |
130 | def dropout(input_tensor, dropout_prob):
131 | """Perform dropout.
132 |
133 | Args:
134 | input_tensor: float Tensor.
135 | dropout_prob: Python float. The probability of dropping out a value (NOT of
136 | *keeping* a dimension as in `tf.nn.dropout`).
137 |
138 | Returns:
139 | A version of `input_tensor` with dropout applied.
140 | """
141 | if dropout_prob is None or dropout_prob == 0.0:
142 | return input_tensor
143 | output = tf.nn.dropout(input_tensor, rate=dropout_prob)
144 | return output
145 |
146 |
147 | def get_attention_mask(nd, ns, *, dtype):
148 | """
149 | this is a TPU compatible version of tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd)
150 | where the lower right triangle contains 1s
151 | """
152 | i = tf.range(nd)[:, None]
153 | j = tf.range(ns)
154 | m = i >= j - ns + nd
155 | return tf.cast(m, dtype)
156 |
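# Editor's note (worked example): for nd=3 query positions and ns=5 total
# key positions (2 of them cached), i >= j - ns + nd gives
#   [[1, 1, 1, 0, 0],
#    [1, 1, 1, 1, 0],
#    [1, 1, 1, 1, 1]]
# i.e. each new token attends to every cached token plus itself and earlier
# new tokens -- the "lower right triangle" described above.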
157 |
158 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
159 | """Compute the union of the current variables and checkpoint variables."""
160 | assignment_map = {}
161 | initialized_variable_names = {}
162 |
163 | name_to_variable = collections.OrderedDict()
164 | for var in tvars:
165 | name = var.name
166 | m = re.match("^(.*):\\d+$", name)
167 | if m is not None:
168 | name = m.group(1)
169 | name_to_variable[name] = var
170 |
171 | init_vars = tf.train.list_variables(init_checkpoint)
172 |
173 | assignment_map = collections.OrderedDict()
174 | for x in init_vars:
175 | (name, var) = (x[0], x[1])
176 | if name not in name_to_variable:
177 | continue
178 | assignment_map[name] = name
179 | initialized_variable_names[name] = 1
180 | initialized_variable_names[name + ":0"] = 1
181 | return (assignment_map, initialized_variable_names)
182 |
183 |
184 | def construct_scalar_host_call(metric_dict, model_dir, prefix=""):
185 | """Construct a host call to log scalars when training on TPU.
186 |
187 | Args:
188 | metric_dict: A dict of the tensors to be logged.
189 | model_dir: The location to write the summary.
190 | prefix: The prefix (if any) to prepend to the metric names.
191 |
192 | Returns:
193 | A tuple of (function, args_to_be_passed_to_said_function)
194 | """
195 | metric_names = list(metric_dict.keys())
196 |
197 | def host_call_fn(global_step, *args):
198 | """Training host call. Creates scalar summaries for training metrics.
199 |
200 | This function is executed on the CPU and should not directly reference
201 | any Tensors in the rest of the `model_fn`. To pass Tensors from the
202 | model to the `metric_fn`, provide as part of the `host_call`. See
203 | https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
204 | for more information.
205 |
206 | Arguments should match the list of `Tensor` objects passed as the second
207 | element in the tuple passed to `host_call`.
208 |
209 | Args:
210 |       global_step: `Tensor` with shape `[batch]` for the global_step
211 | *args: Remaining tensors to log.
212 |
213 | Returns:
214 | List of summary ops to run on the CPU host.
215 | """
216 | step = global_step[0]
217 | with tf.contrib.summary.create_file_writer(
218 | logdir=model_dir, filename_suffix=".host_call").as_default():
219 | with tf.contrib.summary.always_record_summaries():
220 | for i, name in enumerate(metric_names):
221 | tf.contrib.summary.scalar(prefix + name, args[i][0], step=step)
222 |
223 | return tf.contrib.summary.all_summary_ops()
224 |
225 | # To log the current learning rate, and gradient norm for Tensorboard, the
226 | # summary op needs to be run on the host CPU via host_call. host_call
227 | # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
228 | # dimension. These Tensors are implicitly concatenated to
229 | # [params['batch_size']].
230 | global_step_tensor = tf.reshape(
231 | tf.compat.v1.train.get_or_create_global_step(), [1])
232 | other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names]
233 |
234 | return host_call_fn, [global_step_tensor] + other_tensors
235 |
--------------------------------------------------------------------------------