├── Dockerfile
├── LICENSE
├── README.md
├── config.json
├── doc
│   ├── Makefile
│   ├── conf.py
│   ├── index.rst
│   ├── reference
│   │   └── index.rst
│   ├── references.bib
│   ├── references.rst
│   ├── require.txt
│   └── start
│       ├── index.rst
│       ├── inference.rst
│       ├── install.rst
│       └── training.rst
├── environment.yml
├── evaluation.py
├── hlp
│   ├── alphabet_helpers.py
│   ├── csv_helpers.py
│   ├── numbers_mnist_generator.py
│   ├── prepare_iam.py
│   └── string_data_manager.py
├── prediction.py
├── setup.py
├── tf_crnn
│   ├── __init__.py
│   ├── callbacks.py
│   ├── config.py
│   ├── data_handler.py
│   ├── model.py
│   └── preprocessing.py
└── training.py
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM tensorflow/tensorflow:1.8.0-gpu
2 |
3 | # Python version
4 | RUN python --version
5 |
6 | # Additional requirements for TensorFlow
7 | RUN apt-get update && apt-get install -y python3 python3-dev
8 |
9 | # Clean up Python 3 install
10 | RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
11 | python3 get-pip.py && \
12 | rm get-pip.py
13 |
14 | # Install Notebook
15 | RUN pip3 install ipython notebook
16 |
17 | # Install TensorFlow 1.8.0 (the code does not work with 1.7.0)
18 | RUN pip3 install tensorflow-gpu==1.8.0
19 |
20 | # Copy and install TF-CRNN
21 |
22 | ADD . /script
23 | WORKDIR /script
24 | RUN python3 setup.py install
25 |
26 | # Add an additional sources directory
27 | # You should normalize the filepath in your data
28 | VOLUME /sources
29 | VOLUME /config
30 |
31 | # TensorBoard
32 | EXPOSE 6006
33 | # Jupyter Notebook
34 | EXPOSE 8888
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | {one line to give the program's name and a brief idea of what it does.}
635 | Copyright (C) {year} {name of author}
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | {project} Copyright (C) {year} {fullname}
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Text recognition with Convolutional Recurrent Neural Network and TensorFlow 2.0 (tf2-crnn)
2 |
3 | [![Documentation Status](https://readthedocs.org/projects/tf-crnn/badge/?version=latest)](https://tf-crnn.readthedocs.io/en/latest/?badge=latest)
4 |
5 | Implementation of a Convolutional Recurrent Neural Network (CRNN) for image-based sequence recognition tasks, such as scene text recognition and OCR.
6 |
7 | This implementation is based on TensorFlow 2.0 and uses the `tf.keras` and `tf.data` modules to build the model and to handle input data.
8 |
9 | To access the previous version implementing Shi et al.'s paper, go to the [v.0.5.2](https://github.com/solivr/tf-crnn/tree/v0.5.2) tag.
10 |
11 |
12 | ## Installation
13 | `tf_crnn` makes use of the `tensorflow-gpu` package, so CUDA and cuDNN are needed.
14 |
15 | You can install the dependencies using the provided `environment.yml` file and work within the resulting conda environment.
16 |
17 | conda env create -f environment.yml
18 |
19 | See also the [docs](https://tf-crnn.readthedocs.io/en/latest/start/index.html#) for more information.
20 |
21 |
22 | ## Try it
23 |
24 | Train a model with the [IAM dataset](http://www.fki.inf.unibe.ch/databases/iam-handwriting-database).
25 |
26 | **Create an account**
27 |
28 | Create an account on the official IAM dataset page in order to access the data.
29 | Export your credentials as environment variables; they will be used by the download script.
30 |
31 | export IAM_USER=<your-username>
32 | export IAM_PWD=<your-password>
33 |
34 |
35 | **Generate the data in the correct format**
36 |
37 | cd hlp
38 | python prepare_iam.py --download_dir ../data/iam --generated_data_dir ../data/iam/generated
39 | cd ..
40 |
41 | **Train the model**
42 |
43 | python training.py with config.json
44 |
45 | More details in the [documentation](https://tf-crnn.readthedocs.io/en/latest/start/training.html#example-of-training).
46 |
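Because `sacred` manages the experiment configuration, parameters from `config.json` can also be overridden directly on the command line. A hypothetical example (the parameter values here are illustrative):

    python training.py with config.json n_epochs=10 train_batch_size=32
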
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "lookup_alphabet_file" : "./data/iam/generated/generated_alphabet/iam_alphabet_lookup.json",
3 | "csv_files_train" : "./data/iam/generated/generated_csv/lines_train.csv",
4 | "csv_files_eval" : "./data/iam/generated/generated_csv/lines_validation1.csv",
5 | "output_model_dir" : "./output_model",
6 | "num_beam_paths" : 1,
7 | "cnn_features_list" : [64, 128, 256, 512],
8 | "cnn_kernel_size" : [3, 3, 3, 3],
9 | "cnn_stride_size" : [[1, 1], [1, 1], [1, 1], [1, 1]],
10 | "cnn_pool_size" : [[2, 2], [2, 2], [2, 1], [2, 1]],
11 | "cnn_batch_norm" : [true, true, true, true],
12 | "max_chars_per_string" : 80,
13 | "n_epochs" : 200,
14 | "train_batch_size" : 64,
15 | "eval_batch_size" : 64,
16 | "learning_rate": 1e-3,
17 | "input_shape" : [64, 900],
18 | "rnn_units" : [128, 128, 128, 128],
19 | "restore_model" : false
20 | }
--------------------------------------------------------------------------------
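Note: each entry in the `cnn_*` lists above describes one convolutional block. A minimal, hypothetical sketch of how these parameters might translate into a `tf.keras` convolution stack — the actual `tf_crnn/model.py` may differ:

    import tensorflow as tf

    features = [64, 128, 256, 512]                 # cnn_features_list
    kernels = [3, 3, 3, 3]                         # cnn_kernel_size
    strides = [(1, 1), (1, 1), (1, 1), (1, 1)]     # cnn_stride_size
    pools = [(2, 2), (2, 2), (2, 1), (2, 1)]       # cnn_pool_size
    batch_norm = [True, True, True, True]          # cnn_batch_norm

    inputs = tf.keras.Input(shape=(64, 900, 1))    # input_shape; single channel assumed
    x = inputs
    for f, k, s, p, bn in zip(features, kernels, strides, pools, batch_norm):
        x = tf.keras.layers.Conv2D(f, k, strides=s, padding='same')(x)
        if bn:
            x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation('relu')(x)
        # (2, 1) pooling halves the height while preserving sequence width
        x = tf.keras.layers.MaxPool2D(pool_size=p)(x)
    cnn = tf.keras.Model(inputs, x)
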
/doc/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SOURCEDIR = .
8 | BUILDDIR = _build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/doc/conf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Configuration file for the Sphinx documentation builder.
4 | #
5 | # This file does only contain a selection of the most common options. For a
6 | # full list see the documentation:
7 | # http://www.sphinx-doc.org/en/master/config
8 |
9 | # -- Path setup --------------------------------------------------------------
10 |
11 | # If extensions (or modules to document with autodoc) are in another directory,
12 | # add these directories to sys.path here. If the directory is relative to the
13 | # documentation root, use os.path.abspath to make it absolute, like shown here.
14 | #
15 | import os
16 | import sys
17 | sys.path.insert(0, os.path.abspath('..'))
18 |
19 |
20 | # -- Project information -----------------------------------------------------
21 |
22 | project = 'tf_crnn'
23 | copyright = '2019, Digital Humanities Lab - EPFL'
24 | author = 'Sofia ARES OLIVEIRA'
25 |
26 | # The short X.Y version
27 | version = ''
28 | # The full version, including alpha/beta/rc tags
29 | release = ''
30 |
31 |
32 | # -- General configuration ---------------------------------------------------
33 |
34 | # If your documentation needs a minimal Sphinx version, state it here.
35 | #
36 | # needs_sphinx = '1.0'
37 |
38 | # Add any Sphinx extension module names here, as strings. They can be
39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
40 | # ones.
41 | extensions = [
42 | 'sphinx.ext.autodoc',
43 | 'sphinx.ext.autosummary',
44 | 'sphinx.ext.coverage',
45 | 'sphinx.ext.viewcode',
46 | 'sphinx.ext.githubpages',
47 | 'sphinxcontrib.bibtex', # for bibtex
48 | 'sphinx_autodoc_typehints'
49 | ]
50 |
51 | # Add any paths that contain templates here, relative to this directory.
52 | templates_path = ['_templates']
53 |
54 | # The suffix(es) of source filenames.
55 | # You can specify multiple suffix as a list of string:
56 | #
57 | # source_suffix = ['.rst', '.md']
58 | source_suffix = '.rst'
59 |
60 | # The master toctree document.
61 | master_doc = 'index'
62 |
63 | # The language for content autogenerated by Sphinx. Refer to documentation
64 | # for a list of supported languages.
65 | #
66 | # This is also used if you do content translation via gettext catalogs.
67 | # Usually you set "language" from the command line for these cases.
68 | language = None
69 |
70 | # List of patterns, relative to source directory, that match files and
71 | # directories to ignore when looking for source files.
72 | # This pattern also affects html_static_path and html_extra_path.
73 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
74 |
75 | # The name of the Pygments (syntax highlighting) style to use.
76 | pygments_style = None
77 |
78 |
79 | # -- Options for HTML output -------------------------------------------------
80 |
81 | # The theme to use for HTML and HTML Help pages. See the documentation for
82 | # a list of builtin themes.
83 | #
84 | html_theme = 'sphinx_rtd_theme'
85 |
86 | # Theme options are theme-specific and customize the look and feel of a theme
87 | # further. For a list of options available for each theme, see the
88 | # documentation.
89 | #
90 | # html_theme_options = {}
91 |
92 | # Add any paths that contain custom static files (such as style sheets) here,
93 | # relative to this directory. They are copied after the builtin static files,
94 | # so a file named "default.css" will overwrite the builtin "default.css".
95 | html_static_path = ['_static']
96 |
97 | # Custom sidebar templates, must be a dictionary that maps document names
98 | # to template names.
99 | #
100 | # The default sidebars (for documents that don't match any pattern) are
101 | # defined by theme itself. Builtin themes are using these templates by
102 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
103 | # 'searchbox.html']``.
104 | #
105 | # html_sidebars = {}
106 |
107 |
108 | # -- Options for HTMLHelp output ---------------------------------------------
109 |
110 | # Output file base name for HTML help builder.
111 | htmlhelp_basename = 'tf_crnndoc'
112 |
113 |
114 | # -- Options for LaTeX output ------------------------------------------------
115 |
116 | latex_elements = {
117 | # The paper size ('letterpaper' or 'a4paper').
118 | #
119 | # 'papersize': 'letterpaper',
120 |
121 | # The font size ('10pt', '11pt' or '12pt').
122 | #
123 | # 'pointsize': '10pt',
124 |
125 | # Additional stuff for the LaTeX preamble.
126 | #
127 | # 'preamble': '',
128 |
129 | # Latex figure (float) alignment
130 | #
131 | # 'figure_align': 'htbp',
132 | }
133 |
134 | # Grouping the document tree into LaTeX files. List of tuples
135 | # (source start file, target name, title,
136 | # author, documentclass [howto, manual, or own class]).
137 | latex_documents = [
138 | (master_doc, 'tf_crnn.tex', 'tf\\_crnn Documentation',
139 | author, 'manual'),
140 | ]
141 |
142 |
143 | # -- Options for manual page output ------------------------------------------
144 |
145 | # One entry per manual page. List of tuples
146 | # (source start file, name, description, authors, manual section).
147 | man_pages = [
148 | (master_doc, 'tf_crnn', 'tf_crnn Documentation',
149 | [author], 1)
150 | ]
151 |
152 |
153 | # -- Options for Texinfo output ----------------------------------------------
154 |
155 | # Grouping the document tree into Texinfo files. List of tuples
156 | # (source start file, target name, title, author,
157 | # dir menu entry, description, category)
158 | texinfo_documents = [
159 | (master_doc, 'tf_crnn', 'tf_crnn Documentation',
160 | author, 'tf_crnn', 'One line description of project.',
161 | 'Miscellaneous'),
162 | ]
163 |
164 |
165 | # -- Options for Epub output -------------------------------------------------
166 |
167 | # Bibliographic Dublin Core info.
168 | epub_title = project
169 |
170 | # The unique identifier of the text. This can be a ISBN number
171 | # or the project homepage.
172 | #
173 | # epub_identifier = ''
174 |
175 | # A unique identification for the text.
176 | #
177 | # epub_uid = ''
178 |
179 | # A list of files that should not be packed into the epub file.
180 | epub_exclude_files = ['search.html']
181 |
182 |
183 | # -- Extension configuration -------------------------------------------------
184 | autodoc_mock_imports = [
185 | # 'numpy',
186 | 'tensorflow',
187 | 'tensorflow_addons',
188 | 'pandas',
189 | 'typing',
190 | 'cv2'
191 | ]
--------------------------------------------------------------------------------
/doc/index.rst:
--------------------------------------------------------------------------------
1 | .. tf_crnn documentation master file, created by
2 | sphinx-quickstart on Mon Jan 7 14:43:48 2019.
3 |
4 | ===============================================================================
5 | TF-CRNN : A TensorFlow implementation of Convolutional Recurrent Neural Network
6 | ===============================================================================
7 |
8 | .. toctree::
9 | :maxdepth: 2
10 |
11 | start/index
12 | reference/index
13 | references
14 | .. :caption: Contents:
15 |
16 | A TensorFlow implementation of the Convolutional Recurrent Neural Network (CRNN) for image-based sequence recognition
17 | tasks, such as scene text recognition and OCR.
18 |
19 | This implementation uses ``tf.keras`` to build the model and ``tf.data`` modules to handle input data.
20 |
21 |
22 | Indices and tables
23 | ==================
24 |
25 | * :ref:`genindex`
26 | * :ref:`modindex`
27 | * :ref:`search`
--------------------------------------------------------------------------------
/doc/reference/index.rst:
--------------------------------------------------------------------------------
1 | ===============
2 | Reference guide
3 | ===============
4 |
5 | .. automodule:: tf_crnn
6 |
7 |
8 | .. automodule:: tf_crnn.data_handler
9 | :members:
10 | :undoc-members:
11 |
12 | .. automodule:: tf_crnn.config
13 | :members:
14 | :undoc-members:
15 | :exclude-members: CONST
16 |
17 | .. automodule:: tf_crnn.model
18 | :members:
19 | :undoc-members:
20 |
21 | .. automodule:: tf_crnn.callbacks
22 | :members:
23 | :undoc-members:
24 |
25 | .. automodule:: tf_crnn.preprocessing
26 | :members:
27 | :undoc-members:
28 |
--------------------------------------------------------------------------------
/doc/references.bib:
--------------------------------------------------------------------------------
1 | @article{marti2002iam,
2 | title={The IAM-database: an English sentence database for offline handwriting recognition},
3 | author={Marti, U-V and Bunke, Horst},
4 | journal={International Journal on Document Analysis and Recognition},
5 | volume={5},
6 | number={1},
7 | pages={39--46},
8 | year={2002},
9 | publisher={Springer}
10 | }
11 |
--------------------------------------------------------------------------------
/doc/references.rst:
--------------------------------------------------------------------------------
1 | ==========
2 | References
3 | ==========
4 |
5 | .. bibliography:: references.bib
6 | :cited:
7 | :all:
8 | :style: alpha
--------------------------------------------------------------------------------
/doc/require.txt:
--------------------------------------------------------------------------------
1 | sphinx
2 | sphinx-autodoc-typehints
3 | sphinx-rtd-theme
4 | sphinxcontrib-bibtex
5 | sphinxcontrib-websupport
6 |
--------------------------------------------------------------------------------
/doc/start/index.rst:
--------------------------------------------------------------------------------
1 | Quickstart
2 | ==========
3 |
4 | .. toctree::
5 | install
6 | training
7 | .. inference
--------------------------------------------------------------------------------
/doc/start/inference.rst:
--------------------------------------------------------------------------------
1 | Using a saved model for prediction
2 | ----------------------------------
3 |
4 | During the training, the model is exported every *n* epochs (you can set *n* in the config file, by default *n=5*).
5 | The exported models are ``SavedModel`` TensorFlow objects, which need to be loaded in order to be used.
6 |
7 | Assuming that the output folder is named ``output_dir``, the exported models will be saved in ``output_dir/export/``
8 | with different timestamps for each export. Each ``<timestamp>`` folder contains a ``saved_model.pb``
9 | file and a ``variables`` folder.
10 |
11 | The ``saved_model.pb`` contains the graph definition of your model and the ``variables`` folder contains the
12 | saved variables (where the weights are stored). You can find more information about SavedModel
13 | on the `TensorFlow dedicated page <https://www.tensorflow.org/guide/saved_model>`_.
14 |
15 |
16 | In order to easily handle the loading of the exported models, a ``PredictionModel`` class is provided and
17 | you can use the trained model to transcribe new image segments in the following way:
18 |
19 | .. code-block:: python
20 |
21 | import tensorflow as tf
22 | from tf_crnn.loader import PredictionModel
23 |
24 | model_directory = 'output/export/<timestamp>/'
25 | image_filename = 'data/images/b04-034-04-04.png'
26 |
27 | with tf.Session() as session:
28 | model = PredictionModel(model_directory, signature='filename')
29 | prediction = model.predict(image_filename)
30 |
31 |
--------------------------------------------------------------------------------
/doc/start/install.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ------------
3 |
4 | ``tf_crnn`` uses the ``tensorflow-gpu`` package, which needs the CUDA and CuDNN libraries for GPU support. TensorFlow's
5 | `GPU support page <https://www.tensorflow.org/install/gpu>`_ lists the requirements.
6 |
7 | Using Anaconda
8 | ^^^^^^^^^^^^^^
9 |
10 | When using Anaconda (or Miniconda), conda will automatically install compatible versions of CUDA and CuDNN::
11 |
12 | conda env create -f environment.yml
13 |
14 |
15 | From `this page `_:
16 |
17 | When the GPU accelerated version of TensorFlow is installed using conda, by the command
18 | “conda install tensorflow-gpu”, these libraries are installed automatically, with versions
19 | known to be compatible with the tensorflow-gpu package. Furthermore, conda installs these libraries
20 | into a location where they will not interfere with other instances of these libraries that may have
21 | been installed via another method. Regardless of using pip or conda-installed tensorflow-gpu,
22 | the NVIDIA driver must be installed separately.
23 |
24 | .. Using ``pip``
25 | ^^^^^^^^^^^^^
26 |
27 | Before using ``tf_crnn`` we recommend creating a virtual environment (python 3.5).
28 | Then, install the dependencies using the GitHub repository's ``setup.py`` file::
29 |
30 | pip install git+https://github.com/solivr/tf-crnn
31 |
32 | You will then need to install CUDA and CuDNN libraries manually.
33 |
34 |
35 | .. Using Docker
36 | ^^^^^^^^^^^^
37 | (thanks to `PonteIneptique <https://github.com/PonteIneptique>`_)
38 |
39 | The ``Dockerfile`` in the root directory allows you to run the whole program as an NVIDIA TensorFlow GPU Docker container.
40 | This is potentially helpful for dealing with external dependencies like CUDA.
41 |
42 | You can follow installations processes here :
43 |
44 | - docker-ce : `Ubuntu `_
45 | - nvidia-docker : `Ubuntu `_
46 |
47 | Once these are installed, build the container image ::
48 |
49 | nvidia-docker build . --tag tf-crnn
50 |
51 |
52 | The container image is now named ``tf-crnn``.
53 | You can run it with ``nvidia-docker run -it tf-crnn:latest bash``,
54 | which opens a bash session inside the container. However, we recommend using ::
55 |
56 | nvidia-docker run -it -p 8888:8888 -p 6006:6006 -v /absolute/path/to/here/config:/config -v $INPUT_DATA:/sources tf-crnn:latest bash
57 |
58 | where ``$INPUT_DATA`` should be replaced by the directory containing your training and testing data;
59 | it will be mounted at ``/sources``. The local ``config`` directory is mounted at ``/config`` inside the container.
60 | Note that the host paths must be absolute. We also recommend changing ::
61 |
62 | //...
63 | "output_model_dir" : "/.output/"
64 |
65 |
66 | to ::
67 |
68 | //...
69 | "output_model_dir" : "/config/output"
70 |
71 |
72 | **Do not forget** to update the image paths in your training and testing csv files so that each image
73 | is referenced by its path inside the container, i.e. ``/sources/.../file.{png,jpg}``
74 |
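75 | For example, assuming the csv files currently reference images under ``/home/me/data`` (an illustrative path),
76 | a one-liner such as ::
77 |
78 | sed -i 's|/home/me/data|/sources|g' train_data.csv validation_data.csv
79 |
80 | rewrites the paths in place.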
75 |
76 | .. note:: if you are uncomfortable with bash, you can always replace ``bash`` with ``ipython3 notebook --allow-root``
77 | and point your browser to ``http://localhost:8888/``. A token will be shown in the terminal.
--------------------------------------------------------------------------------
/doc/start/training.rst:
--------------------------------------------------------------------------------
1 | How to train a model
2 | --------------------
3 |
4 | The ``sacred`` package is used to manage experiments.
5 | If you are not yet familiar with it, have a quick look at the `documentation `_.
6 |
7 | Input data
8 | ^^^^^^^^^^
9 |
10 | In order to train a model, you should provide a csv file with each row containing the filename of the image (full path)
11 | and its label (plain text) separated by a delimiting character (let's say ``;``).
12 | Also, each character should be separated by a splitting character (let's say ``|``), in order to deal with arbitrary
13 | alphabets (especially alphabet units that consist of more than one character, such as abbreviations).
14 |
15 | An example of such a csv file would look like ::
16 |
17 | /full/path/to/image1.{jpg,png};|s|t|r|i|n|g|_|l|a|b|e|l|1|
18 | /full/path/to/image2.{jpg,png};|s|t|r|i|n|g|_|l|a|b|e|l|2| |w|i|t|h| |special_char|
19 | ...
20 |
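21 | A minimal sketch that writes such a file from a list of ``(path, label)`` pairs (the paths and labels here are
22 | illustrative) ::
23 |
24 | samples = [('/full/path/to/image1.png', 'string_label1'),
25 |            ('/full/path/to/image2.png', 'string_label2')]
26 |
27 | with open('train_data.csv', 'w', encoding='utf8') as f:
28 |     for path, label in samples:
29 |         # join the characters of the label with '|' and wrap the result in '|'
30 |         f.write('{};|{}|\n'.format(path, '|'.join(label)))
31 |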
21 | Input lookup alphabet file
22 | ^^^^^^^^^^^^^^^^^^^^^^^^^^
23 |
24 | You also need to provide a lookup table for the *alphabet* that will be used. The term *alphabet* refers to all the
25 | symbols you want the network to learn, whether they are characters, digits, symbols, abbreviations, or any other graphical element.
26 |
27 | The lookup table is a dictionary mapping alphabet units to integer codes (i.e. ``{'char': <code>}``).
28 | Some lookup tables are already provided as examples in ``data/alphabet/``.
29 |
30 | For example, to transcribe words that contain only the characters *'abcdefg'*, one possible lookup table would be ::
31 |
32 | {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7}
33 |
34 | The lookup table / dictionary needs to be saved in a json file.
35 |
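36 | For instance, the table above can be written to disk with the standard library (the filename is illustrative) ::
37 |
38 | import json
39 |
40 | lookup = {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7}
41 | with open('lookup.json', 'w', encoding='utf8') as f:
42 |     json.dump(lookup, f, ensure_ascii=False)
43 |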
36 | Config file (with ``sacred``)
37 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
38 |
39 | Set the parameters of the experiment in ``config.json``. The file looks like this ::
40 |
41 | {
42 | "lookup_alphabet_file" : "./data/alphabet/lookup.json",
43 | "csv_files_train" : "./data/csv_experiments/train_data.csv",
44 | "csv_files_eval" : "./data/csv_experiments/validation_data.csv",
45 | "output_model_dir" : "./output_model",
46 | "num_beam_paths" : 1,
47 | "max_chars_per_string" : 80,
48 | "n_epochs" : 50,
49 | "train_batch_size" : 64,
50 | "eval_batch_size" : 64,
51 | "learning_rate": 1e-4,
52 | "input_shape" : [128, 1400],
53 | "restore_model" : false
54 | }
55 |
56 | In order to use your data, you should change the parameters ``csv_files_train``, ``csv_files_eval`` and ``lookup_alphabet_file``.
57 |
58 | All the configurable parameters can be found in the ``tf_crnn.config.Params`` class; any of them can be added to the config file if needed.
59 |
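59 | For example, to switch the optimizer to ``rms`` and save the model every 10 epochs (both parameters are
60 | defined in ``Params``), you could add ::
61 |
62 | //...
63 | "optimizer" : "rms",
64 | "save_interval" : 10
65 |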
60 | Training
61 | ^^^^^^^^
62 |
63 | Once you have your input csv and alphabet files ready, and the parameters set in ``config.json``,
64 | use the ``sacred`` syntax to launch the training ::
65 |
66 | python training.py with config.json
67 |
68 | The saved model and logs will then be exported to the folder specified in the config file (``output_model_dir``).
69 |
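70 | ``sacred`` also allows overriding individual parameters from the command line without editing the file,
71 | for example (the values are illustrative) ::
72 |
73 | python training.py with config.json n_epochs=30 train_batch_size=32
74 |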
70 |
71 | Example of training
72 | -------------------
73 |
74 | We will use the `IAM Database <http://www.fki.inf.unibe.ch/databases/iam-handwriting-database>`_ :cite:`marti2002iam`
75 | as an example to generate data in the correct input format and train a model.
76 |
77 | Go to the official page and create an account in order to get access to the data.
78 | You don't need to download the data yourself; the ``prepare_iam.py`` script will take care of that for you.
79 |
80 | Generating data
81 | ^^^^^^^^^^^^^^^
82 |
83 | First create the ``IAM_USER`` and ``IAM_PWD`` environment variables to store your credentials; they will be used by the download script ::
84 |
85 | export IAM_USER=<your_username>
86 | export IAM_PWD=<your_password>
87 |
88 |
89 | Run the script ``hlp/prepare_iam.py`` to download the data, extract it, and format it correctly for training ::
90 |
91 | cd hlp
92 | python prepare_iam.py --download_dir ../data/iam --generated_data_dir ../data/iam/generated
93 | cd ..
94 |
95 | The images of the lines are extracted in ``data/iam/lines/`` and the folder ``data/iam/generated/`` contains all the
96 | additional files necessary to run the experiment. The csv files are saved in ``data/iam/generated/generated_csv`` and
97 | the alphabet is placed in ``data/iam/generated/generated_alphabet``.
98 |
99 | Training the model
100 | ^^^^^^^^^^^^^^^^^^
101 |
102 | Make sure the ``config.json`` file has the correct paths for the training and validation data, as well as for the
103 | alphabet lookup file, and run ::
104 |
105 | python training.py with config.json
106 |
107 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: crnn-tf2
2 | dependencies:
3 | - python=3.6
4 | - imageio
5 | - numpy
6 | - tqdm
7 | - pandas
8 | - click
9 | - pip
10 | - pip:
11 | - sacred
12 | - opencv-python
13 | - tensorflow-gpu>=2.0
14 | - tensorflow-addons>=0.5
15 | - git+https://github.com/solivr/taputapu.git#egg=taputapu
16 | - sphinx
17 | - sphinx-autodoc-typehints
18 | - sphinx-rtd-theme
19 |
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = "solivr"
3 | __license__ = "GPL"
4 |
5 | import os
6 | from glob import glob
7 |
8 | import click
9 | from tf_crnn.callbacks import CustomLoaderCallback, FOLDER_SAVED_MODEL
10 | from tf_crnn.config import Params, CONST
11 | from tf_crnn.data_handler import dataset_generator
12 | from tf_crnn.preprocessing import preprocess_csv
13 | from tf_crnn.model import get_model_train
14 |
15 |
16 | @click.command()
17 | @click.option('--csv_filename')
18 | @click.option('--model_dir')
19 | def evaluation(csv_filename: str,
20 | model_dir: str):
21 |
22 | config_filename = os.path.join(model_dir, 'config.json')
23 | parameters = Params.from_json_file(config_filename)
24 |
25 | saving_dir = os.path.join(parameters.output_model_dir, FOLDER_SAVED_MODEL)
26 |
27 | # Callback for model weights loading
28 | last_time_stamp = max([int(p.split(os.path.sep)[-1].split('-')[0])
29 | for p in glob(os.path.join(saving_dir, '*'))])
30 | loading_dir = os.path.join(saving_dir, str(last_time_stamp))
31 | ld_callback = CustomLoaderCallback(loading_dir)
32 |
33 | # Preprocess csv data
34 | csv_evaluation_file = os.path.join(parameters.output_model_dir, CONST.PREPROCESSING_FOLDER, 'evaluation_data.csv')
35 | n_samples = preprocess_csv(csv_filename,
36 | parameters,
37 | csv_evaluation_file)
38 |
39 | dataset_evaluation = dataset_generator([csv_evaluation_file],
40 | parameters,
41 | batch_size=parameters.eval_batch_size,
42 | num_epochs=1)
43 |
44 | # get model and evaluation
45 | model = get_model_train(parameters)
46 | eval_output = model.evaluate(dataset_evaluation,
47 | callbacks=[ld_callback])
48 | print('-- Metrics: ', eval_output)
49 |
50 |
51 | if __name__ == '__main__':
52 | evaluation()
53 |
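54 | # Example usage (paths are illustrative):
55 | #   python evaluation.py --csv_filename data/iam/generated/generated_csv/test.csv --model_dir ./output_model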
--------------------------------------------------------------------------------
/hlp/alphabet_helpers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = "solivr"
3 | __license__ = "GPL"
4 |
5 | from typing import List, Union
6 | import csv
7 | import json
8 | import numpy as np
9 | import pandas as pd
10 |
11 |
12 | def get_alphabet_units_from_input_data(csv_filename: str,
13 | split_char: str='|'):
14 | """
15 | Get alphabet units from the input_data csv file (which contains in each row the tuple
16 | (filename image segment, transcription formatted))
17 |
18 | :param csv_filename: csv file containing the input data
19 | :param split_char: splitting character in input_data separting the alphabet units
20 | :return:
21 | """
22 | df = pd.read_csv(csv_filename, sep=';', header=None, names=['image', 'labels'],
23 | encoding='utf8', escapechar="\\", quoting=3)
24 | transcriptions = list(df.labels.apply(lambda x: x.split(split_char)))
25 |
26 | unique_units = np.unique([chars for list_chars in transcriptions for chars in list_chars])
27 |
28 | return unique_units
29 |
30 |
31 | def generate_alphabet_file(csv_filenames: List[str],
32 | alphabet_filename: str):
33 | """
34 |
35 | :param csv_filenames:
36 | :param alphabet_filename:
37 | :return:
38 | """
39 | symbols = list()
40 | for file in csv_filenames:
41 | symbols.append(get_alphabet_units_from_input_data(file))
42 |
43 | alphabet_units = np.unique(np.concatenate(symbols))
44 |
45 | alphabet_lookup = dict([(au, i + 1) for i, au in enumerate(alphabet_units)])
46 |
47 | with open(alphabet_filename, 'w') as f:
48 | json.dump(alphabet_lookup, f)
49 |
50 |
51 | def get_abbreviations_from_csv(csv_filename: str) -> List[str]:
52 | with open(csv_filename, 'r', encoding='utf8') as f:
53 | csvreader = csv.reader(f, delimiter='\n')
54 | alphabet_units = [row[0] for row in csvreader]
55 | return alphabet_units
56 |
--------------------------------------------------------------------------------
/hlp/csv_helpers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = 'solivr'
3 | __license__ = "GPL"
4 |
5 | import csv
6 | import os
7 | import argparse
8 | from typing import List
9 | from tqdm import tqdm, trange
9 |
10 |
11 | def csv_rel2abs_path_convertor(csv_filenames: List[str], delimiter: str=' ', encoding='utf8') -> None:
12 | """
13 | Convert relative paths into absolute paths
14 |
15 | :param csv_filenames: list of csv filenames to convert
16 | :param delimiter: character to delimit fields in the csv
17 | :param encoding: encoding format of csv file
18 | :return:
19 | """
20 |
21 | for filename in tqdm(csv_filenames):
22 | absolute_path, basename = os.path.split(os.path.abspath(filename))
23 | relative_paths = list()
24 | labels = list()
25 | # Reading CSV
26 | with open(filename, 'r', encoding=encoding) as f:
27 | csvreader = csv.reader(f, delimiter=delimiter)
28 | for row in csvreader:
29 | relative_paths.append(row[0])
30 | labels.append(row[1])
31 |
32 | # Writing converted_paths CSV
33 | export_filename = os.path.join(absolute_path, '{}_abs{}'.format(*os.path.splitext(basename)))
34 | with open(export_filename, 'w', encoding=encoding) as f:
35 | csvwriter = csv.writer(f, delimiter=delimiter)
36 | for i in trange(0, len(relative_paths)):
37 | csvwriter.writerow([os.path.abspath(os.path.join(absolute_path, relative_paths[i])), labels[i]])
38 |
39 |
40 | def csv_filtering_chars_from_labels(csv_filename: str, chars_to_remove: str,
41 | delimiter: str=' ', encoding='utf8') -> int:
42 | """
43 | Remove labels containing chars_to_remove in csv_filename
44 |
45 | :param chars_to_remove: string (or list) with the undesired characters
46 | :param csv_filename: filename of the csv
47 | :param delimiter: delimiter character
48 | :param encoding: encoding format of csv file
49 | :return: number of deleted labels
50 | """
51 |
52 | if not isinstance(chars_to_remove, list):
53 | chars_to_remove = list(chars_to_remove)
54 |
55 | paths = list()
56 | labels = list()
57 | n_deleted = 0
58 | with open(csv_filename, 'r', encoding=encoding) as file:
59 | csvreader = csv.reader(file, delimiter=delimiter)
60 | for row in csvreader:
61 | if not any((d in chars_to_remove) for d in row[1]):
62 | paths.append(row[0])
63 | labels.append(row[1])
64 | else:
65 | n_deleted += 1
66 |
67 | with open(csv_filename, 'w', encoding=encoding) as file:
68 | csvwriter = csv.writer(file, delimiter=delimiter)
69 | for i in tqdm(range(len(paths)), total=len(paths)):
70 | csvwriter.writerow([paths[i], labels[i]])
71 |
72 | return n_deleted
73 |
74 |
75 | if __name__ == '__main__':
76 | parser = argparse.ArgumentParser()
77 | parser.add_argument('-i', '--input_files', type=str, required=True, help='CSV filename to convert', nargs='*')
78 | parser.add_argument('-d', '--delimiter_char', type=str, help='CSV delimiter character', default=' ')
79 |
80 | args = vars(parser.parse_args())
81 |
82 | csv_filenames = args.get('input_files')
83 |
84 | csv_rel2abs_path_convertor(csv_filenames, delimiter=args.get('delimiter_char'))
85 |
86 |
--------------------------------------------------------------------------------
/hlp/numbers_mnist_generator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = 'solivr'
3 | __license__ = "GPL"
4 |
5 | from tensorflow.examples.tutorials.mnist import input_data
6 | import numpy as np
7 | import os
8 | import csv
9 | from imageio import imsave
10 | from tqdm import tqdm
11 | import random
12 | import argparse
13 |
14 |
15 | def generate_random_image_numbers(mnist_dir, dataset, output_dir, csv_filename, n_numbers):
16 |     """Concatenate random MNIST digits into multi-digit number images and write a (path, label) csv."""
17 |
17 | mnist = input_data.read_data_sets(mnist_dir, one_hot=False)
18 |
19 | output_dir_img = os.path.join(output_dir, 'images')
20 | if not os.path.exists(output_dir):
21 | os.mkdir(output_dir)
22 | if not os.path.exists(output_dir_img):
23 | os.mkdir(output_dir_img)
24 |
25 | if dataset == 'train':
26 | dataset = mnist.train
27 | elif dataset == 'validation':
28 | dataset = mnist.validation
29 | elif dataset == 'test':
30 | dataset = mnist.test
31 |
32 | list_paths = list()
33 | list_labels = list()
34 |
35 | for i in tqdm(range(n_numbers), total=n_numbers):
36 | n_digits = random.randint(3, 8)
37 | digits, labels = dataset.next_batch(n_digits)
38 | # Reshape to have 28x28 image
39 | square_digits = np.reshape(digits, [-1, 28, 28])
40 | # White background
41 | square_digits = -(square_digits - 1) * 255
42 | stacked_number = np.hstack(square_digits[:, :, 4:-4])
43 | stacked_label = ''.join(map(str, labels))
44 | # chans3 = np.dstack([stacked_number]*3)
45 |
46 | # Save image number
47 | img_filename = '{:09}_{}.jpg'.format(i, stacked_label)
48 | img_path = os.path.join(output_dir_img, img_filename)
49 | imsave(img_path, stacked_number)
50 |
51 | # Add to list of paths and list of labels
52 | list_paths.append(img_filename)
53 | list_labels.append(stacked_label)
54 |
55 | root = './images'
56 | csv_path = os.path.join(output_dir, csv_filename)
57 | with open(csv_path, 'w') as csvfile:
58 |     # create the writer once, then write one row per generated image
59 |     csvwriter = csv.writer(csvfile, delimiter=' ')
60 |     for i in tqdm(range(len(list_paths)), total=len(list_paths)):
61 |         csvwriter.writerow([os.path.join(root, list_paths[i]), list_labels[i]])
61 |
62 |
63 | if __name__ == '__main__':
64 | parser = argparse.ArgumentParser()
65 | parser.add_argument('-md', '--mnist_dir', type=str, help='Directory for MNIST data', default='./MNIST_data')
66 | parser.add_argument('-d', '--dataset', type=str, help='Dataset wanted (train, test, validation)', default='train')
67 | parser.add_argument('-csv', '--csv_filename', type=str, help='CSV filename to output paths and labels')
68 | parser.add_argument('-od', '--output_dir', type=str, help='Directory to output images and csv files', default='./output_numbers')
69 | parser.add_argument('-n', '--n_samples', type=int, help='Desired numbers of generated samples', default=1000)
70 |
71 | args = parser.parse_args()
72 |
73 | generate_random_image_numbers(args.mnist_dir, args.dataset, args.output_dir, args.csv_filename, args.n_samples)
74 |
75 |
76 |
--------------------------------------------------------------------------------
/hlp/prepare_iam.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = "solivr"
3 | __license__ = "GPL"
4 |
5 | from taputapu.databases import iam
6 | import os
7 | from glob import glob
8 | from string_data_manager import tf_crnn_label_formatting
9 | from alphabet_helpers import generate_alphabet_file
10 | import click
11 |
12 |
13 | @click.command()
14 | @click.option('--download_dir')
15 | @click.option('--generated_data_dir')
16 | def prepare_iam_data(download_dir: str,
17 | generated_data_dir: str):
18 |
19 | # Download data
20 | print('Starting downloads...')
21 | iam.download(download_dir)
22 |
23 | # Extract archives
24 | print('Starting extractions...')
25 | iam.extract(download_dir)
26 |
27 | print('Generating files for the experiment...')
28 | # Generate splits (same format as ascii files)
29 | export_splits_dir = os.path.join(generated_data_dir, 'generated_splits')
30 | os.makedirs(export_splits_dir, exist_ok=True)
31 |
32 | iam.generate_splits_txt(os.path.join(download_dir, 'ascii', 'lines.txt'),
33 | os.path.join(download_dir, 'largeWriterIndependentTextLineRecognitionTask'),
34 | export_splits_dir)
35 |
36 | # Generate csv from .txt splits files
37 | export_csv_dir = os.path.join(generated_data_dir, 'generated_csv')
38 | os.makedirs(export_csv_dir, exist_ok=True)
39 |
40 | for file in glob(os.path.join(export_splits_dir, '*')):
41 | export_basename = os.path.basename(file).split('.')[0]
42 | iam.create_experiment_csv(file,
43 | os.path.join(download_dir, 'lines'),
44 | os.path.join(export_csv_dir, export_basename + '.csv'),
45 | False,
46 | True)
47 |
48 | # Format string label to tf_crnn formatting
49 | for csv_filename in glob(os.path.join(export_csv_dir, '*')):
50 | tf_crnn_label_formatting(csv_filename)
51 |
52 | # Generate alphabet
53 | alphabet_dir = os.path.join(generated_data_dir, 'generated_alphabet')
54 | os.makedirs(alphabet_dir, exist_ok=True)
55 |
56 | generate_alphabet_file(glob(os.path.join(export_csv_dir, '*')),
57 | os.path.join(alphabet_dir, 'iam_alphabet_lookup.json'))
58 |
59 |
60 | if __name__ == '__main__':
61 | prepare_iam_data()
62 |
--------------------------------------------------------------------------------
/hlp/string_data_manager.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = 'solivr'
3 | __license__ = 'GPL'
4 |
5 | import pandas as pd
6 |
7 | _accents_list = 'àéèìîóòù'
8 | _accent_mapping = {'à': 'a',
9 | 'é': 'e',
10 | 'è': 'e',
11 | 'ì': 'i',
12 | 'î': 'i',
13 | 'ó': 'o',
14 | 'ò': 'o',
15 | 'ù': 'u'}
16 |
17 |
18 | def map_accentuated_characters_in_dataframe(dataframe_transcriptions: pd.DataFrame,
19 | dict_mapping: dict=_accent_mapping) -> pd.DataFrame:
20 | """
21 |
22 | :param dataframe_transcriptions: must have a field 'transcription'
23 | :param dict_mapping
24 | :return:
25 | """
26 | items = dataframe_transcriptions.transcription.iteritems()
27 |
28 | for i in range(dataframe_transcriptions.transcription.count()):
29 | df_id, transcription = next(items)
30 | # https://stackoverflow.com/questions/30020184/how-to-find-the-first-index-of-any-of-a-set-of-characters-in-a-string
31 | ch_index = next((i for i, ch in enumerate(transcription) if ch in dict_mapping), None)
32 | while ch_index is not None:
33 | transcription = list(transcription)
34 | ch = transcription[ch_index]
35 | transcription[ch_index] = dict_mapping[ch]
36 | transcription = ''.join(transcription)
37 | dataframe_transcriptions.at[df_id, 'transcription'] = transcription
38 | ch_index = next((i for i, ch in enumerate(transcription) if ch in dict_mapping), None)
39 |
40 | return dataframe_transcriptions
41 |
42 |
43 | def map_accentuated_characters_in_string(string_to_format: str, dict_mapping: dict=_accent_mapping) -> str:
44 | """
45 |
46 | :param string_to_format:
47 | :param dict_mapping:
48 | :return:
49 | """
50 | # https://stackoverflow.com/questions/30020184/how-to-find-the-first-index-of-any-of-a-set-of-characters-in-a-string
51 | ch_index = next((i for i, ch in enumerate(string_to_format) if ch in dict_mapping), None)
52 | while ch_index is not None:
53 | string_to_format = list(string_to_format)
54 | ch = string_to_format[ch_index]
55 | string_to_format[ch_index] = dict_mapping[ch]
56 | string_to_format = ''.join(string_to_format)
57 | ch_index = next((i for i, ch in enumerate(string_to_format) if ch in dict_mapping), None)
58 |
59 | return string_to_format
60 |
61 |
62 | def format_string_for_tf_split(string_to_format: str,
63 | separator_character: str= '|',
64 | replace_brackets_abbreviations=True) -> str:
65 | """
66 | Formats transcriptions to be split by tf.string_split using character separator "|"
67 |
68 | :param string_to_format: string to format
69 | :param separator_character: character that separates alphabet units
70 | :param replace_brackets_abbreviations: if True will replace '[' and ']' chars by separator character
71 | :return:
72 | """
73 |
74 | if replace_brackets_abbreviations:
75 | # Replace "[]" chars by "|"
76 | string_to_format = string_to_format.replace("[", separator_character).replace("]", separator_character)
77 |
78 | splits = string_to_format.split(separator_character)
79 |
80 | final_string = separator_character
81 | # Case where string starts with a separator_character
82 | if splits[0] == '':
83 | for i, sp in enumerate(splits):
84 | if i % 2 > 0: # uneven -> abbreviation
85 | final_string += separator_character + sp + separator_character
86 | else: # even -> no abbreviation
87 | final_string += sp.replace('', separator_character)[1:-1]
88 |
89 | else:
90 | for i, sp in enumerate(splits):
91 | if i % 2 > 0: # uneven -> abbreviation
92 | final_string += separator_character + sp + separator_character
93 | else: # even -> no abbreviation
94 | final_string += sp.replace('', separator_character)[1:-1]
95 |
96 | # Add separator at beginning or end of string if it hasn't been added yet
97 | if final_string[1] == separator_character:
98 | final_string = final_string[1:]
99 | if final_string[-1] != separator_character:
100 | final_string += separator_character
101 |
102 | return final_string
103 |
104 |
105 | def tf_crnn_label_formatting(csv_filename: str):
106 |     """Rewrite the labels of a csv file in tf_crnn format, i.e. 'string' becomes '|s|t|r|i|n|g|'."""
107 |
108 |     def _string_formatting(string_to_format: str,
109 |                            separator_character: str = '|'):
110 |         chars = list(string_to_format)
111 |         formatted_string = separator_character + separator_character.join(chars) + separator_character
112 |         return formatted_string
112 |
113 | df = pd.read_csv(csv_filename, sep=';', header=None, names=['image', 'labels'], encoding='utf8',
114 | escapechar="\\", quoting=3)
115 |
116 | df.labels = df.labels.apply(lambda x: _string_formatting(x))
117 |
118 | df.to_csv(csv_filename, sep=';', encoding='utf-8', header=False, index=False, escapechar="\\", quoting=3)
119 |
120 |
121 | def lower_abbreviation_in_string(string_to_format: str):
122 | # Split with '['
123 | tokens_opening = string_to_format.split('[')
124 |
125 | valid_string = True
126 | final_string = tokens_opening[0]
127 | for tok in tokens_opening[1:]:
128 | if len(tok) > 1:
129 | token_closing = tok.split(']')
130 | if len(token_closing) == 2: # checks if abbreviation starts with [ and ends with ]
131 | final_string += '[' + token_closing[0].lower() + ']' + token_closing[1]
132 | else: # No closing ']'
133 | valid_string = False
134 | else:
135 | final_string += ']'
136 | if valid_string:
137 | return final_string
138 | else:
139 | return ''
140 |
141 |
142 | def add_abbreviation_brackets(label: str):
143 | """
144 | Adds brackets in formatted strings i.e label= '|B|e|n|e|t|t|a| |M|a|z|z|o|l|e|n|i| |quondam| |A|n|z|o|l|o|'
145 | turns to '|B|e|n|e|t|t|a| |M|a|z|z|o|l|e|n|i| |[quondam]| |A|n|z|o|l|o|'
146 | :param label:
147 | :return:
148 | """
149 | splits = label.split('|')
150 |
151 | is_abbrev = [len(tok) > 1 for tok in splits]
152 | bracketing = ['[' + tok + ']' if abbrev else tok for (tok, abbrev) in zip(splits, is_abbrev)]
153 |
154 | return '|'.join(bracketing)
--------------------------------------------------------------------------------
/prediction.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = "solivr"
3 | __license__ = "GPL"
4 |
5 | import os
6 | from glob import glob
7 |
8 | import click
9 | from tf_crnn.callbacks import CustomPredictionSaverCallback, FOLDER_SAVED_MODEL
10 | from tf_crnn.config import Params
11 | from tf_crnn.data_handler import dataset_generator
12 | from tf_crnn.model import get_model_inference
13 |
14 |
15 | @click.command()
16 | @click.option('--csv_filename', help='A csv file containing the path to the images to predict')
17 | @click.option('--output_model_dir', help='Directory where all the exported data related to an experiment has been saved')
18 | def prediction(csv_filename: str,
19 | output_model_dir: str):
20 | parameters = Params.from_json_file(os.path.join(output_model_dir, 'config.json'))
21 |
22 | saving_dir = os.path.join(output_model_dir, FOLDER_SAVED_MODEL)
23 | last_time_stamp = str(max([int(p.split(os.path.sep)[-1].split('-')[0])
24 | for p in glob(os.path.join(saving_dir, '*'))]))
25 | model = get_model_inference(parameters, os.path.join(saving_dir, last_time_stamp, 'weights.h5'))
26 |
27 | dataset_test = dataset_generator([csv_filename],
28 | parameters,
29 | use_labels=False,
30 | batch_size=parameters.eval_batch_size,
31 | shuffle=False)
32 |
33 | ps_callback = CustomPredictionSaverCallback(output_model_dir, parameters)
34 |
35 | _, _, _ = model.predict(x=dataset_test, callbacks=[ps_callback])
36 |
37 |
38 | if __name__ == '__main__':
39 | prediction()
40 |
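41 | # Example usage (paths are illustrative):
42 | #   python prediction.py --csv_filename data/to_predict.csv --output_model_dir ./output_model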
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = "solivr"
3 | __license__ = "GPL"
4 |
5 | from setuptools import setup, find_packages
6 |
7 | setup(name='tf_crnn',
8 | version='0.6.0',
9 | license='GPL',
10 | author='Sofia Ares Oliveira',
11 | url='https://github.com/solivr/tf-crnn',
12 | description='TensorFlow Convolutional Recurrent Neural Network (CRNN)',
13 | install_requires=[
14 | 'imageio',
15 | 'numpy',
16 | 'tqdm',
17 | 'sacred',
18 | 'opencv-python',
19 | 'pandas',
20 | 'click',
21 | #'tensorflow-addons',
22 | 'tensorflow-gpu',
23 | 'taputapu'
24 | ],
25 | dependency_links=['https://github.com/solivr/taputapu/tarball/master#egg=taputapu-1.0'],
26 | extras_require={
27 | 'doc': [
28 | 'sphinx',
29 | 'sphinx-autodoc-typehints',
30 | 'sphinx-rtd-theme',
31 | 'sphinxcontrib-bibtex',
32 | 'sphinxcontrib-websupport'
33 | ],
34 | },
35 | packages=find_packages(where='.'),
36 | zip_safe=False)
37 |
--------------------------------------------------------------------------------
/tf_crnn/__init__.py:
--------------------------------------------------------------------------------
1 | r"""
2 |
3 |
4 | Data handling for input function
5 | --------------------------------
6 | .. currentmodule:: tf_crnn.data_handler
7 |
8 | .. autosummary::
9 | dataset_generator
10 | padding_inputs_width
11 | augment_data
12 | random_rotation
13 |
14 |
15 | Model definitions
16 | -----------------
17 | .. currentmodule:: tf_crnn.model
18 |
19 | .. autosummary::
20 | ConvBlock
21 | get_model_train
22 | get_model_inference
23 | get_crnn_output
24 |
25 |
26 | Config for training
27 | -------------------
28 | .. currentmodule:: tf_crnn.config
29 |
30 | .. autosummary::
31 | Alphabet
32 | Params
33 | import_params_from_json
34 |
35 |
36 | Custom Callbacks
37 | ----------------
38 | .. currentmodule:: tf_crnn.callbacks
39 |
40 | .. autosummary::
41 | CustomSavingCallback
42 | LRTensorBoard
43 | CustomLoaderCallback
44 | CustomPredictionSaverCallback
45 |
46 |
47 | Preprocessing data
48 | ------------------
49 | .. currentmodule:: tf_crnn.preprocessing
50 |
51 | .. autosummary::
52 | data_preprocessing
53 | preprocess_csv
54 |
55 |
56 | ----
57 |
58 | """
59 |
60 | _DATA_HANDLING = [
61 | 'dataset_generator',
62 | 'padding_inputs_width',
63 | 'augment_data',
64 | 'random_rotation'
65 | ]
66 |
67 | _CONFIG = [
68 | 'Alphabet',
69 | 'Params',
70 | 'import_params_from_json'
71 |
72 | ]
73 |
74 | _MODEL = [
75 | 'ConvBlock',
76 | 'get_model_train',
77 | 'get_model_inference',
78 | 'get_crnn_output'
79 | ]
80 |
81 | _CALLBACKS = [
82 | 'CustomSavingCallback',
83 | 'CustomLoaderCallback',
84 | 'CustomPredictionSaverCallback',
85 | 'LRTensorBoard'
86 | ]
87 |
88 | _PREPROCESSING = [
89 | 'data_preprocessing',
90 | 'preprocess_csv'
91 | ]
92 |
93 | __all__ = _DATA_HANDLING + _CONFIG + _MODEL + _CALLBACKS + _PREPROCESSING
94 |
95 | from tf_crnn.config import *
96 | from tf_crnn.model import *
97 | from tf_crnn.callbacks import *
98 | from tf_crnn.preprocessing import *
99 | from tf_crnn.data_handler import *
--------------------------------------------------------------------------------
/tf_crnn/callbacks.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = "solivr"
3 | __license__ = "GPL"
4 |
5 | import tensorflow as tf
6 | from tensorflow.keras.callbacks import Callback, TensorBoard
7 | import os
8 | import shutil
9 | import pickle
10 | import json
11 | import time
12 | import numpy as np
13 | from .config import Params
14 |
15 |
16 | MODEL_WEIGHTS_FILENAME = 'weights.h5'
17 | OPTIMIZER_WEIGHTS_FILENAME = 'optimizer_weights.pkl'
18 | LEARNING_RATE_FILENAME = 'learning_rate.pkl'
19 | LAYERS_FILENAME = 'architecture.json'
20 | EPOCH_FILENAME = 'epoch.pkl'
21 | FOLDER_SAVED_MODEL = 'saving'
22 |
23 |
24 | class CustomSavingCallback(Callback):
25 | """
26 | Callback to save weights, architecture, and optimizer state every ``saving_freq`` epochs and at the end of training.
27 | Inspired by `ModelCheckpoint`.
28 |
29 | :ivar output_dir: path to the folder where files will be saved
30 | :vartype output_dir: str
31 | :ivar saving_freq: save every `n` epochs
32 | :vartype saving_freq: int
33 | :ivar save_best_only: whether to save the model only if it is better than the last saved one
34 | :vartype save_best_only: bool
35 | :ivar keep_max_models: number of models to keep; the older ones will be deleted
36 | :vartype keep_max_models: int
37 | """
38 | def __init__(self,
39 | output_dir: str,
40 | saving_freq: int,
41 | save_best_only: bool=False,
42 | keep_max_models: int=5):
43 | super(CustomSavingCallback, self).__init__()
44 |
45 | self.saving_dir = output_dir
46 | self.saving_freq = saving_freq
47 | self.save_best_only = save_best_only
48 | self.keep_max_models = keep_max_models
49 |
50 | self.epochs_since_last_save = 0
51 |
52 | self.monitor = 'val_loss'
53 | self.monitor_op = np.less
54 | self.best_value = np.Inf # todo: when restoring model we could also restore val_loss and metric
55 |
56 | def on_epoch_begin(self,
57 | epoch,
58 | logs=None):
59 | self._current_epoch = epoch
60 |
61 | def on_epoch_end(self,
62 | epoch,
63 | logs=None):
64 |
65 | self.logs = logs
66 | self.epochs_since_last_save += 1
67 |
68 | if self.epochs_since_last_save == self.saving_freq:
69 | self._export_model(logs)
70 | self.epochs_since_last_save = 0
71 |
72 | def on_train_end(self,
73 | logs=None):
74 | self._export_model(self.logs)
75 | self.epochs_since_last_save = 0
76 |
77 |
78 | def _export_model(self, logs):
79 | timestamp = str(int(time.time()))
80 | folder = os.path.join(self.saving_dir, timestamp)
81 |
82 | if self.save_best_only:
83 | current_value = logs.get(self.monitor)
84 |
85 | if self.monitor_op(current_value, self.best_value):
86 | print('\n{} improved from {:0.5f} to {:0.5f},'
87 | ' saving model to {}'.format(self.monitor, self.best_value,
88 | current_value, folder))
89 | self.best_value = current_value
90 |
91 | else:
92 | print('\n{} did not improve from {:0.5f}'.format(self.monitor, self.best_value))
93 | return
94 |
95 | os.makedirs(folder)
96 |
97 | # save architecture
98 | model_json = self.model.to_json()
99 | with open(os.path.join(folder, LAYERS_FILENAME), 'w') as f:
100 | json.dump(model_json, f)
101 |
102 | # model weights
103 | self.model.save_weights(os.path.join(folder, MODEL_WEIGHTS_FILENAME))
104 |
105 | # optimizer weights
106 | optimizer_weights = tf.keras.backend.batch_get_value(self.model.optimizer.weights)
107 | with open(os.path.join(folder, OPTIMIZER_WEIGHTS_FILENAME), 'wb') as f:
108 | pickle.dump(optimizer_weights, f)
109 |
110 | # learning rate
111 | learning_rate = self.model.optimizer.learning_rate
112 | with open(os.path.join(folder, LEARNING_RATE_FILENAME), 'wb') as f:
113 | pickle.dump(learning_rate, f)
114 |
115 | # n epochs
116 | epoch = self._current_epoch + 1
117 | with open(os.path.join(folder, EPOCH_FILENAME), 'wb') as f:
118 | pickle.dump(epoch, f)
119 |
120 | self._clean_exports()
121 |
122 | def _clean_exports(self):
123 | timestamp_folders = [int(f) for f in os.listdir(self.saving_dir)]
124 | timestamp_folders.sort(reverse=True)
125 |
126 | if len(timestamp_folders) > self.keep_max_models:
127 | folders_to_remove = timestamp_folders[self.keep_max_models:]
128 | for f in folders_to_remove:
129 | shutil.rmtree(os.path.join(self.saving_dir, str(f)))
130 |
131 |
132 |
133 | class CustomLoaderCallback(Callback):
134 | """
135 | Callback to load necessary weight and parameters for training, evaluation and prediction.
136 |
137 | :ivar loading_dir: path to the directory containing the saved weights and optimizer state to load
138 | :vartype loading_dir: str
139 | """
140 | def __init__(self,
141 | loading_dir: str):
142 | super(CustomLoaderCallback, self).__init__()
143 |
144 | self.loading_dir = loading_dir
145 |
146 | def set_model(self, model):
147 | self.model = model
148 |
149 | print('-- Loading ', self.loading_dir)
150 | # Load model weights
151 | self.model.load_weights(os.path.join(self.loading_dir, MODEL_WEIGHTS_FILENAME))
152 |
153 | # Load optimizer params
154 | with open(os.path.join(self.loading_dir, OPTIMIZER_WEIGHTS_FILENAME), 'rb') as f:
155 | optimizer_weights = pickle.load(f)
156 | with open(os.path.join(self.loading_dir, LEARNING_RATE_FILENAME), 'rb') as f:
157 | learning_rate = pickle.load(f)
158 |
159 | # Set optimizer params
160 | self.model.optimizer.learning_rate.assign(learning_rate)
161 | self.model._make_train_function()
162 | self.model.optimizer.set_weights(optimizer_weights)
163 |
164 |
165 | class CustomPredictionSaverCallback(Callback):
166 | """
167 | Callback to save prediction decoded outputs.
168 | This will save the decoded outputs into a file.
169 |
170 | :ivar output_dir: path to the directory where the predictions file will be written
171 | :vartype output_dir: str
172 | :ivar parameters: parameters of the experiment (``Params``)
173 | :vartype parameters: Params
174 | """
175 | def __init__(self,
176 | output_dir: str,
177 | parameters: Params):
178 | super(CustomPredictionSaverCallback, self).__init__()
179 |
180 | self.saving_dir = output_dir
181 | self.parameters = parameters
182 |
183 | # Inference
184 | def on_predict_begin(self,
185 | logs=None):
186 | # Create file to add predictions
187 | timestamp = str(int(time.time()))
188 | self._prediction_filename = os.path.join(self.saving_dir, 'predictions-{}.txt'.format(timestamp))
189 |
190 | def on_predict_batch_end(self,
191 | batch,
192 | logs):
193 | logits, seq_len, filenames = logs['outputs']
194 |
195 | codes = tf.keras.backend.ctc_decode(logits, tf.squeeze(seq_len), greedy=True)[0][0].numpy()
196 | strings = [''.join([self.parameters.alphabet.lookup_int2str[c] for c in lc if c != -1]) for lc in codes]
197 |
198 | with open(self._prediction_filename, 'ab') as f:
199 | for n, s in zip(filenames, strings):
200 | n = n[0] # n is a list of one element
201 | f.write((n.decode() + ';' + s + '\n').encode('utf8'))
202 |
203 |
204 | class LRTensorBoard(TensorBoard):
205 | """
206 | Adds learning rate to TensorBoard scalars.
207 |
208 | :ivar logdir: path to directory to save logs
209 | :vartype logdir: str
210 | """
211 | # From https://github.com/keras-team/keras/pull/9168#issuecomment-359901128
212 | def __init__(self,
213 | log_dir: str,
214 | **kwargs): # add other arguments to __init__ if you need
215 | super(LRTensorBoard, self).__init__(log_dir=log_dir, **kwargs)
216 |
217 | def on_epoch_end(self,
218 | epoch,
219 | logs=None):
220 | logs.update({'lr': tf.keras.backend.eval(self.model.optimizer.lr)})
221 | super(LRTensorBoard, self).on_epoch_end(epoch, logs)
--------------------------------------------------------------------------------
/tf_crnn/config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = 'solivr'
3 | __license__ = "GPL"
4 |
5 | import csv
6 | import json
7 | import os
8 | import string
9 | from functools import reduce
10 | from glob import glob
11 | from typing import List, Union
12 | import pandas as pd
13 |
14 |
15 | class CONST:
16 | DIMENSION_REDUCTION_W_POOLING = 2*2 # 2x2 pooling in dimension W on layer 1 and 2
17 | PREPROCESSING_FOLDER = 'preprocessed'
18 |
19 |
20 | class Alphabet:
21 | """
22 | Class for alphabet / symbols units.
23 |
24 | :ivar _blank_symbol: Blank symbol used for CTC
25 | :vartype _blank_symbol: str
26 | :ivar _alphabet_units: list of elements composing the alphabet. The units may be a single character or multiple characters.
27 | :vartype _alphabet_units: List[str]
28 | :ivar _codes: Each alphabet unit has a unique corresponding code.
29 | :vartype _codes: List[int]
30 | :ivar _nclasses: number of alphabet units.
31 | :vartype _nclasses: int
32 | """
33 | def __init__(self, lookup_alphabet_file: str=None, blank_symbol: str='$'):
34 |
35 | self._blank_symbol = blank_symbol
36 |
37 | if lookup_alphabet_file:
38 | lookup_alphabet = self.load_lookup_from_json(lookup_alphabet_file)
39 | # Blank symbol must have the largest value
40 | if self._blank_symbol in lookup_alphabet.keys():
41 |
42 | # TODO : check if blank symbol is the last one
43 | assert lookup_alphabet[self._blank_symbol] == max(lookup_alphabet.values()), \
44 | "Blank symbol should have the largest code integer"
45 | lookup_alphabet[self._blank_symbol] = max(lookup_alphabet.values()) + 1
46 | else:
47 | lookup_alphabet.update({self._blank_symbol: max(lookup_alphabet.values()) + 1})
48 |
49 | self._alphabet_units = list(lookup_alphabet.keys())
50 | self._codes = list(lookup_alphabet.values())
51 | self._nclasses = len(self.codes) + 1 # n_classes should be + 1 of labels codes
52 |
53 | if 0 in self._codes:
54 | raise ValueError("The code 0 is in the lookup table, you shouldn't use it.")
55 |
56 | self.lookup_int2str = dict(zip(self.codes, self.alphabet_units))
57 |
58 | def check_input_file_alphabet(self, csv_filenames: List[str],
59 | discarded_chars: str=';|{}'.format(string.whitespace[1:]),
60 | csv_delimiter: str=";") -> None:
61 | """
62 | Checks if labels of input files contains only characters that are in the Alphabet.
63 |
64 | :param csv_filenames: list of the csv filename
65 | :param discarded_chars: discarded characters
66 | :param csv_delimiter: character delimiting field in the csv file
67 | :return:
68 | """
69 | assert isinstance(csv_filenames, list), 'csv_filenames argument is not a list'
70 |
71 | alphabet_set = set(self.alphabet_units)
72 |
73 | for filename in csv_filenames:
74 | input_chars_set = set()
75 |
76 | with open(filename, 'r', encoding='utf8') as f:
77 | csvreader = csv.reader(f, delimiter=csv_delimiter, escapechar='\\', quoting=0)
78 | for line in csvreader:
79 | input_chars_set.update(line[1])
80 |
81 | # Discard all whitespaces except space ' '
82 | for whitespace in discarded_chars:
83 | input_chars_set.discard(whitespace)
84 |
85 | extra_chars = input_chars_set - alphabet_set
86 | assert len(extra_chars) == 0, 'There are {} unknown chars in {} : {}'.format(len(extra_chars),
87 | filename, extra_chars)
88 |
89 | @classmethod
90 | def map_lookup(cls, lookup_table: dict, unique_entry: bool = True) -> dict:
91 | """
92 | Converts an existing lookup table with minimal range code ([1, len(lookup_table)-1])
93 | and avoids multiple instances of the same code label (bijectivity)
94 |
95 | :param lookup_table: dictionary to be mapped {alphabet_unit : code label}
96 | :param unique_entry: If each alphabet unit has a unique code and each code a unique alphabet unique ('bijective'),
97 | only True is implemented for now
98 | :return: a mapped dictionary
99 | """
100 |
101 | # Create tuple (alphabet unit, code)
102 | tuple_char_code = list(zip(list(lookup_table.keys()), list(lookup_table.values())))
103 | # Sort by code
104 | tuple_char_code.sort(key=lambda x: x[1])
105 |
106 | # If each alphabet unit has a unique code and each code a unique alphabet unit ('bijective')
107 | if unique_entry:
108 | mapped_lookup = [[tp[0], i + 1] for i, tp in enumerate(tuple_char_code)]
109 | else:
110 | raise NotImplementedError
111 | # Todo
112 |
113 | return dict(mapped_lookup)
114 |
115 | @classmethod
116 | def create_lookup_from_labels(cls, csv_files: List[str], export_lookup_filename: str,
117 | original_lookup_filename: str=None):
118 | """
119 | Create a lookup dictionary for csv files containing labels. Exports a json file with the Alphabet.
120 |
121 | :param csv_files: list of files to get the labels from (should be of format path;label)
122 | :param export_lookup_filename: filename to export alphabet lookup dictionary
123 | :param original_lookup_filename: original lookup filename to update (optional)
124 | :return:
125 | """
126 | if original_lookup_filename:
127 | with open(original_lookup_filename, 'r') as f:
128 | lookup = json.load(f)
129 | set_chars = set(list(lookup.keys()))
130 | else:
131 | set_chars = set(list(string.ascii_letters) + list(string.digits))
132 | lookup = dict()
133 |
134 | for filename in csv_files:
135 | data = pd.read_csv(filename, sep=';', encoding='utf8', error_bad_lines=False, header=None,
136 | names=['path', 'transcription'], escapechar='\\')
137 | for index, row in data.iterrows():
138 | set_chars.update(row.transcription.split('|'))
139 |
140 | # Update (key, values) of lookup table
141 | for el in set_chars:
142 | if el not in lookup.keys():
143 | lookup[el] = max(lookup.values()) + 1 if lookup.values() else 0
144 |
145 | lookup = cls.map_lookup(lookup)
146 |
147 | # Save new lookup
148 | with open(export_lookup_filename, 'w', encoding='utf8') as f:
149 | json.dump(lookup, f)
150 |
151 | @classmethod
152 | def load_lookup_from_json(cls, json_filenames: Union[List[str], str]) -> dict:
153 | """
154 | Load a lookup table from a json file into a dictionary
155 |
156 | :param json_filenames: either a filename or a list of filenames
156 | :return:
157 | """
158 |
159 | lookup = dict()
160 | if isinstance(json_filenames, list):
161 | for file in json_filenames:
162 | with open(file, 'r', encoding='utf8') as f:
163 | data_dict = json.load(f)
164 | lookup.update(data_dict)
165 |
166 | elif isinstance(json_filenames, str):
167 | with open(json_filenames, 'r', encoding='utf8') as f:
168 | lookup = json.load(f)
169 |
170 | return cls.map_lookup(lookup)
171 |
172 | @classmethod
173 | def make_json_lookup_alphabet(cls, string_chars: str = None) -> dict:
174 | """
175 |
176 | :param string_chars: for example string.ascii_letters, string.digits
177 | :return:
178 | """
179 | lookup = dict()
180 | if string_chars:
181 | # Add characters to lookup table
182 | lookup.update({char: ord(char) for char in string_chars})
183 |
184 | return cls.map_lookup(lookup)
185 |
186 | @property
187 | def n_classes(self):
188 | return self._nclasses
189 |
190 | @property
191 | def blank_symbol(self):
192 | return self._blank_symbol
193 |
194 | @property
195 | def codes(self):
196 | return self._codes
197 |
198 | @property
199 | def alphabet_units(self):
200 | return self._alphabet_units
201 |
202 |
203 | class Params:
204 | """
205 | Class for parameters of the model and the experiment
206 |
207 | :ivar input_shape: input shape of the image to batch (this is the shape after data augmentation).
208 | The original image will either be resized or padded, depending on its original size
209 | :vartype input_shape: Tuple[int, int]
210 | :ivar input_channels: number of color channels for input image (default: 1)
211 | :vartype input_channels: int
212 | :ivar cnn_features_list: a list of length `n_layers` containing the number of features for each convolutional layer
213 | (default: [16, 32, 64, 96, 128])
214 | :vartype cnn_features_list: List(int)
215 | :ivar cnn_kernel_size: a list of length `n_layers` containing the size of the kernel for each convolutional layer
216 | (default: [3, 3, 3, 3, 3])
217 | :vartype cnn_kernel_size: List(int)
218 | :ivar cnn_stride_size: a list of length `n_layers` containing the stride size of each convolutional layer
219 | (default: [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)])
220 | :vartype cnn_stride_size: List((int, int))
221 | :ivar cnn_pool_size: a list of length `n_layers` containing the pool size of each MaxPool layer
222 | (default: [(2, 2), (2, 2), (2, 2), (2, 2), (1, 1)])
223 | :vartype cnn_pool_size: List((int, int))
224 | :ivar cnn_batch_norm: a list of length `n_layers` containing a bool that indicates whether or not to use batch normalization
225 | (default: [False, False, False, False, False])
226 | :vartype cnn_batch_norm: List(bool)
227 | :ivar rnn_units: a list containing the number of units per rnn layer (default: [256, 256])
228 | :vartype rnn_units: List(int)
229 | :ivar num_beam_paths: number of paths (transcriptions) to return for ctc beam search (only used when predicting)
230 | :vartype num_beam_paths: int
231 | :ivar csv_delimiter: character to delimit csv input files (default: ';')
232 | :vartype csv_delimiter: str
233 | :ivar string_split_delimiter: character that delimits each alphabet unit in the labels (default: '|')
234 | :vartype string_split_delimiter: str
235 | :ivar csv_files_train: csv filename which contains the (path;label) of each training sample
236 | :vartype csv_files_train: str
237 | :ivar csv_files_eval: csv filename which contains the (path;label) of each eval sample
238 | :vartype csv_files_eval: str
239 | :ivar lookup_alphabet_file: json file that contains the mapping alphabet units <-> codes
240 | :vartype lookup_alphabet_file: str
241 | :ivar blank_symbol: symbol to be considered as blank by the CTC decoder (default: '$')
242 | :vartype blank_symbol: str
243 | :ivar max_chars_per_string: maximum characters per sample (to avoid CTC decoder errors) (default: 75)
244 | :vartype max_chars_per_string: int
245 | :ivar data_augmentation: if True augments data on the fly (default: true)
246 | :vartype data_augmentation: bool
247 | :ivar data_augmentation_max_rotation: maximum rotation to apply to the image during training, in radians (default: 0.005)
248 | :vartype data_augmentation_max_rotation: float
249 | :ivar data_augmentation_max_slant: maximum angle for slant augmentation (default: 0.7)
250 | :vartype data_augmentation_max_slant: float
251 | :ivar n_epochs: numbers of epochs to run the training (default: 50)
252 | :vartype n_epochs: int
253 | :ivar train_batch_size: batch size during training (default: 64)
254 | :vartype train_batch_size: int
255 | :ivar eval_batch_size: batch size during evaluation (default: 128)
256 | :vartype eval_batch_size: int
257 | :ivar learning_rate: initial learning rate (default: 1e-4)
258 | :vartype learning_rate: float
259 | :ivar evaluate_every_epoch: evaluate every 'evaluate_every_epoch' epoch (default: 5)
260 | :vartype evaluate_every_epoch: int
261 | :ivar save_interval: save the model every 'save_interval' epoch (default: 20)
262 | :vartype save_interval: int
263 | :ivar optimizer: which optimizer to use ('adam', 'rms', 'ada') (default: 'adam')
264 | :vartype optimizer: str
265 | :ivar output_model_dir: output directory where the model will be saved and exported
266 | :vartype output_model_dir: str
267 | :ivar restore_model: boolean to continue training with saved weights (default: False)
268 | :vartype restore_model: bool
269 | """
270 | def __init__(self, **kwargs):
271 | # model params
272 | self.input_shape = kwargs.get('input_shape', (96, 1400))
273 | self.input_channels = kwargs.get('input_channels', 1)
274 | self.cnn_features_list = kwargs.get('cnn_features_list', [16, 32, 64, 96, 128])
275 | self.cnn_kernel_size = kwargs.get('cnn_kernel_size', [3, 3, 3, 3, 3])
276 | self.cnn_stride_size = kwargs.get('cnn_stride_size', [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)])
277 | self.cnn_pool_size = kwargs.get('cnn_pool_size', [(2, 2), (2, 2), (2, 2), (2, 2), (1, 1)])
278 | self.cnn_batch_norm = kwargs.get('cnn_batch_norm', [False, False, False, False, False])
279 | self.rnn_units = kwargs.get('rnn_units', [256, 256])
280 | # self._keep_prob_dropout = kwargs.get('keep_prob_dropout', 0.5)
281 | self.num_beam_paths = kwargs.get('num_beam_paths', 1)
282 | # csv params
283 | self.csv_delimiter = kwargs.get('csv_delimiter', ';')
284 | self.string_split_delimiter = kwargs.get('string_split_delimiter', '|')
285 | self.csv_files_train = kwargs.get('csv_files_train')
286 | self.csv_files_eval = kwargs.get('csv_files_eval')
287 | # alphabet params
288 | self.blank_symbol = kwargs.get('blank_symbol', '$')
289 | self.max_chars_per_string = kwargs.get('max_chars_per_string', 75)
290 | self.lookup_alphabet_file = kwargs.get('lookup_alphabet_file')
291 | # data augmentation params
292 | self.data_augmentation = kwargs.get('data_augmentation', True)
293 | self.data_augmentation_max_rotation = kwargs.get('data_augmentation_max_rotation', 0.005)
294 | self.data_augmentation_max_slant = kwargs.get('data_augmentation_max_slant', 0.7)
295 | # training params
296 | self.n_epochs = kwargs.get('n_epochs', 50)
297 | self.train_batch_size = kwargs.get('train_batch_size', 64)
298 | self.eval_batch_size = kwargs.get('eval_batch_size', 128)
299 | self.learning_rate = kwargs.get('learning_rate', 1e-4)
300 | self.optimizer = kwargs.get('optimizer', 'adam')
301 | self.output_model_dir = kwargs.get('output_model_dir', '')
302 | self.evaluate_every_epoch = kwargs.get('evaluate_every_epoch', 5)
303 | self.save_interval = kwargs.get('save_interval', 20)
304 | self.restore_model = kwargs.get('restore_model', False)
305 |
306 | self._assign_alphabet()
307 |
308 | cnn_params = zip(self.cnn_pool_size, self.cnn_stride_size)
309 | self.downscale_factor = reduce(lambda i, j: i * j, map(lambda k: k[0][1] * k[1][1], cnn_params))
310 |
311 | # TODO add additional checks for the architecture
312 | assert len(self.cnn_features_list) == len(self.cnn_kernel_size) == len(self.cnn_stride_size) \
313 | == len(self.cnn_pool_size) == len(self.cnn_batch_norm), \
314 | "Length of parameters of model are not the same, check that all the layers parameters have the same length."
315 |
316 | max_input_width = (self.max_chars_per_string + 1) * self.downscale_factor
317 | assert max_input_width <= self.input_shape[1], "Maximum length of labels is {}, input width should be greater or " \
318 | "equal to {} but is {}".format(self.max_chars_per_string,
319 | max_input_width,
320 | self.input_shape[1])
321 |
322 | assert self.optimizer in ['adam', 'rms', 'ada'], 'Unknown optimizer {}'.format(self.optimizer)
323 |
324 | if os.path.isdir(self.output_model_dir):
325 | print('WARNING : The output directory {} already exists.'.format(self.output_model_dir))
326 |
327 | def show_experiment_params(self) -> dict:
328 | """
329 | Returns a dictionary with the variables of the class.
330 |
331 | :return:
332 | """
333 | return vars(self)
334 |
335 | def _assign_alphabet(self):
336 | self.alphabet = Alphabet(lookup_alphabet_file=self.lookup_alphabet_file, blank_symbol=self.blank_symbol)
337 |
338 | # @property
339 | # def keep_prob_dropout(self):
340 | # return self._keep_prob_dropout
341 | #
342 | # @keep_prob_dropout.setter
343 | # def keep_prob_dropout(self, value):
344 | # assert (0.0 < value <= 1.0), 'Must be 0.0 < value <= 1.0'
345 | # self._keep_prob_dropout = value
346 |
347 | def to_dict(self) -> dict:
348 | """
349 | Returns the parameters as a dictionary
350 |
351 | :return:
352 | """
353 | new_dict = self.__dict__.copy()
354 | del new_dict['alphabet']
355 | del new_dict['downscale_factor']
356 | return new_dict
357 |
358 | @classmethod
359 | def from_json_file(cls, json_file: str):
360 | """
361 | Given a json file, creates a ``Params`` object.
362 |
363 | :param json_file: path to the json file
364 | :return: ``Params`` object
365 | """
366 | with open(json_file, 'r') as file:
367 | config = json.load(file)
368 |
369 | return cls(**config)
370 |
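371 | # Example (illustrative): parameters = Params.from_json_file('./output_model/config.json')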
371 |
372 | def import_params_from_json(model_directory: str=None, json_filename: str=None) -> dict:
373 | """
374 | Read the exported json file with parameters of the experiment.
375 |
376 | :param model_directory: directory where the model was exported
377 | :param json_filename: filename of the file
378 | :return: a dictionary containing the parameters of the experiment
379 | """
380 |
381 | assert not all(p is None for p in [model_directory, json_filename]), 'One argument at least should not be None'
382 |
383 | if model_directory:
384 | # Import parameters from the json file
385 | try:
386 | json_filename = glob(os.path.join(model_directory, 'model_params*.json'))[-1]
387 | except IndexError:
388 | print('No json found in dir {}'.format(model_directory))
389 | raise FileNotFoundError
390 | else:
391 | if not os.path.isfile(json_filename):
392 | print('No json found with filename {}'.format(json_filename))
393 | raise FileNotFoundError
394 |
395 | with open(json_filename, 'r') as data_json:
396 | params_json = json.load(data_json)
397 |
398 | # Remove 'private' keys
399 | keys = list(params_json.keys())
400 | for key in keys:
401 | if key[0] == '_':
402 | params_json.pop(key)
403 |
404 | return params_json
405 |
--------------------------------------------------------------------------------
/tf_crnn/data_handler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = 'solivr'
3 | __license__ = "GPL"
4 |
5 | import tensorflow as tf
6 | from tensorflow_addons.image.transform_ops import rotate, transform
7 | from .config import Params, CONST
8 | from typing import Tuple, Union, List
9 | import collections
10 |
11 |
12 | @tf.function
13 | def random_rotation(img: tf.Tensor,
14 | max_rotation: float=0.1,
15 | crop: bool=True,
16 | minimum_width: int=0) -> tf.Tensor: # adapted from SeguinBe
17 | """
18 | Rotates an image with a random angle.
19 | See https://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders for formulae
20 |
21 | :param img: Tensor
22 | :param max_rotation: maximum angle to rotate (radians)
23 | :param crop: whether to crop the rotated image to remove the black borders
24 | :param minimum_width: minimum width of image after data augmentation
25 | :return:
26 | """
27 | with tf.name_scope('RandomRotation'):
28 | rotation = tf.random.uniform([], -max_rotation, max_rotation, name='pick_random_angle')
29 | # rotated_image = tf.contrib.image.rotate(img, rotation, interpolation='BILINEAR')
30 | rotated_image = rotate(tf.expand_dims(img, axis=0), rotation, interpolation='BILINEAR')
31 | rotated_image = tf.squeeze(rotated_image, axis=0)
32 | if crop:
33 | rotation = tf.abs(rotation)
34 | original_shape = tf.shape(rotated_image)[:2]
35 | h, w = original_shape[0], original_shape[1]
36 | old_l, old_s = tf.cond(h > w, lambda: [h, w], lambda: [w, h])
37 | old_l, old_s = tf.cast(old_l, tf.float32), tf.cast(old_s, tf.float32)
38 | new_l = (old_l * tf.cos(rotation) - old_s * tf.sin(rotation)) / tf.cos(2*rotation)
39 | new_s = (old_s - tf.sin(rotation) * new_l) / tf.cos(rotation)
40 | new_h, new_w = tf.cond(h > w, lambda: [new_l, new_s], lambda: [new_s, new_l])
41 | new_h, new_w = tf.cast(new_h, tf.int32), tf.cast(new_w, tf.int32)
42 | bb_begin = tf.cast(tf.math.ceil((h-new_h)/2), tf.int32), tf.cast(tf.math.ceil((w-new_w)/2), tf.int32)
43 | # Test sliced
44 | rotated_image_crop = tf.cond(
45 | tf.logical_and(bb_begin[0] < h - bb_begin[0], bb_begin[1] < w - bb_begin[1]),
46 | true_fn=lambda: rotated_image[bb_begin[0]:h - bb_begin[0], bb_begin[1]:w - bb_begin[1], :],
47 | false_fn=lambda: img,
48 | name='check_slices_indices'
49 | )
50 | # rotated_image_crop = rotated_image[bb_begin[0]:h - bb_begin[0], bb_begin[1]:w - bb_begin[1], :]
51 |
52 | # If crop removes the entire image, keep the original image
53 | rotated_image = tf.cond(tf.less_equal(tf.shape(rotated_image_crop)[1], minimum_width),
54 | true_fn=lambda: img,
55 | false_fn=lambda: rotated_image_crop,
56 | name='check_size_crop')
57 |
58 | return rotated_image
59 |
60 |
61 | # def random_padding(image: tf.Tensor, max_pad_w: int=5, max_pad_h: int=10) -> tf.Tensor:
62 | # """
63 | # Given an image will pad its border adding a random number of rows and columns
64 | #
65 | # :param image: image to pad
66 | # :param max_pad_w: maximum padding in width
67 | # :param max_pad_h: maximum padding in height
68 | # :return: a padded image
69 | # """
70 | # # TODO specify image shape in doc
71 | #
72 | # w_pad = list(np.random.randint(0, max_pad_w, size=[2]))
73 | # h_pad = list(np.random.randint(0, max_pad_h, size=[2]))
74 | # paddings = [h_pad, w_pad, [0, 0]]
75 | #
76 | # return tf.pad(image, paddings, mode='REFLECT', name='random_padding')
77 |
78 | @tf.function
79 | def augment_data(image: tf.Tensor,
80 | max_rotation: float=0.1,
81 | minimum_width: int=0) -> tf.Tensor:
82 | """
83 | Data augmentation on an image (brightness, contrast, rotation; hue and saturation for colour images)
84 |
85 | :param image: Tensor
86 | :param max_rotation: float, maximum permitted rotation (in radians)
87 | :param minimum_width: minimum width of image after data augmentation
88 | :return: Tensor
89 | """
90 | with tf.name_scope('DataAugmentation'):
91 |
92 | # Random padding
93 | # image = random_padding(image)
94 |
95 | # TODO : add random scaling
96 | image = tf.image.random_brightness(image, max_delta=0.1)
97 | image = tf.image.random_contrast(image, 0.5, 1.5)
98 | image = random_rotation(image, max_rotation, crop=True, minimum_width=minimum_width)
99 |
100 | if image.shape[-1] >= 3:
101 | image = tf.image.random_hue(image, 0.2)
102 | image = tf.image.random_saturation(image, 0.5, 1.5)
103 |
104 | return image
105 |
106 | @tf.function
107 | def get_resized_width(image: tf.Tensor,
108 | target_height: int,
109 | increment: int):
110 | """
111 | Computes the width the image would have after being resized to `target_height`.
112 |
113 | :param image: image to resize
114 | :param target_height: height of the resized image
115 | :param increment: reduction factor due to pooling between input width and output width,
116 | this makes sure that the final width will be a multiple of increment
117 | :return: (new width, width/height ratio of the image)
118 | """
119 |
120 | image_shape = tf.shape(image)
121 | image_ratio = tf.divide(image_shape[1], image_shape[0], name='ratio')
122 |
123 | new_width = tf.cast(tf.round((image_ratio * target_height) / increment) * increment, tf.int32)
124 | f1 = lambda: (new_width, image_ratio)
125 | f2 = lambda: (target_height, tf.constant(1.0, dtype=tf.float64))
126 | if tf.math.less_equal(new_width, 0):
127 | return f2()
128 | else:
129 | return f1()
130 |
131 |
132 | @tf.function
133 | def padding_inputs_width(image: tf.Tensor,
134 | target_shape: Tuple[int, int],
135 | increment: int) -> Tuple[tf.Tensor, tf.Tensor]:
136 | """
137 | Given an input image, will pad it to return a target_shape size padded image.
138 | There are 3 cases:
139 | - image width >= target width : simple resizing to shrink the image
140 | - 0.5*target width <= image width < target width : mirror-pad the image on the right
141 | - image width < 0.5*target width : tile the image horizontally and crop to the target width
142 |
143 | :param image: Tensor of shape [H,W,C]
144 | :param target_shape: final shape after padding [H, W]
145 | :param increment: reduction factor due to pooling between input width and output width,
146 | this makes sure that the final width will be a multiple of increment
147 | :return: (image padded, output width)
148 | """
149 |
150 | target_ratio = target_shape[1]/target_shape[0]
151 | target_w = target_shape[1]
152 | # Compute ratio to keep the same ratio in new image and get the size of padding
153 | # necessary to have the final desired shape
154 | new_h = target_shape[0]
155 | new_w, ratio = get_resized_width(image, new_h, increment)
156 |
157 | # Definitions for cases
158 | def pad_fn():
159 | with tf.name_scope('mirror_padding'):
160 | pad = tf.subtract(target_w, new_w)
161 |
162 | img_resized = tf.image.resize(image, [new_h, new_w])
163 |
164 | # Padding to have the desired width
165 | paddings = [[0, 0], [0, pad], [0, 0]]
166 | pad_image = tf.pad(img_resized, paddings, mode='SYMMETRIC', name=None)
167 |
168 | # Set manually the shape
169 | pad_image.set_shape([target_shape[0], target_shape[1], img_resized.get_shape()[2]])
170 |
171 | return pad_image, (new_h, new_w)
172 |
173 | def replicate_fn():
174 | with tf.name_scope('replication_padding'):
175 | img_resized = tf.image.resize(image, [new_h, new_w])
176 |
177 | # If one symmetry is not enough to have a full width
178 | # Count number of replications needed
179 | n_replication = tf.cast(tf.math.ceil(target_shape[1]/new_w), tf.int32)
180 | img_replicated = tf.tile(img_resized, tf.stack([1, n_replication, 1]))
181 | pad_image = tf.image.crop_to_bounding_box(image=img_replicated, offset_height=0, offset_width=0,
182 | target_height=target_shape[0], target_width=target_shape[1])
183 |
184 | # Set manually the shape
185 | pad_image.set_shape([target_shape[0], target_shape[1], img_resized.get_shape()[2]])
186 |
187 | return pad_image, (new_h, new_w)
188 |
189 | def simple_resize():
190 | with tf.name_scope('simple_resize'):
191 | img_resized = tf.image.resize(image, target_shape)
192 |
193 | img_resized.set_shape([target_shape[0], target_shape[1], img_resized.get_shape()[2]])
194 |
195 | return img_resized, tuple(target_shape)
196 |
197 | # case 1 : new_w >= target_w
198 | if tf.logical_and(tf.greater_equal(ratio, target_ratio), tf.greater_equal(new_w, target_w)):
199 | pad_image, (new_h, new_w) = simple_resize()
200 | # case 2 : new_w >= target_w/2 & new_w < target_w & ratio < target_ratio
201 | elif tf.logical_and(tf.less(ratio, target_ratio),
202 | tf.logical_and(tf.greater_equal(new_w, tf.cast(tf.divide(target_w, 2), tf.int32)),
203 | tf.less(new_w, target_w))):
204 | pad_image, (new_h, new_w) = pad_fn()
205 | # case 3 : new_w < target_w/2 & new_w < target_w & ratio < target_ratio
206 | elif tf.logical_and(tf.less(ratio, target_ratio),
207 | tf.logical_and(tf.less(new_w, target_w),
208 | tf.less(new_w, tf.cast(tf.divide(target_w, 2), tf.int32)))):
209 | pad_image, (new_h, new_w) = replicate_fn()
210 | else:
211 | pad_image, (new_h, new_w) = simple_resize()
212 |
213 | return pad_image, new_w
214 |
215 |
216 | # def apply_slant(image: np.ndarray, alpha: np.ndarray) -> (np.ndarray, np.ndarray):
217 | # alpha = alpha[0]
218 | #
219 | # def _find_background_color(image: np.ndarray) -> int:
220 | # """
221 | # Given a grayscale image, finds the background color value
222 | # :param image: grayscale image
223 | # :return: background color value (int)
224 | # """
225 | # # Otsu's thresholding after Gaussian filtering
226 | # blur = cv2.GaussianBlur(image[:, :, 0].astype(np.uint8), (5, 5), 0)
227 | # thresh_value, thresholded_image = cv2.threshold(blur.astype(np.uint8), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
228 | #
229 | # # Find which is the background (0 or 255). Supposing that the background color occurrence is higher
230 | # # than the writing color
231 | # counts, bin_edges = np.histogram(thresholded_image, bins=2)
232 | # background_color = int(np.median(image[thresholded_image == 255 * np.argmax(counts)]))
233 | #
234 | # return background_color
235 | #
236 | # shape_image = image.shape
237 | # shift = max(-alpha * shape_image[0], 0)
238 | # output_size = (int(shape_image[1] + np.ceil(abs(alpha * shape_image[0]))), int(shape_image[0]))
239 | #
240 | # warpM = np.array([[1, alpha, shift], [0, 1, 0]])
241 | #
242 | # # Find color of background in order to replicate it in the borders
243 | # border_value = _find_background_color(image)
244 | #
245 | # image_warp = cv2.warpAffine(image, np.array(warpM), output_size, borderValue=border_value)
246 | #
247 | # return image_warp, np.array(output_size)
248 |
249 |
250 | def dataset_generator(csv_filename: Union[List[str], str],
251 | params: Params,
252 | use_labels: bool=True,
253 | batch_size: int=64,
254 | data_augmentation: bool=False,
255 | num_epochs: int=None,
256 | shuffle: bool=True):
257 | """
258 | Generates the dataset for the experiment.
259 |
260 |
261 | :param csv_filename: Path to csv file containing the data
262 | :param params: parameters of the experiment (``Params``)
263 | :param use_labels: boolean to indicate dataset generation during training / evaluation (true) or prediction (false)
264 | :param batch_size: size of the generated batches
265 | :param data_augmentation: whether to use data augmentation strategies or not
266 | :param num_epochs: number of epochs to repeat the dataset generation
267 | :param shuffle: whether to shuffle the data
268 | :return: ``tf.data.Dataset``
269 | """
270 | do_padding = True
271 |
272 | if use_labels:
273 | column_defaults = [['None'], ['None'], tf.int32]
274 | column_names = ['paths', 'label_codes', 'label_seq_length']
275 | label_name = 'label_codes'
276 | else:
277 | column_defaults = [['None']]
278 | column_names = ['paths']
279 | label_name = None
280 |
281 | num_parallel_reads = 1
282 |
283 | # ----- from data.experimental.make_csv_dataset
284 | def filename_to_dataset(filename):
285 | dataset = tf.data.experimental.CsvDataset(filename,
286 | record_defaults=column_defaults,
287 | field_delim=params.csv_delimiter,
288 | header=False)
289 | return dataset
290 |
291 | def map_fn(*columns):
292 | """Organizes columns into a features dictionary.
293 | Args:
294 | *columns: list of `Tensor`s corresponding to one csv record.
295 | Returns:
296 | An OrderedDict of feature names to values for that particular record. If
297 | label_name is provided, extracts the label feature to be returned as the
298 | second element of the tuple.
299 | """
300 | features = collections.OrderedDict(zip(column_names, columns))
301 | if label_name is not None:
302 | label = features.pop(label_name)
303 | return features, label
304 |
305 | return features
306 |
307 | dataset = tf.data.Dataset.from_tensor_slices(csv_filename)
308 | # Read files sequentially (if num_parallel_reads=1) or in parallel
309 | # dataset = dataset.apply(tf.data.experimental.parallel_interleave(filename_to_dataset,
310 | # cycle_length=num_parallel_reads))
311 | dataset = dataset.interleave(filename_to_dataset, cycle_length=num_parallel_reads,
312 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
313 | dataset = dataset.map(map_fn)
314 | # -----
315 |
316 | def _load_image(features: dict, labels=None):
317 | path = features['paths']
318 | image_content = tf.io.read_file(path)
319 | image = tf.io.decode_jpeg(image_content, channels=params.input_channels,
320 | try_recover_truncated=True, name='image_decoding_op')
321 |
322 | if use_labels:
323 | return {'input_images': image,
324 | 'label_seq_length': features['label_seq_length']}, labels
325 | else:
326 | return {'input_images': image,
327 | 'filename_images': path}
328 |
329 | def _apply_slant(features: dict, labels=None):
330 | image = features['input_images']
331 | height_image = tf.cast(tf.shape(image)[0], dtype=tf.float32)
332 |
333 | with tf.name_scope('add_slant'):
334 | alpha = tf.random.uniform([],
335 | -params.data_augmentation_max_slant,
336 | params.data_augmentation_max_slant,
337 | name='pick_random_slant_angle')
338 |
339 | shiftx = tf.math.maximum(tf.math.multiply(-alpha, height_image), 0)
340 |
341 | # Pad in order not to lose image info when the transformation is applied
342 | x_pad = 0
343 | y_pad = tf.math.round(tf.math.ceil(tf.math.abs(tf.math.multiply(alpha, height_image))))
344 | y_pad = tf.cast(y_pad, dtype=tf.int32)
345 | paddings = [[x_pad, x_pad], [y_pad, 0], [0, 0]]
346 | transform_matrix = [1, alpha, shiftx, 0, 1, 0, 0, 0]
347 |
348 | # Apply transformation to image
349 | image_pad = tf.pad(image, paddings)
350 | image_transformed = transform(image_pad, transform_matrix, interpolation='BILINEAR')
351 |
352 | # Apply the transformation to a mask. The mask is used to find the pixels that were filled
353 | # with zeros during the transformation and to update them with the background value
354 | # TODO : Would be better to have some kind of binarization (i.e Otsu) and get the mean background value
355 | background_pixel_value = 255
356 | empty = background_pixel_value * tf.ones(tf.shape(image))
357 | empty_pad = tf.pad(empty, paddings)
358 | empty_transformed = tf.subtract(
359 | tf.cast(background_pixel_value, dtype=tf.int32),
360 | tf.cast(transform(empty_pad, transform_matrix, interpolation='NEAREST'), dtype=tf.int32)
361 | )
362 |
363 | # Update additional zeros values with background_pixel_value and cast result to uint8
364 | image = tf.add(tf.cast(image_transformed, dtype=tf.int32), empty_transformed)
365 | image = tf.cast(image, tf.uint8)
366 |
367 | features['input_images'] = image
368 | return (features, labels) if use_labels else features
369 |
370 | def _data_augment_fn(features: dict, labels=None) -> tf.data.Dataset:
371 | image = features['input_images']
372 | image = augment_data(image, params.data_augmentation_max_rotation, minimum_width=params.max_chars_per_string)
373 |
374 | features.update({'input_images': image})
375 | return (features, labels) if use_labels else features
376 |
377 | def _pad_image_or_resize(features: dict, labels=None):
378 | image = features['input_images']
379 | if do_padding:
380 | with tf.name_scope('padding'):
381 | image, img_width = padding_inputs_width(image, target_shape=params.input_shape,
382 | increment=params.downscale_factor) # todo this needs to be updated
383 | # Resize
384 | else:
385 | image = tf.image.resize(image, size=params.input_shape)
386 | img_width = tf.shape(image)[1]
387 |
388 | input_seq_length = tf.cast(tf.floor(tf.divide(img_width, params.downscale_factor)), tf.int32)
389 | if use_labels:
390 | assert_op = tf.debugging.assert_greater_equal(input_seq_length,
391 | features['label_seq_length'])
392 | with tf.control_dependencies([assert_op]):
393 | return {'input_images': image,
394 | 'label_seq_length': features['label_seq_length'],
395 | 'input_seq_length': input_seq_length}, labels
396 | else:
397 | return {'input_images': image,
398 | 'input_seq_length': input_seq_length,
399 | 'filename_images': features['filename_images']}
400 |
401 | def _normalize_image(features: dict, labels=None):
402 | image = tf.cast(features['input_images'], tf.float32)
403 | image = tf.image.per_image_standardization(image)
404 |
405 | features['input_images'] = image
406 | return (features, labels) if use_labels else features
407 |
408 | def _format_label_codes(features: dict, string_label_codes):
409 | splits = tf.strings.split([string_label_codes], sep=' ')
410 | label_codes = tf.squeeze(tf.strings.to_number(splits, out_type=tf.int32), axis=0)
411 |
412 | features.update({'label_codes': label_codes})
413 | return features, [0]  # dummy target: the loss reads the true codes from the features dict
414 |
415 |
416 | num_parallel_calls = tf.data.experimental.AUTOTUNE
417 | # 1. load image 2. data augmentation 3. padding
418 | dataset = dataset.map(_load_image, num_parallel_calls=num_parallel_calls)
419 | # this causes problems when using the same cache for training, validation and prediction data...
420 | # dataset = dataset.cache(filename=os.path.join(params.output_model_dir, 'cache.tf-data'))
421 | if data_augmentation and params.data_augmentation_max_slant != 0:
422 | dataset = dataset.map(_apply_slant, num_parallel_calls=num_parallel_calls)
423 | if data_augmentation:
424 | dataset = dataset.map(_data_augment_fn, num_parallel_calls=num_parallel_calls)
425 | dataset = dataset.map(_normalize_image, num_parallel_calls=num_parallel_calls)
426 | dataset = dataset.map(_pad_image_or_resize, num_parallel_calls=num_parallel_calls)
427 | dataset = dataset.map(_format_label_codes, num_parallel_calls=num_parallel_calls) if use_labels else dataset
428 | dataset = dataset.shuffle(10 * batch_size, reshuffle_each_iteration=False) if shuffle else dataset
429 | dataset = dataset.repeat(num_epochs) if num_epochs is not None else dataset
430 |
431 | return dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
432 |
433 |
434 | # def dataset_prediction(image_filenames: Union[List[str], str]=None,
435 | # csv_filename: str=None,
436 | # params: Params=None,
437 | # batch_size: int=64):
438 | #
439 | # assert params, 'params cannot be None'
440 | # assert image_filenames or csv_filename, 'You need to feed an input (image_filenames or csv_filename)'
441 | #
442 | # do_padding = True
443 | #
444 | # def _load_image(path):
445 | # image_content = tf.io.read_file(path)
446 | # image = tf.io.decode_jpeg(image_content, channels=params.input_channels,
447 | # try_recover_truncated=True, name='image_decoding_op')
448 | #
449 | # return {'input_images': image}
450 | #
451 | # def _normalize_image(features: dict):
452 | # image = tf.cast(features['input_images'], tf.float32)
453 | # image = tf.image.per_image_standardization(image)
454 | #
455 | # features['input_images'] = image
456 | # return features
457 | #
458 | # def _pad_image_or_resize(features: dict):
459 | # image = features['input_images']
460 | # if do_padding:
461 | # with tf.name_scope('padding'):
462 | # image, img_width = padding_inputs_width(image, target_shape=params.input_shape,
463 | # increment=CONST.DIMENSION_REDUCTION_W_POOLING)
464 | # # Resize
465 | # else:
466 | # image = tf.image.resize(image, size=params.input_shape)
467 | # img_width = tf.shape(image)[1]
468 | #
469 | # input_seq_length = tf.cast(tf.floor(tf.math.divide(img_width, params.n_pool)), tf.int32)
470 | #
471 | # return {'input_images': image,
472 | # 'input_seq_length': input_seq_length}
473 | # if image_filenames is not None:
474 | # dataset = tf.data.Dataset.from_tensor_slices(image_filenames)
475 | # elif csv_filename is not None:
476 | # column_defaults = [['None']]
477 | # dataset = tf.data.experimental.CsvDataset(csv_filename,
478 | # record_defaults=column_defaults,
479 | # field_delim=params.csv_delimiter,
480 | # header=False)
481 | # # dataset = dataset.map(map_fn)
482 | # dataset = dataset.map(_load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
483 | # dataset = dataset.map(_normalize_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
484 | # dataset = dataset.map(_pad_image_or_resize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
485 | #
486 | # return dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
487 |
--------------------------------------------------------------------------------
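A minimal usage sketch of `dataset_generator`, assuming a `Params` instance built from the repository's config.json (with the usual keys such as `train_batch_size`) and a csv file already produced by `preprocess_csv`; the csv path below is a placeholder:

from tf_crnn.config import Params
from tf_crnn.data_handler import dataset_generator

params = Params.from_json_file('config.json')

# 'updated_train.csv' stands in for the csv written by preprocess_csv
dataset = dataset_generator(['updated_train.csv'],
                            params,
                            use_labels=True,
                            batch_size=params.train_batch_size,
                            data_augmentation=True,
                            num_epochs=1)

# Each element is (features, dummy_label); the real codes live in the features dict
for features, _ in dataset.take(1):
    print(features['input_images'].shape)    # (batch, H, W, C)
    print(features['label_codes'][0])        # dense, zero-padded label codes
    print(features['input_seq_length'][0])   # image width after downscaling
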
/tf_crnn/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = "solivr"
3 | __license__ = "GPL"
4 |
5 | import tensorflow as tf
6 | from tensorflow.keras import Model
7 | from tensorflow.keras.backend import ctc_batch_cost, ctc_decode
8 | from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, MaxPool2D, Input, Permute, \
9 | Reshape, Bidirectional, LSTM, Dense, Softmax, Lambda
10 | from typing import List, Tuple
11 | from .config import Params
12 |
13 |
14 | class ConvBlock(Layer):
15 | """
16 | Convolutional block class.
17 | It is composed of a `Conv2D` layer, a `BatchNormalization` layer (optional),
18 | a `MaxPool2D` layer (optional) and a `ReLU` activation.
19 |
20 | :ivar features: number of features of the convolutional layer
21 | :vartype features: int
22 | :ivar kernel_size: size of the convolutional kernel
23 | :vartype kernel_size: int
24 | :ivar stride: stride of the convolutional layer
25 | :vartype stride: int, int
26 | :ivar cnn_padding: padding of the convolution ('same' or 'valid')
27 | :vartype cnn_padding: str
28 | :ivar pool_size: size of the maxpooling
29 | :vartype pool_size: int, int
30 | :ivar batchnorm: use batch norm or not
31 | :vartype batchnorm: bool
32 | """
33 | def __init__(self,
34 | features: int,
35 | kernel_size: int,
36 | stride: Tuple[int, int],
37 | cnn_padding: str,
38 | pool_size: Tuple[int, int],
39 | batchnorm: bool,
40 | **kwargs):
41 | super(ConvBlock, self).__init__(**kwargs)
42 | self.conv = Conv2D(features,
43 | kernel_size,
44 | strides=stride,
45 | padding=cnn_padding)
46 | self.bn = BatchNormalization(renorm=True,
47 | renorm_clipping={'rmax': 1e2, 'rmin': 1e-1, 'dmax': 1e1},
48 | trainable=True) if batchnorm else None
49 | self.pool = MaxPool2D(pool_size=pool_size,
50 | padding='same') if list(pool_size) > [1, 1] else None  # no pooling layer for pool_size (1, 1)
51 |
52 | # for config purposes
53 | self._features = features
54 | self._kernel_size = kernel_size
55 | self._stride = stride
56 | self._cnn_padding = cnn_padding
57 | self._pool_size = pool_size
58 | self._batchnorm = batchnorm
59 |
60 | def call(self, inputs, training=False):
61 | x = self.conv(inputs)
62 | if self.bn is not None:
63 | x = self.bn(x, training=training)
64 | if self.pool is not None:
65 | x = self.pool(x)
66 | x = tf.nn.relu(x)
67 | return x
68 |
69 | def get_config(self) -> dict:
70 | """
71 | Get a dictionary with all the necessary properties to recreate the same layer.
72 |
73 | :return: dictionary containing the properties of the layer
74 | """
75 | super_config = super(ConvBlock, self).get_config()
76 | config = {
77 | 'features': self._features,
78 | 'kernel_size': self._kernel_size,
79 | 'stride': self._stride,
80 | 'cnn_padding': self._cnn_padding,
81 | 'pool_size': self._pool_size,
82 | 'batchnorm': self._batchnorm
83 | }
84 | return dict(list(super_config.items()) + list(config.items()))
85 |
86 |
87 | def get_crnn_output(input_images,
88 | parameters: Params=None) -> tf.Tensor:
89 | """
90 | Creates the CRNN network, passes `input_images` through it
91 | and returns the network's output tensor.
92 |
93 | :param input_images: images to process (B, H, W, C)
94 | :param parameters: parameters of the model (``Params``)
95 | :return: the output of the CRNN model
96 | """
97 |
98 | # params of the architecture
99 | cnn_features_list = parameters.cnn_features_list
100 | cnn_kernel_size = parameters.cnn_kernel_size
101 | cnn_pool_size = parameters.cnn_pool_size
102 | cnn_stride_size = parameters.cnn_stride_size
103 | cnn_batch_norm = parameters.cnn_batch_norm
104 | rnn_units = parameters.rnn_units
105 |
106 | # CNN layers
107 | cnn_params = zip(cnn_features_list, cnn_kernel_size, cnn_stride_size, cnn_pool_size, cnn_batch_norm)
108 | conv_layers = [ConvBlock(ft, ks, ss, 'same', psz, bn) for ft, ks, ss, psz, bn in cnn_params]
109 |
110 | x = conv_layers[0](input_images)
111 | for conv in conv_layers[1:]:
112 | x = conv(x)
113 |
114 | # Permutation and reshape
115 | x = Permute((2, 1, 3))(x)
116 | shape = x.get_shape().as_list()
117 | x = Reshape((shape[1], shape[2] * shape[3]))(x) # [B, W, H*C]
118 |
119 | # RNN layers
120 | rnn_layers = [Bidirectional(LSTM(ru, dropout=0.5, return_sequences=True, time_major=False)) for ru in
121 | rnn_units]
122 | for rnn in rnn_layers:
123 | x = rnn(x)
124 |
125 | # Dense and softmax
126 | x = Dense(parameters.alphabet.n_classes)(x)
127 | net_output = Softmax()(x)
128 |
129 | return net_output
130 |
131 |
132 | def get_model_train(parameters: Params):
133 | """
134 | Constructs the full model for training.
135 | Defines inputs and outputs, loss function and metric (CER).
136 |
137 | :param parameters: parameters of the model (``Params``)
138 | :return: the model (``tf.Keras.Model``)
139 | """
140 |
141 | h, w = parameters.input_shape
142 | c = parameters.input_channels
143 |
144 | input_images = Input(shape=(h, w, c), name='input_images')
145 | input_seq_len = Input(shape=[1], dtype=tf.int32, name='input_seq_length')
146 |
147 | label_codes = Input(shape=(parameters.max_chars_per_string,), dtype=tf.int32, name='label_codes')
148 | label_seq_length = Input(shape=[1], dtype=tf.int32, name='label_seq_length')
149 |
150 | net_output = get_crnn_output(input_images, parameters)
151 |
152 | # Loss function (y_true is a dummy: the real labels arrive through the label_codes input)
153 | def warp_ctc_loss(y_true, y_pred):
154 | return ctc_batch_cost(label_codes, y_pred, input_seq_len, label_seq_length)
155 |
156 | # Metric function
157 | def warp_cer_metric(y_true, y_pred):
158 | pred_sequence_length, true_sequence_length = input_seq_len, label_seq_length
159 |
160 | # y_pred needs to be decoded (it holds the per-timestep class probabilities)
161 | pred_codes_dense = ctc_decode(y_pred, tf.squeeze(pred_sequence_length, axis=-1), greedy=True)
162 | pred_codes_dense = tf.squeeze(tf.cast(pred_codes_dense[0], tf.int64), axis=0)  # ctc_decode returns a single-element list when greedy=True
163 |
164 | # create sparse tensor
165 | idx = tf.where(tf.not_equal(pred_codes_dense, -1))
166 | pred_codes_sparse = tf.SparseTensor(tf.cast(idx, tf.int64),
167 | tf.gather_nd(pred_codes_dense, idx),
168 | tf.cast(tf.shape(pred_codes_dense), tf.int64))
169 |
170 | idx = tf.where(tf.not_equal(label_codes, 0))
171 | label_sparse = tf.SparseTensor(tf.cast(idx, tf.int64),
172 | tf.gather_nd(label_codes, idx),
173 | tf.cast(tf.shape(label_codes), tf.int64))
174 | label_sparse = tf.cast(label_sparse, tf.int64)
175 |
176 | # Compute edit distance and total chars count
177 | distance = tf.reduce_sum(tf.edit_distance(pred_codes_sparse, label_sparse, normalize=False))
178 | count_chars = tf.reduce_sum(true_sequence_length)
179 |
180 | return tf.divide(distance, tf.cast(count_chars, tf.float32), name='CER')
181 |
182 | # Define model and compile it
183 | model = Model(inputs=[input_images, label_codes, input_seq_len, label_seq_length], outputs=net_output, name='CRNN')
184 | optimizer = tf.keras.optimizers.Adam(learning_rate=parameters.learning_rate)
185 | model.compile(loss=[warp_ctc_loss],
186 | optimizer=optimizer,
187 | metrics=[warp_cer_metric],
188 | experimental_run_tf_function=False) # TODO this is set to true by default but does not seem to work...
189 |
190 | return model
191 |
192 |
193 | def get_model_inference(parameters: Params,
194 | weights_path: str=None):
195 | """
196 | Constructs the full model for prediction.
197 | Defines inputs and outputs, and loads the weights.
198 |
199 |
200 | :param parameters: parameters of the model (``Params``)
201 | :param weights_path: path to the weights (.h5 file)
202 | :return: the model (``tf.Keras.Model``)
203 | """
204 | h, w = parameters.input_shape
205 | c = parameters.input_channels
206 |
207 | input_images = Input(shape=(h, w, c), name='input_images')
208 | input_seq_len = Input(shape=[1], dtype=tf.int32, name='input_seq_length')
209 | filename_images = Input(shape=[1], dtype=tf.string, name='filename_images')
210 |
211 | net_output = get_crnn_output(input_images, parameters)
212 | output_seq_len = tf.identity(input_seq_len) # need this op to pass it to output
213 | filenames = tf.identity(filename_images)
214 |
215 | model = Model(inputs=[input_images, input_seq_len, filename_images], outputs=[net_output, output_seq_len, filenames])
216 |
217 | if weights_path:
218 | model.load_weights(weights_path)
219 |
220 | return model
221 |
--------------------------------------------------------------------------------
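To make the CER computation in `warp_cer_metric` concrete, here is a toy, self-contained illustration of the same sparse edit-distance arithmetic; the values are made up:

import tensorflow as tf

# Prediction [3, 1, 2] vs ground truth [3, 1, 1]: one substitution
pred = tf.sparse.from_dense(tf.constant([[3, 1, 2]], dtype=tf.int64))
truth = tf.sparse.from_dense(tf.constant([[3, 1, 1]], dtype=tf.int64))

# Unnormalized edit distance summed over the batch, as in warp_cer_metric
distance = tf.reduce_sum(tf.edit_distance(pred, truth, normalize=False))
cer = distance / 3.0  # divide by the total number of ground-truth characters
print(float(cer))  # 0.333... -> one error over three characters
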
/tf_crnn/preprocessing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = "solivr"
3 | __license__ = "GPL"
4 |
5 | import re
6 | import numpy as np
7 | import os
8 | from .config import Params, CONST
9 | import pandas as pd
10 | from typing import List, Tuple
11 | from taputapu.io.image import get_image_shape_without_loading
12 |
13 |
14 | def _convert_label_to_dense_codes(labels: List[str],
15 | split_char: str,
16 | max_width: int,
17 | table_str2int: dict):
18 | """
19 | Converts a list of formatted string to a dense matrix of codes
20 |
21 | :param labels: list of strings containing formatted labels
22 | :param split_char: character to split the formatted label
23 | :param max_width: maximum length of string label (max_n_chars = max_width_dense_codes)
24 | :param table_str2int: mapping table between alphabet units and alphabet codes
25 | :return: dense matrix N x max_width, list of the lengths of each string (length N)
26 | """
27 | labels_chars = [[c for c in label.split(split_char) if c] for label in labels]
28 | codes_list = [[table_str2int[c] for c in list_char] for list_char in labels_chars]
29 |
30 | seq_lengths = [len(cl) for cl in codes_list]
31 |
32 | dense_codes = list()
33 | for ls in codes_list:
34 | dense_codes.append(ls + np.maximum(0, (max_width - len(ls))) * [0])
35 |
36 | return dense_codes, seq_lengths
37 |
38 |
39 | def _compute_length_inputs(path: str,
40 | target_shape: Tuple[int, int]):
41 |
42 | w, h = get_image_shape_without_loading(path)
43 | ratio = w / h
44 |
45 | new_h = target_shape[0]
46 | new_w = np.minimum(new_h * ratio, target_shape[1])
47 |
48 | return new_w
49 |
50 |
51 | def preprocess_csv(csv_filename: str,
52 | parameters: Params,
53 | output_csv_filename: str) -> int:
54 | """
55 | Converts the original csv data to the format required by the experiment.
56 | Removes the samples whose labels have too many characters. Computes the width of each input image after resizing
57 | and removes the samples whose labels have more characters than the downscaled image width. Converts the string
58 | labels to dense codes. The output csv file contains the path to the image, the dense list of codes corresponding
59 | to the alphabet units (padded with 0 if `len(label)` < `max_len`) and the length of the label sequence.
60 |
61 | :param csv_filename: path to csv file
62 | :param parameters: parameters of the experiment (``Params``)
63 | :param output_csv_filename: path to the output csv file
64 | :return: number of samples in the output csv file
65 | """
66 |
67 | # Conversion table
68 | table_str2int = dict(zip(parameters.alphabet.alphabet_units, parameters.alphabet.codes))
69 |
70 | # Read file
71 | dataframe = pd.read_csv(csv_filename,
72 | sep=parameters.csv_delimiter,
73 | header=None,
74 | names=['paths', 'labels'],
75 | encoding='utf8',
76 | escapechar="\\",
77 | quoting=0)
78 |
79 | original_len = len(dataframe)
80 |
81 | dataframe['label_string'] = dataframe.labels.apply(lambda x: re.sub(re.escape(parameters.string_split_delimiter), '', x))
82 | dataframe['label_len'] = dataframe.label_string.apply(lambda x: len(x))
83 |
84 | # remove long labels
85 | dataframe = dataframe[dataframe.label_len <= parameters.max_chars_per_string]
86 |
87 | # Compute width images (after resizing)
88 | dataframe['input_length'] = dataframe.paths.apply(lambda x: _compute_length_inputs(x, parameters.input_shape))
89 | dataframe.input_length = dataframe.input_length.apply(lambda x: np.floor(x / parameters.downscale_factor))
90 | # Remove items with longer label than input
91 | dataframe = dataframe[dataframe.label_len < dataframe.input_length]
92 |
93 | final_length = len(dataframe)
94 |
95 | n_removed = original_len - final_length
96 | if n_removed > 0:
97 | print('-- Removed {} samples ({:.2f} %)'.format(n_removed,
98 | 100 * n_removed / original_len))
99 |
100 | # Convert fields to list
101 | paths = dataframe.paths.to_list()
102 | labels = dataframe.labels.to_list()
103 |
104 | # Convert string labels to dense codes
105 | label_dense_codes, label_seq_length = _convert_label_to_dense_codes(labels,
106 | parameters.string_split_delimiter,
107 | parameters.max_chars_per_string,
108 | table_str2int)
109 | # format in string to be easily parsed by tf.data
110 | string_label_codes = [[str(ldc) for ldc in list_ldc] for list_ldc in label_dense_codes]
111 | string_label_codes = [' '.join(list_slc) for list_slc in string_label_codes]
112 |
113 | data = {'paths': paths, 'label_codes': string_label_codes, 'label_len': label_seq_length}
114 | new_dataframe = pd.DataFrame(data)
115 |
116 | new_dataframe.to_csv(output_csv_filename,
117 | sep=parameters.csv_delimiter,
118 | header=False,
119 | encoding='utf8',
120 | index=False,
121 | escapechar="\\",
122 | quoting=0)
123 | return len(new_dataframe)
124 |
125 |
126 | def data_preprocessing(params: Params) -> (str, str, int, int):
127 | """
128 | Preprocesses the data for the experiment (training and evaluation data).
129 | Exports the updated csv files into `/preprocessed/updated_{eval,train}.csv`
130 |
131 | :param params: parameters of the experiment (``Params``)
132 | :return: paths to the updated train and eval csv files, and the number of samples in each
133 | """
134 | output_dir = os.path.join(params.output_model_dir, CONST.PREPROCESSING_FOLDER)
135 | if not os.path.exists(output_dir):
136 | os.makedirs(output_dir)
137 | else:
138 | print('Output directory {} already exists'.format(output_dir))
139 |
140 | csv_train_output = os.path.join(output_dir, 'updated_train.csv')
141 | csv_eval_output = os.path.join(output_dir, 'updated_eval.csv')
142 |
143 | # Preprocess train csv
144 | n_samples_train = preprocess_csv(params.csv_files_train, params, csv_train_output)
145 |
146 | # Preprocess evaluation csv
147 | n_samples_eval = preprocess_csv(params.csv_files_eval, params, csv_eval_output)
148 |
149 | return csv_train_output, csv_eval_output, n_samples_train, n_samples_eval
150 |
151 |
--------------------------------------------------------------------------------
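A toy illustration of what `_convert_label_to_dense_codes` produces, with a made-up three-letter alphabet, '|' as the split character and `max_width=5`:

from tf_crnn.preprocessing import _convert_label_to_dense_codes

table_str2int = {'a': 1, 'b': 2, 'c': 3}  # toy mapping; real codes come from the Alphabet
labels = ['a|b', 'c|a|b']

dense, lengths = _convert_label_to_dense_codes(labels, split_char='|',
                                               max_width=5, table_str2int=table_str2int)
print(dense)    # [[1, 2, 0, 0, 0], [3, 1, 2, 0, 0]] -- zero-padded up to max_width
print(lengths)  # [2, 3]
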
/training.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = "solivr"
3 | __license__ = "GPL"
4 |
5 | import logging
6 | logging.getLogger("tensorflow").setLevel(logging.INFO)
7 |
8 | from tf_crnn.config import Params
9 | from tf_crnn.model import get_model_train
10 | from tf_crnn.preprocessing import data_preprocessing
11 | from tf_crnn.data_handler import dataset_generator
12 | from tf_crnn.callbacks import CustomLoaderCallback, CustomSavingCallback, LRTensorBoard, EPOCH_FILENAME, FOLDER_SAVED_MODEL
13 | import tensorflow as tf
14 | import numpy as np
15 | import os
16 | import json
17 | import pickle
18 | from glob import glob
19 | from sacred import Experiment, SETTINGS
20 |
21 | SETTINGS.CONFIG.READ_ONLY_CONFIG = False
22 |
23 | ex = Experiment('crnn')
24 |
25 | ex.add_config('config.json')
26 |
27 | @ex.automain
28 | def training(_config: dict):
29 | parameters = Params(**_config)
30 |
31 | export_config_filename = os.path.join(parameters.output_model_dir, 'config.json')
32 | saving_dir = os.path.join(parameters.output_model_dir, FOLDER_SAVED_MODEL)
33 |
34 | if not parameters.restore_model:
35 | # check if output folder already exists
36 | assert not os.path.isdir(parameters.output_model_dir), \
37 | '{} already exists, you cannot use it as output directory.'.format(parameters.output_model_dir)
38 | # 'Set "restore_model=True" to continue training, or delete dir "rm -r {0}"'.format(parameters.output_model_dir)
39 | os.makedirs(parameters.output_model_dir)
40 |
41 | # data and csv preprocessing
42 | csv_train_file, csv_eval_file, \
43 | n_samples_train, n_samples_eval = data_preprocessing(parameters)
44 |
45 | # export config file in model output dir
46 | with open(export_config_filename, 'w') as file:
47 | json.dump(parameters.to_dict(), file)
48 |
49 | # Create callbacks
50 | logdir = os.path.join(parameters.output_model_dir, 'logs')
51 | tb_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir,
52 | profile_batch=0)
53 |
54 | lrtb_callback = LRTensorBoard(log_dir=logdir,
55 | profile_batch=0)
56 |
57 | lr_callback = tf.keras.callbacks.ReduceLROnPlateau(factor=0.5,
58 | patience=10,
59 | cooldown=0,
60 | min_lr=1e-8,
61 | verbose=1)
62 |
63 | es_callback = tf.keras.callbacks.EarlyStopping(min_delta=0.005,
64 | patience=20,
65 | verbose=1)
66 |
67 | sv_callback = CustomSavingCallback(saving_dir,
68 | saving_freq=parameters.save_interval,
69 | save_best_only=True,
70 | keep_max_models=3)
71 |
72 | list_callbacks = [tb_callback, lrtb_callback, lr_callback, es_callback, sv_callback]
73 |
74 | if parameters.restore_model:
75 | last_time_stamp = max([int(p.split(os.path.sep)[-1].split('-')[0])
76 | for p in glob(os.path.join(saving_dir, '*'))])
77 |
78 | loading_dir = os.path.join(saving_dir, str(last_time_stamp))
79 | ld_callback = CustomLoaderCallback(loading_dir)
80 |
81 | list_callbacks.append(ld_callback)
82 |
83 | with open(os.path.join(loading_dir, EPOCH_FILENAME), 'rb') as f:
84 | initial_epoch = pickle.load(f)
85 |
86 | epochs = initial_epoch + parameters.n_epochs
87 | else:
88 | initial_epoch = 0
89 | epochs = parameters.n_epochs
90 |
91 | # Get model
92 | model = get_model_train(parameters)
93 |
94 | # Get datasets
95 | dataset_train = dataset_generator([csv_train_file],
96 | parameters,
97 | batch_size=parameters.train_batch_size,
98 | data_augmentation=parameters.data_augmentation,
99 | num_epochs=parameters.n_epochs)
100 |
101 | dataset_eval = dataset_generator([csv_eval_file],
102 | parameters,
103 | batch_size=parameters.eval_batch_size,
104 | data_augmentation=False,
105 | num_epochs=parameters.n_epochs)
106 |
107 | # Train model
108 | model.fit(dataset_train,
109 | epochs=epochs,
110 | initial_epoch=initial_epoch,
111 | steps_per_epoch=np.floor(n_samples_train / parameters.train_batch_size),
112 | validation_data=dataset_eval,
113 | validation_steps=np.floor(n_samples_eval / parameters.eval_batch_size),
114 | callbacks=list_callbacks)
115 |
--------------------------------------------------------------------------------
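Since the script is a sacred `Experiment`, any key from config.json can be overridden at launch time (e.g. `python training.py with n_epochs=5`). A minimal sketch of the programmatic equivalent, assuming `n_epochs` and `train_batch_size` are keys of config.json:

# Run the experiment with overridden config values (illustrative values)
from training import ex

ex.run(config_updates={'n_epochs': 5,
                       'train_batch_size': 32})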