├── .gitignore ├── LICENSE ├── README.md ├── environment.yml ├── flatting_server.spec ├── pyproject.toml ├── requirements.txt └── src ├── flatting ├── __init__.py ├── __main__.py ├── app.py ├── client_debug.py ├── demo.py ├── dice_loss.py ├── eval.py ├── flatting_api.py ├── flatting_api_async.py ├── hubconf.py ├── predict.py ├── resources │ ├── __init__.py │ ├── flatting.icns │ ├── flatting.ico │ └── flatting.png ├── submit.py ├── tkapp.py ├── train.py ├── trapped_ball │ ├── adjacency_matrix.pyx │ ├── examples │ │ ├── 01.png │ │ ├── 01_sim.png │ │ ├── 02.png │ │ ├── tiny.png │ │ └── tiny_sim.png │ ├── run.py │ ├── thinning.py │ ├── thinning_zhang.py │ └── trappedball_fill.py ├── unet │ ├── __init__.py │ ├── unet_model.py │ └── unet_parts.py └── utils │ ├── add_white_background.py │ ├── data_vis.py │ ├── dataset.py │ ├── ground_truth_creation.py │ ├── move_to_duplicate.py │ ├── polyvector │ └── run_all_examples.py │ └── preprocessing.py └── flatting_server.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Specific to this project 2 | /data 3 | /src/flatting/checkpoints 4 | /src/flatting.dist-info 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # OSX useful to ignore 12 | *.DS_Store 13 | .AppleDouble 14 | .LSOverride 15 | 16 | # Thumbnails 17 | ._* 18 | 19 | # Files that might appear in the root of a volume 20 | .DocumentRevisions-V100 21 | .fseventsd 22 | .Spotlight-V100 23 | .TemporaryItems 24 | .Trashes 25 | .VolumeIcon.icns 26 | .com.apple.timemachine.donotpresent 27 | 28 | # Directories potentially created on remote AFP share 29 | .AppleDB 30 | .AppleDesktop 31 | Network Trash Folder 32 | Temporary Items 33 | .apdisk 34 | 35 | # C extensions 36 | *.so 37 | 38 | # Distribution / packaging 39 | .Python 40 | env/ 41 | build/ 42 | develop-eggs/ 43 | dist/ 44 | downloads/ 45 | eggs/ 46 | .eggs/ 47 | lib/ 48 | lib64/ 49 | parts/ 50 | sdist/ 51 | var/ 52 | *.egg-info/ 53 | .installed.cfg 54 | *.egg 55 | 56 | # IntelliJ Idea family of suites 57 | .idea 58 | *.iml 59 | ## File-based project format: 60 | *.ipr 61 | *.iws 62 | ## mpeltonen/sbt-idea plugin 63 | .idea_modules/ 64 | 65 | # Briefcase build directories 66 | iOS/ 67 | macOS/ 68 | windows/ 69 | android/ 70 | linux/ 71 | django/ 72 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | Copyright (C) 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Flatting 2 | This project is based on [U-net](https://github.com/milesial/Pytorch-UNet) 3 | 4 | ## Install 5 | ### 1. Install required packages 6 | To run color filling, you need the following module installed: 7 | 8 | - numpy 9 | - opencv-python 10 | - tqdm 11 | - pillow 12 | - cython 13 | - aiohttp 14 | - scikit-image 15 | - torch 16 | - torchvision 17 | 18 | You can install these dependencies via [Anaconda](https://www.anaconda.com/products/individual) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html). 19 | Miniconda is faster to install. (On Windows, choose the 64-bit Python 3.x version. Launch the Anaconda shell from the Start menu and navigate to this directory.) 20 | Then: 21 | 22 | conda env create -f environment.yml 23 | conda activate flatting 24 | 25 | To update an already created environment if the `environment.yml` file changes or to change environments, activate and then run `conda env update --file environment.yml --prune`. 26 | 27 | ### 2. Download pretrained models 28 | Download the [pretrained network model](https://drive.google.com/file/d/1NLooRQ8uZ3ZwQnAYjQAiGhOJqit5Q2_J/view?usp=sharing) and unzip `checkpoints.zip` into `./src/flatting/`. 29 | 30 | ### 3. Run 31 | You can run our backend directly by: 32 | 33 | cd src 34 | python -m flatting 35 | 36 | 37 | 38 | ### 4. Package 39 | If you just want to run the backend only and don't want to touch the code. We provide a [portable backend (Windows only)](https://drive.google.com/file/d/1s9Z5Qgc9siWMu45iOetEUhuzNfJbjbGw/view?usp=sharing) which packaged by the pyinstaller (see sec 4b.) You can download it and unzip to any place, then run: 40 | 41 | cd flatting_server 42 | flatting_server.exe 43 | 44 | ### 4a. Packaging with Briefcase 45 | **Issues:** Although briefcase can output a cleaner package of our backend but it seems also hide the running log as well, we currently don't have a good solution for this issue yet. 46 | 47 | Use `briefcase` [commands](https://docs.beeware.org/en/latest/tutorial/tutorial-1.html) for packaging. Briefcase can't compile Cython modules, so you must first do that. There is only one. Compile it via `cythonize -i src/flatting/trapped_ball/adjacency_matrix.pyx`. 48 | 49 | To start the process, run: 50 | 51 | briefcase create 52 | briefcase build 53 | 54 | To run the standalone program: 55 | 56 | briefcase run 57 | 58 | To create an installer: 59 | 60 | briefcase package 61 | 62 | To update the standalone program when your code or dependencies change: 63 | 64 | briefcase update -r -d 65 | 66 | You can also simply run `briefcase run -u`. 67 | 68 | To debug this process, you can run your code from the entrypoint briefcase uses: 69 | 70 | briefcase dev 71 | 72 | This reveals some issues important to debug. It doesn't reveal dependency issues, because it's not using briefcase's python installation. 73 | 74 | On my setup, I have to manually edit `edit macOS/app/Flatting/Flatting.app/Contents/Resources/app_packages/torch/distributed/rpc/api.py` to insert a line `if docstring is None: continue` after line 443: 75 | 76 | assert docstring is not None, "RRef user-facing methods should all have docstrings." 77 | 78 | ### 4b. Packaging with pyinstaller 79 | 80 | If briefcase doesn't work, you can use [pyinstaller](https://www.pyinstaller.org/): 81 | 82 | pyinstaller --noconfirm flatting_server.spec 83 | 84 | ### 5. Install Photoshop plugin 85 | Download the [flatting plugin](https://drive.google.com/file/d/1HivdqU2Z2dIL2MvqzEYmCLO2_nDL2Cnk/view?usp=sharing) and unzip it to any place. 86 | Download the backend server by following the instructions inside the "flatting plugin.zip" 87 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: flatting 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - numpy 8 | - tqdm 9 | - pillow 10 | - py-opencv 11 | - aiohttp 12 | - scikit-image 13 | - pytorch 14 | - torchvision 15 | - appdirs 16 | - pyinstaller 17 | - pip 18 | - pip: 19 | - briefcase 20 | prefix: /Users/anonymous/opt/miniconda3/envs/flatting 21 | -------------------------------------------------------------------------------- /flatting_server.spec: -------------------------------------------------------------------------------- 1 | # -*- mode: python ; coding: utf-8 -*- 2 | import os.path 3 | block_cipher = None 4 | a = Analysis(['src/flatting_server.py'], 5 | ## pyinstaller iheartla.spec must be run from 6 | ## ??? 7 | pathex=[os.path.abspath(os.getcwd())], 8 | binaries=[], 9 | datas=[('src/flatting/checkpoints','checkpoints')], 10 | hiddenimports=[], 11 | hookspath=[], 12 | runtime_hooks=[], 13 | excludes=[], 14 | win_no_prefer_redirects=False, 15 | win_private_assemblies=False, 16 | cipher=block_cipher, 17 | noarchive=False) 18 | pyz = PYZ(a.pure, a.zipped_data, 19 | cipher=block_cipher) 20 | exe = EXE(pyz, 21 | a.scripts, 22 | [], 23 | exclude_binaries=True, 24 | name='flatting_server', 25 | debug=False, 26 | bootloader_ignore_signals=False, 27 | strip=False, 28 | upx=True, 29 | console=True, 30 | icon='src/flatting/resources/flatting.ico' ) 31 | coll = COLLECT(exe, 32 | a.binaries, 33 | a.zipfiles, 34 | a.datas, 35 | strip=False, 36 | upx=True, 37 | upx_exclude=[], 38 | name='flatting_server') 39 | app = BUNDLE(coll, 40 | name='flatting_server.app', 41 | icon='src/flatting/resources/flatting.icns', 42 | bundle_identifier=None, 43 | info_plist={'NSHighResolutionCapable': 'True'} 44 | ) 45 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.briefcase] 2 | project_name = "Flatting" 3 | bundle = "edu.gmu.cs.cragl.flatting" 4 | version = "0.0.1" 5 | url = "https://cragl.cs.gmu.edu/flatting" 6 | license = "Proprietary" 7 | author = 'Chuan Yan' 8 | author_email = "cyan3@gmu.edu" 9 | 10 | [tool.briefcase.app.flatting] 11 | formal_name = "Flatting" 12 | description = "Back-end for Photoshop flatting plugin" 13 | icon = "src/flatting/resources/flatting" 14 | sources = ['src/flatting'] 15 | requires = [ 16 | "numpy", 17 | "tqdm", 18 | "pillow", 19 | "opencv-python-headless", 20 | "aiohttp", 21 | "scikit-image", 22 | "torch", 23 | "torchvision" 24 | ] 25 | 26 | [tool.briefcase.app.flatting.macOS] 27 | requires = [] 28 | 29 | [tool.briefcase.app.flatting.linux] 30 | requires = [] 31 | system_requires = [] 32 | 33 | [tool.briefcase.app.flatting.windows] 34 | requires = [] 35 | 36 | # Mobile deployments 37 | [tool.briefcase.app.flatting.iOS] 38 | requires = [] 39 | 40 | [tool.briefcase.app.flatting.android] 41 | requires = [] 42 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | Pillow 4 | torch 5 | torchvision 6 | tensorboard 7 | future 8 | tqdm 9 | -------------------------------------------------------------------------------- /src/flatting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/__init__.py -------------------------------------------------------------------------------- /src/flatting/__main__.py: -------------------------------------------------------------------------------- 1 | from . import app 2 | 3 | if __name__ == '__main__': 4 | ## https://docs.python.org/3/library/multiprocessing.html#multiprocessing.freeze_support 5 | if app.MULTIPROCESS: app.multiprocessing.freeze_support() 6 | app.main() 7 | -------------------------------------------------------------------------------- /src/flatting/app.py: -------------------------------------------------------------------------------- 1 | from aiohttp import web 2 | from PIL import Image 3 | from io import BytesIO 4 | from datetime import datetime 5 | from os.path import join, exists 6 | 7 | import appdirs 8 | import numpy as np 9 | from . import flatting_api 10 | import base64 11 | import os 12 | import io 13 | import json 14 | import asyncio 15 | import multiprocessing 16 | 17 | MULTIPROCESS = True 18 | LOG = True 19 | 20 | if MULTIPROCESS: 21 | # Importing this module creates multiprocessing pools, which is problematic 22 | # in Briefcase and PyInstaller on macOS. 23 | from . import flatting_api_async 24 | 25 | routes = web.RouteTableDef() 26 | 27 | @routes.get('/') 28 | # seems the function name is not that important? 29 | async def hello(request): 30 | return web.Response(text="Flatting API server is running") 31 | 32 | ## Add more API entry points 33 | @routes.post('/flatsingle') 34 | async def flatsingle( request ): 35 | data = await request.json() 36 | try: 37 | data = json.loads(data) 38 | except: 39 | print("got dict directly") 40 | 41 | # convert to json 42 | img = to_pil(data['image']) 43 | net = str(data['net']) 44 | radii = int(data['radius']) 45 | resize = data['resize'] 46 | if 'userName' in data: 47 | user = data['userName'] 48 | img_name = data['fileName'] 49 | if resize: 50 | w_new, h_new = data["newSize"] 51 | else: 52 | w_new = None 53 | h_new = None 54 | 55 | if MULTIPROCESS: 56 | flatted = await flatting_api_async.run_single(img, net, radii, resize, w_new, h_new, img_name) 57 | else: 58 | flatted = flatting_api.run_single(img, net, radii, resize, w_new, h_new, img_name) 59 | 60 | result = {} 61 | result['line_artist'] = to_base64(flatted['line_artist']) 62 | result['line_hint'] = to_base64(flatted['line_hint']) 63 | result['line_simplified'] = to_base64(flatted['line_simplified']) 64 | result['image'] = to_base64(flatted['fill_color']) 65 | result['fill_artist'] = to_base64(flatted['components_color']) 66 | 67 | if LOG: 68 | now = datetime.now() 69 | save_to_log(now, flatted['line_artist'], user, img_name, "line_artist", "flat") 70 | save_to_log(now, flatted['line_hint'], user, img_name, "line_hint", "flat") 71 | save_to_log(now, flatted['line_simplified'], user, img_name, "line_simplified", "flat") 72 | save_to_log(now, flatted['fill_color'], user, img_name, "fill_color", "flat") 73 | save_to_log(now, flatted['components_color'], user, img_name, "fill_color_floodfill", "flat") 74 | save_to_log(now, flatted['fill_color_neural'], user, img_name, "fill_color_neural", "flat") 75 | save_to_log(now, flatted['line_neural'], user, img_name, "line_neural", "flat") 76 | print("Log:\tlogs saved") 77 | return web.json_response( result ) 78 | 79 | @routes.post('/merge') 80 | async def merge( request ): 81 | data = await request.json() 82 | try: 83 | data = json.loads(data) 84 | except: 85 | print("got dict directly") 86 | 87 | line_artist = to_pil(data['line_artist']) 88 | fill_neural = np.array(to_pil(data['fill_neural'])) 89 | fill_artist = np.array(to_pil(data['fill_artist'])) 90 | stroke = to_pil(data['stroke']) 91 | if 'userName' in data: 92 | user = data['userName'] 93 | img_name = data['fileName'] 94 | # palette = np.array(data['palette']) 95 | 96 | if MULTIPROCESS: 97 | merged = await flatting_api_async.merge(fill_neural, fill_artist, stroke, line_artist) 98 | else: 99 | merged = flatting_api.merge(fill_neural, fill_artist, stroke, line_artist) 100 | 101 | result = {} 102 | result['image'] = to_base64(merged['fill_color']) 103 | result['line_simplified'] = to_base64(merged['line_simplified']) 104 | if LOG: 105 | now = datetime.now() 106 | save_to_log(now, merged['line_simplified'], user, img_name, "line_simplified", "merge") 107 | save_to_log(now, merged['fill_color'], user, img_name, "fill_color", "merge") 108 | save_to_log(now, stroke, user, img_name, "merge_stroke", "merge") 109 | save_to_log(now, fill_artist, user, img_name, "fill_color_floodfill", "merge") 110 | print("Log:\tlogs saved") 111 | 112 | return web.json_response(result) 113 | 114 | @routes.post('/splitmanual') 115 | async def split_manual( request ): 116 | data = await request.json() 117 | try: 118 | data = json.loads(data) 119 | except: 120 | print("got dict directly") 121 | 122 | fill_neural = np.array(to_pil(data['fill_neural'])) 123 | fill_artist = np.array(to_pil(data['fill_artist'])) 124 | stroke = np.array(to_pil(data['stroke'])) 125 | line_artist = to_pil(data['line_artist']) 126 | add_only = data['mode'] 127 | if 'userName' in data: 128 | user = data['userName'] 129 | img_name = data['fileName'] 130 | 131 | if MULTIPROCESS: 132 | splited = await flatting_api_async.split_manual(fill_neural, fill_artist, stroke, line_artist, add_only) 133 | else: 134 | splited = flatting_api.split_manual(fill_neural, fill_artist, stroke, line_artist, add_only) 135 | 136 | result = {} 137 | result['line_artist'] = to_base64(splited['line_artist']) 138 | result['line_simplified'] = to_base64(splited['line_neural']) 139 | result['image'] = to_base64(splited['fill_color']) 140 | result['fill_artist'] = to_base64(splited['fill_artist']) 141 | result['line_hint'] = to_base64(splited['line_hint']) 142 | 143 | if LOG: 144 | now = datetime.now() 145 | save_to_log(now, splited['line_neural'], user, img_name, "line_simplified", "split_%s"%str(add_only)) 146 | save_to_log(now, splited['line_artist'], user, img_name, "line_artist", "split_%s"%str(add_only)) 147 | save_to_log(now, splited['fill_color'], user, img_name, "fill_color", "split_%s"%str(add_only)) 148 | save_to_log(now, stroke, user, img_name, "split_stroke", "split_%s"%str(add_only)) 149 | save_to_log(now, splited['fill_artist'], user, img_name, "fill_color_floodfill", "split_%s"%str(add_only)) 150 | save_to_log(now, splited['line_hint'], user, img_name, "line_hint", "split_%s"%str(add_only)) 151 | print("Log:\tlogs saved") 152 | return web.json_response(result) 153 | 154 | @routes.post('/flatlayers') 155 | async def export_fill_to_layers( request ): 156 | data = await request.json() 157 | try: 158 | data = json.loads(data) 159 | except: 160 | print("got dict directly") 161 | 162 | img = np.array(to_pil(data['fill_neural'])) 163 | layers = flatting_api.export_layers(img) 164 | layers_base64 = [] 165 | for layer in layers["layers"]: 166 | layers_base64.append(to_base64(layer)) 167 | 168 | result = {} 169 | result["layersImage"] = layers_base64 170 | 171 | return web.json_response(result) 172 | 173 | def to_base64(array): 174 | ''' 175 | A helper function to convert numpy array to png in base64 format 176 | ''' 177 | with io.BytesIO() as output: 178 | if type(array) == np.ndarray: 179 | Image.fromarray(array).save(output, format='png') 180 | else: 181 | array.save(output, format='png') 182 | img = output.getvalue() 183 | img = base64.encodebytes(img).decode("utf-8") 184 | return img 185 | 186 | def to_pil(byte): 187 | ''' 188 | A helper function to convert byte png to PIL.Image 189 | ''' 190 | byte = base64.b64decode(byte) 191 | return Image.open(BytesIO(byte)) 192 | 193 | def save_to_log(date, data, user, img_name, save_name, op): 194 | log_dir = appdirs.user_log_dir( "Flatting Server", "CraGL" ) 195 | save_folder = "[%s][%s][%s_%s]"%(user, str(date.strftime("%d-%b-%Y %H-%M-%S")), img_name, op) 196 | save_folder = join( log_dir, save_folder) 197 | try: 198 | if exists(save_folder) == False: 199 | os.makedirs(save_folder) 200 | if type(data) == np.ndarray: 201 | Image.fromarray(data).save(join(save_folder, "%s.png"%save_name)) 202 | else: 203 | data.save(join(save_folder, "%s.png"%save_name)) 204 | except: 205 | print("Warning:\tsave log failed!") 206 | 207 | def main(): 208 | ''' 209 | import traceback 210 | with open('/Users/yotam/Work/GMU/flatting/code/log.txt','a') as f: 211 | f.write('=================================================================\n') 212 | f.write('__name__: %s\n' % __name__) 213 | traceback.print_stack(file=f) 214 | ''' 215 | app = web.Application(client_max_size = 1024 * 1024 ** 2) 216 | app.add_routes(routes) 217 | web.run_app(app) 218 | 219 | ## From JavaScript: 220 | # let result = await fetch( url_of_server.py, { method: 'POST', body: JSON.stringify(data) } ).json(); 221 | 222 | if __name__ == '__main__': 223 | main() 224 | -------------------------------------------------------------------------------- /src/flatting/client_debug.py: -------------------------------------------------------------------------------- 1 | # need to write some test case 2 | import requests 3 | import base64 4 | import io 5 | import json 6 | import os 7 | 8 | from os.path import * 9 | from io import BytesIO 10 | from PIL import Image 11 | 12 | url = "http://jixuanzhi.asuscomm.com:8080/" 13 | image = "./trapped_ball/examples/01.png" 14 | # image = "test1.png" 15 | run_single_test = "flatsingle" 16 | run_multi_test = "flatmultiple" 17 | merge_test = "merge" 18 | split_auto_test = "splitauto" 19 | split_manual_test = "splitmanual" 20 | show_fillmap_test = "showfillmap" 21 | 22 | def png_to_base64(path_to_img): 23 | with open(path_to_img, 'rb') as f: 24 | return base64.encodebytes(f.read()).decode("utf-8") 25 | 26 | def test_case1(): 27 | # case for run single test 28 | data = {} 29 | data['image'] = png_to_base64(image) 30 | data['net'] = '1024_base' 31 | data['radius'] = 1 32 | data['preview'] = False 33 | 34 | # convert to json 35 | result = requests.post(url+run_single_test, json = json.dumps(data)) 36 | if result.status_code == 200: 37 | result = result.json() 38 | import pdb 39 | pdb.set_trace() 40 | line_sim = to_pil(result['line_artist']) 41 | line_sim.show() 42 | os.system("pause") 43 | 44 | line_sim = to_pil(result['image']) 45 | line_sim.show() 46 | os.system("pause") 47 | line_sim = to_pil(result['image_c']) 48 | line_sim.show() 49 | os.system("pause") 50 | line_sim = to_pil(result['line_simplified']) 51 | line_sim.show() 52 | os.system("pause") 53 | 54 | else: 55 | raise ValueError("Test failed") 56 | 57 | print("Done") 58 | 59 | def to_pil(byte): 60 | ''' 61 | A helper function to convert byte png to PIL.Image 62 | ''' 63 | byte = base64.b64decode(byte) 64 | return Image.open(BytesIO(byte)) 65 | 66 | test_case1() -------------------------------------------------------------------------------- /src/flatting/demo.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | from os.path import * 4 | sys.path.append(join(dirname(abspath(__file__)), "trapped_ball")) 5 | 6 | 7 | import gradio as gr 8 | import numpy as np 9 | import torch 10 | import random 11 | 12 | from PIL import Image 13 | from torchvision import transforms as T 14 | from torchvision import utils 15 | 16 | # import model 17 | from unet import UNet 18 | from predict import predict_img 19 | 20 | # import trapped ball filling func 21 | from run import region_get_map 22 | 23 | from functools import partial 24 | from zipfile import ZipFile 25 | 26 | def to_t(array): 27 | return torch.Tensor(array).cuda().unsqueeze(0) 28 | 29 | 30 | def to_tensor(img): 31 | 32 | img_t = ( 33 | torch.from_numpy(img).unsqueeze(-1) 34 | .to(torch.float32) 35 | .div(255) 36 | .add_(-0.5) 37 | .mul_(2) 38 | .permute(2, 0, 1) 39 | ) 40 | return img_t.unsqueeze(0).cuda() 41 | 42 | # def denormalize(img): 43 | # # denormalize 44 | # inv_normalize = T.Normalize( mean=[-1, -1, -1], std=[2, 2, 2]) 45 | 46 | # img_np = inv_normalize(img.repeat(3,1,1)) 47 | # img_np = (img_np * 255).clamp(0, 255) 48 | 49 | # # to numpy 50 | # img_np = img_np.cpu().numpy().transpose((1,2,0)) 51 | 52 | # return Image.fromarray(img_np.astype(np.uint8)) 53 | 54 | def zip_files(files): 55 | with ZipFile("./flatting/gradio/all.zip", 'w') as zipObj: 56 | for f in files: 57 | zipObj.write(f) 58 | return "./flatting/gradio/all.zip" 59 | 60 | def split_to_4(img): 61 | 62 | # now I just write a simple code to split images into 4 evenly 63 | w, h = img.size 64 | h1 = h // 2 65 | w1 = w // 2 66 | img = np.array(img) 67 | 68 | # top left 69 | img1 = Image.fromarray(img[:h1, :w1]) 70 | 71 | # top right 72 | img2 = Image.fromarray(img[:h1, w1:]) 73 | 74 | # bottom left 75 | img3 = Image.fromarray(img[h1:, :w1]) 76 | 77 | # bottom right 78 | img4 = Image.fromarray(img[h1:, w1:]) 79 | 80 | return img1, img2, img3, img4 81 | 82 | def merge_to_1(imgs): 83 | 84 | img1, img2, img3, img4 = imgs 85 | img_top = np.concatenate((img1, img2), axis = 1) 86 | img_bottom = np.concatenate((img3, img4), axis = 1) 87 | 88 | return np.concatenate((img_top, img_bottom), axis = 0) 89 | 90 | def pred_and_fill(img, op, radius, patch, nets, outputs="./flatting/gradio"): 91 | 92 | # initail out files 93 | outs = [] 94 | outs.append(join(outputs, "%s_input.png"%op)) 95 | outs.append(join(outputs, "%s_fill.png"%op)) 96 | outs.append(join(outputs, "%s_fill_edge.png"%op)) 97 | outs.append(join(outputs, "%s_fill_line.png"%op)) 98 | outs.append(join(outputs, "%s_fill_line_full.png"%op)) 99 | 100 | 101 | # predict full image 102 | # img = cv2.threshold(img, 240, 255, cv2.THRESH_BINARY) 103 | if patch == "False": 104 | # img_w = Image.new("RGBA", img.size, "WHITE") 105 | # try: 106 | # img_w.paste(img, None, img) 107 | # img = img_w.convert("L") 108 | # except: 109 | # print("Log:\tfailed to add white background") 110 | 111 | edge = predict_img(net=nets[op][0], 112 | full_img=img, 113 | device=nets[op][1], 114 | size = int(op.replace("_rand", ""))) 115 | else: 116 | print("Log:\tsplit input into 4 patch with model %s"%(op)) 117 | # cut image into non-overlapping patches 118 | imgs = split_to_4(img) 119 | 120 | edges = [] 121 | for patch in imgs: 122 | edge = predict_img(net=nets[op][0], 123 | full_img=patch, 124 | device=nets[op][1], 125 | size = int(op)) 126 | 127 | edges.append(np.array(edge)) 128 | 129 | edge = Image.fromarray(merge_to_1(edges)) 130 | 131 | print("Log:\ttrapping ball filling with radius %s"%radius) 132 | fill = region_get_map(edge.convert("L"), 133 | radius_set=[int(radius)], percentiles=[0], 134 | path_to_line_artist=img, 135 | return_numpy=True, 136 | preview = True) 137 | 138 | return edge, fill 139 | 140 | def initial_models(path_to_ckpt): 141 | 142 | # find the lastest model 143 | ckpt_list = [] 144 | 145 | if ".pth" not in path_to_ckpt: 146 | for c in os.listdir(path_to_ckpt): 147 | if ".pth" in c: 148 | ckpt_list.append(c) 149 | ckpt_list.sort() 150 | path_to_ckpt = join(path_to_ckpt, ckpt_list[-1]) 151 | 152 | assert exists(path_to_ckpt) 153 | 154 | # init model 155 | net = UNet(in_channels=1, out_channels=1, bilinear=True) 156 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 157 | net.to(device=device) 158 | 159 | # load model 160 | print("Log:\tload %s"%path_to_ckpt) 161 | try: 162 | net.load_state_dict(torch.load(path_to_ckpt, map_location=device)) 163 | except: 164 | net = torch.nn.DataParallel(net) 165 | net.load_state_dict(torch.load(path_to_ckpt, map_location=device)) 166 | net.eval() 167 | 168 | return net, device 169 | 170 | def initial_flatting_input(): 171 | 172 | # inputs 173 | img = gr.inputs.Image(image_mode='L', 174 | invert_colors=False, source="upload", label="Input Image", 175 | type = "pil") 176 | # resize = gr.inputs.Radio(choices=["1024", "512", "256"], label="Resize") 177 | model = gr.inputs.Radio(choices=["1024", "1024_rand", "512", "512_rand"], label="Model") 178 | # split = gr.inputs.Radio(choices=["True", "False"], label="Split") 179 | radius = gr.inputs.Slider(minimum=1, maximum=10, step=1, default=7, label="kernel radius") 180 | 181 | # outputs 182 | out1 = gr.outputs.Image(type='pil', label='line prediction') 183 | out2 = gr.outputs.Image(type='pil', label='fill') 184 | # out5 = gr.outputs.File(label="all results") 185 | 186 | return [img, model, radius], [out1, out2] 187 | # return [img, resize], [out1, out2, out3, out4, out5] 188 | 189 | def start_demo(fn, inputs, outputs, examples): 190 | iface = gr.Interface(fn = fn, inputs = inputs, outputs = outputs, examples = examples, layout = "unaligned") 191 | iface.launch() 192 | 193 | def main(): 194 | 195 | # get base tcode number 196 | parser = argparse.ArgumentParser() 197 | parser.add_argument("--ckpt-1024", type=str, default = "./checkpoints/base_1024/") 198 | parser.add_argument("--ckpt-512", type=str, default = "./checkpoints/base_512/") 199 | # parser.add_argument("--ckpt-256", type=str, default = "./checkpoints/base_256/") 200 | parser.add_argument("--ckpt-512-rand", type=str, default = "./checkpoints/rand_512/") 201 | # parser.add_argument("--ckpt-256-rand", type=str, default = "./checkpoints/rand_256/") 202 | parser.add_argument("--ckpt-1024-rand", type=str, default = "./checkpoints/rand_1024/") 203 | 204 | args = parser.parse_args() 205 | 206 | # initailize 207 | nets = {} 208 | nets["1024"] = initial_models(args.ckpt_1024) 209 | nets["1024_rand"] = initial_models(args.ckpt_1024_rand) 210 | nets["512"] = initial_models(args.ckpt_512) 211 | nets["512_rand"] = initial_models(args.ckpt_512_rand) 212 | # nets["256"] = initial_models(args.ckpt_256) 213 | # nets["256_rand"] = initial_models(args.ckpt_256_rand) 214 | 215 | 216 | # construct exmaples 217 | example_path = "./flatting/validation" 218 | example_list = os.listdir(example_path) 219 | example_list.sort() 220 | 221 | examples = [] 222 | 223 | for file in example_list: 224 | print("find %s"%file) 225 | img = os.path.join(example_path, file) 226 | model = random.choice(["512_rand"]) 227 | radius = 2 228 | examples.append([img, model, radius]) 229 | 230 | # initial pred func 231 | fn = partial(pred_and_fill, nets=nets, patch="False", outputs="./flatting/gradio") 232 | 233 | # bug fix 234 | fn.__name__ = fn.func.__name__ 235 | 236 | # start 237 | inputs, outputs = initial_flatting_input() 238 | start_demo(fn=fn, inputs=inputs, outputs=outputs, examples=examples) 239 | 240 | def debug(): 241 | # get base tcode number 242 | parser = argparse.ArgumentParser() 243 | parser.add_argument("--ckpt-1024", type=str, default = "./checkpoints/base_1024/") 244 | parser.add_argument("--ckpt-512", type=str, default = "./checkpoints/base_512/") 245 | # parser.add_argument("--ckpt-256", type=str, default = "./checkpoints/base_256/") 246 | parser.add_argument("--ckpt-512-rand", type=str, default = "./checkpoints/rand_512/") 247 | # parser.add_argument("--ckpt-256-rand", type=str, default = "./checkpoints/rand_256/") 248 | parser.add_argument("--ckpt-1024-rand", type=str, default = "./checkpoints/rand_1024/") 249 | args = parser.parse_args() 250 | 251 | # initailize 252 | nets = {} 253 | nets["1024"] = initial_models(args.ckpt_1024) 254 | nets["1024_rand"] = initial_models(args.ckpt_1024_rand) 255 | nets["512"] = initial_models(args.ckpt_512) 256 | nets["512_rand"] = initial_models(args.ckpt_512_rand) 257 | # nets["256"] = initial_models(args.ckpt_256) 258 | # nets["256_rand"] = initial_models(args.ckpt_256_rand) 259 | 260 | 261 | 262 | img = Image.open("./flatting/validation/train_008.png").convert("L") 263 | pred_and_fill(img, radius=2, op='512_rand', patch="False", nets=nets, outputs="./flatting/gradio") 264 | 265 | if __name__ == '__main__': 266 | main() 267 | # debug() 268 | -------------------------------------------------------------------------------- /src/flatting/dice_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | 5 | class DiceCoeff(Function): 6 | """Dice coeff for individual examples""" 7 | 8 | def forward(self, input, target): 9 | self.save_for_backward(input, target) 10 | eps = 0.0001 11 | self.inter = torch.dot(input.view(-1), target.view(-1)) 12 | self.union = torch.sum(input) + torch.sum(target) + eps 13 | 14 | t = (2 * self.inter.float() + eps) / self.union.float() 15 | return t 16 | 17 | # This function has only a single output, so it gets only one gradient 18 | def backward(self, grad_output): 19 | 20 | input, target = self.saved_variables 21 | grad_input = grad_target = None 22 | 23 | if self.needs_input_grad[0]: 24 | grad_input = grad_output * 2 * (target * self.union - self.inter) \ 25 | / (self.union * self.union) 26 | if self.needs_input_grad[1]: 27 | grad_target = None 28 | 29 | return grad_input, grad_target 30 | 31 | 32 | def dice_coeff(input, target): 33 | """Dice coeff for batches""" 34 | if input.is_cuda: 35 | s = torch.FloatTensor(1).cuda().zero_() 36 | else: 37 | s = torch.FloatTensor(1).zero_() 38 | 39 | for i, c in enumerate(zip(input, target)): 40 | s = s + DiceCoeff().forward(c[0], c[1]) 41 | 42 | return s / (i + 1) 43 | -------------------------------------------------------------------------------- /src/flatting/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from tqdm import tqdm 4 | 5 | from dice_loss import dice_coeff 6 | 7 | 8 | # 对,这个函数也要修改 9 | 10 | def eval_net(net, loader, device): 11 | """Evaluation without the densecrf with the dice coefficient""" 12 | net.eval() 13 | mask_type = torch.float32 if net.n_classes == 1 else torch.long 14 | n_val = len(loader) # the number of batch 15 | tot = 0 16 | 17 | with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar: 18 | for batch in loader: 19 | imgs, true_masks = batch['image'], batch['mask'] 20 | imgs = imgs.to(device=device, dtype=torch.float32) 21 | true_masks = true_masks.to(device=device, dtype=mask_type) 22 | 23 | with torch.no_grad(): 24 | mask_pred = net(imgs) 25 | 26 | if net.n_classes > 1: 27 | tot += F.cross_entropy(mask_pred, true_masks).item() 28 | else: 29 | pred = torch.sigmoid(mask_pred) 30 | pred = (pred > 0.5).float() 31 | tot += dice_coeff(pred, true_masks).item() 32 | pbar.update() 33 | 34 | net.train() 35 | return tot / n_val 36 | -------------------------------------------------------------------------------- /src/flatting/flatting_api_async.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from concurrent.futures import ProcessPoolExecutor 3 | import functools 4 | 5 | from . import flatting_api 6 | 7 | ## This controls the number of parallel processes. 8 | ## Keep in mind that parallel processes will load duplicate networks 9 | ## and compete for the same RAM, which could lead to thrashing. 10 | ## Pass `max_workers = N` for exactly `N` parallel processes. 11 | # import multiprocessing 12 | # HALF_CORES = max( multiprocessing.cpu_count()//2, 1 ) ) 13 | executor_batch = ProcessPoolExecutor(4) 14 | executor_interactive = ProcessPoolExecutor(4) 15 | 16 | async def run_async( executor, f ): 17 | ## We expect this to be called from inside an existing loop. 18 | ## As a result, we call `get_running_loop()` instead of `get_event_loop()` so that 19 | ## it raises an error if our assumption is false, rather than creating a new loop. 20 | loop = asyncio.get_running_loop() 21 | data = await loop.run_in_executor( executor, f ) 22 | return data 23 | 24 | async def run_single( *args, **kwargs ): 25 | return await run_async( executor_batch, functools.partial( flatting_api.run_single, *args, **kwargs ) ) 26 | 27 | async def merge( *args, **kwargs ): 28 | return await run_async( executor_interactive, functools.partial( flatting_api.merge, *args, **kwargs ) ) 29 | 30 | async def split_manual( *args, **kwargs ): 31 | return await run_async( executor_interactive, functools.partial( flatting_api.split_manual, *args, **kwargs ) ) 32 | -------------------------------------------------------------------------------- /src/flatting/hubconf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from unet import UNet as _UNet 3 | 4 | def unet_carvana(pretrained=False): 5 | """ 6 | UNet model trained on the Carvana dataset ( https://www.kaggle.com/c/carvana-image-masking-challenge/data ). 7 | Set the scale to 1 (100%) when predicting. 8 | """ 9 | net = _UNet(n_channels=3, n_classes=1, bilinear=True) 10 | if pretrained: 11 | checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v1.0/unet_carvana_scale1_epoch5.pth' 12 | net.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=True)) 13 | 14 | return net 15 | 16 | -------------------------------------------------------------------------------- /src/flatting/predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | import os 5 | # import os, sys 6 | from os.path import join 7 | # sys.path.append(join(dirname(abspath(__file__)), "trapped_ball")) 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | import cv2 13 | 14 | from PIL import Image 15 | from torchvision import transforms as T 16 | 17 | from .unet import UNet 18 | from .utils.preprocessing import to_point_list, find_bbox, crop_img 19 | from .trapped_ball.run import region_get_map 20 | 21 | 22 | def to_tensor(img): 23 | 24 | transforms = T.Compose( 25 | [ 26 | # to tensor will change the channel order and divide 255 if necessary 27 | T.ToTensor(), 28 | T.Normalize(0.5, 0.5, inplace = True) 29 | ] 30 | ) 31 | 32 | return transforms(img) 33 | 34 | def denormalize(img): 35 | # denormalize 36 | inv_normalize = T.Normalize( mean=-1, std=2) 37 | 38 | img_np = inv_normalize(img.repeat(3,1,1)).clamp(0, 1) 39 | img_np = img_np * 255 40 | 41 | # to numpy 42 | img_np = img_np.cpu().numpy().transpose((1,2,0)) 43 | 44 | return Image.fromarray(img_np.astype(np.uint8)).convert("L") 45 | 46 | def to_numpy(f, size, bbox=None): 47 | 48 | if type(f) == str: 49 | img = np.array(Image.open(f).convert("L")) 50 | 51 | else: 52 | img = np.array(f.convert("L")) 53 | 54 | if bbox != None: 55 | img = crop_img(bbox, img) 56 | 57 | h, w = img.shape 58 | ratio = size/w if w < h else size/h 59 | 60 | return cv2.resize(img, (int(w*ratio+0.5), int(h*ratio+0.5)), interpolation=cv2.INTER_AREA) 61 | 62 | def predict_img(net, 63 | full_img, 64 | device, 65 | size): 66 | net.eval() 67 | 68 | # corp image 69 | bbox = find_bbox(to_point_list(np.array(full_img))) 70 | 71 | # read image 72 | print("Log:\tpredict image at size %d"%size) 73 | img = to_tensor(to_numpy(full_img, size, bbox = bbox)) 74 | img = img.unsqueeze(0) 75 | img = img.to(device=device, dtype=torch.float32) 76 | 77 | with torch.no_grad(): 78 | output = net(img) 79 | 80 | output = denormalize(output[0]) 81 | 82 | return output, bbox 83 | 84 | 85 | def get_args(): 86 | parser = argparse.ArgumentParser(description='Predict edge from line art', 87 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 88 | 89 | parser.add_argument('--model', '-m', default='./checkpoints/exp1/CP_epoch2001.pth', 90 | metavar='FILE', 91 | help="Specify the file in which the model is stored") 92 | 93 | parser.add_argument('--input', '-i', type=str, 94 | help='filename of single input image, include path') 95 | 96 | parser.add_argument('--output', '-o', type=str, 97 | help='filename of single ouput image, include path') 98 | 99 | parser.add_argument('--input-path', type=str, default="./flatting/validation", 100 | help='path to input images') 101 | 102 | parser.add_argument('--output-path', type=str, default="./results/val", 103 | help='path to ouput images') 104 | 105 | return parser.parse_args() 106 | 107 | 108 | if __name__ == "__main__": 109 | 110 | args = get_args() 111 | 112 | in_files = args.input 113 | 114 | net = UNet(in_channels=1, out_channels=1, bilinear=True) 115 | 116 | logging.info("Loading model {}".format(args.model)) 117 | 118 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 119 | logging.info(f'Using device {device}') 120 | 121 | net.to(device=device) 122 | net.load_state_dict(torch.load(args.model, map_location=device)) 123 | 124 | logging.info("Model loaded !") 125 | 126 | for f in os.listdir(args.input_path): 127 | name, _ = splitext(f) 128 | 129 | logging.info("\nPredicting image {} ...".format(join(args.input_path, f))) 130 | 131 | 132 | # predict edge and save image 133 | edge = predict_img(net=net, 134 | full_img=join(args.input_path, f), 135 | device=device, 136 | size = 1024) 137 | 138 | edge.save(join(args.output_path, name + "_pred.png")) 139 | 140 | # trapped ball fill and save image 141 | region_get_map(join(args.output_path, name + "_pred.png"), args.output_path, 142 | radius_set=[1], percentiles=[0], 143 | path_to_line = join(args.input_path, f), 144 | save_org_size = True) 145 | -------------------------------------------------------------------------------- /src/flatting/resources/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/resources/__init__.py -------------------------------------------------------------------------------- /src/flatting/resources/flatting.icns: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/resources/flatting.icns -------------------------------------------------------------------------------- /src/flatting/resources/flatting.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/resources/flatting.ico -------------------------------------------------------------------------------- /src/flatting/resources/flatting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/resources/flatting.png -------------------------------------------------------------------------------- /src/flatting/submit.py: -------------------------------------------------------------------------------- 1 | """ Submit code specific to the kaggle challenge""" 2 | 3 | import os 4 | 5 | import torch 6 | from PIL import Image 7 | import numpy as np 8 | 9 | from predict import predict_img 10 | from unet import UNet 11 | 12 | # credits to https://stackoverflow.com/users/6076729/manuel-lagunas 13 | def rle_encode(mask_image): 14 | pixels = mask_image.flatten() 15 | # We avoid issues with '1' at the start or end (at the corners of 16 | # the original image) by setting those pixels to '0' explicitly. 17 | # We do not expect these to be non-zero for an accurate mask, 18 | # so this should not harm the score. 19 | pixels[0] = 0 20 | pixels[-1] = 0 21 | runs = np.where(pixels[1:] != pixels[:-1])[0] + 2 22 | runs[1::2] = runs[1::2] - runs[:-1:2] 23 | return runs 24 | 25 | 26 | def submit(net): 27 | """Used for Kaggle submission: predicts and encode all test images""" 28 | dir = 'data/test/' 29 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 30 | N = len(list(os.listdir(dir))) 31 | with open('SUBMISSION.csv', 'a') as f: 32 | f.write('img,rle_mask\n') 33 | for index, i in enumerate(os.listdir(dir)): 34 | print('{}/{}'.format(index, N)) 35 | 36 | img = Image.open(dir + i) 37 | 38 | mask = predict_img(net, img, device) 39 | enc = rle_encode(mask) 40 | f.write('{},{}\n'.format(i, ' '.join(map(str, enc)))) 41 | 42 | 43 | if __name__ == '__main__': 44 | net = UNet(3, 1).cuda() 45 | net.load_state_dict(torch.load('MODEL.pth')) 46 | submit(net) 47 | -------------------------------------------------------------------------------- /src/flatting/tkapp.py: -------------------------------------------------------------------------------- 1 | from flatting.app import main as start_server 2 | 3 | ## Briefcase doesn't support tkinter 4 | 5 | import asyncio 6 | import tkinter as tk 7 | 8 | async def start_server_in_thread(): 9 | def start_server_wrapper(): 10 | asyncio.set_event_loop(asyncio.new_event_loop()) 11 | start_server() 12 | 13 | # Run the server in a thread. 14 | import threading 15 | server = threading.Thread( name='flatting_server', target=start_server_wrapper ) 16 | server.setDaemon( True ) 17 | server.start() 18 | 19 | def start_gui(): 20 | import tkinter as tk 21 | root = tk.Tk() 22 | 23 | ## Adapting: https://stackoverflow.com/questions/47895765/use-asyncio-and-tkinter-or-another-gui-lib-together-without-freezing-the-gui 24 | loop = asyncio.get_event_loop() 25 | 26 | INTERVAL = 1/30 27 | async def guiloop(): 28 | while True: 29 | root.update() 30 | await asyncio.sleep( INTERVAL ) 31 | task = loop.create_task( guiloop() ) 32 | 33 | def shutdown(): 34 | task.cancel() 35 | loop.stop() 36 | # loop.close() 37 | root.destroy() 38 | 39 | root.title("Flatting Backend") 40 | ## The port changes if we pass a port argument to `web.run_app`. 41 | tk.Label( root, text="Serving at http://127.0.0.1:8080" ).pack() 42 | tk.Button( root, text="Quit", command=shutdown ).pack() 43 | 44 | # tk.mainloop() 45 | 46 | 47 | def main(): 48 | start_gui() 49 | start_server() 50 | 51 | if __name__ == '__main__': 52 | main() 53 | -------------------------------------------------------------------------------- /src/flatting/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch import optim 10 | from tqdm import tqdm 11 | 12 | from eval import eval_net 13 | from unet import UNet 14 | 15 | from torch.utils.tensorboard import SummaryWriter 16 | from utils.dataset import BasicDataset 17 | from torch.utils.data import DataLoader, random_split 18 | from torchvision import transforms as T 19 | from torchvision import utils 20 | from PIL import Image 21 | from io import BytesIO 22 | 23 | dir_line = './flatting/size_512/line_croped' 24 | dir_edge = './flatting/size_512/line_detection_croped' 25 | dir_checkpoint = './checkpoints' 26 | 27 | def denormalize(img): 28 | # denormalize 29 | inv_normalize = T.Normalize( mean=[-1], std=[2]) 30 | 31 | img_np = inv_normalize(img) 32 | img_np = img_np.clamp(0, 1) 33 | # to numpy 34 | return img_np 35 | 36 | def train_net(net, 37 | device, 38 | epochs=100, 39 | batch_size=1, 40 | lr=0.001, 41 | val_percent=0.1, 42 | save_cp=True, 43 | crop_size = None): 44 | 45 | # dataset = BasicDataset(dir_line, dir_edge, crop_size = crop_size) 46 | logging.info("Loading training set to memory") 47 | lines_bytes, edges_bytes = load_to_ram(dir_line, dir_edge) 48 | 49 | dataset = BasicDataset(lines_bytes, edges_bytes, crop_size = crop_size) 50 | 51 | n_train = len(dataset) 52 | 53 | train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True) 54 | 55 | # we don't need valiation currently 56 | # val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True) 57 | 58 | writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}') 59 | 60 | global_step = 0 61 | 62 | # logging.info(f'''Starting training: 63 | # Epochs: {epochs} 64 | # Batch size: {batch_size} 65 | # Learning rate: {lr} 66 | # Training size: {n_train} 67 | # Validation size: {n_val} 68 | # Checkpoints: {save_cp} 69 | # Device: {device.type} 70 | # Images scaling: {img_scale} 71 | # ''') 72 | 73 | logging.info(f'''Starting training: 74 | Epochs: {epochs} 75 | Batch size: {batch_size} 76 | Learning rate: {lr} 77 | Training size: {n_train} 78 | Checkpoints: {save_cp} 79 | Device: {device.type} 80 | Crop size: {crop_size} 81 | ''') 82 | 83 | #optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9) 84 | optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8) 85 | 86 | # how to use scheduler? 87 | # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2) 88 | 89 | # since now we are trying to generate images, so we use l1 loss 90 | # if net.n_classes > 1: 91 | # criterion = nn.CrossEntropyLoss() 92 | # else: 93 | # criterion = nn.BCEWithLogitsLoss() 94 | 95 | criterion = nn.L1Loss() 96 | 97 | for epoch in range(epochs): 98 | net.train() 99 | 100 | epoch_loss = 0 101 | with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar: 102 | for imgs, gts, mask1, mask2 in train_loader: 103 | 104 | # assert imgs.shape[1] == net.in_channels, \ 105 | # f'Network has been defined with {net.in_channels} input channels, ' \ 106 | # f'but loaded images have {imgs.shape[1]} channels. Please check that ' \ 107 | # 'the images are loaded correctly.' 108 | 109 | imgs = imgs.to(device=device, dtype=torch.float32) 110 | gts = gts.to(device=device, dtype=torch.float32) 111 | 112 | # forward 113 | pred = net(imgs) 114 | 115 | ''' 116 | baseline 117 | ''' 118 | # loss1 = criterion(pred, gts) 119 | 120 | ''' 121 | weighted loss 122 | ''' 123 | mask_1 = (1-mask1) 124 | mask_2 = 100 * (1-mask2) 125 | mask_3 = 0.5 * mask2 126 | mask_w = mask_1 + mask_2 + mask_3 127 | mask_w = mask_w.to(device=device, dtype=torch.float32) 128 | loss1 = criterion(pred * mask_w, gts * mask_w) 129 | 130 | ''' 131 | point number loss 132 | the point number of the perdiction and gt should close, too 133 | 134 | ''' 135 | loss2 = criterion( 136 | ((denormalize(gts)==0).sum()).float(), 137 | ((denormalize(pred)==0).sum()).float() 138 | ) 139 | 140 | # total loss 141 | loss = loss1 + 0.5 * torch.log(torch.abs(loss2 + 1)) 142 | 143 | epoch_loss += loss.item() 144 | writer.add_scalar('Loss/total', loss.item(), global_step) 145 | writer.add_scalar('Loss/l1', loss1.item(), global_step) 146 | writer.add_scalar('Loss/point', loss2.item(), global_step) 147 | 148 | pbar.set_postfix(**{'loss (batch)': loss.item()}) 149 | 150 | # back propagate 151 | optimizer.zero_grad() 152 | loss.backward() 153 | nn.utils.clip_grad_value_(net.parameters(), 0.1) 154 | optimizer.step() 155 | 156 | pbar.update(imgs.shape[0]) 157 | 158 | global_step += 1 159 | 160 | # if global_step % (n_train // (10 * batch_size)) == 0: 161 | # for tag, value in net.named_parameters(): 162 | # tag = tag.replace('.', '/') 163 | # writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step) 164 | # writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step) 165 | # val_score = eval_net(net, val_loader, device) 166 | # scheduler.step(val_score) 167 | # writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step) 168 | 169 | # if net.n_classes > 1: 170 | # logging.info('Validation cross entropy: {}'.format(val_score)) 171 | # writer.add_scalar('Loss/test', val_score, global_step) 172 | # else: 173 | # logging.info('Validation Dice Coeff: {}'.format(val_score)) 174 | # writer.add_scalar('Dice/test', val_score, global_step) 175 | 176 | # writer.add_images('images', imgs, global_step) 177 | # if net.n_classes == 1: 178 | # writer.add_images('masks/true', true_masks, global_step) 179 | # writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step) 180 | 181 | if global_step % 1000 == 0: 182 | sample = torch.cat((imgs, pred, gts), dim = 0) 183 | if os.path.exists("./results/train/") is False: 184 | logging.info("Creating ./results/train/") 185 | os.makedirs("./results/train/") 186 | 187 | utils.save_image( 188 | sample, 189 | f"./results/train/{str(global_step).zfill(6)}.png", 190 | nrow=int(batch_size), 191 | # nrow=int(sample.shape[0] ** 0.5), 192 | normalize=True, 193 | range=(-1, 1), 194 | ) 195 | 196 | if save_cp and epoch % 100 == 0: 197 | try: 198 | os.mkdir(dir_checkpoint) 199 | logging.info('Created checkpoint directory') 200 | except OSError: 201 | pass 202 | torch.save(net.state_dict(), 203 | dir_checkpoint + f'/CP_epoch{epoch + 1}.pth') 204 | logging.info(f'Checkpoint {epoch + 1} saved !') 205 | 206 | writer.close() 207 | 208 | def save_to_ram(path_to_img): 209 | 210 | img = Image.open(path_to_img).convert("L") 211 | buffer = BytesIO() 212 | img.save(buffer, format='png') 213 | 214 | return buffer.getvalue() 215 | 216 | def load_to_ram(path_to_line, path_to_edge): 217 | lines = os.listdir(path_to_line) 218 | lines.sort() 219 | 220 | edges = os.listdir(path_to_edge) 221 | edges.sort() 222 | 223 | assert len(lines) == len(edges) 224 | 225 | lines_bytes = [] 226 | edges_bytes = [] 227 | 228 | # read everything into memory 229 | for img in tqdm(lines): 230 | assert img.replace("webp", "png") in edges 231 | 232 | lines_bytes.append(save_to_ram(os.path.join(path_to_line, img))) 233 | edges_bytes.append(save_to_ram(os.path.join(path_to_edge, img.replace("webp", "png")))) 234 | 235 | return lines_bytes, edges_bytes 236 | 237 | 238 | def get_args(): 239 | parser = argparse.ArgumentParser(description='Train the UNet on images and target masks', 240 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 241 | parser.add_argument('-e', '--epochs', metavar='E', type=int, default=90000, 242 | help='Number of epochs', dest='epochs') 243 | parser.add_argument('-m', '--multi-gpu', action='store_true') 244 | parser.add_argument('-c', '--crop-size', metavar='C', type=int, default=512, 245 | help='the size of random cropping') 246 | parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1, 247 | help='Batch size', dest='batchsize') 248 | parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.0001, 249 | help='Learning rate', dest='lr') 250 | parser.add_argument('-f', '--load', dest='load', type=str, default=False, 251 | help='Load model from a .pth file') 252 | 253 | return parser.parse_args() 254 | 255 | 256 | if __name__ == '__main__': 257 | 258 | __spec__ = None 259 | 260 | logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') 261 | args = get_args() 262 | 263 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 264 | logging.info(f'Using device {device}') 265 | 266 | # Change here to adapt to your data 267 | # n_channels=3 for RGB images 268 | # n_classes is the number of probabilities you want to get per pixel 269 | # - For 1 class and background, use n_classes=1 270 | # - For 2 classes, use n_classes=1 271 | # - For N > 2 classes, use n_classes=N 272 | 273 | net = UNet(in_channels=1, out_channels=1, bilinear=True) 274 | 275 | if args.multi_gpu: 276 | logging.info("using data parallel") 277 | net = nn.DataParallel(net).cuda() 278 | else: 279 | net.to(device=device) 280 | 281 | # logging.info(f'Network:\n' 282 | # f'\t{net.in_channels} input channels\n' 283 | # f'\t{net.out_channels} output channels\n' 284 | # f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling' 285 | # ) 286 | 287 | if args.load: 288 | net.load_state_dict( 289 | torch.load(args.load, map_location=device) 290 | ) 291 | logging.info(f'Model loaded from {args.load}') 292 | 293 | 294 | # faster convolutions, but more memory 295 | # cudnn.benchmark = True 296 | 297 | try: 298 | train_net(net=net, 299 | epochs=args.epochs, 300 | batch_size=args.batchsize, 301 | lr=args.lr, 302 | device=device, 303 | crop_size=args.crop_size) 304 | 305 | # this is interesting, save model when keyborad interrupt 306 | except KeyboardInterrupt: 307 | torch.save(net.state_dict(), './checkpoints/INTERRUPTED.pth') 308 | # logging.info('Saved interrupt') 309 | try: 310 | sys.exit(0) 311 | except SystemExit: 312 | os._exit(0) 313 | -------------------------------------------------------------------------------- /src/flatting/trapped_ball/adjacency_matrix.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # cython: boundscheck=False 3 | # cython: wraparound=False 4 | 5 | #import numpy as np 6 | #cimport numpy as np 7 | 8 | ## TODO: Allocate memory internally. See: https://stackoverflow.com/questions/18462785/what-is-the-recommended-way-of-allocating-memory-for-a-typed-memory-view 9 | 10 | def adjacency_matrix( image, num_regions ): 11 | ''' 12 | Given: 13 | image: A 2D image of integer labels in the range [0,num_regions]. 14 | num_regions: The number of regions in `image`. 15 | Returns: 16 | A: The adjacency matrix such that A[i,j] is 1 if region i is 17 | connected to region j and 0 otherwise. 18 | ''' 19 | 20 | import numpy as np 21 | A = np.zeros( ( num_regions, num_regions ), dtype = int ) 22 | adjacency_matrix_internal( image, A ) 23 | return A 24 | 25 | cpdef long[:,:] adjacency_matrix_internal( long[:,:] image, long[:,:] A ) nogil: 26 | ''' 27 | Given: 28 | image: A 2D image of integer labels in the range [0,num_regions]. 29 | Returns: 30 | A: The adjacency matrix such that A[i,j] is 1 if region i is 31 | connected to region j and 0 otherwise. 32 | 33 | Note: `A` is an output parameter. Allocate space and pass it in. 34 | ''' 35 | 36 | # A = np.zeros( ( num_regions, num_regions ), dtype = int ) 37 | A[:] = 0 38 | 39 | cdef long nrow = image.shape[0] 40 | cdef long ncol = image.shape[1] 41 | cdef long i,j,region0,region1 42 | 43 | ## Sweep with left-right neighbors. Skip the right-most column. 44 | for i in range(nrow): 45 | for j in range(ncol-1): 46 | region0 = image[i,j] 47 | region1 = image[i,j+1] 48 | A[region0,region1] = 1 49 | A[region1,region0] = 1 50 | 51 | ## Sweep with top-bottom neighbors. Skip the bottom-most row. 52 | for i in range(nrow-1): 53 | for j in range(ncol): 54 | region0 = image[i,j] 55 | region1 = image[i+1,j] 56 | A[region0,region1] = 1 57 | A[region1,region0] = 1 58 | 59 | ## Sweep with top-left-to-bottom-right neighbors. Skip the bottom row and right column. 60 | for i in range(nrow-1): 61 | for j in range(ncol-1): 62 | region0 = image[i,j] 63 | region1 = image[i+1,j+1] 64 | A[region0,region1] = 1 65 | A[region1,region0] = 1 66 | 67 | ## Sweep with top-right-to-bottom-left neighbors. Skip the bottom row and left column. 68 | for i in range(nrow-1): 69 | for j in range(1,ncol): 70 | region0 = image[i,j] 71 | region1 = image[i+1,j-1] 72 | A[region0,region1] = 1 73 | A[region1,region0] = 1 74 | 75 | ## region will not connect to itself 76 | for i in range(len(A)): 77 | A[i, i] = 0 78 | 79 | return A 80 | 81 | def region_sizes( image, num_regions ): 82 | ''' 83 | Given: 84 | image: A 2D image of integer labels in the range [0,num_regions]. 85 | num_regions: The number of regions in `image`. 86 | Returns: 87 | sizes: An array of length `num_regions`. Each element stores the 88 | number of pixels with the corresponding region number. 89 | That is, region i has `region_sizes[i]` pixels. 90 | ''' 91 | 92 | import numpy as np 93 | sizes = np.zeros( num_regions, dtype = int ) 94 | region_sizes_internal( image, sizes ) 95 | return sizes 96 | 97 | cpdef long[:] region_sizes_internal( long[:,:] image, long[:] sizes ) nogil: 98 | ''' 99 | Given: 100 | image: A 2D image of integer labels in the range [0,num_regions]. 101 | Returns: 102 | sizes: An array of length `num_regions`. Each element stores the 103 | number of pixels with the corresponding region number. 104 | That is, region i has `region_sizes[i]` pixels. 105 | 106 | Note: `sizes` is an output parameter. Allocate space and pass it in. 107 | ''' 108 | 109 | # sizes = np.zeros( num_regions, dtype = int ) 110 | sizes[:] = 0 111 | 112 | cdef long nrow = image.shape[0] 113 | cdef long ncol = image.shape[1] 114 | cdef long i,j,region 115 | 116 | ## Sweep with left-right neighbors. Skip the right-most column. 117 | for i in range(nrow): 118 | for j in range(ncol): 119 | region = image[i,j] 120 | sizes[region] += 1 121 | 122 | return sizes 123 | 124 | def remap_labels( image, remaps ): 125 | ''' 126 | Given: 127 | image: A 2D image of integer labels in the range [0,len(remaps)]. 128 | remaps: A 1D array of integer remappings that occurred. 129 | Returns: 130 | image_remapped: An array the same size as `image` except labels have been 131 | remapped according to `remaps`. 132 | ''' 133 | 134 | import numpy as np 135 | image_remapped = image.copy().astype( int ) 136 | remaps = remaps.astype( int ) 137 | remap_labels_internal( image_remapped, remaps ) 138 | return image_remapped 139 | 140 | cpdef void remap_labels_internal( long[:,:] image, long[:] remaps ) nogil: 141 | ''' 142 | Given: 143 | image: A 2D image of integer labels in the range [0,len(remaps)]. 144 | remaps: A 1D array of integer remappings that occurred. 145 | Modifies in-place: 146 | image: `image` with labels remapped according to `remaps`. 147 | ''' 148 | 149 | cdef long nrow = image.shape[0] 150 | cdef long ncol = image.shape[1] 151 | cdef long i,j,region 152 | 153 | ## Sweep with left-right neighbors. Skip the right-most column. 154 | for i in range(nrow): 155 | for j in range(ncol): 156 | region = image[i,j] 157 | # I see, the merge operation likes a chain, we need to apply all of them one by one 158 | # or we can apply this on remaps frist, I'm not sure which one will be faster. 159 | while remaps[region] != region: 160 | image[i,j] = remaps[region] 161 | region = remaps[region] 162 | 163 | return 164 | -------------------------------------------------------------------------------- /src/flatting/trapped_ball/examples/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/trapped_ball/examples/01.png -------------------------------------------------------------------------------- /src/flatting/trapped_ball/examples/01_sim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/trapped_ball/examples/01_sim.png -------------------------------------------------------------------------------- /src/flatting/trapped_ball/examples/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/trapped_ball/examples/02.png -------------------------------------------------------------------------------- /src/flatting/trapped_ball/examples/tiny.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/trapped_ball/examples/tiny.png -------------------------------------------------------------------------------- /src/flatting/trapped_ball/examples/tiny_sim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/trapped_ball/examples/tiny_sim.png -------------------------------------------------------------------------------- /src/flatting/trapped_ball/run.py: -------------------------------------------------------------------------------- 1 | 2 | from .trappedball_fill import trapped_ball_fill_multi, flood_fill_multi, mark_fill, build_fill_map, merge_fill, show_fill_map, merger_fill_2nd 3 | from .trappedball_fill import get_ball_structuring_element, extract_line, to_masked_line 4 | from .thinning import thinning 5 | # from skimage.morphology import skeletonize 6 | from PIL import Image 7 | from tqdm import tqdm 8 | 9 | import argparse 10 | import cv2 11 | # import matplotlib.pyplot as plt 12 | import os 13 | import numpy as np 14 | from os.path import * 15 | 16 | # use cython adjacency matrix 17 | try: 18 | from . import adjacency_matrix 19 | ## If it's not already compiled, compile it. 20 | except: 21 | import pyximport 22 | pyximport.install() 23 | from . import adjacency_matrix 24 | 25 | def extract_skeleton(img): 26 | 27 | size = np.size(img) 28 | skel = np.zeros(img.shape,np.uint8) 29 | element = cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3)) 30 | done = False 31 | 32 | while done is False: 33 | eroded = cv2.erode(img,element) 34 | temp = cv2.dilate(eroded,element) 35 | temp = cv2.subtract(img,temp) 36 | skel = cv2.bitwise_or(skel,temp) 37 | img = eroded.copy() 38 | 39 | zeros = size - cv2.countNonZero(img) 40 | if zeros==size: 41 | done = True 42 | 43 | return skel 44 | 45 | def generate_masked_line(line_simplify, line_artist, line_artist_fullsize): 46 | line_masked = to_masked_line(line_simplify, line_artist, rk1=1, rk2=1, tn=1) 47 | 48 | # remove isolate points 49 | # it is not safe to do that at down scaled size 50 | # _, result = cv2.connectedComponents(255 - line_masked, connectivity=8) 51 | 52 | # up scale masked line to full size 53 | line_masked_fullsize_t = cv2.resize(line_masked.astype(np.uint8), 54 | (line_artist_fullsize.shape[1], line_artist_fullsize.shape[0]), 55 | interpolation = cv2.INTER_NEAREST) 56 | 57 | # maske with fullsize artist line again 58 | line_masked_fullsize = to_masked_line(line_masked_fullsize_t, line_artist_fullsize, rk1=7, rk2=1, tn=2) 59 | 60 | # remove isolate points 61 | _, temp = cv2.connectedComponents(255 - line_masked_fullsize, connectivity=4) 62 | 63 | def remove_stray_points(fillmap, drop_thres = 32): 64 | ids = np.unique(fillmap) 65 | result = np.ones(fillmap.shape) * 255 66 | 67 | for i in tqdm(ids): 68 | if i == 0: continue 69 | if len(np.where(fillmap == i)[0]) < drop_thres: 70 | # set them as background 71 | result[fillmap == i] = 255 72 | else: 73 | # set them as line 74 | result[fillmap == i] = 0 75 | 76 | return result 77 | 78 | line_masked_fullsize = remove_stray_points(temp, 16) 79 | 80 | return line_masked_fullsize 81 | 82 | 83 | def region_get_map(path_to_line_sim, 84 | path_to_line_artist=None, 85 | output_path=None, 86 | radius_set=[3,2,1], 87 | percentiles=[90, 0, 0], 88 | visualize_steps=False, 89 | return_numpy=False, 90 | preview=False): 91 | ''' 92 | Given: 93 | the path to input png file 94 | Return: 95 | the initial region map as a numpy matrix 96 | ''' 97 | def read_png(path_to_png, to_grayscale=True): 98 | ''' 99 | Given: 100 | path_to_png, it accept be any type of input, path, numpy array or PIL Image 101 | Return: 102 | the numpy array of a image 103 | ''' 104 | 105 | # if it is png file, open it 106 | if isinstance(path_to_png, str): 107 | # get file name 108 | _, file = os.path.split(path_to_png) 109 | name, _ = os.path.splitext(file) 110 | 111 | print("Log:\topen %s"%path_to_png) 112 | img_org = cv2.imread(path_to_png, cv2.IMREAD_COLOR) 113 | if to_grayscale: 114 | img = cv2.cvtColor(img_org, cv2.COLOR_BGR2GRAY) 115 | else: 116 | img = img_org 117 | 118 | elif isinstance(path_to_png, Image.Image): 119 | if to_grayscale: 120 | path_to_png = path_to_png.convert("L") 121 | img = np.array(path_to_png) 122 | name = "result" 123 | 124 | elif isinstance(path_to_png, np.ndarray): 125 | img = path_to_png 126 | if len(img.shape) > 2 and to_grayscale: 127 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 128 | name = "result" 129 | 130 | else: 131 | raise ValueError("The input data type %s is not supported"%str(type(path_to_png))) 132 | 133 | return img, name 134 | 135 | # read files 136 | img, name = read_png(path_to_line_sim) 137 | line_artist_fullsize, _ = read_png(path_to_line_artist) 138 | if len(line_artist_fullsize.shape) == 3: 139 | line_artist_fullsize = cv2.cvtColor(line_artist_fullsize, cv2.COLOR_BGR2GRAY) 140 | # line_artist_fullsize = cv2.adaptiveThreshold(line_artist_fullsize, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY,11,2) 141 | 142 | print("Log:\ttrapped ball filling") 143 | # threshold line arts before filling 144 | _, line_simplify = cv2.threshold(img, 200, 255, cv2.THRESH_BINARY) 145 | _, line_artist_fullsize = cv2.threshold(line_artist_fullsize, 200, 255, cv2.THRESH_BINARY) 146 | fills = [] 147 | result = line_simplify # this should be line_simplify numpu array 148 | line = line_artist_fullsize.copy() # change to a shorter name 149 | 150 | # may be resize the original line is not a good idea 151 | if line.shape[:2] != line_simplify.shape[:2]: 152 | line = cv2.resize(line, (line_simplify.shape[1],line_simplify.shape[0]), 153 | interpolation = cv2.INTER_AREA) 154 | _, line_artist = cv2.threshold(line, 200, 255, cv2.THRESH_BINARY) 155 | assert len(radius_set) == len(percentiles) 156 | 157 | # trapped ball fillilng 158 | for i in range(len(radius_set)): 159 | fill = trapped_ball_fill_multi(result, radius_set[i], percentile=percentiles[i]) 160 | fills += fill 161 | result = mark_fill(result, fill) 162 | if visualize_steps: 163 | cv2.imwrite("%d.r%d_per%.2f.png"%(i+1, radius_set[i], percentiles[i]), 164 | show_fill_map(build_fill_map(result, fills))) 165 | 166 | # fill up remaining regions if there still have 167 | fill = flood_fill_multi(result) 168 | fills += fill 169 | 170 | # convert fill mask to fill map 171 | fillmap_neural = build_fill_map(result, fills) 172 | if visualize_steps: 173 | i+=1 174 | cv2.imwrite("%d.final_fills.png"%i, show_fill_map(fillmap_neural)) 175 | 176 | # final refine, remove tiny regions in the fill map 177 | fillmap_neural = merge_fill(fillmap_neural) 178 | if visualize_steps: 179 | i+=1 180 | cv2.imwrite("%d.merged.png"%i, show_fill_map(fillmap_neural)) 181 | 182 | # remove the line art region 183 | fillmap_neural = thinning(fillmap_neural) 184 | if visualize_steps: 185 | i+=1 186 | cv2.imwrite("%d.fills_final.png"%i, show_fill_map(fillmap_neural)) 187 | 188 | # upscale neural fill map back to original size 189 | fillmap_neural_fullsize = cv2.resize(fillmap_neural.astype(np.uint8), 190 | (line_artist_fullsize.shape[1], line_artist_fullsize.shape[0]), 191 | interpolation = cv2.INTER_NEAREST) 192 | fillmap_neural_fullsize = fillmap_neural_fullsize.astype(np.int32) 193 | 194 | if preview: 195 | fill_neural_fullsize = show_fill_map(fillmap_neural_fullsize) 196 | fill_neural_fullsize[line_artist_fullsize < 125] = 0 197 | return Image.fromarray(fill_neural_fullsize.astype(np.uint8)) 198 | 199 | ## bleeding removal 200 | # prepare nerual map and FF map with full size 201 | fill_neural = show_fill_map(fillmap_neural) 202 | fill_neural_line = fill_neural.copy() 203 | fill_neural_line[line_simplify < 200] = 0 204 | fillmap_artist_fullsize = np.ones(fillmap_neural_fullsize.shape, dtype=np.uint8) * 255 205 | fillmap_artist_fullsize[line_artist_fullsize < 125] = 0 206 | _, fillmap_artist_fullsize_c = cv2.connectedComponents(fillmap_artist_fullsize, connectivity=8) 207 | 208 | print("Log:\tcompute cartesian product") 209 | fillmap_neural_fullsize_c = fillmap_neural_fullsize.copy() 210 | 211 | fillmap_neural_fullsize[line_artist_fullsize < 125] = 0 212 | fillmap_neural_fullsize = verify_region(fillmap_neural_fullsize) 213 | 214 | fillmap_artist_fullsize = fillmap_cartesian_product(fillmap_artist_fullsize_c, fillmap_neural_fullsize) 215 | fillmap_artist_fullsize[line_artist_fullsize < 125] = 0 216 | 217 | # re-order both fillmaps 218 | fillmap_artist_fullsize = verify_region(fillmap_artist_fullsize, True) 219 | # fillmap_neural_fullsize = verify_region(fillmap_neural_fullsize, True) 220 | 221 | fillmap_neural_fullsize = bleeding_removal_yotam(fillmap_neural_fullsize_c, fillmap_artist_fullsize, th=0.0002) 222 | fillmap_neural_fullsize[line_artist_fullsize < 125] = 0 223 | fillmap_neural_fullsize = verify_region(fillmap_neural_fullsize, True) 224 | 225 | # convert final result to graph 226 | # we have adjacency matrix, we have fillmap, do we really need another graph for it? 227 | fillmap_artist_fullsize_c = thinning(fillmap_artist_fullsize_c) 228 | fillmap_neural_fullsize = thinning(fillmap_neural_fullsize) 229 | 230 | fill_artist_fullsize = show_fill_map(fillmap_artist_fullsize_c) 231 | fill_neural_fullsize = show_fill_map(fillmap_neural_fullsize) 232 | # fill_neural_fullsize[line_artist_fullsize < 125] = 0 233 | 234 | # if output_path is not None: 235 | 236 | # print("Log:\tsave final fill at %s"%os.path.join(output_path, str(name)+"_fill.png")) 237 | # cv2.imwrite(os.path.join(output_path, str(name)+"_fill.png"), fill_neural_fullsize) 238 | 239 | # print("Log:\tsave neural fill at %s"%os.path.join(output_path, str(name)+"_neural.png")) 240 | # cv2.imwrite(os.path.join(output_path, str(name)+"_neural.png"), fill_neural) 241 | 242 | # print("Log:\tsave fine fill at %s"%os.path.join(output_path, str(name)+"_fine.png")) 243 | # cv2.imwrite(os.path.join(output_path, str(name)+"_fine.png"), 244 | # show_fill_map(fillmap_artist_fullsize_c)) 245 | 246 | print("Log:\tdone") 247 | if return_numpy: 248 | return fill_neural, fill_neural_line, fill_artist_fullsize, fill_neural_fullsize 249 | else: 250 | return fillmap_neural_fullsize, fillmap_artist_fullsize_c,\ 251 | fill_neural_fullsize, fill_neural, fill_artist_fullsize 252 | 253 | def fillmap_cartesian_product(fill1, fill2): 254 | ''' 255 | Given: 256 | fill1, fillmap 1 257 | fill2, fillmap 2 258 | Return: 259 | A new fillmap based on its cartesian_product 260 | ''' 261 | assert fill1.shape == fill2.shape 262 | 263 | if len(fill1.shape)==2: 264 | fill1 = np.expand_dims(fill1, axis=-1) 265 | 266 | if len(fill2.shape)==2: 267 | fill2 = np.expand_dims(fill2, axis=-1) 268 | 269 | # cat along channel 270 | fill_c = np.concatenate((fill1, fill2), axis=-1) 271 | 272 | # regnerate all region labels 273 | labels, inv = np.unique(fill_c.reshape(-1, 2), return_inverse=True, axis=0) 274 | labels = tuple(map(tuple, labels)) # convert array to tuple 275 | 276 | # assign a number lable to each cartesian product tuple 277 | l_to_r = {} 278 | for i in range(len(labels)): 279 | l_to_r[labels[i]] = i+1 280 | 281 | # assign new labels back to fillmap 282 | # https://stackoverflow.com/questions/16992713/translate-every-element-in-numpy-array-according-to-key 283 | fill_c = np.array(list(map(l_to_r.get, labels)))[inv] 284 | fill_c = fill_c.reshape(fill1.shape[0:2]) 285 | 286 | return fill_c 287 | 288 | 289 | # verify if there is no isolate sub-region in each region, if yes, split it and assign a new region id 290 | # Yotam: Can this function be replaced with a single call to cv2.connectedComponents()? 291 | # Chuan: I think no, to find the bleeding regions on the bounderay, iteratively flood fill each region is necessary 292 | def verify_region(fillmap, reorder_only=False): 293 | fillmap = fillmap.copy().astype(np.int32) 294 | labels = np.unique(fillmap) 295 | h, w = fillmap.shape 296 | # split region 297 | # is this really necessary? 298 | # yes, without this snippet, the result will be bad at line boundary 299 | # intuitively, this is like an "alingment" of smaller neural fill map to the large original line art 300 | # is it possible to crop the image before connectedComponents filling? 301 | next_label = labels.max() + 1 302 | if reorder_only == False: 303 | print("Log:\tsplit isolate regions in fillmap") 304 | for r in tqdm(labels): 305 | if r == 0: continue 306 | # inital input fill map 307 | region = np.ones(fillmap.shape, dtype=np.uint8) 308 | region[fillmap != r] = 0 309 | ''' 310 | seems this get the speed even slower, sad 311 | need to find a better way 312 | ''' 313 | # # try to split region 314 | # def find_bounding_box(region): 315 | # # find the pixel coordination of this region 316 | # points = np.array(np.where(region == 1)).T 317 | # t = points[:,0].min() # top 318 | # l = points[:,1].min() # left 319 | # b = points[:,0].max() # bottom 320 | # r = points[:,1].max() # right 321 | # return t, l, b, r 322 | # t, l, b, r = find_bounding_box(region) 323 | # region_cropped = region[t:b+1, l:r+1] 324 | # # fill_map_corpped = fill_map[t:b+1, l:r+1] 325 | 326 | _, region_verify = cv2.connectedComponents(region, connectivity=8) 327 | 328 | ''' 329 | seems this get the speed even slower, sad 330 | ''' 331 | # padding 0 back to the region 332 | # region_padded = cv2.copyMakeBorder(region_verify, t, h-b-1, l, w-r-1, cv2.BORDER_CONSTANT, 0) 333 | # assert region_padded.shape == fillmap.shape 334 | # region_verify = region_padded 335 | 336 | 337 | # split region if necessary 338 | label_verify = np.unique(region_verify) 339 | if len(label_verify) > 2: # skip 0 and the first region 340 | for j in range(2, len(label_verify)): 341 | fillmap[region_verify == label_verify[j]] = next_label 342 | next_label += 1 343 | 344 | # re-order regions 345 | assert np.unique(fillmap).max() == next_label - 1 346 | old_to_new = [0] * next_label 347 | idx = 1 348 | l = len(old_to_new) 349 | labels = np.unique(fillmap) 350 | for i in range(l): 351 | if i in labels and i != 0: 352 | old_to_new[i] = idx 353 | idx += 1 354 | else: 355 | old_to_new[i] = 0 356 | old_to_new = np.array(old_to_new) 357 | fillmap_out = old_to_new[fillmap] 358 | 359 | # assert np.unique(fillmap_out).max()+1 == len(np.unique(fillmap_out)) 360 | return fillmap_out 361 | 362 | def update_adj_matrix(A, source, target): 363 | 364 | # update A, region s and max is not neigbor any more 365 | # assert A[source, target] == 1 366 | A[source, target] = 0 367 | 368 | # assert A[target, source] == 1 369 | A[target, source] = 0 370 | 371 | # neighbors of s should become neighbor of max 372 | s_neighbors_x = np.where(A[source,:] == 1) 373 | s_neighbors_y = np.where(A[:,source] == 1) 374 | A[source, s_neighbors_x] = 0 375 | A[s_neighbors_y, source] = 0 376 | 377 | # neighbor of neighbors of s should use max instead of s 378 | A[s_neighbors_x, target] = 1 379 | A[target, s_neighbors_y] = 1 380 | 381 | return A 382 | 383 | def merge_to_ref(fill_map_ref, fill_map_source, r_idx, result): 384 | 385 | # this could be imporved as well 386 | # r_idx is the region labels 387 | F = {} #mapping of large region to ref region 388 | for i in range(len(r_idx)): 389 | r = r_idx[i] 390 | 391 | if r == 0: continue 392 | label_mask = fill_map_source == r 393 | idx, count = np.unique(fill_map_ref[label_mask], return_counts=True) 394 | most_common = idx[np.argmax(count)] 395 | F[r] = most_common 396 | 397 | for r in r_idx: 398 | if r == 0: continue 399 | label_mask = fill_map_source == r 400 | result[label_mask] = F[r] 401 | 402 | return result 403 | 404 | def merge_small_fast(fill_map_ref, fill_map_source, th): 405 | ''' 406 | OK let's understand the improved version 407 | 408 | ''' 409 | 410 | fill_map_source = fill_map_source.copy() 411 | fill_map_ref = fill_map_ref.copy() 412 | 413 | num_regions = len(np.unique(fill_map_source)) 414 | 415 | # the definition of long int is different on windows and linux 416 | try: 417 | A = adjacency_matrix.adjacency_matrix(fill_map_source.astype(np.int32), num_regions) 418 | except: 419 | A = adjacency_matrix.adjacency_matrix(fill_map_source.astype(np.int64), num_regions) 420 | 421 | r_idx_source, r_count_source = np.unique(fill_map_source, return_counts=True) 422 | 423 | 424 | 425 | ## Labels should be contiguous. 426 | assert len(r_idx_source) == max(r_idx_source)+1 427 | ## A should have the same dimensions as number of labels. 428 | assert A.shape[0] == A.shape[1] 429 | assert A.shape[0] == len( r_idx_source ) 430 | ARTIST_LINE_LABEL = 0 431 | def get_small_region(r_idx_source, r_count_source, th): 432 | return set( 433 | # 1. size less that threshold 434 | r_idx_source[ r_count_source < th ] 435 | ) | set( 436 | # 2. not the neighbor of artist line 437 | ## Which of `r_idx_source` have a 0 in the adjacency position for `ARTIST_LINE_LABEL`? 438 | r_idx_source[ A[r_idx_source,ARTIST_LINE_LABEL] == 0 ] 439 | ) 440 | 441 | r_idx_source_small = get_small_region(r_idx_source, r_count_source, th) 442 | 443 | stop = False 444 | 445 | while len(r_idx_source_small) > 0 and stop == False: 446 | 447 | stop = True 448 | 449 | for s in r_idx_source_small: 450 | if s == ARTIST_LINE_LABEL: continue 451 | 452 | neighbors = np.where(A[s,:] == 1)[0] 453 | 454 | # remove line regions 455 | neighbors = neighbors[neighbors != ARTIST_LINE_LABEL] 456 | 457 | # skip if this region doesn't have neighbors 458 | if len(neighbors) == 0: continue 459 | 460 | # find region size 461 | # sizes = np.array([get_size(r_idx_source, r_count_source, n) for n in neighbors]).flatten() 462 | sizes = r_count_source[ neighbors ] 463 | 464 | # merge regions if necessary 465 | largest_index = np.argmax(sizes) 466 | if neighbors[largest_index] == ARTIST_LINE_LABEL and len(neighbors) > 1: 467 | # if its largest neighbor is line skip it 468 | del neighbors[ largest_index ] 469 | del sizes[ largest_index ] 470 | 471 | if len(neighbors) >= 1: 472 | label_mask = fill_map_source == s 473 | max_neighbor = neighbors[np.argmax(sizes)] 474 | A = update_adj_matrix(A, s, max_neighbor) 475 | fill_map_source[label_mask] = max_neighbor 476 | stop = False 477 | else: 478 | continue 479 | 480 | r_idx_source, r_count_source = np.unique(fill_map_source, return_counts=True) 481 | r_idx_source_small = get_small_region(r_idx_source, r_count_source, th) 482 | 483 | ''' 484 | for debug 485 | after the first for loop, these 3 variable should have exactly same value compare to the merge_small_fast2's result 486 | ''' 487 | # return r_idx_source_small, r_idx_source, r_count_source 488 | # return fill_map_source 489 | 490 | return fill_map_source 491 | 492 | def merge_small_fast2(fill_map_ref, fill_map_source, th): 493 | ''' 494 | 495 | ''' 496 | 497 | fill_map_source = fill_map_source.copy() 498 | fill_map_ref = fill_map_ref.copy() 499 | 500 | num_regions = len(np.unique(fill_map_source)) 501 | 502 | # the definition of long int is different on windows and linux 503 | try: 504 | A = adjacency_matrix.adjacency_matrix(fill_map_source.astype(np.int32), num_regions) 505 | except: 506 | A = adjacency_matrix.adjacency_matrix(fill_map_source.astype(np.int64), num_regions) 507 | 508 | r_idx_source, r_count_source = np.unique(fill_map_source, return_counts=True) 509 | ## Convert them to masked arrays 510 | # why? 511 | r_idx_source = np.ma.masked_array( r_idx_source ) 512 | r_count_source = np.ma.masked_array( r_count_source ) 513 | 514 | 515 | ## Labels should be contiguous. 516 | assert len(r_idx_source) == max(r_idx_source)+1 517 | ## A should have the same dimensions as number of labels. 518 | assert A.shape[0] == A.shape[1] 519 | assert A.shape[0] == len( r_idx_source ) 520 | ARTIST_LINE_LABEL = 0 521 | def get_small_region(r_idx_source, r_count_source, th): 522 | return set( 523 | # 1. size less that threshold 524 | r_idx_source[ r_count_source < th ].compressed() 525 | ) | set( 526 | # 2. not the neighbor of artist line 527 | ## Which of `r_idx_source` have a 0 in the adjacency position for `ARTIST_LINE_LABEL`? 528 | r_idx_source[ A[r_idx_source,ARTIST_LINE_LABEL] == 0 ].compressed() 529 | ) 530 | 531 | r_idx_source_small = get_small_region(r_idx_source, r_count_source, th) 532 | 533 | # since the region labels are always continous numbers, so it is safe to create a remap array like this 534 | # in other word, r_idx_source.max() + 1 == len(r_idx_source) 535 | remap = np.arange(len(r_idx_source)) 536 | 537 | stop = False 538 | 539 | while len(r_idx_source_small) > 0 and stop == False: 540 | 541 | stop = True 542 | 543 | for s in r_idx_source_small: 544 | if s == ARTIST_LINE_LABEL: continue 545 | 546 | neighbors = np.where(A[s,:] == 1)[0] 547 | 548 | # remove line regions 549 | neighbors = neighbors[neighbors != ARTIST_LINE_LABEL] 550 | 551 | # skip if this region doesn't have neighbors 552 | if len(neighbors) == 0: continue 553 | 554 | # find region size 555 | # sizes = np.array([get_size(r_idx_source, r_count_source, n) for n in neighbors]).flatten() 556 | sizes = r_count_source[ neighbors ] 557 | 558 | # merge regions if necessary 559 | largest_index = np.argmax(sizes) 560 | if neighbors[largest_index] == ARTIST_LINE_LABEL and len(neighbors) > 1: 561 | # if its largest neighbor is line skip it 562 | del neighbors[ largest_index ] 563 | del sizes[ largest_index ] 564 | 565 | if len(neighbors) >= 1: 566 | max_neighbor = neighbors[np.argmax(sizes)] 567 | A = update_adj_matrix(A, s, max_neighbor) 568 | # record the operation of merge 569 | remap[s] = max_neighbor 570 | # update the region size 571 | r_count_source[max_neighbor] = r_count_source[max_neighbor] + r_count_source[s] 572 | # remove the merged region, however, we should keep the index unchanged 573 | r_count_source[s] = np.ma.masked 574 | r_idx_source[s] = np.ma.masked 575 | stop = False 576 | else: 577 | continue 578 | 579 | r_idx_source_small = get_small_region(r_idx_source, r_count_source, th) 580 | 581 | ''' 582 | for debug 583 | after the first for loop, these 3 variable should have exactly same value compare to the merge_small_fast2's result 584 | ''' 585 | # return r_idx_source_small, r_idx_source, r_count_source 586 | # adjacency_matrix.remap_labels( fill_map_source, remap ) 587 | # return fill_map_source 588 | 589 | fill_map_source = adjacency_matrix.remap_labels( fill_map_source, remap ) 590 | 591 | return fill_map_source 592 | 593 | def merge_small(fill_map_ref, fill_map_source, th): 594 | ''' 595 | Given: 596 | fill_map_ref: 2D numpy array as neural fill map on neural line 597 | fill_map_source: Connected commponent fill map on artist line 598 | th: A threshold to identify small regions 599 | Returns: 600 | 601 | ''' 602 | 603 | # result_fast1 = merge_small_fast(fill_map_ref, fill_map_source, th) 604 | # result_fast2 = merge_small_fast2(fill_map_ref, fill_map_source, th) 605 | # assert ( result_fast1 == result_fast2 ).all() 606 | # r1, r2, r3 = merge_small_fast(fill_map_ref, fill_map_source, th) 607 | # s1, s2, s3 = merge_small_fast2(fill_map_ref, fill_map_source, th) 608 | # return result_fast1 609 | 610 | # make a copy of input, we don't want to affect the array outside of this function 611 | fill_map_source = fill_map_source.copy() 612 | fill_map_ref = fill_map_ref.copy() 613 | 614 | num_regions = len(np.unique(fill_map_source)) 615 | 616 | # the definition of long int is different on windows and linux 617 | try: 618 | A = adjacency_matrix.adjacency_matrix(fill_map_source.astype(np.int32), num_regions) 619 | except: 620 | A = adjacency_matrix.adjacency_matrix(fill_map_source.astype(np.int64), num_regions) 621 | 622 | # find the label and size of each region 623 | r_idx_source, r_count_source = np.unique(fill_map_source, return_counts=True) 624 | 625 | def get_small_region(r_idx_source, r_count_source, th): 626 | ''' 627 | Find the 'small' region that need to be merged to its neighbor 628 | ''' 629 | r_idx_source_small = [] 630 | for i in range(len(r_idx_source)): 631 | # there are two kinds of region should be identified as small region: 632 | # 1. size less the threshold 633 | if r_count_source[i] < th: 634 | r_idx_source_small.append(r_idx_source[i]) 635 | # 2. not the neighbor of artist line, this type of region is not adjecent to any stroke lines, 636 | # so it need to be merged to a neighbor which touch the strokes no matter how big it is 637 | n = np.where(A[r_idx_source[i],:] == 1)[0] 638 | if 0 not in n: 639 | r_idx_source_small.append(r_idx_source[i]) 640 | return r_idx_source_small 641 | 642 | # find the small regions that need to be merged 643 | r_idx_source_small = get_small_region(r_idx_source, r_count_source, th) 644 | 645 | # early stop sign 646 | stop = False 647 | 648 | # main loop to iteratively merge all small regions into its largest neighbor 649 | while len(r_idx_source_small) > 0 and stop == False: 650 | 651 | stop = True 652 | # each time process small regions in the list sequentially 653 | for s in r_idx_source_small: 654 | if s == 0: continue 655 | 656 | # get the pixel mask of region s 657 | label_mask = fill_map_source == s 658 | 659 | # find all neighbors of region s 660 | neighbors = np.where(A[s,:] == 1)[0] 661 | 662 | # remove line regions 663 | neighbors = neighbors[neighbors != 0] 664 | 665 | # skip if this region doesn't have neighbors 666 | if len(neighbors) == 0: continue 667 | 668 | # find region size of s's neighbors 669 | sizes = np.array([get_size(r_idx_source, r_count_source, n) for n in neighbors]).flatten() 670 | 671 | # merge regions s to its largest neighbor 672 | if neighbors[np.argsort(sizes)[-1]] == 0 and len(neighbors) > 1: 673 | # if its largest neighbor is line skip it 674 | max_neighbor = neighbors[np.argsort(sizes)[-2]] 675 | A = update_adj_matrix(A, s, max_neighbor) 676 | fill_map_source[label_mask] = max_neighbor 677 | stop = False 678 | elif len(neighbors) >= 1: 679 | # esle return its largest nerighbor 680 | max_neighbor = neighbors[np.argsort(sizes)[-1]] 681 | A = update_adj_matrix(A, s, max_neighbor) 682 | fill_map_source[label_mask] = max_neighbor 683 | stop = False 684 | else: 685 | continue 686 | 687 | # re-search the small regions for next loop 688 | r_idx_source, r_count_source = np.unique(fill_map_source, return_counts=True) 689 | r_idx_source_small = get_small_region(r_idx_source, r_count_source, th) 690 | 691 | # assert ( fill_map_source == result_fast2 ).all() 692 | return fill_map_source 693 | 694 | def get_size(idx, count, r): 695 | assert r in idx 696 | assert r != 0 697 | 698 | return count[np.where(idx==r)] 699 | 700 | def bleeding_removal_yotam(fill_map_ref, fill_map_source, th): 701 | 702 | fill_map_ref = fill_map_ref.copy() # connected compoenent fill map 703 | fill_map_source = fill_map_source.copy() # the cartesian product of connected component and neural fill map 704 | 705 | w, h = fill_map_ref.shape 706 | th = int(w * h * th) 707 | 708 | result = np.zeros(fill_map_ref.shape, dtype=np.int32) 709 | # 1. merge small regions which has neighbors 710 | # the int64 means long on linux but long long on windows, sad 711 | print("Log:\tmerge small regions") 712 | fill_map_source = merge_small_fast2(fill_map_ref, fill_map_source, th) 713 | 714 | # 2. merge large regions 715 | # now the fill_map_source is clean, no bleeding. but it still contains many "broken" pieces which 716 | # should belong to the same semantical regions. So, we can merge these "large but still broken" region 717 | # together by the neural fill map. 718 | print("Log:\tmerge large regions") 719 | r_idx_source= np.unique(fill_map_source) 720 | result = merge_to_ref(fill_map_ref, fill_map_source, r_idx_source, result) 721 | 722 | return result 723 | 724 | def sweep_line_merge(fillmap_neural_fullsize, fillmap_artist_fullsize, add_th, keep_th): 725 | 726 | assert fillmap_neural_fullsize.shape == fillmap_artist_fullsize.shape 727 | 728 | result = np.zeros(fillmap_neural_fullsize.shape) 729 | 730 | def to_sweep_list(fillmap): 731 | sweep_dict = {} 732 | sweep_ml = [] # most left position, which is also the sweep line's anchor 733 | sweep_list, sweep_count = np.unique(fillmap, return_counts=True) 734 | for i in range(len(sweep_list)): 735 | idx = sweep_list[i] 736 | if idx == 0: continue 737 | # 1. point sets 2. if have been merged 3. region area 738 | points = np.where(fillmap == idx) 739 | sweep_dict[idx] = [points, False, sweep_count[i]] 740 | sweep_ml.append(points[0].min()) 741 | 742 | sweep_list = sweep_list[np.argsort(np.array(sweep_ml))] 743 | 744 | return sweep_list, sweep_dict 745 | 746 | # turn fill map to sweep list 747 | r_idx_ref, r_dict_ref = to_sweep_list(fillmap_neural_fullsize) 748 | r_idx_source, r_dict_artist = to_sweep_list(fillmap_artist_fullsize) 749 | 750 | skip = [] 751 | for rn in tqdm(r_idx_ref): 752 | 753 | if rn == 0: continue 754 | 755 | r1 = np.zeros(fillmap_neural_fullsize.shape) 756 | r1[fillmap_neural_fullsize == rn] = 1 757 | 758 | for ra in r_idx_source: 759 | if ra == 0: continue 760 | 761 | # skip if this region has been merged 762 | if r_dict_artist[ra][1]: continue 763 | 764 | # compute iou of this two regions 765 | r2 = np.zeros(r1.shape) 766 | r2[fillmap_artist_fullsize == ra] = 1 767 | iou = (r1 * r2).sum() 768 | 769 | # compute the precentage of iou/region area 770 | c1 = iou/r_dict_ref[rn][2] 771 | c2 = iou/r_dict_artist[ra][2] 772 | 773 | # merge 774 | # r1 and r2 are quite similar, then use r2 instead of r1 775 | if c1 > 0.9 and c2 > 0.9: 776 | result[r_dict_artist[ra][0]] = rn 777 | r_dict_artist[ra][1] = True 778 | continue 779 | 780 | # # r1 is almost contained by r2, the keep r1 781 | # elif c1 > 0.9 and c2 < 0.6: 782 | # result[r_dict_ref[rn][0]] = rn 783 | # # todo: 784 | # # then we need refinement! 785 | 786 | # r2 is almost covered by r1, then merge r2 into r1 787 | elif c1 < 0.6 and c2 > 0.9: 788 | result[r_dict_artist[ra][0]] = rn 789 | r_dict_artist[ra][1] = True 790 | 791 | # r1 and r2 are not close, do nothing then 792 | else: 793 | # we probably could record the c1 and c2, see what the parameter looks like 794 | if c1 != 0 and c2 != 0: 795 | skip.append((c1,c2)) 796 | 797 | return result.astype(np.uint8), skip 798 | 799 | 800 | def show_region(region_bit): 801 | plt.imshow(show_fill_map(region_bit)) 802 | plt.show() 803 | 804 | def get_figsize(img_num, row_num, img_size, dpi = 100): 805 | # inches = resolution / dpi 806 | # assume all image have the same resolution 807 | width = row_num * (img_size[0] + 200) 808 | height = round(img_num / row_num + 0.5) * (img_size[1] + 300) 809 | return width / dpi, height / dpi 810 | 811 | def visualize(test_folder, row_num = 3): 812 | ''' 813 | 814 | ''' 815 | 816 | img_list = [] 817 | for img in os.listdir(test_folder): 818 | img_list.append(img) 819 | 820 | img = Image.open(os.path.join(test_folder, img_list[0])) 821 | 822 | 823 | # visualize collected centers 824 | plt.rcParams["figure.figsize"] = get_figsize(len(img_list), row_num, img.size) 825 | 826 | 827 | i = 0 828 | for i in range((len(img_list)//row_num + 1 if len(img_list)%row_num != 0 else len(img_list)//row_num)): 829 | for j in range(row_num): 830 | plt.subplot(len(img_list)//row_num + 1 , row_num , i*row_num+j+1) 831 | if i*row_num+j < len(img_list): 832 | # the whole function contains two steps 833 | # 1. get region map 834 | img = region_get_map(os.path.join(test_folder, img_list[i*row_num+j])) 835 | # 2. fill the region map 836 | plt.imshow(show_fill_map(img)) 837 | plt.title(img_list[i]) 838 | return plt 839 | 840 | def radius_percentile_explor(radius_set, method_set, input, output): 841 | for radius in radius_set: 842 | for method in method_set: 843 | print("Log:\ttrying radius %d with percentile %s"%(radius, method)) 844 | 845 | # get file name 846 | _, file = os.path.split(input) 847 | name, _ = os.path.splitext(file) 848 | 849 | # open image 850 | img_org = cv2.imread(input, cv2.IMREAD_COLOR) 851 | img = cv2.cvtColor(img_org, cv2.COLOR_BGR2GRAY) 852 | 853 | ret, binary = cv2.threshold(img, 220, 255, cv2.THRESH_BINARY) 854 | fills = [] 855 | result = binary # this should be binary numpu array 856 | 857 | # save result 858 | fill = trapped_ball_fill_multi(result, radius, percentile=method) 859 | 860 | outpath = os.path.join(output, name+"_%d"%radius+"_percentail %s.png"%str(method)) 861 | out_map = show_fill_map(build_fill_map(result, fill)) 862 | out_map[np.where(binary == 0)]=0 863 | cv2.imwrite(outpath, out_map) 864 | 865 | 866 | # outpath = os.path.join(output, name+"_%d"%radius+"_percentail%s_final.png"%str(method)) 867 | # out_map = show_fill_map(thinning(build_fill_map(result, fill))) 868 | # cv2.imwrite(outpath, out_map) 869 | def radius_percentile_explor_repeat(radius_set, input, output, percentile_set = [100], repeat = 20): 870 | for r in radius_set: 871 | for p in percentile_set: 872 | _, file = os.path.split(input) 873 | name, _ = os.path.splitext(file) 874 | 875 | # open image 876 | img_org = cv2.imread(input, cv2.IMREAD_COLOR) 877 | img = cv2.cvtColor(img_org, cv2.COLOR_BGR2GRAY) 878 | 879 | ret, binary = cv2.threshold(img, 220, 255, cv2.THRESH_BINARY) 880 | fills = [] 881 | result = binary # this should be binary numpu array 882 | 883 | for i in range(1, repeat+1): 884 | print("Log:\ttrying radius %d with percentile %s, No.%d"%(r, str(p), i)) 885 | 886 | # get file name 887 | 888 | 889 | # save result 890 | fill = trapped_ball_fill_multi(result, r, percentile=p) 891 | fills+=fill 892 | 893 | outpath = os.path.join(output, name+"_%d"%r+"_percentail %s_No %d.png"%(str(p), i)) 894 | out_map = show_fill_map(build_fill_map(result, fills)) 895 | out_map[np.where(binary == 0)]=0 896 | 897 | cv2.imwrite(outpath, out_map) 898 | result = mark_fill(result, fill) 899 | 900 | def trappedball_2pass_exp(path_line, path_line_sim, save_file=False): 901 | 902 | # open image 903 | line = cv2.imread(path_line, cv2.IMREAD_COLOR) 904 | line = cv2.cvtColor(line, cv2.COLOR_BGR2GRAY) 905 | 906 | line_sim = cv2.imread(path_line_sim, cv2.IMREAD_COLOR) 907 | line_sim = cv2.cvtColor(line_sim, cv2.COLOR_BGR2GRAY) 908 | 909 | _, line = cv2.threshold(line, 220, 255, cv2.THRESH_BINARY) 910 | _, binary = cv2.threshold(line_sim, 220, 255, cv2.THRESH_BINARY) 911 | 912 | result = binary 913 | fills = [] 914 | 915 | # filling 916 | fill = trapped_ball_fill_multi(result, 1, percentile=0) 917 | fills += fill 918 | result = mark_fill(result, fill) 919 | 920 | # fill rest region 921 | fill = flood_fill_multi(result) 922 | fills += fill 923 | 924 | # merge 925 | fillmap = build_fill_map(result, fills) 926 | fillmap = merge_fill(fillmap) 927 | 928 | # thin 929 | fillmap = thinning(fillmap) 930 | 931 | # let's do 2nd pass merge! 932 | fillmap_full = cv2.resize(fillmap.astype(np.uint8), 933 | (line.shape[1], line.shape[0]), 934 | interpolation = cv2.INTER_NEAREST) 935 | 936 | # construct a full mask 937 | line_sim_scaled = cv2.resize(line_sim.astype(np.uint8), 938 | (line.shape[1], line.shape[0]), 939 | interpolation = cv2.INTER_NEAREST) 940 | # line_full = cv2.bitwise_and(line, line_sim_scaled) 941 | line_full = line 942 | 943 | # fillmap_full[np.where(line_full<220)] = 0 944 | fillmap_full[line_full<220] = 0 945 | fillmap_full = merger_fill_2nd(fillmap_full)[0] 946 | # fillmap_full = thinning(fillmap_full) 947 | 948 | ''' 949 | save results 950 | ''' 951 | if save_file: 952 | # show fill map 953 | fill_scaled = show_fill_map(fillmap) 954 | fill_scaled_v1 = show_fill_map(fillmap_full) 955 | fill_full = cv2.resize(fill_scaled.astype(np.uint8), 956 | (line.shape[1], line.shape[0]), 957 | interpolation = cv2.INTER_NEAREST) 958 | line_scaled = cv2.resize(line.astype(np.uint8), 959 | (line_sim.shape[1], line_sim.shape[0]), 960 | interpolation = cv2.INTER_NEAREST) 961 | 962 | # overlay strokes 963 | fill_scaled[np.where(line_scaled<220)] = 0 964 | fill_scaled_v1[np.where(line<220)] = 0 965 | fill_full[np.where(line<220)]=0 966 | 967 | # save result 968 | cv2.imwrite("fill_sacled.png", fill_scaled) 969 | cv2.imwrite("fill_scaled_v1.png", fill_scaled_v1) 970 | cv2.imwrite("fill_full.png", fill_full) 971 | 972 | return fillmap_full 973 | if __name__ == '__main__': 974 | 975 | __spec__ = None 976 | 977 | parser = argparse.ArgumentParser() 978 | 979 | parser.add_argument("--single", action = 'store_true', help="process and save a single image to output") 980 | parser.add_argument("--show-intermediate", action = 'store_true', help="save intermediate results") 981 | parser.add_argument("--visualize", action = 'store_true') 982 | parser.add_argument("--exp1", action = 'store_true', help="experiment of exploring the parameter") 983 | parser.add_argument("--exp3", action = 'store_true', help="experiment of exploring the parameter") 984 | parser.add_argument("--exp4", action = 'store_true', help="experiment of exploring the parameter") 985 | parser.add_argument("--exp5", action = 'store_true', help="experiment of exploring the parameter") 986 | parser.add_argument('--input', type=str, default="./flatting/line_white_background/image0001_line.png", 987 | help = "path to input image, support png file only") 988 | parser.add_argument('--output', type=str, default="./exp1", 989 | help = "the path to result saving folder") 990 | 991 | args = parser.parse_args() 992 | 993 | if args.single: 994 | bit_map = region_get_map(args.input, args.output) 995 | if args.visualize: 996 | show_region(bit_map) 997 | elif args.exp1: 998 | # define the range of parameters 999 | radius_set1 = list(range(7, 15)) 1000 | method_set = list(range(0, 101, 5)) + ["mean"] 1001 | radius_percentile_explor(radius_set1, method_set, args.input, args.output) 1002 | elif args.exp3: 1003 | radius_set2 = list(range(1, 15)) 1004 | radius_percentile_explor_repeat(radius_set2, args.input, "./exp3") 1005 | elif args.exp4: 1006 | # let's test 2 pass merge 1007 | line = "./examples/01.png" 1008 | line_sim = "./examples/01_sim.png" 1009 | # trappedball_2pass_exp(line, line_sim) 1010 | region_get_map(line_sim, 1011 | path_to_line_artist=line, 1012 | output_path='./', 1013 | radius_set=[1], 1014 | percentiles=[0], 1015 | visualize_steps=False, 1016 | return_numpy=False) 1017 | elif args.exp5: 1018 | # let's test 2 pass merge 1019 | line = "./examples/tiny.png" 1020 | line_sim = "./examples/tiny_sim.png" 1021 | # trappedball_2pass_exp(line, line_sim) 1022 | region_get_map(line_sim, 1023 | path_to_line_artist=line, 1024 | output_path='./', 1025 | radius_set=[1], 1026 | percentiles=[0], 1027 | visualize_steps=False, 1028 | return_numpy=False) 1029 | else: 1030 | in_path = "./flatting/size_2048/line_detection_croped" 1031 | out_path = "./exp4" 1032 | for img in os.listdir(in_path): 1033 | region_get_map(join(in_path, img), out_path, radius_set=[1], percentiles=[0]) 1034 | -------------------------------------------------------------------------------- /src/flatting/trapped_ball/thinning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | 5 | def thinning(fillmap, max_iter=100): 6 | """Fill area of line with surrounding fill color. 7 | 8 | # Arguments 9 | fillmap: an image. 10 | max_iter: max iteration number. 11 | 12 | # Returns 13 | an image. 14 | """ 15 | line_id = 0 16 | h, w = fillmap.shape[:2] 17 | result = fillmap.copy() 18 | 19 | for iterNum in range(max_iter): 20 | # Get points of line. if there is not point, stop. 21 | line_points = np.where(result == line_id) 22 | if not len(line_points[0]) > 0: 23 | break 24 | 25 | # Get points between lines and fills. 26 | line_mask = np.full((h, w), 255, np.uint8) 27 | line_mask[line_points] = 0 28 | line_border_mask = cv2.morphologyEx(line_mask, cv2.MORPH_DILATE, 29 | cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)), anchor=(-1, -1), 30 | iterations=1) - line_mask 31 | line_border_points = np.where(line_border_mask == 255) 32 | 33 | result_tmp = result.copy() 34 | # Iterate over points, fill each point with nearest fill's id. 35 | for i, _ in enumerate(line_border_points[0]): 36 | x, y = line_border_points[1][i], line_border_points[0][i] 37 | 38 | if x - 1 > 0 and result[y][x - 1] != line_id: 39 | result_tmp[y][x] = result[y][x - 1] 40 | continue 41 | 42 | if x - 1 > 0 and y - 1 > 0 and result[y - 1][x - 1] != line_id: 43 | result_tmp[y][x] = result[y - 1][x - 1] 44 | continue 45 | 46 | if y - 1 > 0 and result[y - 1][x] != line_id: 47 | result_tmp[y][x] = result[y - 1][x] 48 | continue 49 | 50 | if y - 1 > 0 and x + 1 < w and result[y - 1][x + 1] != line_id: 51 | result_tmp[y][x] = result[y - 1][x + 1] 52 | continue 53 | 54 | if x + 1 < w and result[y][x + 1] != line_id: 55 | result_tmp[y][x] = result[y][x + 1] 56 | continue 57 | 58 | if x + 1 < w and y + 1 < h and result[y + 1][x + 1] != line_id: 59 | result_tmp[y][x] = result[y + 1][x + 1] 60 | continue 61 | 62 | if y + 1 < h and result[y + 1][x] != line_id: 63 | result_tmp[y][x] = result[y + 1][x] 64 | continue 65 | 66 | if y + 1 < h and x - 1 > 0 and result[y + 1][x - 1] != line_id: 67 | result_tmp[y][x] = result[y + 1][x - 1] 68 | continue 69 | 70 | result = result_tmp.copy() 71 | 72 | return result 73 | -------------------------------------------------------------------------------- /src/flatting/trapped_ball/thinning_zhang.py: -------------------------------------------------------------------------------- 1 | """ 2 | =========================== 3 | @Author : Linbo 4 | @Version: 1.0 25/10/2014 5 | This is the implementation of the 6 | Zhang-Suen Thinning Algorithm for skeletonization. 7 | =========================== 8 | """ 9 | 10 | def neighbours(x,y,image): 11 | "Return 8-neighbours of image point P1(x,y), in a clockwise order" 12 | img = image 13 | x_1, y_1, x1, y1 = x-1, y-1, x+1, y+1 14 | return [ img[x_1][y], img[x_1][y1], img[x][y1], img[x1][y1], # P2,P3,P4,P5 15 | img[x1][y], img[x1][y_1], img[x][y_1], img[x_1][y_1] ] # P6,P7,P8,P9 16 | 17 | def transitions(neighbours): 18 | "No. of 0,1 patterns (transitions from 0 to 1) in the ordered sequence" 19 | n = neighbours + neighbours[0:1] # P2, P3, ... , P8, P9, P2 20 | return sum( (n1, n2) == (0, 1) for n1, n2 in zip(n, n[1:]) ) # (P2,P3), (P3,P4), ... , (P8,P9), (P9,P2) 21 | 22 | def zhangSuen(image): 23 | "the Zhang-Suen Thinning Algorithm" 24 | Image_Thinned = image.copy() # deepcopy to protect the original image 25 | changing1 = changing2 = 1 # the points to be removed (set as 0) 26 | while changing1 or changing2: # iterates until no further changes occur in the image 27 | # Step 1 28 | changing1 = [] 29 | rows, columns = Image_Thinned.shape # x for rows, y for columns 30 | for x in range(1, rows - 1): # No. of rows 31 | for y in range(1, columns - 1): # No. of columns 32 | P2,P3,P4,P5,P6,P7,P8,P9 = n = neighbours(x, y, Image_Thinned) 33 | if (Image_Thinned[x][y] == 1 and # Condition 0: Point P1 in the object regions 34 | 2 <= sum(n) <= 6 and # Condition 1: 2<= N(P1) <= 6 35 | transitions(n) == 1 and # Condition 2: S(P1)=1 36 | P2 * P4 * P6 == 0 and # Condition 3 37 | P4 * P6 * P8 == 0): # Condition 4 38 | changing1.append((x,y)) 39 | for x, y in changing1: 40 | Image_Thinned[x][y] = 0 41 | # Step 2 42 | changing2 = [] 43 | for x in range(1, rows - 1): 44 | for y in range(1, columns - 1): 45 | P2,P3,P4,P5,P6,P7,P8,P9 = n = neighbours(x, y, Image_Thinned) 46 | if (Image_Thinned[x][y] == 1 and # Condition 0 47 | 2 <= sum(n) <= 6 and # Condition 1 48 | transitions(n) == 1 and # Condition 2 49 | P2 * P4 * P8 == 0 and # Condition 3 50 | P2 * P6 * P8 == 0): # Condition 4 51 | changing2.append((x,y)) 52 | for x, y in changing2: 53 | Image_Thinned[x][y] = 0 54 | return Image_Thinned 55 | 56 | import matplotlib 57 | import matplotlib.pyplot as plt 58 | import skimage.io as io 59 | 60 | 61 | if __name__ == "__main__": 62 | "load image data" 63 | Img_Original = io.imread( './data/test1.bmp') # Gray image, rgb images need pre-conversion 64 | 65 | "Convert gray images to binary images using Otsu's method" 66 | from skimage import filters 67 | Otsu_Threshold = filters.threshold_otsu(Img_Original) 68 | BW_Original = Img_Original < Otsu_Threshold # must set object region as 1, background region as 0 ! 69 | 70 | "Apply the algorithm on images" 71 | BW_Skeleton = zhangSuen(BW_Original) 72 | # BW_Skeleton = BW_Original 73 | "Display the results" 74 | fig, ax = plt.subplots(1, 2) 75 | ax1, ax2 = ax.ravel() 76 | ax1.imshow(BW_Original, cmap=plt.cm.gray) 77 | ax1.set_title('Original binary image') 78 | ax1.axis('off') 79 | ax2.imshow(BW_Skeleton, cmap=plt.cm.gray) 80 | ax2.set_title('Skeleton of the image') 81 | ax2.axis('off') 82 | plt.show() 83 | -------------------------------------------------------------------------------- /src/flatting/trapped_ball/trappedball_fill.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import pdb 4 | import pickle 5 | import time 6 | # use cython adjacency matrix 7 | try: 8 | from . import adjacency_matrix 9 | ## If it's not already compiled, compile it. 10 | except: 11 | import pyximport 12 | pyximport.install() 13 | from . import adjacency_matrix 14 | # it seemed that multi thread will not help to reduce running time 15 | # https://medium.com/python-experiments/parallelising-in-python-mutithreading-and-mutiprocessing-with-practical-templates-c81d593c1c49 16 | from multiprocessing import Pool 17 | from multiprocessing import freeze_support 18 | from functools import partial 19 | # from skimage.morphology import skeletonize, thin 20 | 21 | 22 | 23 | from tqdm import tqdm 24 | from PIL import Image 25 | 26 | def save_obj(fill_graph, save_path='fill_map.pickle'): 27 | 28 | with open(save_path, 'wb') as f: 29 | pickle.dump(fill_graph, f, protocol=pickle.HIGHEST_PROTOCOL) 30 | 31 | def load_obj(load_path='fill_map.pickle'): 32 | 33 | with open(load_path, 'rb') as f: 34 | fill_graph = pickle.load(f) 35 | 36 | return fill_graph 37 | 38 | def extract_line(fills_result): 39 | 40 | img = cv2.blur(fills_result,(5,5)) 41 | 42 | # analyaze the gradient of flat image 43 | grad = cv2.Laplacian(img,cv2.CV_64F) 44 | grad = abs(grad).sum(axis = -1) 45 | grad_v, grad_c = np.unique(grad, return_counts=True) 46 | 47 | # remove the majority grad, which is 0 48 | assert np.where(grad_v==0) == np.where(grad_c==grad_c.max()) 49 | grad_v = np.delete(grad_v, np.where(grad_v==0)) 50 | grad_c = np.delete(grad_c, np.where(grad_c==grad_c.max())) 51 | print("Log:\tlen of grad_v %d"%len(grad_v)) 52 | grad_c_cum = np.cumsum(grad_c) 53 | 54 | # if grad number is greater than 100, then this probably means the current 55 | # image exists pretty similar colors, then we should apply 56 | # another set of parameter to detect edge 57 | # this could be better if we can find the realtion between them 58 | if len(grad_v) < 100: 59 | min_val = grad_v[np.where(grad_c_cum<=np.percentile(grad_c_cum, 25))[0].max()] 60 | max_val = grad_v[np.where(grad_c_cum<=np.percentile(grad_c_cum, 40))[0].max()] 61 | else: 62 | min_val = grad_v[np.where(grad_c_cum<=np.percentile(grad_c_cum, 1))[0].max()] 63 | max_val = grad_v[np.where(grad_c_cum<=np.percentile(grad_c_cum, 10))[0].max()] 64 | 65 | edges = cv2.Canny(img, min_val, max_val, L2gradient=True) 66 | return 255-edges 67 | 68 | def to_masked_line(line_sim, line_artist, rk1=None, rk2=None, ak=None, tn=1): 69 | ''' 70 | Given: 71 | line_sim, simplified line, which is also the neural networks output 72 | line_artist, artist line, the original input 73 | rk, remove kernel, thicken kernel. if the neural network ouput too thin line, use this option 74 | ak, add kernel, thinning kernel. if the neural network output too thick line, use this option 75 | Return: 76 | the masked line for filling 77 | ''' 78 | # 1. generate lines removed unecessary strokes 79 | if rk1 != None: 80 | kernel_remove1 = get_ball_structuring_element(rk1) 81 | # make the simplified line to cover the artist's line 82 | mask_remove = cv2.morphologyEx(line_sim, cv2.MORPH_ERODE, kernel_remove1) 83 | else: 84 | mask_remove = line_sim 85 | 86 | mask_remove = np.logical_and(line_artist==0, mask_remove==0) 87 | 88 | # 2. generate lines that added by line_sim 89 | if ak != None: 90 | kernel_add = get_ball_structuring_element(ak) 91 | # try to make the artist's line cover the simplified line 92 | mask_add = cv2.morphologyEx(line_sim, cv2.MORPH_DILATE, kernel_add) 93 | else: 94 | mask_add = line_sim 95 | 96 | # may be we don't need that skeleton 97 | # mask_add = 255 - skeletonize((255 - mask_add)/255, method='lee') 98 | 99 | # let's try just thin it 100 | mask_add = 255 - thin(255 - mask_add, max_iter=tn).astype(np.uint8)*255 101 | 102 | if rk2 != None: 103 | kernel_remove2 = get_ball_structuring_element(rk2) 104 | line_artist = cv2.morphologyEx(line_artist, cv2.MORPH_ERODE, kernel_remove2) 105 | mask_add = np.logical_and(mask_add==0, np.logical_xor(mask_add==0, line_artist==0)) 106 | 107 | # 3. combine and return the result 108 | mask = np.logical_or(mask_remove, mask_add).astype(np.uint8)*255 109 | 110 | # # 4. connect dot lines if exists 111 | # if connect != None: 112 | # kernel_con = get_ball_structuring_element(1) 113 | # for _ in range(connect): 114 | # mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel_con) 115 | 116 | return 255 - mask 117 | 118 | 119 | def get_ball_structuring_element(radius): 120 | """Get a ball shape structuring element with specific radius for morphology operation. 121 | The radius of ball usually equals to (leaking_gap_size / 2). 122 | 123 | # Arguments 124 | radius: radius of ball shape. 125 | 126 | # Returns 127 | an array of ball structuring element. 128 | """ 129 | return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * radius + 1, 2 * radius + 1)) 130 | 131 | 132 | def get_unfilled_point(image): 133 | """Get points belong to unfilled(value==255) area. 134 | 135 | # Arguments 136 | image: an image. 137 | 138 | # Returns 139 | an array of points. 140 | """ 141 | y, x = np.where(image == 255) 142 | 143 | return np.stack((x.astype(int), y.astype(int)), axis=-1) 144 | 145 | 146 | def exclude_area(image, radius): 147 | """Perform erosion on image to exclude points near the boundary. 148 | We want to pick part using floodfill from the seed point after dilation. 149 | When the seed point is near boundary, it might not stay in the fill, and would 150 | not be a valid point for next floodfill operation. So we ignore these points with erosion. 151 | 152 | # Arguments 153 | image: an image. 154 | radius: radius of ball shape. 155 | 156 | # Returns 157 | an image after dilation. 158 | """ 159 | # https://docs.opencv.org/3.4/d4/d86/group__imgproc__filter.html 160 | return cv2.morphologyEx(image, cv2.MORPH_ERODE, get_ball_structuring_element(radius), anchor=(-1, -1), iterations=1) 161 | 162 | 163 | def trapped_ball_fill_single(image, seed_point, radius): 164 | """Perform a single trapped ball fill operation. 165 | 166 | # Arguments 167 | image: an image. the image should consist of white background, black lines and black fills. 168 | the white area is unfilled area, and the black area is filled area. 169 | seed_point: seed point for trapped-ball fill, a tuple (integer, integer). 170 | radius: radius of ball shape. 171 | # Returns 172 | an image after filling. 173 | """ 174 | 175 | ball = get_ball_structuring_element(radius) 176 | 177 | pass1 = np.full(image.shape, 255, np.uint8) 178 | pass2 = np.full(image.shape, 255, np.uint8) 179 | 180 | im_inv = cv2.bitwise_not(image) # why inverse image? 181 | 182 | # Floodfill the image 183 | mask1 = cv2.copyMakeBorder(im_inv, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0) 184 | 185 | # retval, image, mask, rect = cv.floodFill(image, mask, seedPoint, newVal[, loDiff[, upDiff[, flags]]]) 186 | # fill back pixles, Flood-filling cannot go across non-zero pixels in the input mask. 187 | _, pass1, _, _ = cv2.floodFill(pass1, mask1, seed_point, 0, 0, 0, 4) #seed point is the first unfilled point 188 | 189 | # Perform dilation on image. The fill areas between gaps became disconnected. 190 | # close any possible gaps that could be coverd by the ball 191 | pass1 = cv2.morphologyEx(pass1, cv2.MORPH_DILATE, ball, anchor=(-1, -1), iterations=1) 192 | mask2 = cv2.copyMakeBorder(pass1, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0) 193 | 194 | # Floodfill with seed point again to select one fill area. 195 | _, pass2, _, rect = cv2.floodFill(pass2, mask2, seed_point, 0, 0, 0, 4) 196 | 197 | # Perform erosion on the fill result leaking-proof fill. 198 | 199 | pass2 = cv2.morphologyEx(pass2, cv2.MORPH_ERODE, ball, anchor=(-1, -1), iterations=1) 200 | 201 | return pass2 202 | 203 | 204 | def trapped_ball_fill_multi(image, radius, percentile='mean', max_iter=1000, verbo=False): 205 | """Perform multi trapped ball fill operations until all valid areas are filled. 206 | 207 | # Arguments 208 | image: an image. The image should consist of white background, black lines and black fills. 209 | the white area is unfilled area, and the black area is filled area. 210 | radius: radius of ball shape. 211 | method: method for filtering the fills. 212 | 'max' is usually with large radius for select large area such as background. 213 | max_iter: max iteration number. 214 | # Returns 215 | an array of fills' points. 216 | """ 217 | if verbo: 218 | print('trapped-ball ' + str(radius)) 219 | 220 | unfill_area = image # so unfill_area is the binary numpy array (but contain 0, 255 only), 0 means filled region I guess 221 | 222 | h, w = image.shape 223 | 224 | filled_area, filled_area_size, result = [], [], [] 225 | 226 | for _ in range(max_iter): 227 | 228 | # get the point list of unfilled regions 229 | points = get_unfilled_point(exclude_area(unfill_area, radius)) 230 | # points = get_unfilled_point(unfill_area) 231 | 232 | # terminate if all points have been filled 233 | if not len(points) > 0: 234 | break 235 | 236 | # perform a single flood fill 237 | fill = trapped_ball_fill_single(unfill_area, (points[0][0], points[0][1]), radius) 238 | 239 | # update filled region 240 | unfill_area = cv2.bitwise_and(unfill_area, fill) 241 | 242 | # record filled region of each iter 243 | filled_area.append(np.where(fill == 0)) 244 | filled_area_size.append(len(np.where(fill == 0)[0])) 245 | 246 | filled_area_size = np.asarray(filled_area_size) 247 | 248 | # a filter to remove the "half" filed regions 249 | if percentile == "mean": 250 | area_size_filter = np.mean(filled_area_size) 251 | 252 | elif type(percentile)==int: 253 | assert percentile>=0 and percentile<=100 254 | area_size_filter = np.percentile(filled_area_size, percentile) 255 | else: 256 | print("wrong percentile %s"%percentile) 257 | raise ValueError 258 | 259 | result_idx = np.where(filled_area_size >= area_size_filter)[0] 260 | 261 | # filter out all region that is less than the area_size_filter 262 | for i in result_idx: 263 | result.append(filled_area[i]) 264 | 265 | # result is a list of point list for each filled region 266 | return result 267 | 268 | 269 | def flood_fill_single(im, seed_point): 270 | """Perform a single flood fill operation. 271 | 272 | # Arguments 273 | image: an image. the image should consist of white background, black lines and black fills. 274 | the white area is unfilled area, and the black area is filled area. 275 | seed_point: seed point for trapped-ball fill, a tuple (integer, integer). 276 | # Returns 277 | an image after filling. 278 | """ 279 | pass1 = np.full(im.shape, 255, np.uint8) 280 | 281 | im_inv = cv2.bitwise_not(im) 282 | 283 | mask1 = cv2.copyMakeBorder(im_inv, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0) 284 | _, pass1, _, _ = cv2.floodFill(pass1, mask1, seed_point, 0, 0, 0, 4) 285 | 286 | return pass1 287 | 288 | 289 | def flood_fill_multi(image, max_iter=20000, verbo=False): 290 | 291 | """Perform multi flood fill operations until all valid areas are filled. 292 | This operation will fill all rest areas, which may result large amount of fills. 293 | 294 | # Arguments 295 | image: an image. the image should contain white background, black lines and black fills. 296 | the white area is unfilled area, and the black area is filled area. 297 | max_iter: max iteration number. 298 | # Returns 299 | an array of fills' points. 300 | """ 301 | if verbo: 302 | print('floodfill') 303 | 304 | unfill_area = image 305 | filled_area = [] 306 | 307 | for _ in range(max_iter): 308 | points = get_unfilled_point(unfill_area) 309 | 310 | if not len(points) > 0: 311 | break 312 | 313 | fill = flood_fill_single(unfill_area, (points[0][0], points[0][1])) 314 | unfill_area = cv2.bitwise_and(unfill_area, fill) 315 | 316 | filled_area.append(np.where(fill == 0)) 317 | 318 | return filled_area 319 | 320 | 321 | def mark_fill(image, fills): 322 | """Mark filled areas with 0. 323 | 324 | # Arguments 325 | image: an image. 326 | fills: an array of fills' points. 327 | # Returns 328 | an image. 329 | """ 330 | result = image.copy() 331 | 332 | for fill in fills: 333 | result[fill] = 0 334 | 335 | return result 336 | 337 | 338 | def build_fill_map(image, fills): 339 | """Make an image(array) with each pixel(element) marked with fills' id. id of line is 0. 340 | 341 | # Arguments 342 | image: an image. 343 | fills: an array of fills' points. 344 | # Returns 345 | an array. 346 | """ 347 | result = np.zeros(image.shape[:2], np.int) 348 | 349 | for index, fill in enumerate(fills): 350 | result[fill] = index + 1 351 | 352 | return result 353 | 354 | 355 | def show_fill_map(fillmap): 356 | """Mark filled areas with colors. It is useful for visualization. 357 | 358 | # Arguments 359 | image: an image. 360 | fills: an array of fills' points. 361 | # Returns 362 | an image. 363 | """ 364 | # Generate color for each fill randomly. 365 | colors = np.random.randint(0, 255, (np.max(fillmap) + 1, 3), dtype=np.uint8) 366 | # Id of line is 0, and its color is black. 367 | colors[0] = [0, 0, 0] 368 | 369 | return colors[fillmap] 370 | 371 | 372 | def get_bounding_rect(points): 373 | """Get a bounding rect of points. 374 | 375 | # Arguments 376 | points: array of points. 377 | # Returns 378 | rect coord 379 | """ 380 | x1, y1, x2, y2 = np.min(points[1]), np.min(points[0]), np.max(points[1]), np.max(points[0]) 381 | return x1, y1, x2, y2 382 | 383 | 384 | def get_border_bounding_rect(h, w, p1, p2, r): 385 | """Get a valid bounding rect in the image with border of specific size. 386 | 387 | # Arguments 388 | h: image max height. 389 | w: image max width. 390 | p1: start point of rect. 391 | p2: end point of rect. 392 | r: border radius. 393 | # Returns 394 | rect coord 395 | """ 396 | x1, y1, x2, y2 = p1[0], p1[1], p2[0], p2[1] 397 | 398 | x1 = x1 - r if 0 < x1 - r else 0 399 | y1 = y1 - r if 0 < y1 - r else 0 400 | x2 = x2 + r + 1 if x2 + r + 1 < w else w # why here plus 1? 401 | y2 = y2 + r + 1 if y2 + r + 1 < h else h 402 | 403 | return x1, y1, x2, y2 404 | 405 | 406 | def get_border_point(points, rect, max_height, max_width): 407 | """Get border points of a fill area 408 | 409 | # Arguments 410 | points: points of fill . 411 | rect: bounding rect of fill. 412 | max_height: image max height. 413 | max_width: image max width. 414 | # Returns 415 | points , convex shape of points 416 | """ 417 | 418 | # Get a local bounding rect. 419 | # what this function used for? 420 | border_rect = get_border_bounding_rect(max_height, max_width, rect[:2], rect[2:], 2) 421 | 422 | # Get fill in rect, all 0s 423 | fill = np.zeros((border_rect[3] - border_rect[1], border_rect[2] - border_rect[0]), np.uint8) 424 | 425 | # Move points to the rect. 426 | # offset points into the fill 427 | fill[(points[0] - border_rect[1], points[1] - border_rect[0])] = 255 428 | 429 | # Get shape. 430 | # pdb.set_trace() 431 | contours, _ = cv2.findContours(fill, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 432 | approx_shape = cv2.approxPolyDP(contours[0], 0.02 * cv2.arcLength(contours[0], True), True) 433 | 434 | # Get border pixel. 435 | # Structuring element in cross shape is used instead of box to get 4-connected border. 436 | ''' 437 | # Cross-shaped Kernel 438 | >>> cv2.getStructuringElement(cv2.MORPH_CROSS,(5,5)) 439 | array([[0, 0, 1, 0, 0], 440 | [0, 0, 1, 0, 0], 441 | [1, 1, 1, 1, 1], 442 | [0, 0, 1, 0, 0], 443 | [0, 0, 1, 0, 0]], dtype=uint8) 444 | ''' 445 | cross = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)) # this is a ball shape kernel 446 | border_pixel_mask = cv2.morphologyEx(fill, cv2.MORPH_DILATE, cross, anchor=(-1, -1), iterations=1) - fill 447 | border_pixel_points = np.where(border_pixel_mask == 255) 448 | 449 | # Transform points back to fillmap. 450 | border_pixel_points = (border_pixel_points[0] + border_rect[1], border_pixel_points[1] + border_rect[0]) 451 | 452 | return border_pixel_points, approx_shape 453 | 454 | 455 | def merge_fill(fillmap, max_iter=10, verbo=False): 456 | """Merge fill areas. 457 | 458 | # Arguments 459 | fillmap: an image. 460 | max_iter: max iteration number. 461 | # Returns 462 | an image. 463 | """ 464 | max_height, max_width = fillmap.shape[:2] 465 | result = fillmap.copy() 466 | 467 | for i in range(max_iter): 468 | if verbo: 469 | print('merge ' + str(i + 1)) 470 | 471 | # set stroke as black 472 | result[np.where(fillmap == 0)] = 0 473 | 474 | # get list of fill id 475 | fill_id = np.unique(result.flatten()) 476 | fills = [] 477 | 478 | for j in fill_id: 479 | 480 | # select one region each time 481 | point = np.where(result == j) 482 | 483 | fills.append({ 484 | 'id': j, 485 | 'point': point, 486 | 'area': len(point[0]), 487 | 'rect': get_bounding_rect(point) 488 | }) 489 | 490 | for j, f in enumerate(fills): 491 | 492 | # ignore lines 493 | if f['id'] == 0: 494 | continue 495 | 496 | # get border shape of a region, but that may contains many nosiy segementation? 497 | border_points, approx_shape = get_border_point(f['point'], f['rect'], max_height, max_width) 498 | border_pixels = result[border_points] # pixel values or seg index of that region 499 | pixel_ids, counts = np.unique(border_pixels, return_counts=True) 500 | 501 | # remove id that equal 0 502 | ids = pixel_ids[np.nonzero(pixel_ids)] 503 | new_id = f['id'] 504 | if len(ids) == 0: 505 | # points with lines around color change to line color 506 | # regions surrounded by line remain the same 507 | if f['area'] < 5: 508 | # if f['area'] < 32: 509 | new_id = 0 510 | else: 511 | # region id may be set to region with largest contact 512 | new_id = ids[0] 513 | 514 | # a point, because the convex shape only contains 1 point 515 | if len(approx_shape) == 1 or f['area'] == 1: 516 | result[f['point']] = new_id 517 | 518 | # so this means 519 | if len(approx_shape) in [2, 3, 4, 5] and f['area'] < 500: 520 | # if len(approx_shape) in [2, 3, 4, 5] and f['area'] < 10000: 521 | result[f['point']] = new_id 522 | 523 | if f['area'] < 250 and len(ids) == 1: 524 | # if f['area'] < 5000 and len(ids) == 1: 525 | result[f['point']] = new_id 526 | 527 | if f['area'] < 50: 528 | # if f['area'] < 100: 529 | result[f['point']] = new_id 530 | 531 | # if no merge happen, stop this process 532 | if len(fill_id) == len(np.unique(result.flatten())): 533 | break 534 | 535 | return result 536 | 537 | def search_point(points, point): 538 | 539 | idx = np.where((points == point).all(axis = 1))[0] 540 | 541 | return idx 542 | 543 | def extract_region_obsolete(points, point, width, height): 544 | 545 | # unfortunately, this function is too costly to run 546 | 547 | # get 8-connectivity neighbors 548 | point_list = [] 549 | 550 | # search top left 551 | # point[0] is height 552 | # point[1] is width 553 | if point[0] > 0 and point[1] > 0: 554 | tl = np.array([point[0]-1, point[1]-1]) 555 | idx = search_point(points, tl) 556 | if len(idx) == 0: 557 | pass 558 | elif len(idx) == 1: 559 | point_list.append(tl) 560 | # pop out current point 561 | points = np.delete(points, idx, axis=0) 562 | point_list += extract_region(points, tl, width, height) 563 | else: 564 | raise ValueError("There should not exist two identical points in the list!") 565 | # search top 566 | if point[0] > 0: 567 | t = np.array([point[0], point[1]-1]) 568 | idx = search_point(points, t) 569 | if len(idx) == 0: 570 | pass 571 | elif len(idx) == 1: 572 | point_list.append(t) 573 | # pop out current point 574 | points = np.delete(points, idx, axis=0) 575 | point_list += extract_region(points, t, width, height) 576 | else: 577 | raise ValueError("There should not exist two identical points in the list!") 578 | 579 | # search top right 580 | if point[0] > 0 and point[1] < width: 581 | tr = np.array([point[0]-1, point[1]+1]) 582 | idx = search_point(points, tr) 583 | if len(idx) == 0: 584 | pass 585 | elif len(idx) == 1: 586 | point_list.append(tr) 587 | # pop out current point 588 | points = np.delete(points, idx, axis=0) 589 | point_list += extract_region(points, tr, width, height) 590 | else: 591 | raise ValueError("There should not exist two identical points in the list!") 592 | 593 | # search mid left 594 | if point[1] > 0: 595 | ml = np.array([point[0], point[1]-1]) 596 | idx = search_point(points, ml) 597 | if len(idx) == 0: 598 | pass 599 | elif len(idx) == 1: 600 | point_list.append(ml) 601 | # pop out current point 602 | points = np.delete(points, idx, axis=0) 603 | point_list += extract_region(points, ml, width, height) 604 | else: 605 | raise ValueError("There should not exist two identical points in the list!") 606 | 607 | # search mid right 608 | if point[1] < width: 609 | mr = np.array([point[0], point[1]+1]) 610 | idx = search_point(points, mr) 611 | if len(idx) == 0: 612 | pass 613 | elif len(idx) == 1: 614 | point_list.append(mr) 615 | # pop out current point 616 | points = np.delete(points, idx, axis=0) 617 | point_list += extract_region(points, mr, width, height) 618 | else: 619 | raise ValueError("There should not exist two identical points in the list!") 620 | 621 | # search bottom left 622 | if point[0] < height and point[1] > 0: 623 | bl = np.array([point[0]+1, point[1]-1]) 624 | idx = search_point(points, bl) 625 | if len(idx) == 0: 626 | pass 627 | elif len(idx) == 1: 628 | point_list.append(bl) 629 | # pop out current point 630 | points = np.delete(points, idx, axis=0) 631 | point_list += extract_region(points, bl, width, height) 632 | else: 633 | raise ValueError("There should not exist two identical points in the list!") 634 | 635 | # search bottom 636 | if point[0] < height: 637 | b = np.array([point[0]+1, point[1]]) 638 | idx = search_point(points, b) 639 | if len(idx) == 0: 640 | pass 641 | elif len(idx) == 1: 642 | point_list.append(b) 643 | # pop out current point 644 | points = np.delete(points, idx, axis=0) 645 | point_list += extract_region(points, b, width, height) 646 | else: 647 | raise ValueError("There should not exist two identical points in the list!") 648 | 649 | # search bottom right 650 | if point[0] < height and point[1] < width: 651 | br = np.array([point[0]+1, point[1]+1]) 652 | idx = search_point(points, br) 653 | if len(idx) == 0: 654 | pass 655 | elif len(idx) == 1: 656 | point_list.append(br) 657 | # pop out current point 658 | points = np.delete(points, idx, axis=0) 659 | point_list += extract_region(points, br, width, height) 660 | else: 661 | raise ValueError("There should not exist two identical points in the list!") 662 | 663 | return point_list 664 | 665 | def extract_region(): 666 | 667 | # let's try flood fill 668 | 669 | flood_fill_multi(image, max_iter=20000) 670 | 671 | def to_graph(fillmap, fillid): 672 | 673 | # how to speed up this part? 674 | # use another graph data structure 675 | # or maybe use list instead of dict 676 | 677 | fills = {} 678 | for j in tqdm(fillid): 679 | 680 | # select one region each time 681 | point = np.where(fillmap == j) 682 | 683 | fills[j] = {"point":point, 684 | "area": len(point[0]), 685 | "rect": get_bounding_rect(point), 686 | "neighbor":[]} 687 | return fills 688 | 689 | def to_fillmap(fillmap, fills): 690 | 691 | for j in fills: 692 | if fills[j] == None: 693 | continue 694 | fillmap[fills[j]['point']] = j 695 | 696 | return fillmap 697 | 698 | def merge_list_ordered(list1, list2, idx): 699 | 700 | for value in list2: 701 | if value not in list1 and value != idx and value != None: 702 | list1.append(value) 703 | 704 | return list1 705 | 706 | def merge_region(fills, source_idx, target_idx, result): 707 | 708 | # merge from source to target 709 | assert fills[source_idx] != None 710 | assert fills[target_idx] != None 711 | 712 | # update target region 713 | fills[target_idx]['point'] = (np.concatenate((fills[target_idx]['point'][0], fills[source_idx]['point'][0])), 714 | np.concatenate((fills[target_idx]['point'][1], fills[source_idx]['point'][1]))) 715 | 716 | fills[target_idx]['area'] += fills[source_idx]['area'] 717 | assert len(fills[target_idx]['point'][0]) + len(fills[source_idx]['point'][0]) == fills[target_idx]['area'] + fills[source_idx]['area'] 718 | 719 | fills[target_idx]['neighbor'] = merge_list_ordered(fills[target_idx]['neighbor'], 720 | fills[source_idx]['neighbor'], target_idx) 721 | 722 | # update source's neighbor 723 | for n in fills[source_idx]['neighbor']: 724 | if n != None: 725 | if source_idx in fills[n]['neighbor']: 726 | t = fills[n]['neighbor'].index(source_idx) 727 | fills[n]['neighbor'][t] = None 728 | else: 729 | print("find one side neighbor") 730 | 731 | if target_idx not in fills[n]['neighbor'] and target_idx != n: 732 | fills[n]['neighbor'].append(target_idx) 733 | 734 | # remove source region 735 | fills[source_idx] = None 736 | 737 | return fills 738 | 739 | def split(): 740 | # this might be a different function 741 | pass 742 | 743 | def list_region(fill_graph, th = None, verbo = True): 744 | 745 | regions = 0 746 | small_regions = [] 747 | for key in fill_graph: 748 | if fill_graph[key] != None: 749 | 750 | if verbo: 751 | print("Log:\tregion %d with size %d"%(key, fill_graph[key]['area'])) 752 | regions += 1 753 | 754 | if th == None: 755 | continue 756 | 757 | # collect small regions 758 | if fill_graph[key]['area'] < th: 759 | small_regions.append(key) 760 | 761 | if verbo: 762 | print("Log:\ttotal regions %d"%regions) 763 | 764 | return small_regions 765 | 766 | def visualize_graph(fills_graph, result, region=None): 767 | if region == None: 768 | Image.fromarray(show_fill_map(to_fillmap(result, fills_graph)).astype(np.uint8)).show() 769 | else: 770 | assert region in fills_graph 771 | show_map = np.zeros(result.shape, np.uint8) 772 | show_map[fills_graph[region]['point']] = 255 773 | Image.fromarray(show_map).show() 774 | 775 | def visualize_result(result, region=None): 776 | if region == None: 777 | Image.fromarray(show_fill_map(result).astype(np.uint8)).show() 778 | else: 779 | assert region in result 780 | show_map = np.zeros(result.shape, np.uint8) 781 | show_map[np.where(result == region)] = 255 782 | Image.fromarray(show_map).show() 783 | 784 | def graph_self_check(fill_graph): 785 | 786 | for key in fill_graph: 787 | if fill_graph[key] != None: 788 | if len(fill_graph[key]['neighbor']) > 0: 789 | if len(fill_graph[key]['neighbor']) != len(set(fill_graph[key]['neighbor'])): 790 | print("Log:\tfind duplicate neighbor!") 791 | for n in fill_graph[key]['neighbor']: 792 | if key not in fill_graph[n]['neighbor']: 793 | print("Log:\tfind missing neighbor") 794 | # print("Log:\tregion %d has %d points"%(key, fill_graph[key]['area'])) 795 | 796 | def flood_fill_single_proc(region_id, img): 797 | 798 | # construct fill region 799 | fill_region = np.full(img.shape, 0, np.uint8) 800 | fill_region[np.where(img == region_id)] = 255 801 | return flood_fill_multi(fill_region, verbo=False) 802 | 803 | def flood_fill_multi_proc(func, fill_id, result, n_proc): 804 | print("Log:\tmulti process spliting bleeding regions") 805 | with Pool(processes=n_proc) as p: 806 | return p.map(partial(func, img=result), fill_id) 807 | 808 | def split_region(result, multi_proc=False): 809 | 810 | # # get list of fill id 811 | # fill_id = np.unique(result.flatten()).tolist() 812 | # fill_id.remove(0) 813 | # assert 0 not in fill_id 814 | 815 | # _, result = cv2.connectedComponents(result, connectivity=4) 816 | # # there will left some small regions, we can merge them into region 0 in the following step 817 | 818 | # # result = build_fill_map(result, fill_points) 819 | 820 | # fill_id_new = np.unique(result) 821 | 822 | # generate thershold of merging region 823 | w, h = result.shape 824 | th = int(w*h*0.09) 825 | # get list of fill id 826 | fill_id = np.unique(result.flatten()).tolist() 827 | fill_id.remove(0) 828 | assert 0 not in fill_id 829 | 830 | fill_points = [] 831 | 832 | # get each region ready to be filled 833 | if multi_proc: 834 | n_proc = 8 835 | start = time.process_time() 836 | 837 | fill_points_multi_proc = flood_fill_multi_proc(flood_fill_single_proc, fill_id, result, n_proc) 838 | for f in fill_points_multi_proc: 839 | fill_points += f 840 | 841 | print("Mutiprocessing time: {}secs\n".format((time.process_time()-start))) 842 | 843 | else: 844 | # split each region if it is splited by ink region 845 | start = time.process_time() 846 | for j in tqdm(fill_id): 847 | 848 | # skip strokes 849 | if j == 0: 850 | continue 851 | 852 | # generate fill mask of that region 853 | fill_region = np.full(result.shape, 0, np.uint8) 854 | fill_region[np.where(result == j)] = 255 855 | 856 | # corp to a smaller region that only cover the current filling region to speed up 857 | # todo 858 | 859 | # assign new id to 860 | fills = flood_fill_multi(fill_region, verbo=False) 861 | 862 | merge = [] 863 | merge_idx = [] 864 | for i in range(len(fills)): 865 | if len(fills[i][0]) > th: 866 | merge_idx.append(i) 867 | 868 | for i in range(len(merge_idx)): 869 | merge.append(fills[merge_idx[i]]) 870 | 871 | for i in merge_idx: 872 | fills.pop(i) 873 | 874 | if len(merge) > 0: 875 | region_merged = merge.pop(0) 876 | for p in merge: 877 | region_merged = (np.concatenate((region_merged[0], p[0])), np.concatenate((region_merged[1], p[1]))) 878 | fills.append(region_merged) 879 | 880 | fill_points += fills 881 | print("Single-processing time: {}secs\n".format((time.process_time()-start))) 882 | 883 | result = build_fill_map(result, fill_points) 884 | fill_id_new = np.unique(result) 885 | 886 | return result, fill_id_new 887 | 888 | def find_neighbor(result, fills_graph, max_height, max_width): 889 | 890 | fill_id_new = np.unique(result) 891 | 892 | for j in tqdm(fill_id_new): 893 | 894 | if j == 0: 895 | continue 896 | 897 | fill_region = np.zeros(result.shape, np.uint8) 898 | fill_region[np.where(result == j)] = 255 899 | 900 | # find boundary of each region 901 | # sometimes this function is not multually correct, why? 902 | border_points, _ = get_border_point(fills_graph[j]['point'], fills_graph[j]['rect'], max_height, max_width) 903 | 904 | # construct a graph map of all regions 905 | neighbor_id = np.unique(result[border_points]) 906 | 907 | # record neighbor information 908 | for k in neighbor_id: 909 | if k != 0: 910 | if k not in fills_graph[j]["neighbor"]: 911 | fills_graph[j]["neighbor"].append(k) 912 | if j not in fills_graph[k]["neighbor"]: 913 | fills_graph[k]["neighbor"].append(j) 914 | 915 | return fills_graph 916 | 917 | def find_min_neighbor(fills_graph, idx): 918 | 919 | neighbors = [] 920 | neighbor_sizes = [] 921 | for n in fills_graph[idx]["neighbor"]: 922 | if n != None: 923 | neighbors.append(n) 924 | neighbor_sizes.append(fills_graph[n]['area']) 925 | else: 926 | neighbors.append(n) 927 | neighbor_sizes.append(-1) 928 | 929 | # we need to sort the index 930 | sort_idx = sorted(range(len(neighbors)), key=lambda k: neighbor_sizes[k]) 931 | 932 | # for i in sort_idx: 933 | # print("Log:\tregion %d with size %d"%(neighbors[i], neighbor_sizes[i])) 934 | 935 | return sort_idx, neighbor_sizes 936 | 937 | def merge_all_neighbor(fills_graph, idx, result): 938 | 939 | for n in fills_graph[idx]['neighbor']: 940 | result[fills_graph[n]['point']] = idx 941 | 942 | return result 943 | 944 | def check_all_neighbor(fills_graph, j, low_th, max_th): 945 | 946 | min_neighbors = find_min_neighbor(fills_graph, j) 947 | for k in min_neighbors: 948 | nb = fills_graph[j]['neighbor'][k] 949 | 950 | if nb != None: 951 | if fills_graph[nb] != None and fills_graph[nb]['area'] <= low_th and nb != j: 952 | continue 953 | else: 954 | print("Log:\texclude region %d"%nb) 955 | 956 | def remove_bleeding(fills_graph, fill_id_new, max_iter, result, low_th, max_th): 957 | 958 | count = 0 959 | really_low_th = 100 960 | # max region absorb small neighbors 961 | for i in range(max_iter): 962 | # print('merge 2nd ' + str(i + 1)) 963 | for j in tqdm(fill_id_new): 964 | if j == 0: 965 | continue 966 | if fills_graph[j] == None: # this region has been removed 967 | continue 968 | if fills_graph[j]['area'] < max_th: 969 | continue 970 | 971 | min_neighbors, min_neighbor_sizes = find_min_neighbor(fills_graph, j) 972 | # print("Log:\tfound region %d have %d neighbors"%(j, len(min_neighbors))) 973 | 974 | for k in min_neighbors: 975 | 976 | nb = fills_graph[j]['neighbor'][k] 977 | 978 | if min_neighbor_sizes[k] == -1: 979 | continue 980 | 981 | if nb != None: 982 | if fills_graph[nb] != None and fills_graph[nb]['area'] <= low_th and nb != j: 983 | fills_graph = merge_region(fills_graph, nb, j, result) 984 | count += 1 985 | else: 986 | fills_graph[j]['neighbor'][k] = None 987 | 988 | # small region join its largest neighbor 989 | small_regions = list_region(fills_graph, low_th, False) 990 | first_loop = True 991 | num_samll_before = len(small_regions) 992 | num_samll_after = len(small_regions) 993 | 994 | while first_loop or num_samll_before - num_samll_after > 0: 995 | first_loop = False 996 | for s in small_regions: 997 | 998 | min_neighbors, min_neighbor_sizes = find_min_neighbor(fills_graph, s) 999 | 1000 | if len(min_neighbors) == 0 or min_neighbors == None or min_neighbor_sizes[min_neighbors[-1]] == -1: 1001 | if fills_graph[s]['area'] < really_low_th: 1002 | fills_graph = merge_region(fills_graph, s, 0, result) 1003 | continue 1004 | 1005 | t = fills_graph[s]['neighbor'][min_neighbors[-1]] 1006 | fills_graph = merge_region(fills_graph, s, t, result) 1007 | count += 1 1008 | 1009 | small_regions = list_region(fills_graph, low_th, False) 1010 | 1011 | num_samll_before = num_samll_after 1012 | num_samll_after = len(small_regions) 1013 | 1014 | print("Log:\t %d neighbors merged"%count) 1015 | 1016 | return fills_graph 1017 | 1018 | 1019 | def merger_fill_2nd(fillmap, max_iter=10, low_th=0.001, max_th=0.01, debug=False): 1020 | 1021 | """ 1022 | next step should be using multi threading in each step 1023 | get the function as fast as I can 1024 | """ 1025 | 1026 | max_height, max_width = fillmap.shape[:2] 1027 | result = fillmap.copy() 1028 | low_th = int(max_height*max_width*low_th) 1029 | max_th = int(max_height*max_width*max_th) 1030 | 1031 | # 1. convert filling map to graphs 1032 | # this step take 99% of running time, need optimaization a lot 1033 | if debug: 1034 | print("Log:\tload fill_map.pickle") 1035 | result = load_obj("fill_map.pickle") 1036 | fill_id_new = np.unique(result) 1037 | else: 1038 | print("Log:\tsplit bleeding regions") 1039 | result, fill_id_new = split_region(result) 1040 | 1041 | # initailize the graph of regions 1042 | if debug: 1043 | print("Log:\tload fills_graph.pickle") 1044 | fills_graph_init = load_obj("fills_graph.pickle") 1045 | fills_graph = load_obj("fills_graph.pickle") 1046 | else: 1047 | print("Log:\tinitialize region graph") 1048 | fills_graph = to_graph(result, fill_id_new) 1049 | 1050 | # find neighbor 1051 | if debug: 1052 | print("Log:\tload fills_graph_n.pickle") 1053 | fills_graph = load_obj("fills_graph_n.pickle") 1054 | else: 1055 | print("Log:\tfind region neighbors") 1056 | fills_graph = find_neighbor(result, fills_graph, max_height, max_width) 1057 | 1058 | # self check if the graph is constructed correctly 1059 | graph_self_check(fills_graph) 1060 | 1061 | # 2. merge all small region to its largest neighbor 1062 | # this step seems fast, it only takes around 20s to finish 1063 | print("Log:\tremove leaking color") 1064 | fills_graph = remove_bleeding(fills_graph, fill_id_new, max_iter, result, low_th, max_th) 1065 | 1066 | # 3. show the refined the result 1067 | visualize_graph(fills_graph, result, region=None) 1068 | 1069 | # 4. map region graph back to fillmaps 1070 | result = to_fillmap(result, fills_graph) 1071 | return result, fills_graph 1072 | -------------------------------------------------------------------------------- /src/flatting/unet/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_model import UNet 2 | -------------------------------------------------------------------------------- /src/flatting/unet/unet_model.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form the complete network """ 2 | 3 | import torch.nn.functional as F 4 | 5 | from .unet_parts import * 6 | 7 | 8 | class UNet(nn.Module): 9 | def __init__(self, in_channels, out_channels, bilinear=True): 10 | super(UNet, self).__init__() 11 | self.in_channels = in_channels 12 | self.out_channels = out_channels 13 | self.bilinear = bilinear 14 | 15 | self.inc = DoubleConv(in_channels, 64) 16 | self.down1 = Down(64, 128) 17 | self.down2 = Down(128, 256) 18 | self.down3 = Down(256, 512) 19 | factor = 2 if bilinear else 1 20 | self.down4 = Down(512, 1024 // factor) 21 | self.up1 = Up(1024, 512 // factor, bilinear) 22 | self.up2 = Up(512, 256 // factor, bilinear) 23 | self.up3 = Up(256, 128 // factor, bilinear) 24 | self.up4 = Up(128, 64, bilinear) 25 | self.outc = OutConv(64, out_channels) 26 | 27 | def forward(self, x): 28 | x1 = self.inc(x) 29 | x2 = self.down1(x1) 30 | x3 = self.down2(x2) 31 | x4 = self.down3(x3) 32 | x5 = self.down4(x4) 33 | x = self.up1(x5, x4) 34 | x = self.up2(x, x3) 35 | x = self.up3(x, x2) 36 | x = self.up4(x, x1) 37 | logits = self.outc(x) 38 | return logits 39 | -------------------------------------------------------------------------------- /src/flatting/unet/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | DoubleConv(in_channels, out_channels) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.maxpool_conv(x) 40 | 41 | 42 | class Up(nn.Module): 43 | """Upscaling then double conv""" 44 | 45 | def __init__(self, in_channels, out_channels, bilinear=True): 46 | super().__init__() 47 | 48 | # if bilinear, use the normal convolutions to reduce the number of channels 49 | if bilinear: 50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 52 | else: 53 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) 54 | self.conv = DoubleConv(in_channels, out_channels) 55 | 56 | 57 | def forward(self, x1, x2): 58 | x1 = self.up(x1) 59 | # input is CHW 60 | diffY = x2.size()[2] - x1.size()[2] 61 | diffX = x2.size()[3] - x1.size()[3] 62 | 63 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 64 | diffY // 2, diffY - diffY // 2]) 65 | # if you have padding issues, see 66 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 67 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 68 | x = torch.cat([x2, x1], dim=1) 69 | return self.conv(x) 70 | 71 | 72 | class OutConv(nn.Module): 73 | def __init__(self, in_channels, out_channels): 74 | super(OutConv, self).__init__() 75 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 76 | 77 | def forward(self, x): 78 | return self.conv(x) 79 | -------------------------------------------------------------------------------- /src/flatting/utils/add_white_background.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from os.path import * 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | source = "L:\\2.Research_project\\3.flatting\\Pytorch-UNet\\flatting\\validation" 9 | target = "L:\\2.Research_project\\3.flatting\\Pytorch-UNet\\flatting\\validation" 10 | 11 | for img in tqdm(os.listdir(source)): 12 | 13 | if ".png" not in img: continue 14 | 15 | # open image 16 | img_a = Image.open(join(source, img)) 17 | 18 | # prepare white backgournd 19 | img_w = Image.new("RGBA", img_a.size, "WHITE") 20 | try: 21 | img_w.paste(img_a, None, img_a) 22 | img_w.convert("RGB").save(join(target, img)) 23 | except: 24 | print("Error:\tfailed on %s"%img) 25 | -------------------------------------------------------------------------------- /src/flatting/utils/data_vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | 4 | def plot_img_and_mask(img, mask): 5 | classes = mask.shape[2] if len(mask.shape) > 2 else 1 6 | fig, ax = plt.subplots(1, classes + 1) 7 | ax[0].set_title('Input image') 8 | ax[0].imshow(img) 9 | if classes > 1: 10 | for i in range(classes): 11 | ax[i+1].set_title(f'Output mask (class {i+1})') 12 | ax[i+1].imshow(mask[:, :, i]) 13 | else: 14 | ax[1].set_title(f'Output mask') 15 | ax[1].imshow(mask) 16 | plt.xticks([]), plt.yticks([]) 17 | plt.show() 18 | -------------------------------------------------------------------------------- /src/flatting/utils/dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import logging 5 | import cv2 6 | # import webp 7 | 8 | from os.path import * 9 | from os import listdir 10 | 11 | from PIL import Image 12 | 13 | from torch.utils.data import Dataset 14 | from torchvision import transforms as T 15 | from torch.nn import Threshold 16 | from io import BytesIO 17 | 18 | 19 | class BasicDataset(Dataset): 20 | # let's try to make this all work in memory 21 | ''' 22 | The original version, which read image from disk 23 | uncomment to enable 24 | ''' 25 | # def __init__(self, line_dir, edge_dir, radius = 2, crop_size = 0): 26 | # self.line_dir = line_dir 27 | # self.edge_dir = edge_dir 28 | # self.kernel = self.get_ball_structuring_element(radius) 29 | 30 | # self.crop_size = crop_size if crop_size != 0 else 1024 31 | # assert self.crop_size > 0 32 | 33 | # self.ids = listdir(line_dir) 34 | # self.length = len(self.ids) 35 | # assert self.length == len(listdir(edge_dir)) 36 | 37 | # logging.info(f'Creating dataset with {len(self.ids)} examples') 38 | 39 | ''' 40 | The modified version, read the whole data set in numpy array 41 | ''' 42 | def __init__(self, lines_bytes, edges_bytes, radius = 2, crop_size = 0): 43 | 44 | self.lines_bytes = lines_bytes 45 | self.edges_bytes = edges_bytes 46 | 47 | self.kernel = self.get_ball_structuring_element(radius) 48 | 49 | self.crop_size = crop_size if crop_size != 0 else 1024 50 | assert self.crop_size > 0 51 | 52 | 53 | 54 | self.length = len(lines_bytes) 55 | # self.length = len(self.ids) 56 | 57 | assert self.length == len(edges_bytes) 58 | 59 | logging.info(f'Creating dataset with {self.length} examples') 60 | 61 | def __len__(self): 62 | return self.length 63 | 64 | def get_ball_structuring_element(self, radius): 65 | """Get a ball shape structuring element with specific radius for morphology operation. 66 | The radius of ball usually equals to (leaking_gap_size / 2). 67 | 68 | # Arguments 69 | radius: radius of ball shape. 70 | 71 | # Returns 72 | an array of ball structuring element. 73 | """ 74 | return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * radius + 1, 2 * radius + 1)) 75 | 76 | def __getitem__(self, i): 77 | 78 | 79 | ''' 80 | The original version, uncomment to enable 81 | ''' 82 | # idx = self.ids[i] 83 | # edge_path = join(self.edge_dir, idx.replace("webp", "png")) 84 | # line_path = join(self.line_dir, idx) 85 | 86 | # assert exists(edge_path), \ 87 | # f'No edge map found for the ID {idx}: {edge_path}' 88 | # assert exists(line_path), \ 89 | # f'No line art found for the ID {idx}: {line_path}' 90 | 91 | # remove white 92 | # if ".webp" in line_path: 93 | # line_np = np.array(webp.load_image(line_path, "RGB").convert("L")) 94 | # else: 95 | # line_np = np.array(Image.open(line_path)) 96 | 97 | # if ".webp" in edge_path: 98 | # edge_np = np.array(webp.load_image(edge_path, "RGB")) 99 | # else: 100 | # edge_np = np.array(Image.open(edge_path)) 101 | ''' 102 | the end of orignal version 103 | ''' 104 | 105 | ''' 106 | The modified version 107 | ''' 108 | line_bytes = self.lines_bytes[i] 109 | edge_bytes = self.edges_bytes[i] 110 | 111 | 112 | buffer = BytesIO(line_bytes) 113 | line_np = np.array(Image.open(buffer).convert("L")) 114 | 115 | buffer = BytesIO(edge_bytes) 116 | edge_np = np.array(Image.open(buffer).convert("L")) 117 | ''' 118 | end of modified version 119 | ''' 120 | 121 | ''' 122 | The following part should be fine 123 | ''' 124 | 125 | # crop_bbox = self.find_bbox(self.to_point_list(line_np)) 126 | # line_np = self.crop_img(crop_bbox, line_np) 127 | # edge_np = self.crop_img(crop_bbox, edge_np) 128 | 129 | # line_np, edge_np = self.random_resize([line_np, edge_np]) 130 | 131 | # or threshold by opencv? 132 | _, mask1_np = cv2.threshold(line_np, 125, 255, cv2.THRESH_BINARY) 133 | _, mask2_np = cv2.threshold(edge_np, 125, 255, cv2.THRESH_BINARY) 134 | 135 | # convert to tensor, and the following process should all be done by cuda 136 | line = self.to_tensor(line_np) 137 | edge = self.to_tensor(edge_np) 138 | 139 | mask1 = self.to_tensor(mask1_np, normalize = False) 140 | mask2 = self.to_tensor(mask2_np, normalize = False) 141 | 142 | assert line.shape == line.shape, \ 143 | f'Line art and edge map {i} should be the same size, but are {line.shape} and {edge.shape}' 144 | 145 | 146 | 147 | imgs = self.augment(torch.cat((line, edge, mask1, mask2), dim=0)) 148 | 149 | # it returns tensor at last 150 | return torch.chunk(imgs, 4, dim=0) 151 | 152 | def to_point_list(self, img_np): 153 | p = np.where(img_np < 220) 154 | return p 155 | 156 | def find_bbox(self, p): 157 | t = p[0].min() 158 | l = p[1].min() 159 | b = p[0].max() 160 | r = p[1].max() 161 | return t,l,b,r 162 | 163 | def crop_img(self, bbox, img_np): 164 | t,l,b,r = bbox 165 | return img_np[t:b, l:r] 166 | 167 | # def random_resize(self, img_np_list): 168 | # ''' 169 | # Experiment shows that random resize is not working well, so this function is obsoleted and just be left here 170 | # as a record. 171 | # Don't try random resize in this way, it will not work! 172 | # Much slower converging speed and not obvious better generalizetion ability 173 | # ''' 174 | # size = self.crop_size * (1 + np.random.rand()/5) 175 | 176 | # # if the image is a very long or wide image, then split it before cropping 177 | # img_np_resize_list = [] 178 | # for img_np in img_np_list: 179 | # if len(img_np.shape) == 2: 180 | # h, w = img_np.shape 181 | # else: 182 | # h, w, _ = img_np.shape 183 | 184 | # short_side = w if w < h else h 185 | # r = size / short_side 186 | # target_w = int(w*r+0.5) 187 | # target_h = int(h*r+0.5) 188 | # img_np = cv2.resize(img_np, (target_w, target_h), interpolation=cv2.INTER_AREA) 189 | # img_np_resize_list.append(img_np) 190 | 191 | # return img_np_resize_list 192 | 193 | def to_tensor(self, pil_img, normalize = True): 194 | 195 | # assume the input is always grayscal 196 | if normalize: 197 | transforms = T.Compose( 198 | [ 199 | # to tensor will change the channel order and divide 255 if necessary 200 | T.ToTensor(), 201 | T.Normalize(0.5, 0.5, inplace = True) 202 | ] 203 | ) 204 | else: 205 | transforms = T.Compose( 206 | [ 207 | # to tensor will change the channel order and divide 255 if necessary 208 | T.ToTensor(), 209 | ] 210 | ) 211 | 212 | return transforms(pil_img) 213 | 214 | def augment(self, tensors): 215 | transforms = T.Compose( 216 | [ 217 | T.RandomHorizontalFlip(), 218 | T.RandomVerticalFlip(), 219 | T.RandomCrop(size = self.crop_size) 220 | 221 | ] 222 | ) 223 | return transforms(tensors) 224 | -------------------------------------------------------------------------------- /src/flatting/utils/ground_truth_creation.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os, sys 4 | sys.path.append("./line") 5 | import torch 6 | 7 | from PIL import Image 8 | from os.path import * 9 | from tqdm import tqdm 10 | from hed.run import estimate 11 | from line.thin import Thinner 12 | from skimage.morphology import skeletonize, thin 13 | 14 | 15 | # let's use a advanced edge detection algorithm 16 | def get_ball_structuring_element(radius): 17 | """Get a ball shape structuring element with specific radius for morphology operation. 18 | The radius of ball usually equals to (leaking_gap_size / 2). 19 | 20 | # Arguments 21 | radius: radius of ball shape. 22 | 23 | # Returns 24 | an array of ball structuring element. 25 | """ 26 | return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * radius + 1, 2 * radius + 1)) 27 | 28 | def to_tensor(path_to_img): 29 | 30 | # img = Image.open(path_to_img) 31 | img_np = cv2.imread(path_to_img, cv2.IMREAD_COLOR) 32 | img_np = np.ascontiguousarray(img_np[:, :, ::-1].transpose(2, 0, 1).astype(np.float32) * (1.0 / 255.0)) 33 | 34 | return torch.FloatTensor(img_np) 35 | 36 | def to_numpy(edge_tensor): 37 | 38 | edge_np = edge_tensor.clamp(0.0, 1.0).numpy().transpose(1, 2, 0)[:, :, 0] * 255.0 39 | edge_np = edge_np.astype(np.uint8) 40 | 41 | return edge_np 42 | 43 | def extract_skeleton(img): 44 | 45 | size = np.size(img) 46 | skel = np.zeros(img.shape,np.uint8) 47 | element = cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3)) 48 | done = False 49 | 50 | while done is False: 51 | eroded = cv2.erode(img,element) 52 | temp = cv2.dilate(eroded,element) 53 | temp = cv2.subtract(img,temp) 54 | skel = cv2.bitwise_or(skel,temp) 55 | img = eroded.copy() 56 | 57 | zeros = size - cv2.countNonZero(img) 58 | if zeros==size: 59 | done = True 60 | 61 | return skel 62 | 63 | def extract_gt_hed(input_line, input_flat, out_path): 64 | 65 | _, line_name = split(input_line) 66 | 67 | # extract edge by HED 68 | tenInput = to_tensor(input_flat) 69 | tenOutput = estimate(tenInput) 70 | 71 | # threshold the output 72 | edge = to_numpy(tenOutput) 73 | edge_thresh = cv2.adaptiveThreshold(255-edge, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2) 74 | # values = np.unique(edge) 75 | # lower_bound = np.percentile(values, 30) 76 | # _, edge_thresh = cv2.threshold(edge, lower_bound, 255, cv2.THRESH_BINARY) 77 | 78 | # get skeleton 79 | # thin = Thinner() 80 | # edge_thin = thin(Image.fromarray(edge)).detach().cpu().numpy().transpose(1,2,0)*255 81 | # edge_thin = edge_thin.astype(np.uint8).repeat(3, axis=-1) 82 | 83 | # all of these not work 84 | # edge_thin = cv2.ximgproc.thinning(edge) 85 | # edge_thin = extract_skeleton(255 - edge_thresh) 86 | # edge_thin = skeletonize(edge_thresh) 87 | 88 | # Image.fromarray(edge_thresh).save(join(out_path, line_name)) 89 | cv2.imwrite(join(out_path, line_name), edge_thresh) 90 | 91 | 92 | def extract_gt(input_line, input_flat, out_path): 93 | # initialize 94 | 95 | print("Log:\topen %s"%input_flat) 96 | if exists(out_path) is False: 97 | os.makedirs(out_path) 98 | 99 | # canny edge detection 100 | img = cv2.imread(input_flat) 101 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 102 | img = cv2.blur(img,(5,5)) 103 | 104 | # analyaze the gradient of flat image 105 | grad = cv2.Laplacian(img,cv2.CV_64F) 106 | grad = abs(grad).sum(axis = -1) 107 | grad_v, grad_c = np.unique(grad, return_counts=True) 108 | 109 | # remove the majority grad, which is 0 110 | assert np.where(grad_v==0) == np.where(grad_c==grad_c.max()) 111 | grad_v = np.delete(grad_v, np.where(grad_v==0)) 112 | grad_c = np.delete(grad_c, np.where(grad_c==grad_c.max())) 113 | print("Log:\tlen of grad_v %d"%len(grad_v)) 114 | grad_c_cum = np.cumsum(grad_c) 115 | 116 | # if grad number is greater than 100, then this probably means the current 117 | # image exists pretty similar colors, then we should apply 118 | # another set of parameter to detect edge 119 | # this could be better if we can find the realtion between them 120 | if len(grad_v) < 100: 121 | min_val = grad_v[np.where(grad_c_cum<=np.percentile(grad_c_cum, 25))[0].max()] 122 | max_val = grad_v[np.where(grad_c_cum<=np.percentile(grad_c_cum, 40))[0].max()] 123 | else: 124 | min_val = grad_v[np.where(grad_c_cum<=np.percentile(grad_c_cum, 1))[0].max()] 125 | max_val = grad_v[np.where(grad_c_cum<=np.percentile(grad_c_cum, 10))[0].max()] 126 | 127 | edges = cv2.Canny(img, min_val, max_val, L2gradient=True) 128 | 129 | # write result 130 | _, line_name = split(input_line) 131 | cv2.imwrite(join(out_path, line_name.replace("webp", "png")), 255-edges) 132 | 133 | def main(): 134 | 135 | input_line_path = "../flatting/size_org/line/" 136 | input_flat_path = "../flatting/size_org/flat/" 137 | out_path = "../flatting/size_org/line_detection/" 138 | 139 | for img in tqdm(os.listdir(input_line_path)): 140 | input_line = join(input_line_path, img) 141 | input_flat = join(input_flat_path, img.replace("line", "flat")) 142 | 143 | # neural net base edge detection 144 | # extract_gt_hed(input_line, input_flat, out_path) 145 | 146 | # canny edge detection 147 | extract_gt(input_line, input_flat, out_path) 148 | 149 | 150 | 151 | if __name__=="__main__": 152 | main() -------------------------------------------------------------------------------- /src/flatting/utils/move_to_duplicate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from os.path import * 4 | 5 | flat = "L:\\2.Research_project\\3.flatting\\flatting_trapped_ball\\flatting\\size_org\\flat" 6 | line = "L:\\2.Research_project\\3.flatting\\flatting_trapped_ball\\flatting\\size_org\\line" 7 | # flat = "L:\\2.Research_project\\3.flatting\\Pytorch-UNet\\flatting\\size_org\\OneDrive_2021-03-14\\[21-02-05] 318 SETS\\flat" 8 | # line = "L:\\2.Research_project\\3.flatting\\Pytorch-UNet\\flatting\\size_org\\OneDrive_2021-03-14\\[21-02-05] 318 SETS\\line" 9 | duplicate = "L:\\2.Research_project\\3.flatting\\flatting_trapped_ball\\flatting\\size_org\\duplicate" 10 | test = "L:\\2.Research_project\\3.flatting\\flatting_trapped_ball\\flatting\\size_org\\test" 11 | 12 | 13 | line_croped = "L:\\2.Research_project\\3.flatting\\flatting_trapped_ball\\flatting\\size_1024\\line_croped" 14 | line_detection_croped = "L:\\2.Research_project\\3.flatting\\flatting_trapped_ball\\flatting\\size_1024\\line_detection_croped" 15 | # with open("moving log.txt", "r") as f: 16 | # move_list = f.readlines() 17 | 18 | # for img in move_list: 19 | # img = img.replace("\n", "").replace("Log: moving ", "") 20 | # if "flat" in img: 21 | # shutil.move(join(duplicate, img), join(flat, img)) 22 | # if "line" in img: 23 | # shutil.move(join(duplicate, img), join(line, img)) 24 | 25 | flats = os.listdir(flat) 26 | lines = os.listdir(line) 27 | 28 | lines_croped = os.listdir(line_croped) 29 | lines_croped.sort() 30 | lines_detection_croped = os.listdir(line_detection_croped) 31 | lines_detection_croped.sort() 32 | 33 | assert len(lines_croped) == len(lines_detection_croped) 34 | 35 | 36 | ''' 37 | Move to test folders, but I think those are not good for evaluation... 38 | ''' 39 | for img in os.listdir(line): 40 | if img.replace("line", "flat") not in flats: 41 | print("Log:\tmoving %s"%img) 42 | shutil.move(join(line, img), join(test, img.replace(".png", "_line.png"))) 43 | 44 | for img in os.listdir(flat): 45 | if img.replace("flat", "line") not in lines: 46 | print("Log:\tmoving %s"%img) 47 | os.remove(join(flat, img)) 48 | # shutil.move(join(flat, img), join(test, img.replace(".png", "_flat.png"))) 49 | 50 | 51 | 52 | ''' 53 | Re-order all images in resized folder 54 | ''' 55 | # count = 0 56 | # for i in range(len(lines_croped)): 57 | # assert lines_croped[i] == lines_detection_croped[i] 58 | # img = lines_croped[i] 59 | # os.rename(join(line_croped, img), join(line_croped, "%04d.png"%count)) 60 | # os.rename(join(line_detection_croped, img), join(line_detection_croped, "%04d.png"%count)) 61 | # count += 1 62 | -------------------------------------------------------------------------------- /src/flatting/utils/polyvector/run_all_examples.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os.path import * 3 | import os 4 | import subprocess 5 | import shutil 6 | import time 7 | import pdb 8 | 9 | def parse(): 10 | # path to executable file 11 | p1 = normcase("./") 12 | p2 = 'polyvector_thing.exe' 13 | # path to input files 14 | p3 = normcase("E:\\OneDrive - George Mason University\\00.Projects\\01.Sketch cleanup\\00.Bechmark Dataset\\SSEB Dataset with GT") 15 | # path to output files, if necessary 16 | # p4 = normcase("./results") 17 | p4 = None 18 | parser = argparse.ArgumentParser(description='Batch Run Script') 19 | parser.add_argument('--exe', 20 | help ='path to executable file', 21 | default = join(p1,p2)) 22 | parser.add_argument('--input', 23 | help='path to input files', 24 | default = p3) 25 | parser.add_argument('--result', 26 | help='path where reuslts will be saved, if necessary', 27 | default = p4) 28 | 29 | return parser 30 | 31 | def main(): 32 | args = parse().parse_args() 33 | for img in os.listdir(args.input): 34 | name, extension = splitext(img) 35 | 36 | if extension == '.png': 37 | subprocess.run([args.exe, "-noisy", join(args.input, img)]) 38 | 39 | 40 | # 这个可能用的到,也可能用不到,以后再详细想想怎么写成一个通用的框架 41 | # path_to_result = join(args.result, name) 42 | # time.sleep(1) 43 | # if not exists(args.result): 44 | # os.mkdir(args.result) 45 | # if not exists(path_to_result): 46 | # os.mkdir(path_to_result) 47 | 48 | # for svg in os.listdir(normcase('./')): 49 | # if svg.endswith('.svg'): 50 | # # pdb.set_trace() 51 | # shutil.move(join(normcase('./'), svg), 52 | # join(path_to_result,svg)) 53 | print("Done") 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /src/flatting/utils/preprocessing.py: -------------------------------------------------------------------------------- 1 | # remove all white regions in image, and do the sample crop to groud turth 2 | # then resize all images by calling magick 3 | import os 4 | import cv2 5 | import numpy as np 6 | 7 | from os.path import * 8 | from PIL import Image 9 | from tqdm import tqdm 10 | 11 | def to_np(img_path, th = False): 12 | 13 | if th: 14 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 15 | _, img = cv2.threshold(img,220,255,cv2.THRESH_BINARY) 16 | # img = cv2.adaptiveThreshold(img,255,cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY,11,2) 17 | else: 18 | img = cv2.imread(img_path, cv2.IMREAD_COLOR) 19 | 20 | return img 21 | 22 | def to_point_list(img_np): 23 | p = np.where(img_np < 220) 24 | return p 25 | 26 | def find_bbox(p): 27 | t = p[0].min() 28 | l = p[1].min() 29 | b = p[0].max()+1 30 | r = p[1].max()+1 31 | return t,l,b,r 32 | 33 | def crop_img(bbox, img_np): 34 | t,l,b,r = bbox 35 | return img_np[t:b, l:r] 36 | 37 | def center_crop_resize(img_np, size, crop=False, th=False): 38 | # if the image is a very long or wide image, then split it before cropping 39 | if len(img_np.shape) == 2: 40 | h, w = img_np.shape 41 | else: 42 | h, w, _ = img_np.shape 43 | 44 | short_side = w if w < h else h 45 | r = size / short_side * 1.2 46 | target_w = int(w*r+0.5) 47 | target_h = int(h*r+0.5) 48 | img_np = cv2.resize(img_np, (target_w, target_h), interpolation = cv2.INTER_AREA) 49 | if th: 50 | _, img_np = cv2.threshold(img_np,250,255,cv2.THRESH_BINARY) 51 | # center crop image 52 | if crop: 53 | l = (target_w - size)//2 54 | t = (target_h - size)//2 55 | r = (target_w + size)//2 56 | b = (target_h + size)//2 57 | img_np = img_np[t:b, l:r] 58 | return img_np 59 | 60 | def try_split(img_np): 61 | 62 | img_list = [] 63 | 64 | h, w = img_np.shape[:2] 65 | if h >= 2*w: 66 | splition = h // w 67 | for i in range(0, h-h//splition, h//splition): 68 | img_list.append(img_np[i:i+h//splition]) 69 | elif w >= 2*h: 70 | splition = w // h 71 | for i in range(0, w-w//splition, w//splition): 72 | img_list.append(img_np[:,i:i+w//splition]) 73 | else: 74 | img_list.append(img_np) 75 | 76 | return img_list 77 | 78 | def main(): 79 | path_root = "../flatting" 80 | org = "size_org" 81 | 82 | crop_size = 512 83 | size = "size_%d"%crop_size 84 | 85 | path_to_img = join(path_root, org, "line") 86 | path_to_mask = join(path_root, org, "line_detection") 87 | out_path_img = join(path_root, size, "line_croped") 88 | out_path_mask = join(path_root, size, "line_detection_croped") 89 | 90 | counter = 0 91 | for img_name in tqdm(os.listdir(path_to_img)): 92 | 93 | mask_name = img_name 94 | assert exists(join(path_to_mask, mask_name)) 95 | 96 | img_np_th = to_np(join(path_to_img, img_name), th=True) 97 | img_np = to_np(join(path_to_img, img_name)) 98 | mask_np = to_np(join(path_to_mask, mask_name)) 99 | 100 | # remove addtional blank area 101 | bbox = find_bbox(to_point_list(img_np_th)) 102 | img_crop = crop_img(bbox, img_np) 103 | mask_crop = crop_img(bbox, mask_np) 104 | 105 | # detect if a image need split 106 | if False: 107 | img_crop_list = try_split(img_crop) 108 | mask_crop_list = try_split(mask_crop) 109 | assert len(img_crop_list) == len(mask_crop_list) 110 | else: 111 | img_crop_list = [img_crop] 112 | mask_crop_list = [mask_crop] 113 | 114 | # crop and resize each image 115 | for i in range(len(img_crop_list)): 116 | 117 | img = center_crop_resize(img_crop_list[i], crop_size) 118 | mask = center_crop_resize(mask_crop_list[i], crop_size, th=True) 119 | 120 | assert img.shape[:2] == mask.shape[:2] 121 | 122 | if True: 123 | cv2.imwrite(join(out_path_img, "%05d.png"%counter), img) 124 | cv2.imwrite(join(out_path_mask, "%05d.png"%counter), mask) 125 | else: 126 | cv2.imwrite(join(out_path_img, img_name), img) 127 | cv2.imwrite(join(out_path_mask, img_name), mask) 128 | counter += 1 129 | 130 | 131 | 132 | if __name__=="__main__": 133 | main() 134 | -------------------------------------------------------------------------------- /src/flatting_server.py: -------------------------------------------------------------------------------- 1 | # we need to import these modules in the first level, otherwise pyinstaller will not be able to import them 2 | # import sys, os 3 | # import pathlib 4 | # sys.path.append(pathlib.Path(__file__).parent.absolute()/"flatting") 5 | # sys.path.append(pathlib.Path(__file__).parent.absolute()/"flatting"/"trapped_ball") 6 | # from aiohttp import web 7 | # from PIL import Image 8 | # from io import BytesIO 9 | # import numpy as np 10 | # import flatting_api 11 | # import flatting_api_async 12 | # import base64 13 | # import io 14 | # import json 15 | # import asyncio 16 | # import multiprocessing 17 | # import cv2 18 | # import torch 19 | # from pathlib import Path 20 | # from os.path import * 21 | # from run import region_get_map, merge_to_ref, verify_region 22 | # from thinning import thinning 23 | # from predict import predict_img 24 | # from unet import UNet 25 | # import asyncio 26 | # from concurrent.futures import ProcessPoolExecutor 27 | # import functools 28 | 29 | if __name__ == '__main__': 30 | from flatting import app 31 | ## https://docs.python.org/3/library/multiprocessing.html#multiprocessing.freeze_support 32 | if app.MULTIPROCESS: app.multiprocessing.freeze_support() 33 | app.main() 34 | --------------------------------------------------------------------------------