├── Dockerfile
├── LICENSE
├── README.md
├── config.json
├── doc
│   ├── Makefile
│   ├── conf.py
│   ├── index.rst
│   ├── reference
│   │   └── index.rst
│   ├── references.bib
│   ├── references.rst
│   ├── require.txt
│   └── start
│       ├── index.rst
│       ├── inference.rst
│       ├── install.rst
│       └── training.rst
├── environment.yml
├── evaluation.py
├── hlp
│   ├── alphabet_helpers.py
│   ├── csv_helpers.py
│   ├── numbers_mnist_generator.py
│   ├── prepare_iam.py
│   └── string_data_manager.py
├── prediction.py
├── setup.py
├── tf_crnn
│   ├── __init__.py
│   ├── callbacks.py
│   ├── config.py
│   ├── data_handler.py
│   ├── model.py
│   └── preprocessing.py
└── training.py

/Dockerfile:
--------------------------------------------------------------------------------
FROM tensorflow/tensorflow:1.8.0-gpu

# Print the Python version
RUN python --version

# Additional requirements for TensorFlow
RUN apt-get update && apt-get install -y python3 python3-dev

# Install pip for Python 3
RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
    python3 get-pip.py && \
    rm get-pip.py

# Install Notebook
RUN pip3 install ipython notebook

# Install tensorflow 1.8.0 (the setup does not work with 1.7.0)
RUN pip3 install tensorflow-gpu==1.8.0

# Copy and install TF-CRNN
ADD . /script
WORKDIR /script
RUN python3 setup.py install

# Mount points for the data sources and the configuration
# You should normalize the file paths in your data
VOLUME /sources
VOLUME /config

# TensorBoard
EXPOSE 6006
# Jupyter Notebook
EXPOSE 8888
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others.
33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. 
Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. 
This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. 
This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. 
Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 
336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 
397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. 
If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 
633 | 634 | {one line to give the program's name and a brief idea of what it does.} 635 | Copyright (C) {year} {name of author} 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail. 651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | {project} Copyright (C) {year} {fullname} 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/philosophy/why-not-lgpl.html>. 675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Text recognition with Convolutional Recurrent Neural Network and TensorFlow 2.0 (tf2-crnn)

[![Documentation Status](https://readthedocs.org/projects/tf-crnn/badge/?version=latest)](https://tf-crnn.readthedocs.io/en/latest/?badge=latest)

Implementation of a Convolutional Recurrent Neural Network (CRNN) for image-based sequence recognition tasks, such as scene text recognition and OCR.

This implementation is based on TensorFlow 2.0 and uses the `tf.keras` and `tf.data` modules to build the model and to handle input data.

To access the previous version, which implements the Shi et al. paper, go to the [v0.5.2](https://github.com/solivr/tf-crnn/tree/v0.5.2) tag.


## Installation
`tf_crnn` makes use of the `tensorflow-gpu` package (so CUDA and cuDNN are needed).

You can install it using the `environment.yml` file provided and use it within an environment:

    conda env create -f environment.yml

See also the [docs](https://tf-crnn.readthedocs.io/en/latest/start/index.html#) for more information.


## Try it

Train a model with the [IAM dataset](http://www.fki.inf.unibe.ch/databases/iam-handwriting-database).

**Create an account**

Create an account on the official IAM dataset page in order to access the data.
Export your credentials as environment variables; they will be used by the download script.

    export IAM_USER=
    export IAM_PWD=


**Generate the data in the correct format**

    cd hlp
    python prepare_iam.py --download_dir ../data/iam --generated_data_dir ../data/iam/generated
    cd ..

**Train the model**

    python training.py with config.json

More details in the [documentation](https://tf-crnn.readthedocs.io/en/latest/start/training.html#example-of-training).
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
{
  "lookup_alphabet_file" : "./data/iam/generated/generated_alphabet/iam_alphabet_lookup.json",
  "csv_files_train" : "./data/iam/generated/generated_csv/lines_train.csv",
  "csv_files_eval" : "./data/iam/generated/generated_csv/lines_validation1.csv",
  "output_model_dir" : "./output_model",
  "num_beam_paths" : 1,
  "cnn_features_list" : [64, 128, 256, 512],
  "cnn_kernel_size" : [3, 3, 3, 3],
  "cnn_stride_size" : [[1, 1], [1, 1], [1, 1], [1, 1]],
  "cnn_pool_size" : [[2, 2], [2, 2], [2, 1], [2, 1]],
  "cnn_batch_norm" : [true, true, true, true],
  "max_chars_per_string" : 80,
  "n_epochs" : 200,
  "train_batch_size" : 64,
  "eval_batch_size" : 64,
  "learning_rate": 1e-3,
  "input_shape" : [64, 900],
  "rnn_units" : [128, 128, 128, 128],
  "restore_model" : false
}
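
As a side note, this configuration can also be loaded programmatically, outside of `sacred`. The following is an editor's sketch, not repository code; it assumes only the `Params` interface that `evaluation.py` (shown further below) already uses, namely `Params.from_json_file` and attribute-style access such as `parameters.output_model_dir`:

```python
# Editor's sketch -- not part of the repository.
# Assumes the Params API used by evaluation.py: Params.from_json_file()
# and attribute access on the returned object.
from tf_crnn.config import Params

parameters = Params.from_json_file('config.json')

print(parameters.output_model_dir)  # './output_model'
print(parameters.input_shape)       # [64, 900]
print(parameters.train_batch_size)  # 64
```

This mirrors what `evaluation.py` does when it restores the configuration of a saved experiment.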
--------------------------------------------------------------------------------
/doc/Makefile:
--------------------------------------------------------------------------------
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS    =
SPHINXBUILD   = sphinx-build
SOURCEDIR     = .
BUILDDIR      = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/doc/conf.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
sys.path.insert(0, os.path.abspath('..'))


# -- Project information -----------------------------------------------------

project = 'tf_crnn'
copyright = '2019, Digital Humanities Lab - EPFL'
author = 'Sofia ARES OLIVEIRA'

# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = ''


# -- General configuration ---------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.autosummary',
    'sphinx.ext.coverage',
    'sphinx.ext.viewcode',
    'sphinx.ext.githubpages',
    'sphinxcontrib.bibtex',  # for bibtex
    'sphinx_autodoc_typehints'
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffixes as a list of strings:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'

# The master toctree document.
master_doc = 'index'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = None


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}


# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = 'tf_crnndoc'


# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',

    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',

    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',

    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, 'tf_crnn.tex', 'tf\\_crnn Documentation',
     author, 'manual'),
]


# -- Options for manual page output ------------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
    (master_doc, 'tf_crnn', 'tf_crnn Documentation',
     [author], 1)
]


# -- Options for Texinfo output ----------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (master_doc, 'tf_crnn', 'tf_crnn Documentation',
     author, 'tf_crnn', 'One line description of project.',
     'Miscellaneous'),
]


# -- Options for Epub output -------------------------------------------------

# Bibliographic Dublin Core info.
epub_title = project

# The unique identifier of the text. This can be an ISBN number
# or the project homepage.
#
# epub_identifier = ''

# A unique identification for the text.
#
# epub_uid = ''

# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']


# -- Extension configuration -------------------------------------------------
autodoc_mock_imports = [
    # 'numpy',
    'tensorflow',
    'tensorflow_addons',
    'pandas',
    'typing',
    'cv2'
]
--------------------------------------------------------------------------------
/doc/index.rst:
--------------------------------------------------------------------------------
.. tf_crnn documentation master file, created by
   sphinx-quickstart on Mon Jan 7 14:43:48 2019.

===============================================================================
TF-CRNN : A TensorFlow implementation of Convolutional Recurrent Neural Network
===============================================================================

.. toctree::
    :maxdepth: 2

    start/index
    reference/index
    references
..    :caption: Contents:

A TensorFlow implementation of the Convolutional Recurrent Neural Network (CRNN) for image-based sequence recognition
tasks, such as scene text recognition and OCR.

This implementation uses ``tf.keras`` to build the model and the ``tf.data`` modules to handle input data.


Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
--------------------------------------------------------------------------------
/doc/reference/index.rst:
--------------------------------------------------------------------------------
===============
Reference guide
===============

.. automodule:: tf_crnn


.. automodule:: tf_crnn.data_handler
    :members:
    :undoc-members:

.. automodule:: tf_crnn.config
    :members:
    :undoc-members:
    :exclude-members: CONST

.. automodule:: tf_crnn.model
    :members:
    :undoc-members:

.. automodule:: tf_crnn.callbacks
    :members:
    :undoc-members:

.. automodule:: tf_crnn.preprocessing
    :members:
    :undoc-members:
--------------------------------------------------------------------------------
/doc/references.bib:
--------------------------------------------------------------------------------
@article{marti2002iam,
  title={The IAM-database: an English sentence database for offline handwriting recognition},
  author={Marti, U-V and Bunke, Horst},
  journal={International Journal on Document Analysis and Recognition},
  volume={5},
  number={1},
  pages={39--46},
  year={2002},
  publisher={Springer}
}
--------------------------------------------------------------------------------
/doc/references.rst:
--------------------------------------------------------------------------------
==========
References
==========

.. bibliography:: references.bib
    :cited:
    :all:
    :style: alpha
--------------------------------------------------------------------------------
/doc/require.txt:
--------------------------------------------------------------------------------
sphinx
sphinx-autodoc-typehints
sphinx-rtd-theme
sphinxcontrib-bibtex
sphinxcontrib-websupport
--------------------------------------------------------------------------------
/doc/start/index.rst:
--------------------------------------------------------------------------------
Quickstart
==========

.. toctree::
    install
    training
..    inference
--------------------------------------------------------------------------------
/doc/start/inference.rst:
--------------------------------------------------------------------------------
Using a saved model for prediction
----------------------------------

During the training, the model is exported every *n* epochs (you can set *n* in the config file; by default *n=5*).
The exported models are ``SavedModel`` TensorFlow objects, which need to be loaded in order to be used.

Assuming that the output folder is named ``output_dir``, the exported models will be saved in ``output_dir/export/``
with different timestamps for each export. Each ``<timestamp>`` folder contains a ``saved_model.pb``
file and a ``variables`` folder.

The ``saved_model.pb`` file contains the graph definition of your model and the ``variables`` folder contains the
saved variables (where the weights are stored). You can find more information about SavedModel
on the `TensorFlow dedicated page <https://www.tensorflow.org/guide/saved_model>`_.


In order to easily handle the loading of the exported models, a ``PredictionModel`` class is provided and
you can use the trained model to transcribe new image segments in the following way:

.. code-block:: python

    import tensorflow as tf
    from tf_crnn.loader import PredictionModel

    model_directory = 'output/export/<timestamp>/'
    image_filename = 'data/images/b04-034-04-04.png'

    with tf.Session() as session:
        model = PredictionModel(model_directory, signature='filename')
        prediction = model.predict(image_filename)
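
The loaded model can also be reused for several images without reloading it for each one. The following batch variant is an editor's sketch rather than repository code: it assumes only the two calls shown above (``PredictionModel(model_directory, signature='filename')`` and ``model.predict(image_filename)``) and that ``predict`` can be called repeatedly within the same session:

.. code-block:: python

    import glob

    import tensorflow as tf
    from tf_crnn.loader import PredictionModel

    model_directory = 'output/export/<timestamp>/'  # replace with an actual export folder

    with tf.Session() as session:
        # Load the exported SavedModel once...
        model = PredictionModel(model_directory, signature='filename')
        # ...and transcribe every PNG image in the folder with it
        for image_filename in sorted(glob.glob('data/images/*.png')):
            prediction = model.predict(image_filename)
            print(image_filename, prediction)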
--------------------------------------------------------------------------------
/doc/start/install.rst:
--------------------------------------------------------------------------------
Installation
------------

``tf_crnn`` uses the ``tensorflow-gpu`` package, which needs the CUDA and cuDNN libraries for GPU support. The TensorFlow
`GPU support page <https://www.tensorflow.org/install/gpu>`_ lists the requirements.

Using Anaconda
^^^^^^^^^^^^^^

When using Anaconda (or Miniconda), conda will automatically install the compatible versions of CUDA and cuDNN::

    conda env create -f environment.yml


From `this page `_:

    When the GPU accelerated version of TensorFlow is installed using conda, by the command
    "conda install tensorflow-gpu", these libraries are installed automatically, with versions
    known to be compatible with the tensorflow-gpu package. Furthermore, conda installs these libraries
    into a location where they will not interfere with other instances of these libraries that may have
    been installed via another method. Regardless of using pip or conda-installed tensorflow-gpu,
    the NVIDIA driver must be installed separately.

.. Using ``pip``
   ^^^^^^^^^^^^^

   Before using ``tf_crnn`` we recommend creating a virtual environment (Python 3.5).
   Then, install the dependencies using the GitHub repository's ``setup.py`` file::

       pip install git+https://github.com/solivr/tf-crnn

   You will then need to install the CUDA and cuDNN libraries manually.


.. Using Docker
   ^^^^^^^^^^^^
   (thanks to `PonteIneptique <https://github.com/PonteIneptique>`_)

   The ``Dockerfile`` in the root directory allows you to run the whole program as a Docker NVIDIA TensorFlow GPU container.
   This is potentially helpful to deal with external dependencies like CUDA and the like.

   You can follow the installation processes here:

   - docker-ce : `Ubuntu <https://docs.docker.com/install/linux/docker-ce/ubuntu/>`_
   - nvidia-docker : `Ubuntu <https://github.com/NVIDIA/nvidia-docker>`_

   Once this is installed, build the image of the container::

       nvidia-docker build . --tag tf-crnn

   Our container image is now named ``tf-crnn``.
   You can run it with ``nvidia-docker run -it tf-crnn:latest bash``,
   which will open a bash shell exactly where you are. However, we recommend using::

       nvidia-docker run -it -p 8888:8888 -p 6006:6006 -v /absolute/path/to/here/config:./config -v $INPUT_DATA:/sources tf-crnn:latest bash

   where ``$INPUT_DATA`` should be replaced by the directory containing your training and testing data.
   It will be mounted on the ``/sources`` folder. We propose to mount by default ``./config`` to the current ``./config`` directory.
   Paths need to be absolute. We also recommend changing::

       //...
       "output_model_dir" : "/.output/"

   to::

       //...
       "output_model_dir" : "/config/output"

   **Do not forget** to rename your training and testing file paths, as well as the paths to their
   images, to ``/sources/.../file.{png,jpg}``.


.. note:: If you are uncomfortable with bash, you can always replace ``bash`` with ``ipython3 notebook --allow-root``
   and go to your browser at ``http://localhost:8888/``. A token will be shown in the terminal.
76 | .. note:: if you are uncomfortable with bash, you can replace ``bash`` with ``ipython3 notebook --allow-root``
77 |           and point your browser to ``http://localhost:8888/``. A token will be shown in the terminal.
--------------------------------------------------------------------------------
/doc/start/training.rst:
--------------------------------------------------------------------------------
 1 | How to train a model
 2 | --------------------
 3 |
 4 | The ``sacred`` package is used to manage the experiments.
 5 | If you are not yet familiar with it, have a quick look at the `documentation `_.
 6 |
 7 | Input data
 8 | ^^^^^^^^^^
 9 |
10 | In order to train a model, you should input a csv file with each row containing the filename of the image (full path)
11 | and its label (plain text) separated by a delimiting character (let's say ``;``).
12 | Also, each character should be separated by a splitting character (let's say ``|``), in order to deal with arbitrary
13 | alphabets, especially alphabet units that span several characters (abbreviations, special symbols, ...).
14 |
15 | An example of such a csv file would look like : ::
16 |
17 |     /full/path/to/image1.{jpg,png};|s|t|r|i|n|g|_|l|a|b|e|l|1|
18 |     /full/path/to/image2.{jpg,png};|s|t|r|i|n|g|_|l|a|b|e|l|2| |w|i|t|h| |special_char|
19 |     ...
20 |
21 | Input lookup alphabet file
22 | ^^^^^^^^^^^^^^^^^^^^^^^^^^
23 |
24 | You also need to provide a lookup table for the *alphabet* that will be used. The term *alphabet* refers to all the
25 | symbols you want the network to learn, whether they are characters, digits, symbols, abbreviations, or any other graphical element.
26 |
27 | The lookup table is a dictionary mapping alphabet units to integer codes (i.e. ``{'char': <int code>}``).
28 | Some lookup tables are already provided as examples in ``data/alphabet/``.
29 |
30 | For example, to transcribe words that contain only the characters *'abcdefg'*, one possible lookup table would be : ::
31 |
32 |     {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7}
33 |
34 | The lookup table / dictionary needs to be saved in a json file (see the sketch at the end of this section for generating it automatically).
35 |
36 | Config file (with ``sacred``)
37 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
38 |
39 | Set the parameters of the experiment in ``config.json``. The file looks like this : ::
40 |
41 |     {
42 |       "lookup_alphabet_file" : "./data/alphabet/lookup.json",
43 |       "csv_files_train" : "./data/csv_experiments/train_data.csv",
44 |       "csv_files_eval" : "./data/csv_experiments/validation_data.csv",
45 |       "output_model_dir" : "./output_model",
46 |       "num_beam_paths" : 1,
47 |       "max_chars_per_string" : 80,
48 |       "n_epochs" : 50,
49 |       "train_batch_size" : 64,
50 |       "eval_batch_size" : 64,
51 |       "learning_rate": 1e-4,
52 |       "input_shape" : [128, 1400],
53 |       "restore_model" : false
54 |     }
55 |
56 | In order to use your own data, you should change the parameters ``csv_files_train``, ``csv_files_eval`` and ``lookup_alphabet_file``.
57 |
58 | All the configurable parameters can be found in the ``tf_crnn.config.Params`` class; any of them can be added to the config file if needed.
59 |
60 | Training
61 | ^^^^^^^^
62 |
63 | Once you have your input csv and alphabet files ready, and the parameters set in ``config.json``,
64 | use the ``sacred`` syntax to launch the training : ::
65 |
66 |     python training.py with config.json
67 |
68 | The saved model and logs will then be exported to the folder specified in the config file (``output_model_dir``).
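
If you would rather generate the lookup table automatically from your labels (as mentioned above), the
``hlp/alphabet_helpers.py`` script provides a ``generate_alphabet_file`` helper that collects the
``|``-separated alphabet units of one or several formatted csv files and exports the lookup dictionary
as json. A minimal sketch, assuming it is run from the ``hlp`` directory and that the csv paths
(taken from the example config above) are adapted to your data:

.. code-block:: python

    from alphabet_helpers import generate_alphabet_file

    # Collects all '|'-separated alphabet units from the labels of the given
    # csv files and writes a {unit: integer_code} lookup (codes start at 1).
    generate_alphabet_file(['../data/csv_experiments/train_data.csv',
                            '../data/csv_experiments/validation_data.csv'],
                           '../data/alphabet/lookup.json')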
 69 |
 70 |
 71 | Example of training
 72 | -------------------
 73 |
 74 | We will use the `IAM Database `_ :cite:`marti2002iam`
 75 | as an example to generate the data in the correct input format and train a model.
 76 |
 77 | Go to the official page and create an account in order to access the data.
 78 | You don't need to download the data yourself, the ``prepare_iam.py`` script will take care of that for you.
 79 |
 80 | Generating data
 81 | ^^^^^^^^^^^^^^^
 82 |
 83 | First, create the ``IAM_USER`` and ``IAM_PWD`` environment variables to store your credentials; they will be used by the download script : ::
 84 |
 85 |     export IAM_USER=<username>
 86 |     export IAM_PWD=<password>
 87 |
 88 |
 89 | Run the script ``hlp/prepare_iam.py`` in order to download the data, extract it and format it correctly to train a model : ::
 90 |
 91 |     cd hlp
 92 |     python prepare_iam.py --download_dir ../data/iam --generated_data_dir ../data/iam/generated
 93 |     cd ..
 94 |
 95 | The images of the lines are extracted in ``data/iam/lines/`` and the folder ``data/iam/generated/`` contains all the
 96 | additional files necessary to run the experiment. The csv files are saved in ``data/iam/generated/generated_csv`` and
 97 | the alphabet is placed in ``data/iam/generated/generated_alphabet``.
 98 |
 99 | Training the model
100 | ^^^^^^^^^^^^^^^^^^
101 |
102 | Make sure the ``config.json`` file has the correct paths for training and validation data, as well as for the
103 | alphabet lookup file, and run : ::
104 |
105 |     python training.py with config.json
106 |
107 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
 1 | name: crnn-tf2
 2 | dependencies:
 3 |   - python=3.6
 4 |   - imageio
 5 |   - numpy
 6 |   - tqdm
 7 |   - pandas
 8 |   - click
 9 |   - pip
10 |   - pip:
11 |     - sacred
12 |     - opencv-python
13 |     - tensorflow-gpu>=2.0
14 |     - tensorflow-addons>=0.5
15 |     - git+https://github.com/solivr/taputapu.git#egg=taputapu
16 |     - sphinx
17 |     - sphinx-autodoc-typehints
18 |     - sphinx-rtd-theme
19 |
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | __author__ = "solivr"
 3 | __license__ = "GPL"
 4 |
 5 | import os
 6 | from glob import glob
 7 |
 8 | import click
 9 | from tf_crnn.callbacks import CustomLoaderCallback, FOLDER_SAVED_MODEL
10 | from tf_crnn.config import Params, CONST
11 | from tf_crnn.data_handler import dataset_generator
12 | from tf_crnn.preprocessing import preprocess_csv
13 | from tf_crnn.model import get_model_train
14 |
15 |
16 | @click.command()
17 | @click.option('--csv_filename')
18 | @click.option('--model_dir')
19 | def evaluation(csv_filename: str,
20 |                model_dir: str):
21 |
22 |     config_filename = os.path.join(model_dir, 'config.json')
23 |     parameters = Params.from_json_file(config_filename)
24 |
25 |     saving_dir = os.path.join(parameters.output_model_dir, FOLDER_SAVED_MODEL)
26 |
27 |     # Callback for model weights loading
28 |     last_time_stamp = max([int(p.split(os.path.sep)[-1].split('-')[0])
29 |                            for p in glob(os.path.join(saving_dir, '*'))])
30 |     loading_dir = os.path.join(saving_dir, str(last_time_stamp))
31 |     ld_callback = CustomLoaderCallback(loading_dir)
32 |
33 |     # Preprocess csv data
34 |     csv_evaluation_file = os.path.join(parameters.output_model_dir, CONST.PREPROCESSING_FOLDER, 'evaluation_data.csv')
35 |     n_samples = preprocess_csv(csv_filename,
36 |                                parameters,
37 | 
csv_evaluation_file) 38 | 39 | dataset_evaluation = dataset_generator([csv_evaluation_file], 40 | parameters, 41 | batch_size=parameters.eval_batch_size, 42 | num_epochs=1) 43 | 44 | # get model and evaluation 45 | model = get_model_train(parameters) 46 | eval_output = model.evaluate(dataset_evaluation, 47 | callbacks=[ld_callback]) 48 | print('-- Metrics: ', eval_output) 49 | 50 | 51 | if __name__ == '__main__': 52 | evaluation() 53 | -------------------------------------------------------------------------------- /hlp/alphabet_helpers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | from typing import List, Union 6 | import csv 7 | import json 8 | import numpy as np 9 | import pandas as pd 10 | 11 | 12 | def get_alphabet_units_from_input_data(csv_filename: str, 13 | split_char: str='|'): 14 | """ 15 | Get alphabet units from the input_data csv file (which contains in each row the tuple 16 | (filename image segment, transcription formatted)) 17 | 18 | :param csv_filename: csv file containing the input data 19 | :param split_char: splitting character in input_data separting the alphabet units 20 | :return: 21 | """ 22 | df = pd.read_csv(csv_filename, sep=';', header=None, names=['image', 'labels'], 23 | encoding='utf8', escapechar="\\", quoting=3) 24 | transcriptions = list(df.labels.apply(lambda x: x.split(split_char))) 25 | 26 | unique_units = np.unique([chars for list_chars in transcriptions for chars in list_chars]) 27 | 28 | return unique_units 29 | 30 | 31 | def generate_alphabet_file(csv_filenames: List[str], 32 | alphabet_filename: str): 33 | """ 34 | 35 | :param csv_filenames: 36 | :param alphabet_filename: 37 | :return: 38 | """ 39 | symbols = list() 40 | for file in csv_filenames: 41 | symbols.append(get_alphabet_units_from_input_data(file)) 42 | 43 | alphabet_units = np.unique(np.concatenate(symbols)) 44 | 45 | alphabet_lookup = dict([(au, i+1)for i, au in enumerate(alphabet_units)]) 46 | 47 | with open(alphabet_filename, 'w') as f: 48 | json.dump(alphabet_lookup, f) 49 | 50 | 51 | def get_abbreviations_from_csv(csv_filename: str) -> List[str]: 52 | with open(csv_filename, 'r', encoding='utf8') as f: 53 | csvreader = csv.reader(f, delimiter='\n') 54 | alphabet_units = [row[0] for row in csvreader] 55 | return alphabet_units 56 | 57 | 58 | # def make_json_lookup_alphabet(string_chars: str=None) -> dict: 59 | # """ 60 | # 61 | # :param string_chars: for example string.ascii_letters, string.digits 62 | # :return: 63 | # """ 64 | # lookup = dict() 65 | # if string_chars: 66 | # # Add characters to lookup table 67 | # lookup.update({char: ord(char) for char in string_chars}) 68 | # 69 | # return map_lookup(lookup) 70 | 71 | 72 | # def load_lookup_from_json(json_filenames: Union[List[str], str])-> dict: 73 | # """ 74 | # Load a lookup table from a json file to a dictionnary 75 | # :param json_filenames: either a filename or a list of filenames 76 | # :return: 77 | # """ 78 | # 79 | # lookup = dict() 80 | # if isinstance(json_filenames, list): 81 | # for file in json_filenames: 82 | # with open(file, 'r', encoding='utf8') as f: 83 | # data_dict = json.load(f) 84 | # lookup.update(data_dict) 85 | # 86 | # elif isinstance(json_filenames, str): 87 | # with open(json_filenames, 'r', encoding='utf8') as f: 88 | # lookup = json.load(f) 89 | # 90 | # return map_lookup(lookup) 91 | 92 | 93 | # def map_lookup(lookup_table: dict, unique_entry: bool=True)-> dict: 94 | # """ 95 
| # Converts an existing lookup table with minimal range code ([1, len(lookup_table)-1]) 96 | # and avoids multiple instances of the same code label (bijectivity) 97 | # 98 | # :param lookup_table: dictionary to be mapped {alphabet_unit : code label} 99 | # :param unique_entry: If each alphabet unit has a unique code and each code a unique alphabet unique ('bijective'), 100 | # only True is implemented for now 101 | # :return: a mapped dictionary 102 | # """ 103 | # 104 | # # Create tuple (alphabet unit, code) 105 | # tuple_char_code = list(zip(list(lookup_table.keys()), list(lookup_table.values()))) 106 | # # Sort by code 107 | # tuple_char_code.sort(key=lambda x: x[1]) 108 | # 109 | # # If each alphabet unit has a unique code and each code a unique alphabet unique ('bijective') 110 | # if unique_entry: 111 | # mapped_lookup = [[tp[0], i + 1] for i, tp in enumerate(tuple_char_code)] 112 | # else: 113 | # raise NotImplementedError 114 | # # Todo 115 | # 116 | # return dict(mapped_lookup) 117 | -------------------------------------------------------------------------------- /hlp/csv_helpers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = 'solivr' 3 | __license__ = "GPL" 4 | 5 | import csv 6 | import os 7 | import argparse 8 | from tqdm import tqdm, trange 9 | 10 | 11 | def csv_rel2abs_path_convertor(csv_filenames: str, delimiter: str=' ', encoding='utf8') -> None: 12 | """ 13 | Convert relative paths into absolute paths 14 | 15 | :param csv_filenames: filename of csv 16 | :param delimiter: character to delimit felds in csv 17 | :param encoding: encoding format of csv file 18 | :return: 19 | """ 20 | 21 | for filename in tqdm(csv_filenames): 22 | absolute_path, basename = os.path.split(os.path.abspath(filename)) 23 | relative_paths = list() 24 | labels = list() 25 | # Reading CSV 26 | with open(filename, 'r', encoding=encoding) as f: 27 | csvreader = csv.reader(f, delimiter=delimiter) 28 | for row in csvreader: 29 | relative_paths.append(row[0]) 30 | labels.append(row[1]) 31 | 32 | # Writing converted_paths CSV 33 | export_filename = os.path.join(absolute_path, '{}_abs{}'.format(*os.path.splitext(basename))) 34 | with open(export_filename, 'w', encoding=encoding) as f: 35 | csvwriter = csv.writer(f, delimiter=delimiter) 36 | for i in trange(0, len(relative_paths)): 37 | csvwriter.writerow([os.path.abspath(os.path.join(absolute_path, relative_paths[i])), labels[i]]) 38 | 39 | 40 | def csv_filtering_chars_from_labels(csv_filename: str, chars_to_remove: str, 41 | delimiter: str=' ', encoding='utf8') -> int: 42 | """ 43 | Remove labels containing chars_to_remove in csv_filename 44 | 45 | :param chars_to_remove: string (or list) with the undesired characters 46 | :param csv_filename: filenmae of csv 47 | :param delimiter: delimiter character 48 | :param encoding: encoding format of csv file 49 | :return: number of deleted labels 50 | """ 51 | 52 | if not isinstance(chars_to_remove, list): 53 | chars_to_remove = list(chars_to_remove) 54 | 55 | paths = list() 56 | labels = list() 57 | n_deleted = 0 58 | with open(csv_filename, 'r', encoding=encoding) as file: 59 | csvreader = csv.reader(file, delimiter=delimiter) 60 | for row in csvreader: 61 | if not any((d in chars_to_remove) for d in row[1]): 62 | paths.append(row[0]) 63 | labels.append(row[1]) 64 | else: 65 | n_deleted += 1 66 | 67 | with open(csv_filename, 'w', encoding=encoding) as file: 68 | csvwriter = csv.writer(file, delimiter=delimiter) 69 | for i in 
tqdm(range(len(paths)), total=len(paths)): 70 | csvwriter.writerow([paths[i], labels[i]]) 71 | 72 | return n_deleted 73 | 74 | 75 | if __name__ == '__main__': 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument('-i', '--input_files', type=str, required=True, help='CSV filename to convert', nargs='*') 78 | parser.add_argument('-d', '--delimiter_char', type=str, help='CSV delimiter character', default=' ') 79 | 80 | args = vars(parser.parse_args()) 81 | 82 | csv_filenames = args.get('input_files') 83 | 84 | csv_rel2abs_path_convertor(csv_filenames, delimiter=args.get('delimiter_char')) 85 | 86 | -------------------------------------------------------------------------------- /hlp/numbers_mnist_generator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = 'solivr' 3 | __license__ = "GPL" 4 | 5 | from tensorflow.examples.tutorials.mnist import input_data 6 | import numpy as np 7 | import os 8 | import csv 9 | from imageio import imsave 10 | from tqdm import tqdm 11 | import random 12 | import argparse 13 | 14 | 15 | def generate_random_image_numbers(mnist_dir, dataset, output_dir, csv_filename, n_numbers): 16 | 17 | mnist = input_data.read_data_sets(mnist_dir, one_hot=False) 18 | 19 | output_dir_img = os.path.join(output_dir, 'images') 20 | if not os.path.exists(output_dir): 21 | os.mkdir(output_dir) 22 | if not os.path.exists(output_dir_img): 23 | os.mkdir(output_dir_img) 24 | 25 | if dataset == 'train': 26 | dataset = mnist.train 27 | elif dataset == 'validation': 28 | dataset = mnist.validation 29 | elif dataset == 'test': 30 | dataset = mnist.test 31 | 32 | list_paths = list() 33 | list_labels = list() 34 | 35 | for i in tqdm(range(n_numbers), total=n_numbers): 36 | n_digits = random.randint(3, 8) 37 | digits, labels = dataset.next_batch(n_digits) 38 | # Reshape to have 28x28 image 39 | square_digits = np.reshape(digits, [-1, 28, 28]) 40 | # White background 41 | square_digits = -(square_digits - 1) * 255 42 | stacked_number = np.hstack(square_digits[:, :, 4:-4]) 43 | stacked_label = ''.join(map(str, labels)) 44 | # chans3 = np.dstack([stacked_number]*3) 45 | 46 | # Save image number 47 | img_filename = '{:09}_{}.jpg'.format(i, stacked_label) 48 | img_path = os.path.join(output_dir_img, img_filename) 49 | imsave(img_path, stacked_number) 50 | 51 | # Add to list of paths and list of labels 52 | list_paths.append(img_filename) 53 | list_labels.append(stacked_label) 54 | 55 | root = './images' 56 | csv_path = os.path.join(output_dir, csv_filename) 57 | with open(csv_path, 'w') as csvfile: 58 | for i in tqdm(range(len(list_paths)), total=len(list_paths)): 59 | csvwriter = csv.writer(csvfile, delimiter=' ') 60 | csvwriter.writerow([os.path.join(root, list_paths[i]), list_labels[i]]) 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('-md', '--mnist_dir', type=str, help='Directory for MNIST data', default='./MNIST_data') 66 | parser.add_argument('-d', '--dataset', type=str, help='Dataset wanted (train, test, validation)', default='train') 67 | parser.add_argument('-csv', '--csv_filename', type=str, help='CSV filename to output paths and labels') 68 | parser.add_argument('-od', '--output_dir', type=str, help='Directory to output images and csv files', default='./output_numbers') 69 | parser.add_argument('-n', '--n_samples', type=int, help='Desired numbers of generated samples', default=1000) 70 | 71 | args = parser.parse_args() 72 | 73 | 
generate_random_image_numbers(args.mnist_dir, args.dataset, args.output_dir, args.csv_filename, args.n_samples) 74 | 75 | 76 | -------------------------------------------------------------------------------- /hlp/prepare_iam.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | from taputapu.databases import iam 6 | import os 7 | from glob import glob 8 | from string_data_manager import tf_crnn_label_formatting 9 | from alphabet_helpers import generate_alphabet_file 10 | import click 11 | 12 | 13 | @click.command() 14 | @click.option('--download_dir') 15 | @click.option('--generated_data_dir') 16 | def prepare_iam_data(download_dir: str, 17 | generated_data_dir: str): 18 | 19 | # Download data 20 | print('Starting downloads...') 21 | iam.download(download_dir) 22 | 23 | # Extract archives 24 | print('Starting extractions...') 25 | iam.extract(download_dir) 26 | 27 | print('Generating files for the experiment...') 28 | # Generate splits (same format as ascii files) 29 | export_splits_dir = os.path.join(generated_data_dir, 'generated_splits') 30 | os.makedirs(export_splits_dir, exist_ok=True) 31 | 32 | iam.generate_splits_txt(os.path.join(download_dir, 'ascii', 'lines.txt'), 33 | os.path.join(download_dir, 'largeWriterIndependentTextLineRecognitionTask'), 34 | export_splits_dir) 35 | 36 | # Generate csv from .txt splits files 37 | export_csv_dir = os.path.join(generated_data_dir, 'generated_csv') 38 | os.makedirs(export_csv_dir, exist_ok=True) 39 | 40 | for file in glob(os.path.join(export_splits_dir, '*')): 41 | export_basename = os.path.basename(file).split('.')[0] 42 | iam.create_experiment_csv(file, 43 | os.path.join(download_dir, 'lines'), 44 | os.path.join(export_csv_dir, export_basename + '.csv'), 45 | False, 46 | True) 47 | 48 | # Format string label to tf_crnn formatting 49 | for csv_filename in glob(os.path.join(export_csv_dir, '*')): 50 | tf_crnn_label_formatting(csv_filename) 51 | 52 | # Generate alphabet 53 | alphabet_dir = os.path.join(generated_data_dir, 'generated_alphabet') 54 | os.makedirs(alphabet_dir, exist_ok=True) 55 | 56 | generate_alphabet_file(glob(os.path.join(export_csv_dir, '*')), 57 | os.path.join(alphabet_dir, 'iam_alphabet_lookup.json')) 58 | 59 | 60 | if __name__ == '__main__': 61 | prepare_iam_data() 62 | -------------------------------------------------------------------------------- /hlp/string_data_manager.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = 'solivr' 3 | __licence__ = 'GPL' 4 | 5 | import pandas as pd 6 | 7 | _accents_list = 'àéèìîóòù' 8 | _accent_mapping = {'à': 'a', 9 | 'é': 'e', 10 | 'è': 'e', 11 | 'ì': 'i', 12 | 'î': 'i', 13 | 'ó': 'o', 14 | 'ò': 'o', 15 | 'ù': 'u'} 16 | 17 | 18 | def map_accentuated_characters_in_dataframe(dataframe_transcriptions: pd.DataFrame, 19 | dict_mapping: dict=_accent_mapping) -> pd.DataFrame: 20 | """ 21 | 22 | :param dataframe_transcriptions: must have a field 'transcription' 23 | :param dict_mapping 24 | :return: 25 | """ 26 | items = dataframe_transcriptions.transcription.iteritems() 27 | 28 | for i in range(dataframe_transcriptions.transcription.count()): 29 | df_id, transcription = next(items) 30 | # https://stackoverflow.com/questions/30020184/how-to-find-the-first-index-of-any-of-a-set-of-characters-in-a-string 31 | ch_index = next((i for i, ch in enumerate(transcription) if ch in _accents_list), None) 32 | while 
ch_index is not None: 33 | transcription = list(transcription) 34 | ch = transcription[ch_index] 35 | transcription[ch_index] = dict_mapping[ch] 36 | transcription = ''.join(transcription) 37 | dataframe_transcriptions.at[df_id, 'transcription'] = transcription 38 | ch_index = next((i for i, ch in enumerate(transcription) if ch in _accents_list), None) 39 | 40 | return dataframe_transcriptions 41 | 42 | 43 | def map_accentuated_characters_in_string(string_to_format: str, dict_mapping: dict=_accent_mapping) -> str: 44 | """ 45 | 46 | :param string_to_format: 47 | :param dict_mapping: 48 | :return: 49 | """ 50 | # https://stackoverflow.com/questions/30020184/how-to-find-the-first-index-of-any-of-a-set-of-characters-in-a-string 51 | ch_index = next((i for i, ch in enumerate(string_to_format) if ch in _accents_list), None) 52 | while ch_index is not None: 53 | string_to_format = list(string_to_format) 54 | ch = string_to_format[ch_index] 55 | string_to_format[ch_index] = dict_mapping[ch] 56 | string_to_format = ''.join(string_to_format) 57 | ch_index = next((i for i, ch in enumerate(string_to_format) if ch in _accents_list), None) 58 | 59 | return string_to_format 60 | 61 | 62 | def format_string_for_tf_split(string_to_format: str, 63 | separator_character: str= '|', 64 | replace_brackets_abbreviations=True) -> str: 65 | """ 66 | Formats transcriptions to be split by tf.string_split using character separator "|" 67 | 68 | :param string_to_format: string to format 69 | :param separator_character: character that separates alphabet units 70 | :param replace_brackets_abbreviations: if True will replace '[' and ']' chars by separator character 71 | :return: 72 | """ 73 | 74 | if replace_brackets_abbreviations: 75 | # Replace "[]" chars by "|" 76 | string_to_format = string_to_format.replace("[", separator_character).replace("]", separator_character) 77 | 78 | splits = string_to_format.split(separator_character) 79 | 80 | final_string = separator_character 81 | # Case where string starts with a separator_character 82 | if splits[0] == '': 83 | for i, sp in enumerate(splits): 84 | if i % 2 > 0: # uneven -> abbreviation 85 | final_string += separator_character + sp + separator_character 86 | else: # even -> no abbreviation 87 | final_string += sp.replace('', separator_character)[1:-1] 88 | 89 | else: 90 | for i, sp in enumerate(splits): 91 | if i % 2 > 0: # uneven -> no abbreviation 92 | final_string += separator_character + sp + separator_character 93 | else: # even -> abbreviation 94 | final_string += sp.replace('', separator_character)[1:-1] 95 | 96 | # Add separator at beginning or end of string if it hasn't been added yet 97 | if final_string[1] == separator_character: 98 | final_string = final_string[1:] 99 | if final_string[-1] != separator_character: 100 | final_string += separator_character 101 | 102 | return final_string 103 | 104 | 105 | def tf_crnn_label_formatting(csv_filename: str): 106 | 107 | def _string_formatting(string_to_format: str, 108 | separator_character: str = '|'): 109 | chars = list(string_to_format) 110 | formated_string = separator_character + '{}'.format(separator_character).join(chars) + separator_character 111 | return formated_string 112 | 113 | df = pd.read_csv(csv_filename, sep=';', header=None, names=['image', 'labels'], encoding='utf8', 114 | escapechar="\\", quoting=3) 115 | 116 | df.labels = df.labels.apply(lambda x: _string_formatting(x)) 117 | 118 | df.to_csv(csv_filename, sep=';', encoding='utf-8', header=False, index=False, escapechar="\\", quoting=3) 119 | 
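# Illustration (hypothetical csv row) of what tf_crnn_label_formatting above does
# to the label column, matching the '|'-separated format expected by tf_crnn:
#     before:  /full/path/to/image1.png;string_label1
#     after:   /full/path/to/image1.png;|s|t|r|i|n|g|_|l|a|b|e|l|1|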
120 | 121 | def lower_abbreviation_in_string(string_to_format: str): 122 | # Split with '[' 123 | tokens_opening = string_to_format.split('[') 124 | 125 | valid_string = True 126 | final_string = tokens_opening[0] 127 | for tok in tokens_opening[1:]: 128 | if len(tok) > 1: 129 | token_closing = tok.split(']') 130 | if len(token_closing) == 2: # checks if abbreviation starts with [ and ends with ] 131 | final_string += '[' + token_closing[0].lower() + ']' + token_closing[1] 132 | else: # No closing ']' 133 | valid_string = False 134 | else: 135 | final_string += ']' 136 | if valid_string: 137 | return final_string 138 | else: 139 | return '' 140 | 141 | 142 | def add_abbreviation_brackets(label: str): 143 | """ 144 | Adds brackets in formatted strings i.e label= '|B|e|n|e|t|t|a| |M|a|z|z|o|l|e|n|i| |quondam| |A|n|z|o|l|o|' 145 | turns to '|B|e|n|e|t|t|a| |M|a|z|z|o|l|e|n|i| |[quondam]| |A|n|z|o|l|o|' 146 | :param label: 147 | :return: 148 | """ 149 | splits = label.split('|') 150 | 151 | is_abbrev = [len(tok) > 1 for tok in splits] 152 | bracketing = ['[' + tok + ']' if abbrev else tok for (tok, abbrev) in zip(splits, is_abbrev)] 153 | 154 | return '|'.join(bracketing) -------------------------------------------------------------------------------- /prediction.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | import os 6 | from glob import glob 7 | 8 | import click 9 | from tf_crnn.callbacks import CustomPredictionSaverCallback, FOLDER_SAVED_MODEL 10 | from tf_crnn.config import Params 11 | from tf_crnn.data_handler import dataset_generator 12 | from tf_crnn.model import get_model_inference 13 | 14 | 15 | @click.command() 16 | @click.option('--csv_filename', help='A csv file containing the path to the images to predict') 17 | @click.option('--output_model_dir', help='Directory where all the exported data related to an experiment has been saved') 18 | def prediction(csv_filename: str, 19 | output_model_dir: str): 20 | parameters = Params.from_json_file(os.path.join(output_model_dir, 'config.json')) 21 | 22 | saving_dir = os.path.join(output_model_dir, FOLDER_SAVED_MODEL) 23 | last_time_stamp = str(max([int(p.split(os.path.sep)[-1].split('-')[0]) 24 | for p in glob(os.path.join(saving_dir, '*'))])) 25 | model = get_model_inference(parameters, os.path.join(saving_dir, last_time_stamp, 'weights.h5')) 26 | 27 | dataset_test = dataset_generator([csv_filename], 28 | parameters, 29 | use_labels=False, 30 | batch_size=parameters.eval_batch_size, 31 | shuffle=False) 32 | 33 | ps_callback = CustomPredictionSaverCallback(output_model_dir, parameters) 34 | 35 | _, _, _ = model.predict(x=dataset_test, callbacks=[ps_callback]) 36 | 37 | 38 | if __name__ == '__main__': 39 | prediction() 40 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | from setuptools import setup, find_packages 6 | 7 | setup(name='tf_crnn', 8 | version='0.6.0', 9 | license='GPL', 10 | author='Sofia Ares Oliveira', 11 | url='https://github.com/solivr/tf-crnn', 12 | description='TensorFlow Convolutional Recurrent Neural Network (CRNN)', 13 | install_requires=[ 14 | 'imageio', 15 | 'numpy', 16 | 'tqdm', 17 | 'sacred', 18 | 'opencv-python', 19 | 'pandas', 20 | 'click', 21 | #'tensorflow-addons', 22 | 'tensorflow-gpu', 
23 | 'taputapu' 24 | ], 25 | dependency_links=['https://github.com/solivr/taputapu/tarball/master#egg=taputapu-1.0'], 26 | extras_require={ 27 | 'doc': [ 28 | 'sphinx', 29 | 'sphinx-autodoc-typehints', 30 | 'sphinx-rtd-theme', 31 | 'sphinxcontrib-bibtex', 32 | 'sphinxcontrib-websupport' 33 | ], 34 | }, 35 | packages=find_packages(where='.'), 36 | zip_safe=False) 37 | -------------------------------------------------------------------------------- /tf_crnn/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 3 | 4 | Data handling for input function 5 | -------------------------------- 6 | .. currentmodule:: tf_crnn.data_handler 7 | 8 | .. autosummary:: 9 | dataset_generator 10 | padding_inputs_width 11 | augment_data 12 | random_rotation 13 | 14 | 15 | Model definitions 16 | ----------------- 17 | .. currentmodule:: tf_crnn.model 18 | 19 | .. autosummary:: 20 | ConvBlock 21 | get_model_train 22 | get_model_inference 23 | get_crnn_output 24 | 25 | 26 | Config for training 27 | ------------------- 28 | .. currentmodule:: tf_crnn.config 29 | 30 | .. autosummary:: 31 | Alphabet 32 | Params 33 | import_params_from_json 34 | 35 | 36 | Custom Callbacks 37 | ---------------- 38 | .. currentmodule:: tf_crnn.callbacks 39 | 40 | .. autosummary:: 41 | CustomSavingCallback 42 | LRTensorBoard 43 | CustomLoaderCallback 44 | CustomPredictionSaverCallback 45 | 46 | 47 | Preprocessing data 48 | ------------------ 49 | .. currentmodule:: tf_crnn.preprocessing 50 | 51 | .. autosummary:: 52 | data_preprocessing 53 | preprocess_csv 54 | 55 | 56 | ---- 57 | 58 | """ 59 | 60 | _DATA_HANDLING = [ 61 | 'dataset_generator', 62 | 'padding_inputs_width', 63 | 'augment_data', 64 | 'random_rotation' 65 | ] 66 | 67 | _CONFIG = [ 68 | 'Alphabet', 69 | 'Params', 70 | 'import_params_from_json' 71 | 72 | ] 73 | 74 | _MODEL = [ 75 | 'ConvBlock', 76 | 'get_model_train', 77 | 'get_model_inference' 78 | 'get_crnn_output' 79 | ] 80 | 81 | _CALLBACKS = [ 82 | 'CustomSavingCallback', 83 | 'CustomLoaderCallback', 84 | 'CustomPredictionSaverCallback', 85 | 'LRTensorBoard' 86 | ] 87 | 88 | _PREPROCESSING = [ 89 | 'data_preprocessing', 90 | 'preprocess_csv' 91 | ] 92 | 93 | __all__ = _DATA_HANDLING + _CONFIG + _MODEL + _CALLBACKS + _PREPROCESSING 94 | 95 | from tf_crnn.config import * 96 | from tf_crnn.model import * 97 | from tf_crnn.callbacks import * 98 | from tf_crnn.preprocessing import * 99 | from tf_crnn.data_handler import * -------------------------------------------------------------------------------- /tf_crnn/callbacks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | import tensorflow as tf 6 | from tensorflow.keras.callbacks import Callback, TensorBoard 7 | import os 8 | import shutil 9 | import pickle 10 | import json 11 | import time 12 | import numpy as np 13 | from .config import Params 14 | 15 | 16 | MODEL_WEIGHTS_FILENAME = 'weights.h5' 17 | OPTIMIZER_WEIGHTS_FILENAME = 'optimizer_weights.pkl' 18 | LEARNING_RATE_FILENAME = 'learning_rate.pkl' 19 | LAYERS_FILENAME = 'architecture.json' 20 | EPOCH_FILENAME = 'epoch.pkl' 21 | FOLDER_SAVED_MODEL = 'saving' 22 | 23 | 24 | class CustomSavingCallback(Callback): 25 | """ 26 | Callback to save weights, architecture, and optimizer at the end of training. 27 | Inspired by `ModelCheckpoint`. 
28 | 29 | :ivar output_dir: path to the folder where files will be saved 30 | :vartype output_dir: str 31 | :ivar saving_freq: save every `n` epochs 32 | :vartype saving_freq: int 33 | :ivar save_best_only: wether to save a model if it is best thant the last saving 34 | :vartype save_best_only: bool 35 | :ivar keep_max_models: number of models to keep, the older ones will be deleted 36 | :vartype keep_max_models: int 37 | """ 38 | def __init__(self, 39 | output_dir: str, 40 | saving_freq:int, 41 | save_best_only: bool=False, 42 | keep_max_models:int=5): 43 | super(CustomSavingCallback, self).__init__() 44 | 45 | self.saving_dir = output_dir 46 | self.saving_freq = saving_freq 47 | self.save_best_only = save_best_only 48 | self.keep_max_models = keep_max_models 49 | 50 | self.epochs_since_last_save = 0 51 | 52 | self.monitor = 'val_loss' 53 | self.monitor_op = np.less 54 | self.best_value = np.Inf # todo: when restoring model we could also restore val_loss and metric 55 | 56 | def on_epoch_begin(self, 57 | epoch, 58 | logs=None): 59 | self._current_epoch = epoch 60 | 61 | def on_epoch_end(self, 62 | epoch, 63 | logs=None): 64 | 65 | self.logs = logs 66 | self.epochs_since_last_save += 1 67 | 68 | if self.epochs_since_last_save == self.saving_freq: 69 | self._export_model(logs) 70 | self.epochs_since_last_save = 0 71 | 72 | def on_train_end(self, 73 | logs=None): 74 | self._export_model(self.logs) 75 | self.epochs_since_last_save = 0 76 | 77 | 78 | def _export_model(self, logs): 79 | timestamp = str(int(time.time())) 80 | folder = os.path.join(self.saving_dir, timestamp) 81 | 82 | if self.save_best_only: 83 | current_value = logs.get(self.monitor) 84 | 85 | if self.monitor_op(current_value, self.best_value): 86 | print('\n{} improved from {:0.5f} to {:0.5f},' 87 | ' saving model to {}'.format(self.monitor, self.best_value, 88 | current_value, folder)) 89 | self.best_value = current_value 90 | 91 | else: 92 | print('\n{} did not improve from {:0.5f}'.format(self.monitor, self.best_value)) 93 | return 94 | 95 | os.makedirs(folder) 96 | 97 | # save architecture 98 | model_json = self.model.to_json() 99 | with open(os.path.join(folder, LAYERS_FILENAME), 'w') as f: 100 | json.dump(model_json, f) 101 | 102 | # model weights 103 | self.model.save_weights(os.path.join(folder, MODEL_WEIGHTS_FILENAME)) 104 | 105 | # optimizer weights 106 | optimizer_weights = tf.keras.backend.batch_get_value(self.model.optimizer.weights) 107 | with open(os.path.join(folder, OPTIMIZER_WEIGHTS_FILENAME), 'wb') as f: 108 | pickle.dump(optimizer_weights, f) 109 | 110 | # learning rate 111 | learning_rate = self.model.optimizer.learning_rate 112 | with open(os.path.join(folder, LEARNING_RATE_FILENAME), 'wb') as f: 113 | pickle.dump(learning_rate, f) 114 | 115 | # n epochs 116 | epoch = self._current_epoch + 1 117 | with open(os.path.join(folder, EPOCH_FILENAME), 'wb') as f: 118 | pickle.dump(epoch, f) 119 | 120 | self._clean_exports() 121 | 122 | def _clean_exports(self): 123 | timestamp_folders = [int(f) for f in os.listdir(self.saving_dir)] 124 | timestamp_folders.sort(reverse=True) 125 | 126 | if len(timestamp_folders) > self.keep_max_models: 127 | folders_to_remove = timestamp_folders[self.keep_max_models:] 128 | for f in folders_to_remove: 129 | shutil.rmtree(os.path.join(self.saving_dir, str(f))) 130 | 131 | 132 | 133 | class CustomLoaderCallback(Callback): 134 | """ 135 | Callback to load necessary weight and parameters for training, evaluation and prediction. 
136 | 137 | :ivar loading_dir: path to directory to save logs 138 | :vartype loading_dir: str 139 | """ 140 | def __init__(self, 141 | loading_dir: str): 142 | super(CustomLoaderCallback, self).__init__() 143 | 144 | self.loading_dir = loading_dir 145 | 146 | def set_model(self, model): 147 | self.model = model 148 | 149 | print('-- Loading ', self.loading_dir) 150 | # Load model weights 151 | self.model.load_weights(os.path.join(self.loading_dir, MODEL_WEIGHTS_FILENAME)) 152 | 153 | # Load optimizer params 154 | with open(os.path.join(self.loading_dir, OPTIMIZER_WEIGHTS_FILENAME), 'rb') as f: 155 | optimizer_weights = pickle.load(f) 156 | with open(os.path.join(self.loading_dir, LEARNING_RATE_FILENAME), 'rb') as f: 157 | learning_rate = pickle.load(f) 158 | 159 | # Set optimizer params 160 | self.model.optimizer.learning_rate.assign(learning_rate) 161 | self.model._make_train_function() 162 | self.model.optimizer.set_weights(optimizer_weights) 163 | 164 | 165 | class CustomPredictionSaverCallback(Callback): 166 | """ 167 | Callback to save prediction decoded outputs. 168 | This will save the decoded outputs into a file. 169 | 170 | :ivar output_dir: path to directory to save logs 171 | :vartype output_dir: str 172 | :ivar parameters: parameters of the experiment (``Params``) 173 | :vartype parameters: Params 174 | """ 175 | def __init__(self, 176 | output_dir: str, 177 | parameters: Params): 178 | super(CustomPredictionSaverCallback, self).__init__() 179 | 180 | self.saving_dir = output_dir 181 | self.parameters = parameters 182 | 183 | # Inference 184 | def on_predict_begin(self, 185 | logs=None): 186 | # Create file to add predictions 187 | timestamp = str(int(time.time())) 188 | self._prediction_filename = os.path.join(self.saving_dir, 'predictions-{}.txt'.format(timestamp)) 189 | 190 | def on_predict_batch_end(self, 191 | batch, 192 | logs): 193 | logits, seq_len, filenames = logs['outputs'] 194 | 195 | codes = tf.keras.backend.ctc_decode(logits, tf.squeeze(seq_len), greedy=True)[0][0].numpy() 196 | strings = [''.join([self.parameters.alphabet.lookup_int2str[c] for c in lc if c != -1]) for lc in codes] 197 | 198 | with open(self._prediction_filename, 'ab') as f: 199 | for n, s in zip(filenames, strings): 200 | n = n[0] # n is a list of one element 201 | f.write((n.decode() + ';' + s + '\n').encode('utf8')) 202 | 203 | 204 | class LRTensorBoard(TensorBoard): 205 | """ 206 | Adds learning rate to TensorBoard scalars. 
207 | 208 | :ivar logdir: path to directory to save logs 209 | :vartype logdir: str 210 | """ 211 | # From https://github.com/keras-team/keras/pull/9168#issuecomment-359901128 212 | def __init__(self, 213 | log_dir: str, 214 | **kwargs): # add other arguments to __init__ if you need 215 | super(LRTensorBoard, self).__init__(log_dir=log_dir, **kwargs) 216 | 217 | def on_epoch_end(self, 218 | epoch, 219 | logs=None): 220 | logs.update({'lr': tf.keras.backend.eval(self.model.optimizer.lr)}) 221 | super(LRTensorBoard, self).on_epoch_end(epoch, logs) -------------------------------------------------------------------------------- /tf_crnn/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = 'solivr' 3 | __license__ = "GPL" 4 | 5 | import csv 6 | import json 7 | import os 8 | import string 9 | from functools import reduce 10 | from glob import glob 11 | from typing import List, Union 12 | import pandas as pd 13 | 14 | 15 | class CONST: 16 | DIMENSION_REDUCTION_W_POOLING = 2*2 # 2x2 pooling in dimension W on layer 1 and 2 17 | PREPROCESSING_FOLDER = 'preprocessed' 18 | 19 | 20 | class Alphabet: 21 | """ 22 | Class for alphabet / symbols units. 23 | 24 | :ivar _blank_symbol: Blank symbol used for CTC 25 | :vartype _blank_symbol: str 26 | :ivar _alphabet_units: list of elements composing the alphabet. The units may be a single character or multiple characters. 27 | :vartype _alphabet_units: List[str] 28 | :ivar _codes: Each alphabet unit has a unique corresponding code. 29 | :vartype _codes: List[int] 30 | :ivar _nclasses: number of alphabet units. 31 | :vartype _nclasses: int 32 | """ 33 | def __init__(self, lookup_alphabet_file: str=None, blank_symbol: str='$'): 34 | 35 | self._blank_symbol = blank_symbol 36 | 37 | if lookup_alphabet_file: 38 | lookup_alphabet = self.load_lookup_from_json(lookup_alphabet_file) 39 | # Blank symbol must have the largest value 40 | if self._blank_symbol in lookup_alphabet.keys(): 41 | 42 | # TODO : check if blank symbol is the last one 43 | assert lookup_alphabet[self._blank_symbol] == max(lookup_alphabet.values()), \ 44 | "Blank symbol should have the largest code integer" 45 | lookup_alphabet[self._blank_symbol] = max(lookup_alphabet.values()) + 1 46 | else: 47 | lookup_alphabet.update({self._blank_symbol: max(lookup_alphabet.values()) + 1}) 48 | 49 | self._alphabet_units = list(lookup_alphabet.keys()) 50 | self._codes = list(lookup_alphabet.values()) 51 | self._nclasses = len(self.codes) + 1 # n_classes should be + 1 of labels codes 52 | 53 | if 0 in self._codes: 54 | raise ValueError('0 code is in the lookup table, you should''nt use it.') 55 | 56 | self.lookup_int2str = dict(zip(self.codes, self.alphabet_units)) 57 | 58 | def check_input_file_alphabet(self, csv_filenames: List[str], 59 | discarded_chars: str=';|{}'.format(string.whitespace[1:]), 60 | csv_delimiter: str=";") -> None: 61 | """ 62 | Checks if labels of input files contains only characters that are in the Alphabet. 
63 | 64 | :param csv_filenames: list of the csv filename 65 | :param discarded_chars: discarded characters 66 | :param csv_delimiter: character delimiting field in the csv file 67 | :return: 68 | """ 69 | assert isinstance(csv_filenames, list), 'csv_filenames argument is not a list' 70 | 71 | alphabet_set = set(self.alphabet_units) 72 | 73 | for filename in csv_filenames: 74 | input_chars_set = set() 75 | 76 | with open(filename, 'r', encoding='utf8') as f: 77 | csvreader = csv.reader(f, delimiter=csv_delimiter, escapechar='\\', quoting=0) 78 | for line in csvreader: 79 | input_chars_set.update(line[1]) 80 | 81 | # Discard all whitespaces except space ' ' 82 | for whitespace in discarded_chars: 83 | input_chars_set.discard(whitespace) 84 | 85 | extra_chars = input_chars_set - alphabet_set 86 | assert len(extra_chars) == 0, 'There are {} unknown chars in {} : {}'.format(len(extra_chars), 87 | filename, extra_chars) 88 | 89 | @classmethod 90 | def map_lookup(cls, lookup_table: dict, unique_entry: bool = True) -> dict: 91 | """ 92 | Converts an existing lookup table with minimal range code ([1, len(lookup_table)-1]) 93 | and avoids multiple instances of the same code label (bijectivity) 94 | 95 | :param lookup_table: dictionary to be mapped {alphabet_unit : code label} 96 | :param unique_entry: If each alphabet unit has a unique code and each code a unique alphabet unique ('bijective'), 97 | only True is implemented for now 98 | :return: a mapped dictionary 99 | """ 100 | 101 | # Create tuple (alphabet unit, code) 102 | tuple_char_code = list(zip(list(lookup_table.keys()), list(lookup_table.values()))) 103 | # Sort by code 104 | tuple_char_code.sort(key=lambda x: x[1]) 105 | 106 | # If each alphabet unit has a unique code and each code a unique alphabet unique ('bijective') 107 | if unique_entry: 108 | mapped_lookup = [[tp[0], i + 1] for i, tp in enumerate(tuple_char_code)] 109 | else: 110 | raise NotImplementedError 111 | # Todo 112 | 113 | return dict(mapped_lookup) 114 | 115 | @classmethod 116 | def create_lookup_from_labels(cls, csv_files: List[str], export_lookup_filename: str, 117 | original_lookup_filename: str=None): 118 | """ 119 | Create a lookup dictionary for csv files containing labels. Exports a json file with the Alphabet. 
120 | 121 | :param csv_files: list of files to get the labels from (should be of format path;label) 122 | :param export_lookup_filename: filename to export alphabet lookup dictionary 123 | :param original_lookup_filename: original lookup filename to update (optional) 124 | :return: 125 | """ 126 | if original_lookup_filename: 127 | with open(original_lookup_filename, 'r') as f: 128 | lookup = json.load(f) 129 | set_chars = set(list(lookup.keys())) 130 | else: 131 | set_chars = set(list(string.ascii_letters) + list(string.digits)) 132 | lookup = dict() 133 | 134 | for filename in csv_files: 135 | data = pd.read_csv(filename, sep=';', encoding='utf8', error_bad_lines=False, header=None, 136 | names=['path', 'transcription'], escapechar='\\') 137 | for index, row in data.iterrows(): 138 | set_chars.update(row.transcription.split('|')) 139 | 140 | # Update (key, values) of lookup table 141 | for el in set_chars: 142 | if el not in lookup.keys(): 143 | lookup[el] = max(lookup.values()) + 1 if lookup.values() else 0 144 | 145 | lookup = cls.map_lookup(lookup) 146 | 147 | # Save new lookup 148 | with open(export_lookup_filename, 'w', encoding='utf8') as f: 149 | json.dump(lookup, f) 150 | 151 | @classmethod 152 | def load_lookup_from_json(cls, json_filenames: Union[List[str], str]) -> dict: 153 | """ 154 | Load a lookup table from a json file to a dictionnary 155 | :param json_filenames: either a filename or a list of filenames 156 | :return: 157 | """ 158 | 159 | lookup = dict() 160 | if isinstance(json_filenames, list): 161 | for file in json_filenames: 162 | with open(file, 'r', encoding='utf8') as f: 163 | data_dict = json.load(f) 164 | lookup.update(data_dict) 165 | 166 | elif isinstance(json_filenames, str): 167 | with open(json_filenames, 'r', encoding='utf8') as f: 168 | lookup = json.load(f) 169 | 170 | return cls.map_lookup(lookup) 171 | 172 | @classmethod 173 | def make_json_lookup_alphabet(cls, string_chars: str = None) -> dict: 174 | """ 175 | 176 | :param string_chars: for example string.ascii_letters, string.digits 177 | :return: 178 | """ 179 | lookup = dict() 180 | if string_chars: 181 | # Add characters to lookup table 182 | lookup.update({char: ord(char) for char in string_chars}) 183 | 184 | return cls.map_lookup(lookup) 185 | 186 | @property 187 | def n_classes(self): 188 | return self._nclasses 189 | 190 | @property 191 | def blank_symbol(self): 192 | return self._blank_symbol 193 | 194 | @property 195 | def codes(self): 196 | return self._codes 197 | 198 | @property 199 | def alphabet_units(self): 200 | return self._alphabet_units 201 | 202 | 203 | class Params: 204 | """ 205 | Class for parameters of the model and the experiment 206 | 207 | :ivar input_shape: input shape of the image to batch (this is the shape after data augmentation). 
208 |         The original image will either be resized or padded depending on its original size
209 |     :vartype input_shape: Tuple[int, int]
210 |     :ivar input_channels: number of color channels for the input image (default: 1)
211 |     :vartype input_channels: int
212 |     :ivar cnn_features_list: a list of length `n_layers` containing the number of features for each convolutional layer
213 |         (default: [16, 32, 64, 96, 128])
214 |     :vartype cnn_features_list: List(int)
215 |     :ivar cnn_kernel_size: a list of length `n_layers` containing the size of the kernel for each convolutional layer
216 |         (default: [3, 3, 3, 3, 3])
217 |     :vartype cnn_kernel_size: List(int)
218 |     :ivar cnn_stride_size: a list of length `n_layers` containing the stride size of each convolutional layer
219 |         (default: [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)])
220 |     :vartype cnn_stride_size: List((int, int))
221 |     :ivar cnn_pool_size: a list of length `n_layers` containing the pool size of each MaxPool layer
222 |         (default: [(2, 2), (2, 2), (2, 2), (2, 2), (1, 1)])
223 |     :vartype cnn_pool_size: List((int, int))
224 |     :ivar cnn_batch_norm: a list of length `n_layers` containing a bool that indicates whether or not to use batch normalization
225 |         (default: [False, False, False, False, False])
226 |     :vartype cnn_batch_norm: List(bool)
227 |     :ivar rnn_units: a list containing the number of units per rnn layer (default: [256, 256])
228 |     :vartype rnn_units: List(int)
229 |     :ivar num_beam_paths: number of paths (transcriptions) to return for ctc beam search (only used when predicting)
230 |     :vartype num_beam_paths: int
231 |     :ivar csv_delimiter: character to delimit csv input files (default: ';')
232 |     :vartype csv_delimiter: str
233 |     :ivar string_split_delimiter: character that delimits each alphabet unit in the labels (default: '|')
234 |     :vartype string_split_delimiter: str
235 |     :ivar csv_files_train: csv filename which contains the (path;label) of each training sample
236 |     :vartype csv_files_train: str
237 |     :ivar csv_files_eval: csv filename which contains the (path;label) of each eval sample
238 |     :vartype csv_files_eval: str
239 |     :ivar lookup_alphabet_file: json file that contains the mapping alphabet units <-> codes
240 |     :vartype lookup_alphabet_file: str
241 |     :ivar blank_symbol: symbol to be considered as blank by the CTC decoder (default: '$')
242 |     :vartype blank_symbol: str
243 |     :ivar max_chars_per_string: maximum characters per sample (to avoid CTC decoder errors) (default: 75)
244 |     :vartype max_chars_per_string: int
245 |     :ivar data_augmentation: if True augments data on the fly (default: True)
246 |     :vartype data_augmentation: bool
247 |     :ivar data_augmentation_max_rotation: maximum permitted rotation to apply to the image during training, in radians (default: 0.005)
248 |     :vartype data_augmentation_max_rotation: float
249 |     :ivar data_augmentation_max_slant: maximum angle for slant augmentation (default: 0.7)
250 |     :vartype data_augmentation_max_slant: float
251 |     :ivar n_epochs: number of epochs to run the training (default: 50)
252 |     :vartype n_epochs: int
253 |     :ivar train_batch_size: batch size during training (default: 64)
254 |     :vartype train_batch_size: int
255 |     :ivar eval_batch_size: batch size during evaluation (default: 128)
256 |     :vartype eval_batch_size: int
257 |     :ivar learning_rate: initial learning rate (default: 1e-4)
258 |     :vartype learning_rate: float
259 |     :ivar evaluate_every_epoch: evaluate every 'evaluate_every_epoch' epoch (default: 5)
260 |     :vartype evaluate_every_epoch: int
261 |     :ivar save_interval: save the model every 'save_interval' epoch (default: 20)
262 |     :vartype save_interval: int
263 |     :ivar optimizer: which optimizer to use ('adam', 'rms', 'ada') (default: 'adam')
264 |     :vartype optimizer: str
265 |     :ivar output_model_dir: output directory where the model will be saved and exported
266 |     :vartype output_model_dir: str
267 |     :ivar restore_model: boolean to continue training with saved weights (default: False)
268 |     :vartype restore_model: bool
269 |     """
270 |     def __init__(self, **kwargs):
271 |         # model params
272 |         self.input_shape = kwargs.get('input_shape', (96, 1400))
273 |         self.input_channels = kwargs.get('input_channels', 1)
274 |         self.cnn_features_list = kwargs.get('cnn_features_list', [16, 32, 64, 96, 128])
275 |         self.cnn_kernel_size = kwargs.get('cnn_kernel_size', [3, 3, 3, 3, 3])
276 |         self.cnn_stride_size = kwargs.get('cnn_stride_size', [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)])
277 |         self.cnn_pool_size = kwargs.get('cnn_pool_size', [(2, 2), (2, 2), (2, 2), (2, 2), (1, 1)])
278 |         self.cnn_batch_norm = kwargs.get('cnn_batch_norm', [False, False, False, False, False])
279 |         self.rnn_units = kwargs.get('rnn_units', [256, 256])
280 |         # self._keep_prob_dropout = kwargs.get('keep_prob_dropout', 0.5)
281 |         self.num_beam_paths = kwargs.get('num_beam_paths', 1)
282 |         # csv params
283 |         self.csv_delimiter = kwargs.get('csv_delimiter', ';')
284 |         self.string_split_delimiter = kwargs.get('string_split_delimiter', '|')
285 |         self.csv_files_train = kwargs.get('csv_files_train')
286 |         self.csv_files_eval = kwargs.get('csv_files_eval')
287 |         # alphabet params
288 |         self.blank_symbol = kwargs.get('blank_symbol', '$')
289 |         self.max_chars_per_string = kwargs.get('max_chars_per_string', 75)
290 |         self.lookup_alphabet_file = kwargs.get('lookup_alphabet_file')
291 |         # data augmentation params
292 |         self.data_augmentation = kwargs.get('data_augmentation', True)
293 |         self.data_augmentation_max_rotation = kwargs.get('data_augmentation_max_rotation', 0.005)
294 |         self.data_augmentation_max_slant = kwargs.get('data_augmentation_max_slant', 0.7)
295 |         # training params
296 |         self.n_epochs = kwargs.get('n_epochs', 50)
297 |         self.train_batch_size = kwargs.get('train_batch_size', 64)
298 |         self.eval_batch_size = kwargs.get('eval_batch_size', 128)
299 |         self.learning_rate = kwargs.get('learning_rate', 1e-4)
300 |         self.optimizer = kwargs.get('optimizer', 'adam')
301 |         self.output_model_dir = kwargs.get('output_model_dir', '')
302 |         self.evaluate_every_epoch = kwargs.get('evaluate_every_epoch', 5)
303 |         self.save_interval = kwargs.get('save_interval', 20)
304 |         self.restore_model = kwargs.get('restore_model', False)
305 |
306 |         self._assign_alphabet()
307 |
308 |         cnn_params = zip(self.cnn_pool_size, self.cnn_stride_size)
309 |         self.downscale_factor = reduce(lambda i, j: i * j, map(lambda k: k[0][1] * k[1][1], cnn_params))
310 |
311 |         # TODO add additional checks for the architecture
312 |         assert len(self.cnn_features_list) == len(self.cnn_kernel_size) == len(self.cnn_stride_size) \
313 |             == len(self.cnn_pool_size) == len(self.cnn_batch_norm), \
314 |             "Lengths of the model parameter lists differ; check that all layer parameter lists have the same length." 
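# Note on the arithmetic above: with the default architecture ((2, 2) pooling on the
# first four layers, (1, 1) on the last, all strides (1, 1)), downscale_factor
# = 2*2*2*2*1 = 16, i.e. each output time step of the CRNN spans 16 input pixels
# in width. The check below then requires input_shape[1] >= (max_chars_per_string + 1) * 16,
# e.g. (75 + 1) * 16 = 1216 <= 1400 with the default values.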
315 | 316 | max_input_width = (self.max_chars_per_string + 1) * self.downscale_factor 317 | assert max_input_width <= self.input_shape[1], "Maximum length of labels is {}, input width should be greater or " \ 318 | "equal to {} but is {}".format(self.max_chars_per_string, 319 | max_input_width, 320 | self.input_shape[1]) 321 | 322 | assert self.optimizer in ['adam', 'rms', 'ada'], 'Unknown optimizer {}'.format(self.optimizer) 323 | 324 | if os.path.isdir(self.output_model_dir): 325 | print('WARNING : The output directory {} already exists.'.format(self.output_model_dir)) 326 | 327 | def show_experiment_params(self) -> dict: 328 | """ 329 | Returns a dictionary with the variables of the class. 330 | 331 | :return: 332 | """ 333 | return vars(self) 334 | 335 | def _assign_alphabet(self): 336 | self.alphabet = Alphabet(lookup_alphabet_file=self.lookup_alphabet_file, blank_symbol=self.blank_symbol) 337 | 338 | # @property 339 | # def keep_prob_dropout(self): 340 | # return self._keep_prob_dropout 341 | # 342 | # @keep_prob_dropout.setter 343 | # def keep_prob_dropout(self, value): 344 | # assert (0.0 < value <= 1.0), 'Must be 0.0 < value <= 1.0' 345 | # self._keep_prob_dropout = value 346 | 347 | def to_dict(self) -> dict: 348 | """ 349 | Returns the parameters as a dictionary 350 | 351 | :return: 352 | """ 353 | new_dict = self.__dict__.copy() 354 | del new_dict['alphabet'] 355 | del new_dict['downscale_factor'] 356 | return new_dict 357 | 358 | @classmethod 359 | def from_json_file(cls, json_file: str): 360 | """ 361 | Given a json file, creates a ``Params`` object. 362 | 363 | :param json_file: path to the json file 364 | :return: ``Params`` object 365 | """ 366 | with open(json_file, 'r') as file: 367 | config = json.load(file) 368 | 369 | return cls(**config) 370 | 371 | 372 | def import_params_from_json(model_directory: str=None, json_filename: str=None) -> dict: 373 | """ 374 | Read the exported json file with parameters of the experiment. 
375 |
376 |     :param model_directory: Directory where the model was exported
377 |     :param json_filename: filename of the json parameters file
378 |     :return: a dictionary containing the parameters of the experiment
379 |     """
380 |
381 |     assert not all(p is None for p in [model_directory, json_filename]), 'At least one argument should not be None'
382 |
383 |     if model_directory:
384 |         # Import parameters from the json file
385 |         try:
386 |             json_filename = glob(os.path.join(model_directory, 'model_params*.json'))[-1]
387 |         except IndexError:
388 |             print('No json found in dir {}'.format(model_directory))
389 |             raise FileNotFoundError
390 |     else:
391 |         if not os.path.isfile(json_filename):
392 |             print('No json found with filename {}'.format(json_filename))
393 |             raise FileNotFoundError
394 |
395 |     with open(json_filename, 'r') as data_json:
396 |         params_json = json.load(data_json)
397 |
398 |     # Remove 'private' keys
399 |     keys = list(params_json.keys())
400 |     for key in keys:
401 |         if key[0] == '_':
402 |             params_json.pop(key)
403 |
404 |     return params_json
405 |
--------------------------------------------------------------------------------
/tf_crnn/data_handler.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python
  2 | __author__ = 'solivr'
  3 | __license__ = "GPL"
  4 |
  5 | import tensorflow as tf
  6 | from tensorflow_addons.image.transform_ops import rotate, transform
  7 | from .config import Params, CONST
  8 | from typing import Tuple, Union, List
  9 | import collections
 10 |
 11 |
 12 | @tf.function
 13 | def random_rotation(img: tf.Tensor,
 14 |                     max_rotation: float=0.1,
 15 |                     crop: bool=True,
 16 |                     minimum_width: int=0) -> tf.Tensor:  # adapted from SeguinBe
 17 |     """
 18 |     Rotates an image with a random angle.
 19 |     See https://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders for the formulae.
 20 |
 21 |     :param img: Tensor
 22 |     :param max_rotation: maximum angle to rotate (radians)
 23 |     :param crop: boolean to crop or not the image after rotation
 24 |     :param minimum_width: minimum width of image after data augmentation
 25 |     :return:
 26 |     """
 27 |     with tf.name_scope('RandomRotation'):
 28 |         rotation = tf.random.uniform([], -max_rotation, max_rotation, name='pick_random_angle')
 29 |         # rotated_image = tf.contrib.image.rotate(img, rotation, interpolation='BILINEAR')
 30 |         rotated_image = rotate(tf.expand_dims(img, axis=0), rotation, interpolation='BILINEAR')
 31 |         rotated_image = tf.squeeze(rotated_image, axis=0)
 32 |         if crop:
 33 |             rotation = tf.abs(rotation)
 34 |             original_shape = tf.shape(rotated_image)[:2]
 35 |             h, w = original_shape[0], original_shape[1]
 36 |             old_l, old_s = tf.cond(h > w, lambda: [h, w], lambda: [w, h])
 37 |             old_l, old_s = tf.cast(old_l, tf.float32), tf.cast(old_s, tf.float32)
 38 |             new_l = (old_l * tf.cos(rotation) - old_s * tf.sin(rotation)) / tf.cos(2*rotation)
 39 |             new_s = (old_s - tf.sin(rotation) * new_l) / tf.cos(rotation)
 40 |             new_h, new_w = tf.cond(h > w, lambda: [new_l, new_s], lambda: [new_s, new_l])
 41 |             new_h, new_w = tf.cast(new_h, tf.int32), tf.cast(new_w, tf.int32)
 42 |             bb_begin = tf.cast(tf.math.ceil((h-new_h)/2), tf.int32), tf.cast(tf.math.ceil((w-new_w)/2), tf.int32)
 43 |             # Test sliced
 44 |             rotated_image_crop = tf.cond(
 45 |                 tf.logical_and(bb_begin[0] < h - bb_begin[0], bb_begin[1] < w - bb_begin[1]),
 46 |                 true_fn=lambda: rotated_image[bb_begin[0]:h - bb_begin[0], bb_begin[1]:w - bb_begin[1], :],
 47 |                 false_fn=lambda: img,
 48 |                 name='check_slices_indices'
 49 |             )
 50 |             # rotated_image_crop = rotated_image[bb_begin[0]:h 
- bb_begin[0], bb_begin[1]:w - bb_begin[1], :] 51 | 52 | # If crop removes the entire image, keep the original image 53 | rotated_image = tf.cond(tf.less_equal(tf.shape(rotated_image_crop)[1], minimum_width), 54 | true_fn=lambda: img, 55 | false_fn=lambda: rotated_image_crop, 56 | name='check_size_crop') 57 | 58 | return rotated_image 59 | 60 | 61 | # def random_padding(image: tf.Tensor, max_pad_w: int=5, max_pad_h: int=10) -> tf.Tensor: 62 | # """ 63 | # Given an image will pad its border adding a random number of rows and columns 64 | # 65 | # :param image: image to pad 66 | # :param max_pad_w: maximum padding in width 67 | # :param max_pad_h: maximum padding in height 68 | # :return: a padded image 69 | # """ 70 | # # TODO specify image shape in doc 71 | # 72 | # w_pad = list(np.random.randint(0, max_pad_w, size=[2])) 73 | # h_pad = list(np.random.randint(0, max_pad_h, size=[2])) 74 | # paddings = [h_pad, w_pad, [0, 0]] 75 | # 76 | # return tf.pad(image, paddings, mode='REFLECT', name='random_padding') 77 | 78 | @tf.function 79 | def augment_data(image: tf.Tensor, 80 | max_rotation: float=0.1, 81 | minimum_width: int=0) -> tf.Tensor: 82 | """ 83 | Data augmentation on an image (padding, brightness, contrast, rotation) 84 | 85 | :param image: Tensor 86 | :param max_rotation: float, maximum permitted rotation (in radians) 87 | :param minimum_width: minimum width of image after data augmentation 88 | :return: Tensor 89 | """ 90 | with tf.name_scope('DataAugmentation'): 91 | 92 | # Random padding 93 | # image = random_padding(image) 94 | 95 | # TODO : add random scaling 96 | image = tf.image.random_brightness(image, max_delta=0.1) 97 | image = tf.image.random_contrast(image, 0.5, 1.5) 98 | image = random_rotation(image, max_rotation, crop=True, minimum_width=minimum_width) 99 | 100 | if image.shape[-1] >= 3: 101 | image = tf.image.random_hue(image, 0.2) 102 | image = tf.image.random_saturation(image, 0.5, 1.5) 103 | 104 | return image 105 | 106 | @tf.function 107 | def get_resized_width(image: tf.Tensor, 108 | target_height: int, 109 | increment: int): 110 | """ 111 | Resizes the image according to `target_height`. 112 | 113 | :param image: image to resize 114 | :param target_height: height of the resized image 115 | :param increment: reduction factor due to pooling between input width and output width, 116 | this makes sure that the final width will be a multiple of increment 117 | :return: resized image 118 | """ 119 | 120 | image_shape = tf.shape(image) 121 | image_ratio = tf.divide(image_shape[1], image_shape[0], name='ratio') 122 | 123 | new_width = tf.cast(tf.round((image_ratio * target_height) / increment) * increment, tf.int32) 124 | f1 = lambda: (new_width, image_ratio) 125 | f2 = lambda: (target_height, tf.constant(1.0, dtype=tf.float64)) 126 | if tf.math.less_equal(new_width, 0): 127 | return f2() 128 | else: 129 | return f1() 130 | 131 | 132 | @tf.function 133 | def padding_inputs_width(image: tf.Tensor, 134 | target_shape: Tuple[int, int], 135 | increment: int) -> Tuple[tf.Tensor, tf.Tensor]: 136 | """ 137 | Given an input image, will pad it to return a target_shape size padded image. 
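# Worked example of the crop formulae above (illustrative, values rounded): for an
# image with old_l = 200, old_s = 100 and rotation = 0.1 rad (~5.7 degrees),
#   new_l = (200*cos(0.1) - 100*sin(0.1)) / cos(0.2) ≈ 192.9
#   new_s = (100 - sin(0.1)*192.9) / cos(0.1)        ≈ 81.2
# i.e. the largest axis-aligned box fully contained in the rotated 100x200 image is
# roughly 81x193, which is what the slicing with `bb_begin` extracts.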
61 | # def random_padding(image: tf.Tensor, max_pad_w: int=5, max_pad_h: int=10) -> tf.Tensor:
62 | #     """
63 | #     Given an image will pad its border adding a random number of rows and columns
64 | #
65 | #     :param image: image to pad
66 | #     :param max_pad_w: maximum padding in width
67 | #     :param max_pad_h: maximum padding in height
68 | #     :return: a padded image
69 | #     """
70 | #     # TODO specify image shape in doc
71 | #
72 | #     w_pad = list(np.random.randint(0, max_pad_w, size=[2]))
73 | #     h_pad = list(np.random.randint(0, max_pad_h, size=[2]))
74 | #     paddings = [h_pad, w_pad, [0, 0]]
75 | #
76 | #     return tf.pad(image, paddings, mode='REFLECT', name='random_padding')
77 | 
78 | @tf.function
79 | def augment_data(image: tf.Tensor,
80 |                  max_rotation: float=0.1,
81 |                  minimum_width: int=0) -> tf.Tensor:
82 |     """
83 |     Data augmentation on an image (padding, brightness, contrast, rotation)
84 | 
85 |     :param image: Tensor
86 |     :param max_rotation: float, maximum permitted rotation (in radians)
87 |     :param minimum_width: minimum width of image after data augmentation
88 |     :return: Tensor
89 |     """
90 |     with tf.name_scope('DataAugmentation'):
91 | 
92 |         # Random padding
93 |         # image = random_padding(image)
94 | 
95 |         # TODO : add random scaling
96 |         image = tf.image.random_brightness(image, max_delta=0.1)
97 |         image = tf.image.random_contrast(image, 0.5, 1.5)
98 |         image = random_rotation(image, max_rotation, crop=True, minimum_width=minimum_width)
99 | 
100 |         if image.shape[-1] >= 3:
101 |             image = tf.image.random_hue(image, 0.2)
102 |             image = tf.image.random_saturation(image, 0.5, 1.5)
103 | 
104 |     return image
105 | 
106 | @tf.function
107 | def get_resized_width(image: tf.Tensor,
108 |                       target_height: int,
109 |                       increment: int):
110 |     """
111 |     Computes the width the image will have once resized to `target_height`.
112 | 
113 |     :param image: image to resize
114 |     :param target_height: height of the resized image
115 |     :param increment: reduction factor due to pooling between input width and output width,
116 |                       this makes sure that the final width will be a multiple of increment
117 |     :return: tuple (resized width, aspect ratio); the image itself is not resized here
118 |     """
119 | 
120 |     image_shape = tf.shape(image)
121 |     image_ratio = tf.divide(image_shape[1], image_shape[0], name='ratio')
122 | 
123 |     new_width = tf.cast(tf.round((image_ratio * target_height) / increment) * increment, tf.int32)
124 |     f1 = lambda: (new_width, image_ratio)
125 |     f2 = lambda: (target_height, tf.constant(1.0, dtype=tf.float64))
126 |     if tf.math.less_equal(new_width, 0):
127 |         return f2()
128 |     else:
129 |         return f1()
130 | 
131 | 
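# Illustrative example (not part of the original file): with target_height=32 and
# increment=4, a 64x256 image (H x W) has ratio 4.0, so the raw resized width is
# 4.0 * 32 = 128, already a multiple of 4; a 60x250 image (ratio ≈ 4.17) would give
# 133.3, which is rounded to the nearest multiple of 4, i.e. new_width = 132.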
132 | @tf.function
133 | def padding_inputs_width(image: tf.Tensor,
134 |                          target_shape: Tuple[int, int],
135 |                          increment: int) -> Tuple[tf.Tensor, tf.Tensor]:
136 |     """
137 |     Given an input image, pads or resizes it to obtain an image of `target_shape` size.
138 |     There are 3 cases:
139 |         - image width >= target width : simple resizing to shrink the image
140 |         - 0.5*target width <= image width < target width : pad the image (mirror padding)
141 |         - image width < 0.5*target width : replicate the image horizontally and crop to the target width
142 | 
143 |     :param image: Tensor of shape [H,W,C]
144 |     :param target_shape: final shape after padding [H, W]
145 |     :param increment: reduction factor due to pooling between input width and output width,
146 |                       this makes sure that the final width will be a multiple of increment
147 |     :return: (image padded, output width)
148 |     """
149 | 
150 |     target_ratio = target_shape[1]/target_shape[0]
151 |     target_w = target_shape[1]
152 |     # Compute ratio to keep the same ratio in new image and get the size of padding
153 |     # necessary to have the final desired shape
154 |     new_h = target_shape[0]
155 |     new_w, ratio = get_resized_width(image, new_h, increment)
156 | 
157 |     # Definitions for cases
158 |     def pad_fn():
159 |         with tf.name_scope('mirror_padding'):
160 |             pad = tf.subtract(target_w, new_w)
161 | 
162 |             img_resized = tf.image.resize(image, [new_h, new_w])
163 | 
164 |             # Padding to have the desired width
165 |             paddings = [[0, 0], [0, pad], [0, 0]]
166 |             pad_image = tf.pad(img_resized, paddings, mode='SYMMETRIC', name=None)
167 | 
168 |             # Set manually the shape
169 |             pad_image.set_shape([target_shape[0], target_shape[1], img_resized.get_shape()[2]])
170 | 
171 |             return pad_image, (new_h, new_w)
172 | 
173 |     def replicate_fn():
174 |         with tf.name_scope('replication_padding'):
175 |             img_resized = tf.image.resize(image, [new_h, new_w])
176 | 
177 |             # If one symmetry is not enough to have a full width
178 |             # Count number of replications needed
179 |             n_replication = tf.cast(tf.math.ceil(target_shape[1]/new_w), tf.int32)
180 |             img_replicated = tf.tile(img_resized, tf.stack([1, n_replication, 1]))
181 |             pad_image = tf.image.crop_to_bounding_box(image=img_replicated, offset_height=0, offset_width=0,
182 |                                                       target_height=target_shape[0], target_width=target_shape[1])
183 | 
184 |             # Set manually the shape
185 |             pad_image.set_shape([target_shape[0], target_shape[1], img_resized.get_shape()[2]])
186 | 
187 |             return pad_image, (new_h, new_w)
188 | 
189 |     def simple_resize():
190 |         with tf.name_scope('simple_resize'):
191 |             img_resized = tf.image.resize(image, target_shape)
192 | 
193 |             img_resized.set_shape([target_shape[0], target_shape[1], img_resized.get_shape()[2]])
194 | 
195 |             return img_resized, tuple(target_shape)
196 | 
197 |     # case 1 : new_w >= target_w
198 |     if tf.logical_and(tf.greater_equal(ratio, target_ratio), tf.greater_equal(new_w, target_w)):
199 |         pad_image, (new_h, new_w) = simple_resize()
200 |     # case 2 : new_w >= target_w/2 & new_w < target_w & ratio < target_ratio
201 |     elif tf.logical_and(tf.less(ratio, target_ratio),
202 |                         tf.logical_and(tf.greater_equal(new_w, tf.cast(tf.divide(target_w, 2), tf.int32)),
203 |                                        tf.less(new_w, target_w))):
204 |         pad_image, (new_h, new_w) = pad_fn()
205 |     # case 3 : new_w < target_w/2 & new_w < target_w & ratio < target_ratio
206 |     elif tf.logical_and(tf.less(ratio, target_ratio),
207 |                         tf.logical_and(tf.less(new_w, target_w),
208 |                                        tf.less(new_w, tf.cast(tf.divide(target_w, 2), tf.int32)))):
209 |         pad_image, (new_h, new_w) = replicate_fn()
210 |     else:
211 |         pad_image, (new_h, new_w) = simple_resize()
212 | 
213 |     return pad_image, new_w
214 | 
215 | 
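# Illustrative walk-through of the three cases (made-up shapes; target_shape=(32, 304),
# increment=4, hence target_w/2 = 152):
#   - a 64x800 image resizes to width 400 >= 304             -> simple_resize()
#   - a 64x400 image resizes to width 200, 152 <= 200 < 304  -> pad_fn() (mirror padding)
#   - a 64x96 image resizes to width 48 < 152                -> replicate_fn() (tiling)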
216 | # def apply_slant(image: np.ndarray, alpha: np.ndarray) -> (np.ndarray, np.ndarray):
217 | #     alpha = alpha[0]
218 | #
219 | #     def _find_background_color(image: np.ndarray) -> int:
220 | #         """
221 | #         Given a grayscale image, finds the background color value
222 | #         :param image: grayscale image
223 | #         :return: background color value (int)
224 | #         """
225 | #         # Otsu's thresholding after Gaussian filtering
226 | #         blur = cv2.GaussianBlur(image[:, :, 0].astype(np.uint8), (5, 5), 0)
227 | #         thresh_value, thresholded_image = cv2.threshold(blur.astype(np.uint8), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
228 | #
229 | #         # Find which is the background (0 or 255). Supposing that the background color occurrence is higher
230 | #         # than the writing color
231 | #         counts, bin_edges = np.histogram(thresholded_image, bins=2)
232 | #         background_color = int(np.median(image[thresholded_image == 255 * np.argmax(counts)]))
233 | #
234 | #         return background_color
235 | #
236 | #     shape_image = image.shape
237 | #     shift = max(-alpha * shape_image[0], 0)
238 | #     output_size = (int(shape_image[1] + np.ceil(abs(alpha * shape_image[0]))), int(shape_image[0]))
239 | #
240 | #     warpM = np.array([[1, alpha, shift], [0, 1, 0]])
241 | #
242 | #     # Find color of background in order to replicate it in the borders
243 | #     border_value = _find_background_color(image)
244 | #
245 | #     image_warp = cv2.warpAffine(image, np.array(warpM), output_size, borderValue=border_value)
246 | #
247 | #     return image_warp, np.array(output_size)
248 | 
249 | 
250 | def dataset_generator(csv_filename: Union[List[str], str],
251 |                       params: Params,
252 |                       use_labels: bool=True,
253 |                       batch_size: int=64,
254 |                       data_augmentation: bool=False,
255 |                       num_epochs: int=None,
256 |                       shuffle: bool=True):
257 |     """
258 |     Generates the dataset for the experiment.
259 | 
260 | 
261 |     :param csv_filename: path(s) to the csv file(s) containing the data
262 |     :param params: parameters of the experiment (``Params``)
263 |     :param use_labels: boolean to indicate dataset generation during training / evaluation (true) or prediction (false)
264 |     :param batch_size: size of the generated batches
265 |     :param data_augmentation: whether to use data augmentation strategies or not
266 |     :param num_epochs: number of epochs to repeat the dataset generation
267 |     :param shuffle: whether to shuffle the data
268 |     :return: ``tf.data.Dataset``
269 |     """
270 |     do_padding = True
271 | 
272 |     if use_labels:
273 |         column_defaults = [['None'], ['None'], tf.int32]
274 |         column_names = ['paths', 'label_codes', 'label_seq_length']
275 |         label_name = 'label_codes'
276 |     else:
277 |         column_defaults = [['None']]
278 |         column_names = ['paths']
279 |         label_name = None
280 | 
281 |     num_parallel_reads = 1
282 | 
283 |     # ----- from data.experimental.make_csv_dataset
284 |     def filename_to_dataset(filename):
285 |         dataset = tf.data.experimental.CsvDataset(filename,
286 |                                                   record_defaults=column_defaults,
287 |                                                   field_delim=params.csv_delimiter,
288 |                                                   header=False)
289 |         return dataset
290 | 
291 |     def map_fn(*columns):
292 |         """Organizes columns into a features dictionary.
293 |         Args:
294 |             *columns: list of `Tensor`s corresponding to one csv record.
295 |         Returns:
296 |             An OrderedDict of feature names to values for that particular record. If
297 |             label_name is provided, extracts the label feature to be returned as the
298 |             second element of the tuple.
299 |         """
300 |         features = collections.OrderedDict(zip(column_names, columns))
301 |         if label_name is not None:
302 |             label = features.pop(label_name)
303 |             return features, label
304 | 
305 |         return features
306 | 
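# Example record (made up, assuming ';' as csv delimiter): with use_labels=True, the row
#   'images/0001.jpg;12 5 33 0 0;3'
# is parsed by map_fn into
#   (OrderedDict([('paths', 'images/0001.jpg'), ('label_seq_length', 3)]), '12 5 33 0 0')
# i.e. the label codes string is popped out of the features and returned separately.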
307 |     dataset = tf.data.Dataset.from_tensor_slices(csv_filename)
308 |     # Read files sequentially (if num_parallel_reads=1) or in parallel
309 |     # dataset = dataset.apply(tf.data.experimental.parallel_interleave(filename_to_dataset,
310 |     #                                                                  cycle_length=num_parallel_reads))
311 |     dataset = dataset.interleave(filename_to_dataset, cycle_length=num_parallel_reads,
312 |                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
313 |     dataset = dataset.map(map_fn)
314 |     # -----
315 | 
316 |     def _load_image(features: dict, labels=None):
317 |         path = features['paths']
318 |         image_content = tf.io.read_file(path)
319 |         image = tf.io.decode_jpeg(image_content, channels=params.input_channels,
320 |                                   try_recover_truncated=True, name='image_decoding_op')
321 | 
322 |         if use_labels:
323 |             return {'input_images': image,
324 |                     'label_seq_length': features['label_seq_length']}, labels
325 |         else:
326 |             return {'input_images': image,
327 |                     'filename_images': path}
328 | 
329 |     def _apply_slant(features: dict, labels=None):
330 |         image = features['input_images']
331 |         height_image = tf.cast(tf.shape(image)[0], dtype=tf.float32)
332 | 
333 |         with tf.name_scope('add_slant'):
334 |             alpha = tf.random.uniform([],
335 |                                       -params.data_augmentation_max_slant,
336 |                                       params.data_augmentation_max_slant,
337 |                                       name='pick_random_slant_angle')
338 | 
339 |             shiftx = tf.math.maximum(tf.math.multiply(-alpha, height_image), 0)
340 | 
341 |             # Pad in order not to lose image info when transformation is applied
342 |             x_pad = 0
343 |             y_pad = tf.math.round(tf.math.ceil(tf.math.abs(tf.math.multiply(alpha, height_image))))
344 |             y_pad = tf.cast(y_pad, dtype=tf.int32)
345 |             paddings = [[x_pad, x_pad], [y_pad, 0], [0, 0]]
346 |             transform_matrix = [1, alpha, shiftx, 0, 1, 0, 0, 0]
347 | 
348 |             # Apply transformation to image
349 |             image_pad = tf.pad(image, paddings)
350 |             image_transformed = transform(image_pad, transform_matrix, interpolation='BILINEAR')
351 | 
352 |             # Apply transformation to mask. The mask will be used to retrieve the pixels that have been filled
353 |             # with zero during transformation and update their value with the background value
354 |             # TODO : Would be better to have some kind of binarization (i.e Otsu) and get the mean background value
355 |             background_pixel_value = 255
356 |             empty = background_pixel_value * tf.ones(tf.shape(image))
357 |             empty_pad = tf.pad(empty, paddings)
358 |             empty_transformed = tf.subtract(
359 |                 tf.cast(background_pixel_value, dtype=tf.int32),
360 |                 tf.cast(transform(empty_pad, transform_matrix, interpolation='NEAREST'), dtype=tf.int32)
361 |             )
362 | 
363 |             # Update additional zeros values with background_pixel_value and cast result to uint8
364 |             image = tf.add(tf.cast(image_transformed, dtype=tf.int32), empty_transformed)
365 |             image = tf.cast(image, tf.uint8)
366 | 
367 |         features['input_images'] = image
368 |         return (features, labels) if use_labels else features  # parenthesized: otherwise (features, features) is returned when use_labels is False
369 | 
370 |     def _data_augment_fn(features: dict, labels=None) -> tf.data.Dataset:
371 |         image = features['input_images']
372 |         image = augment_data(image, params.data_augmentation_max_rotation, minimum_width=params.max_chars_per_string)
373 | 
374 |         features.update({'input_images': image})
375 |         return (features, labels) if use_labels else features
376 | 
377 |     def _pad_image_or_resize(features: dict, labels=None):
378 |         image = features['input_images']
379 |         if do_padding:
380 |             with tf.name_scope('padding'):
381 |                 image, img_width = padding_inputs_width(image, target_shape=params.input_shape,
382 |                                                         increment=params.downscale_factor)  # todo this needs to be updated
383 |         # Resize
384 |         else:
385 |             image = tf.image.resize(image, size=params.input_shape)
386 |             img_width = tf.shape(image)[1]
387 | 
388 |         input_seq_length = tf.cast(tf.floor(tf.divide(img_width, params.downscale_factor)), tf.int32)
389 |         if use_labels:
390 |             assert_op = tf.debugging.assert_greater_equal(input_seq_length,
391 |                                                           features['label_seq_length'])
392 |             with tf.control_dependencies([assert_op]):
393 |                 return {'input_images': image,
394 |                         'label_seq_length': features['label_seq_length'],
395 |                         'input_seq_length': input_seq_length}, labels
396 |         else:
397 |             return {'input_images': image,
398 |                     'input_seq_length': input_seq_length,
399 |                     'filename_images': features['filename_images']}
400 | 
401 |     def _normalize_image(features: dict, labels=None):
402 |         image = tf.cast(features['input_images'], tf.float32)
403 |         image = tf.image.per_image_standardization(image)
404 | 
405 |         features['input_images'] = image
406 |         return (features, labels) if use_labels else features
407 | 
408 |     def _format_label_codes(features: dict, string_label_codes):
409 |         splits = tf.strings.split([string_label_codes], sep=' ')
410 |         label_codes = tf.squeeze(tf.strings.to_number(splits, out_type=tf.int32), axis=0)
411 | 
412 |         features.update({'label_codes': label_codes})
413 |         return features, [0]
414 | 
415 | 
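# Example (illustrative): _format_label_codes turns the string '12 5 33 0 0' into the
# int32 tensor [12, 5, 33, 0, 0] stored under features['label_codes'], and returns [0]
# as a dummy target, since the CTC loss in model.py is computed from the model inputs
# rather than from the y_true that Keras passes to the loss.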
416 |     num_parallel_calls = tf.data.experimental.AUTOTUNE
417 |     # 1. load image 2. data augmentation 3. padding
418 |     dataset = dataset.map(_load_image, num_parallel_calls=num_parallel_calls)
419 |     # this causes problems when using the same cache for training, validation and prediction data...
420 |     # dataset = dataset.cache(filename=os.path.join(params.output_model_dir, 'cache.tf-data'))
421 |     if data_augmentation and params.data_augmentation_max_slant != 0:
422 |         dataset = dataset.map(_apply_slant, num_parallel_calls=num_parallel_calls)
423 |     if data_augmentation:
424 |         dataset = dataset.map(_data_augment_fn, num_parallel_calls=num_parallel_calls)
425 |     dataset = dataset.map(_normalize_image, num_parallel_calls=num_parallel_calls)
426 |     dataset = dataset.map(_pad_image_or_resize, num_parallel_calls=num_parallel_calls)
427 |     dataset = dataset.map(_format_label_codes, num_parallel_calls=num_parallel_calls) if use_labels else dataset
428 |     dataset = dataset.shuffle(10 * batch_size, reshuffle_each_iteration=False) if shuffle else dataset
429 |     dataset = dataset.repeat(num_epochs) if num_epochs is not None else dataset
430 | 
431 |     return dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
432 | 
433 | 
434 | # def dataset_prediction(image_filenames: Union[List[str], str]=None,
435 | #                        csv_filename: str=None,
436 | #                        params: Params=None,
437 | #                        batch_size: int=64):
438 | #
439 | #     assert params, 'params cannot be None'
440 | #     assert image_filenames or csv_filename, 'You need to feed an input (image_filenames or csv_filename)'
441 | #
442 | #     do_padding = True
443 | #
444 | #     def _load_image(path):
445 | #         image_content = tf.io.read_file(path)
446 | #         image = tf.io.decode_jpeg(image_content, channels=params.input_channels,
447 | #                                   try_recover_truncated=True, name='image_decoding_op')
448 | #
449 | #         return {'input_images': image}
450 | #
451 | #     def _normalize_image(features: dict):
452 | #         image = tf.cast(features['input_images'], tf.float32)
453 | #         image = tf.image.per_image_standardization(image)
454 | #
455 | #         features['input_images'] = image
456 | #         return features
457 | #
458 | #     def _pad_image_or_resize(features: dict):
459 | #         image = features['input_images']
460 | #         if do_padding:
461 | #             with tf.name_scope('padding'):
462 | #                 image, img_width = padding_inputs_width(image, target_shape=params.input_shape,
463 | #                                                         increment=CONST.DIMENSION_REDUCTION_W_POOLING)
464 | #         # Resize
465 | #         else:
466 | #             image = tf.image.resize(image, size=params.input_shape)
467 | #             img_width = tf.shape(image)[1]
468 | #
469 | #         input_seq_length = tf.cast(tf.floor(tf.math.divide(img_width, params.n_pool)), tf.int32)
470 | #
471 | #         return {'input_images': image,
472 | #                 'input_seq_length': input_seq_length}
473 | #     if image_filenames is not None:
474 | #         dataset = tf.data.Dataset.from_tensor_slices(image_filenames)
475 | #     elif csv_filename is not None:
476 | #         column_defaults = [['None']]
477 | #         dataset = tf.data.experimental.CsvDataset(csv_filename,
478 | #                                                   record_defaults=column_defaults,
479 | #                                                   field_delim=params.csv_delimiter,
480 | #                                                   header=False)
481 | #         # dataset = dataset.map(map_fn)
482 | #     dataset = dataset.map(_load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
483 | #     dataset = dataset.map(_normalize_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
484 | #     dataset = dataset.map(_pad_image_or_resize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
485 | #
486 | #     return dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
487 | 
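# Illustrative usage sketch (not part of the original file; the csv path is made up
# and follows the naming produced by preprocessing.py below):
#
#   params = Params(**json.load(open('config.json')))
#   dataset = dataset_generator(['preprocessed/updated_train.csv'], params,
#                               batch_size=params.train_batch_size,
#                               data_augmentation=True, num_epochs=1)
#   features, labels = next(iter(dataset))
#   print(features['input_images'].shape)  # (batch_size, H, W, C)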
--------------------------------------------------------------------------------
/tf_crnn/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = "solivr"
3 | __license__ = "GPL"
4 | 
5 | import tensorflow as tf
6 | from tensorflow.keras import Model
7 | from tensorflow.keras.backend import ctc_batch_cost, ctc_decode
8 | from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, MaxPool2D, Input, Permute, \
9 |     Reshape, Bidirectional, LSTM, Dense, Softmax, Lambda
10 | from typing import List, Tuple
11 | from .config import Params
12 | 
13 | 
14 | class ConvBlock(Layer):
15 |     """
16 |     Convolutional block class.
17 |     It is composed of a `Conv2D` layer, a `BatchNormalization` layer (optional),
18 |     a `MaxPool2D` layer (optional) and a `ReLU` activation.
19 | 
20 |     :ivar features: number of features of the convolutional layer
21 |     :vartype features: int
22 |     :ivar kernel_size: size of the convolutional kernel
23 |     :vartype kernel_size: int
24 |     :ivar stride: stride of the convolutional layer
25 |     :vartype stride: int, int
26 |     :ivar cnn_padding: padding of the convolution ('same' or 'valid')
27 |     :vartype cnn_padding: str
28 |     :ivar pool_size: size of the maxpooling
29 |     :vartype pool_size: int, int
30 |     :ivar batchnorm: use batch norm or not
31 |     :vartype batchnorm: bool
32 |     """
33 |     def __init__(self,
34 |                  features: int,
35 |                  kernel_size: int,
36 |                  stride: Tuple[int, int],
37 |                  cnn_padding: str,
38 |                  pool_size: Tuple[int, int],
39 |                  batchnorm: bool,
40 |                  **kwargs):
41 |         super(ConvBlock, self).__init__(**kwargs)
42 |         self.conv = Conv2D(features,
43 |                            kernel_size,
44 |                            strides=stride,
45 |                            padding=cnn_padding)
46 |         self.bn = BatchNormalization(renorm=True,
47 |                                      renorm_clipping={'rmax': 1e2, 'rmin': 1e-1, 'dmax': 1e1},
48 |                                      trainable=True) if batchnorm else None
49 |         self.pool = MaxPool2D(pool_size=pool_size,
50 |                               padding='same') if list(pool_size) > [1, 1] else None
51 | 
52 |         # for config purposes
53 |         self._features = features
54 |         self._kernel_size = kernel_size
55 |         self._stride = stride
56 |         self._cnn_padding = cnn_padding
57 |         self._pool_size = pool_size
58 |         self._batchnorm = batchnorm
59 | 
60 |     def call(self, inputs, training=False):
61 |         x = self.conv(inputs)
62 |         if self.bn is not None:
63 |             x = self.bn(x, training=training)
64 |         if self.pool is not None:
65 |             x = self.pool(x)
66 |         x = tf.nn.relu(x)
67 |         return x
68 | 
69 |     def get_config(self) -> dict:
70 |         """
71 |         Get a dictionary with all the necessary properties to recreate the same layer.
72 | 
73 |         :return: dictionary containing the properties of the layer
74 |         """
75 |         super_config = super(ConvBlock, self).get_config()
76 |         config = {
77 |             'features': self._features,
78 |             'kernel_size': self._kernel_size,
79 |             'stride': self._stride,
80 |             'cnn_padding': self._cnn_padding,
81 |             'pool_size': self._pool_size,
82 |             'batchnorm': self._batchnorm
83 |         }
84 |         return dict(list(super_config.items()) + list(config.items()))
85 | 
86 | 
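# Illustrative instantiation (made-up hyper-parameters, not part of the original file):
#
#   block = ConvBlock(features=64, kernel_size=3, stride=(1, 1),
#                     cnn_padding='same', pool_size=(2, 2), batchnorm=True)
#   y = block(tf.zeros([1, 32, 128, 1]))  # -> shape (1, 16, 64, 64) after the 2x2 pooling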
87 | def get_crnn_output(input_images,
88 |                     parameters: Params=None) -> tf.Tensor:
89 |     """
90 |     Creates the CRNN network.
91 |     Passes the `input_images` through the network and returns its output.
92 | 
93 |     :param input_images: images to process (B, H, W, C)
94 |     :param parameters: parameters of the model (``Params``)
95 |     :return: the output of the CRNN model
96 |     """
97 | 
98 |     # params of the architecture
99 |     cnn_features_list = parameters.cnn_features_list
100 |     cnn_kernel_size = parameters.cnn_kernel_size
101 |     cnn_pool_size = parameters.cnn_pool_size
102 |     cnn_stride_size = parameters.cnn_stride_size
103 |     cnn_batch_norm = parameters.cnn_batch_norm
104 |     rnn_units = parameters.rnn_units
105 | 
106 |     # CNN layers
107 |     cnn_params = zip(cnn_features_list, cnn_kernel_size, cnn_stride_size, cnn_pool_size, cnn_batch_norm)
108 |     conv_layers = [ConvBlock(ft, ks, ss, 'same', psz, bn) for ft, ks, ss, psz, bn in cnn_params]
109 | 
110 |     x = conv_layers[0](input_images)
111 |     for conv in conv_layers[1:]:
112 |         x = conv(x)
113 | 
114 |     # Permutation and reshape
115 |     x = Permute((2, 1, 3))(x)
116 |     shape = x.get_shape().as_list()
117 |     x = Reshape((shape[1], shape[2] * shape[3]))(x)  # [B, W, H*C]
118 | 
119 |     # RNN layers
120 |     rnn_layers = [Bidirectional(LSTM(ru, dropout=0.5, return_sequences=True, time_major=False)) for ru in
121 |                   rnn_units]
122 |     for rnn in rnn_layers:
123 |         x = rnn(x)
124 | 
125 |     # Dense and softmax
126 |     x = Dense(parameters.alphabet.n_classes)(x)
127 |     net_output = Softmax()(x)
128 | 
129 |     return net_output
130 | 
131 | 
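# Shape walk-through (illustrative; the exact downscaling depends on the configured
# strides and pooling): an input batch (B, H, W, C) leaves the CNN stack as
# (B, H', W', C'), is permuted to (B, W', H', C') and reshaped to (B, W', H'*C'),
# so that the bidirectional LSTMs see one feature vector per horizontal position;
# the final Dense + Softmax give (B, W', n_classes), the per-timestep class
# distribution consumed by the CTC loss.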
132 | def get_model_train(parameters: Params):
133 |     """
134 |     Constructs the full model for training.
135 |     Defines inputs and outputs, loss function and metric (CER).
136 | 
137 |     :param parameters: parameters of the model (``Params``)
138 |     :return: the model (``tf.keras.Model``)
139 |     """
140 | 
141 |     h, w = parameters.input_shape
142 |     c = parameters.input_channels
143 | 
144 |     input_images = Input(shape=(h, w, c), name='input_images')
145 |     input_seq_len = Input(shape=[1], dtype=tf.int32, name='input_seq_length')
146 | 
147 |     label_codes = Input(shape=(parameters.max_chars_per_string), dtype=tf.int32, name='label_codes')
148 |     label_seq_length = Input(shape=[1], dtype=tf.int32, name='label_seq_length')
149 | 
150 |     net_output = get_crnn_output(input_images, parameters)
151 | 
152 |     # Loss function
153 |     def warp_ctc_loss(y_true, y_pred):
154 |         return ctc_batch_cost(label_codes, y_pred, input_seq_len, label_seq_length)
155 | 
156 |     # Metric function
157 |     def warp_cer_metric(y_true, y_pred):
158 |         pred_sequence_length, true_sequence_length = input_seq_len, label_seq_length
159 | 
160 |         # y_pred needs to be decoded (it's the logits)
161 |         pred_codes_dense = ctc_decode(y_pred, tf.squeeze(pred_sequence_length, axis=-1), greedy=True)
162 |         pred_codes_dense = tf.squeeze(tf.cast(pred_codes_dense[0], tf.int64), axis=0)  # only [0] if greedy=true
163 | 
164 |         # create sparse tensor
165 |         idx = tf.where(tf.not_equal(pred_codes_dense, -1))
166 |         pred_codes_sparse = tf.SparseTensor(tf.cast(idx, tf.int64),
167 |                                             tf.gather_nd(pred_codes_dense, idx),
168 |                                             tf.cast(tf.shape(pred_codes_dense), tf.int64))
169 | 
170 |         idx = tf.where(tf.not_equal(label_codes, 0))
171 |         label_sparse = tf.SparseTensor(tf.cast(idx, tf.int64),
172 |                                        tf.gather_nd(label_codes, idx),
173 |                                        tf.cast(tf.shape(label_codes), tf.int64))
174 |         label_sparse = tf.cast(label_sparse, tf.int64)
175 | 
176 |         # Compute edit distance and total chars count
177 |         distance = tf.reduce_sum(tf.edit_distance(pred_codes_sparse, label_sparse, normalize=False))
178 |         count_chars = tf.reduce_sum(true_sequence_length)
179 | 
180 |         return tf.divide(distance, tf.cast(count_chars, tf.float32), name='CER')
181 | 
182 |     # Define model and compile it
183 |     model = Model(inputs=[input_images, label_codes, input_seq_len, label_seq_length], outputs=net_output, name='CRNN')
184 |     optimizer = tf.keras.optimizers.Adam(learning_rate=parameters.learning_rate)
185 |     model.compile(loss=[warp_ctc_loss],
186 |                   optimizer=optimizer,
187 |                   metrics=[warp_cer_metric],
188 |                   experimental_run_tf_function=False)  # TODO this is set to true by default but does not seem to work...
189 | 
190 |     return model
191 | 
192 | 
193 | def get_model_inference(parameters: Params,
194 |                         weights_path: str=None):
195 |     """
196 |     Constructs the full model for prediction.
197 |     Defines inputs and outputs, and loads the weights.
198 | 
199 | 
200 |     :param parameters: parameters of the model (``Params``)
201 |     :param weights_path: path to the weights (.h5 file)
202 |     :return: the model (``tf.keras.Model``)
203 |     """
204 |     h, w = parameters.input_shape
205 |     c = parameters.input_channels
206 | 
207 |     input_images = Input(shape=(h, w, c), name='input_images')
208 |     input_seq_len = Input(shape=[1], dtype=tf.int32, name='input_seq_length')
209 |     filename_images = Input(shape=[1], dtype=tf.string, name='filename_images')
210 | 
211 |     net_output = get_crnn_output(input_images, parameters)
212 |     output_seq_len = tf.identity(input_seq_len)  # need this op to pass it to output
213 |     filenames = tf.identity(filename_images)
214 | 
215 |     model = Model(inputs=[input_images, input_seq_len, filename_images], outputs=[net_output, output_seq_len, filenames])
216 | 
217 |     if weights_path:
218 |         model.load_weights(weights_path)
219 | 
220 |     return model
221 | 
--------------------------------------------------------------------------------
/tf_crnn/preprocessing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = "solivr"
3 | __license__ = "GPL"
4 | 
5 | import re
6 | import numpy as np
7 | import os
8 | from .config import Params, CONST
9 | import pandas as pd
10 | from typing import List, Tuple
11 | from taputapu.io.image import get_image_shape_without_loading
12 | 
13 | 
14 | def _convert_label_to_dense_codes(labels: List[str],
15 |                                   split_char: str,
16 |                                   max_width: int,
17 |                                   table_str2int: dict):
18 |     """
19 |     Converts a list of formatted strings to a dense matrix of codes
20 | 
21 |     :param labels: list of strings containing formatted labels
22 |     :param split_char: character to split the formatted label
23 |     :param max_width: maximum length of string label (max_n_chars = max_width_dense_codes)
24 |     :param table_str2int: mapping table between alphabet units and alphabet codes
25 |     :return: dense matrix N x max_width, list of the lengths of each string (length N)
26 |     """
27 |     labels_chars = [[c for c in label.split(split_char) if c] for label in labels]
28 |     codes_list = [[table_str2int[c] for c in list_char] for list_char in labels_chars]
29 | 
30 |     seq_lengths = [len(cl) for cl in codes_list]
31 | 
32 |     dense_codes = list()
33 |     for ls in codes_list:
34 |         dense_codes.append(ls + np.maximum(0, (max_width - len(ls))) * [0])
35 | 
36 |     return dense_codes, seq_lengths
37 | 
38 | 
39 | def _compute_length_inputs(path: str,
40 |                            target_shape: Tuple[int, int]):
41 | 
42 |     w, h = get_image_shape_without_loading(path)
43 |     ratio = w / h
44 | 
45 |     new_h = target_shape[0]
46 |     new_w = np.minimum(new_h * ratio, target_shape[1])
47 | 
48 |     return new_w
49 | 
50 | 
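# Worked example (made-up values): for an image of width 400 and height 100 with
# target_shape = (32, 304), ratio = 400/100 = 4.0 and new_w = min(32 * 4.0, 304) = 128,
# i.e. the width the image will occupy once resized to the target height.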
51 | def preprocess_csv(csv_filename: str,
52 |                    parameters: Params,
53 |                    output_csv_filename: str) -> int:
54 |     """
55 |     Converts the original csv data to the format required by the experiment.
56 |     Removes the samples whose labels have too many characters. Computes the widths of the input images and removes the
57 |     samples whose labels have more characters than the (downscaled) image width. Converts the string labels to dense codes.
58 |     The output csv file contains the path to the image, the dense list of codes corresponding to the alphabet units
59 |     (which are padded with 0 if `len(label)` < `max_len`) and the length of the label sequence.
60 | 
61 |     :param csv_filename: path to csv file
62 |     :param parameters: parameters of the experiment (``Params``)
63 |     :param output_csv_filename: path to the output csv file
64 |     :return: number of samples in the output csv file
65 |     """
66 | 
67 |     # Conversion table
68 |     table_str2int = dict(zip(parameters.alphabet.alphabet_units, parameters.alphabet.codes))
69 | 
70 |     # Read file
71 |     dataframe = pd.read_csv(csv_filename,
72 |                             sep=parameters.csv_delimiter,
73 |                             header=None,
74 |                             names=['paths', 'labels'],
75 |                             encoding='utf8',
76 |                             escapechar="\\",
77 |                             quoting=0)
78 | 
79 |     original_len = len(dataframe)
80 | 
81 |     dataframe['label_string'] = dataframe.labels.apply(lambda x: re.sub(re.escape(parameters.string_split_delimiter), '', x))
82 |     dataframe['label_len'] = dataframe.label_string.apply(lambda x: len(x))
83 | 
84 |     # remove long labels
85 |     dataframe = dataframe[dataframe.label_len <= parameters.max_chars_per_string]
86 | 
87 |     # Compute image widths (after resizing)
88 |     dataframe['input_length'] = dataframe.paths.apply(lambda x: _compute_length_inputs(x, parameters.input_shape))
89 |     dataframe.input_length = dataframe.input_length.apply(lambda x: np.floor(x / parameters.downscale_factor))
90 |     # Remove items with longer label than input
91 |     dataframe = dataframe[dataframe.label_len < dataframe.input_length]
92 | 
93 |     final_length = len(dataframe)
94 | 
95 |     n_removed = original_len - final_length
96 |     if n_removed > 0:
97 |         print('-- Removed {} samples ({:.2f} %)'.format(n_removed,
98 |                                                         100 * n_removed / original_len))
99 | 
100 |     # Convert fields to list
101 |     paths = dataframe.paths.to_list()
102 |     labels = dataframe.labels.to_list()
103 | 
104 |     # Convert string labels to dense codes
105 |     label_dense_codes, label_seq_length = _convert_label_to_dense_codes(labels,
106 |                                                                         parameters.string_split_delimiter,
107 |                                                                         parameters.max_chars_per_string,
108 |                                                                         table_str2int)
109 |     # format in string to be easily parsed by tf.data
110 |     string_label_codes = [[str(ldc) for ldc in list_ldc] for list_ldc in label_dense_codes]
111 |     string_label_codes = [' '.join(list_slc) for list_slc in string_label_codes]
112 | 
113 |     data = {'paths': paths, 'label_codes': string_label_codes, 'label_len': label_seq_length}
114 |     new_dataframe = pd.DataFrame(data)
115 | 
116 |     new_dataframe.to_csv(output_csv_filename,
117 |                          sep=parameters.csv_delimiter,
118 |                          header=False,
119 |                          encoding='utf8',
120 |                          index=False,
121 |                          escapechar="\\",
122 |                          quoting=0)
123 |     return len(new_dataframe)
124 | 
125 | 
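# Illustrative row transformation (made-up path, codes and delimiters; the real codes
# come from the alphabet lookup table): with '|' as string_split_delimiter, ';' as
# csv_delimiter and max_chars_per_string = 5, the input row
#   images/0001.jpg;h|e|l|l|o
# becomes, in the output csv,
#   images/0001.jpg;8 5 12 12 15;5
# where the label was split on '|', mapped to codes and padded with 0s up to
# max_chars_per_string (no padding needed here since len('hello') == 5).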
126 | def data_preprocessing(params: Params) -> (str, str, int, int):
127 |     """
128 |     Preprocesses the data for the experiment (training and evaluation data).
129 |     Exports the updated csv files into `/preprocessed/updated_{eval,train}.csv`
130 | 
131 |     :param params: parameters of the experiment (``Params``)
132 |     :return: paths to the updated train and eval csv files, and the number of samples in each
133 |     """
134 |     output_dir = os.path.join(params.output_model_dir, CONST.PREPROCESSING_FOLDER)
135 |     if not os.path.exists(output_dir):
136 |         os.makedirs(output_dir)
137 |     else:
138 |         print('Output directory {} already exists'.format(output_dir))
139 | 
140 |     csv_train_output = os.path.join(output_dir, 'updated_train.csv')
141 |     csv_eval_output = os.path.join(output_dir, 'updated_eval.csv')
142 | 
143 |     # Preprocess train csv
144 |     n_samples_train = preprocess_csv(params.csv_files_train, params, csv_train_output)
145 | 
146 |     # Preprocess evaluation csv
147 |     n_samples_eval = preprocess_csv(params.csv_files_eval, params, csv_eval_output)
148 | 
149 |     return csv_train_output, csv_eval_output, n_samples_train, n_samples_eval
150 | 
151 | 
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = "solivr"
3 | __license__ = "GPL"
4 | 
5 | import logging
6 | logging.getLogger("tensorflow").setLevel(logging.INFO)
7 | 
8 | from tf_crnn.config import Params
9 | from tf_crnn.model import get_model_train
10 | from tf_crnn.preprocessing import data_preprocessing
11 | from tf_crnn.data_handler import dataset_generator
12 | from tf_crnn.callbacks import CustomLoaderCallback, CustomSavingCallback, LRTensorBoard, EPOCH_FILENAME, FOLDER_SAVED_MODEL
13 | import tensorflow as tf
14 | import numpy as np
15 | import os
16 | import json
17 | import pickle
18 | from glob import glob
19 | from sacred import Experiment, SETTINGS
20 | 
21 | SETTINGS.CONFIG.READ_ONLY_CONFIG = False
22 | 
23 | ex = Experiment('crnn')
24 | 
25 | ex.add_config('config.json')
26 | 
27 | @ex.automain
28 | def training(_config: dict):
29 |     parameters = Params(**_config)
30 | 
31 |     export_config_filename = os.path.join(parameters.output_model_dir, 'config.json')
32 |     saving_dir = os.path.join(parameters.output_model_dir, FOLDER_SAVED_MODEL)
33 | 
34 |     if not parameters.restore_model:
35 |         # check if output folder already exists
36 |         assert not os.path.isdir(parameters.output_model_dir), \
37 |             '{} already exists, you cannot use it as output directory.'.format(parameters.output_model_dir)
38 |         # 'Set "restore_model=True" to continue training, or delete dir "rm -r {0}"'.format(parameters.output_model_dir)
39 |         os.makedirs(parameters.output_model_dir)
40 | 
41 |     # data and csv preprocessing
42 |     csv_train_file, csv_eval_file, \
43 |         n_samples_train, n_samples_eval = data_preprocessing(parameters)
44 | 
45 |     # export config file in model output dir
46 |     with open(export_config_filename, 'w') as file:
47 |         json.dump(parameters.to_dict(), file)
48 | 
49 |     # Create callbacks
50 |     logdir = os.path.join(parameters.output_model_dir, 'logs')
51 |     tb_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir,
52 |                                                  profile_batch=0)
53 | 
54 |     lrtb_callback = LRTensorBoard(log_dir=logdir,
55 |                                   profile_batch=0)
56 | 
57 |     lr_callback = tf.keras.callbacks.ReduceLROnPlateau(factor=0.5,
58 |                                                        patience=10,
59 |                                                        cooldown=0,
60 |                                                        min_lr=1e-8,
61 |                                                        verbose=1)
62 | 
63 |     es_callback = tf.keras.callbacks.EarlyStopping(min_delta=0.005,
64 |                                                    patience=20,
65 |                                                    verbose=1)
66 | 
67 |     sv_callback = CustomSavingCallback(saving_dir,
68 |                                        saving_freq=parameters.save_interval,
69 |                                        save_best_only=True,
70 |                                        keep_max_models=3)
71 | 
72 |     list_callbacks = [tb_callback, lrtb_callback, lr_callback, es_callback, sv_callback]
73 | 
74 |     if parameters.restore_model:
75 |         last_time_stamp = max([int(p.split(os.path.sep)[-1].split('-')[0])
76 |                                for p in glob(os.path.join(saving_dir, '*'))])
77 | 
78 |         loading_dir = os.path.join(saving_dir, str(last_time_stamp))
79 |         ld_callback = CustomLoaderCallback(loading_dir)
80 | 
81 |         list_callbacks.append(ld_callback)
82 | 
83 |         with open(os.path.join(loading_dir, EPOCH_FILENAME), 'rb') as f:
84 |             initial_epoch = pickle.load(f)
85 | 
86 |         epochs = initial_epoch + parameters.n_epochs
87 |     else:
88 |         initial_epoch = 0
89 |         epochs = parameters.n_epochs
90 | 
91 |     # Get model
92 |     model = get_model_train(parameters)
93 | 
94 |     # Get datasets
95 |     dataset_train = dataset_generator([csv_train_file],
96 |                                       parameters,
97 |                                       batch_size=parameters.train_batch_size,
98 |                                       data_augmentation=parameters.data_augmentation,
99 |                                       num_epochs=parameters.n_epochs)
100 | 
101 |     dataset_eval = dataset_generator([csv_eval_file],
102 |                                      parameters,
103 |                                      batch_size=parameters.eval_batch_size,
104 |                                      data_augmentation=False,
105 |                                      num_epochs=parameters.n_epochs)
106 | 
107 |     # Train model
108 |     model.fit(dataset_train,
109 |               epochs=epochs,
110 |               initial_epoch=initial_epoch,
111 |               steps_per_epoch=np.floor(n_samples_train / parameters.train_batch_size),
112 |               validation_data=dataset_eval,
113 |               validation_steps=np.floor(n_samples_eval / parameters.eval_batch_size),
114 |               callbacks=list_callbacks)
115 | 
--------------------------------------------------------------------------------