├── .gitignore
├── LICENSE
├── README.md
├── environment.yml
├── flatting_server.spec
├── pyproject.toml
├── requirements.txt
└── src
    ├── flatting
    │   ├── __init__.py
    │   ├── __main__.py
    │   ├── app.py
    │   ├── client_debug.py
    │   ├── demo.py
    │   ├── dice_loss.py
    │   ├── eval.py
    │   ├── flatting_api.py
    │   ├── flatting_api_async.py
    │   ├── hubconf.py
    │   ├── predict.py
    │   ├── resources
    │   │   ├── __init__.py
    │   │   ├── flatting.icns
    │   │   ├── flatting.ico
    │   │   └── flatting.png
    │   ├── submit.py
    │   ├── tkapp.py
    │   ├── train.py
    │   ├── trapped_ball
    │   │   ├── adjacency_matrix.pyx
    │   │   ├── examples
    │   │   │   ├── 01.png
    │   │   │   ├── 01_sim.png
    │   │   │   ├── 02.png
    │   │   │   ├── tiny.png
    │   │   │   └── tiny_sim.png
    │   │   ├── run.py
    │   │   ├── thinning.py
    │   │   ├── thinning_zhang.py
    │   │   └── trappedball_fill.py
    │   ├── unet
    │   │   ├── __init__.py
    │   │   ├── unet_model.py
    │   │   └── unet_parts.py
    │   └── utils
    │       ├── add_white_background.py
    │       ├── data_vis.py
    │       ├── dataset.py
    │       ├── ground_truth_creation.py
    │       ├── move_to_duplicate.py
    │       ├── polyvector
    │       │   └── run_all_examples.py
    │       └── preprocessing.py
    └── flatting_server.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Specific to this project
2 | /data
3 | /src/flatting/checkpoints
4 | /src/flatting.dist-info
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # OSX useful to ignore
12 | *.DS_Store
13 | .AppleDouble
14 | .LSOverride
15 |
16 | # Thumbnails
17 | ._*
18 |
19 | # Files that might appear in the root of a volume
20 | .DocumentRevisions-V100
21 | .fseventsd
22 | .Spotlight-V100
23 | .TemporaryItems
24 | .Trashes
25 | .VolumeIcon.icns
26 | .com.apple.timemachine.donotpresent
27 |
28 | # Directories potentially created on remote AFP share
29 | .AppleDB
30 | .AppleDesktop
31 | Network Trash Folder
32 | Temporary Items
33 | .apdisk
34 |
35 | # C extensions
36 | *.so
37 |
38 | # Distribution / packaging
39 | .Python
40 | env/
41 | build/
42 | develop-eggs/
43 | dist/
44 | downloads/
45 | eggs/
46 | .eggs/
47 | lib/
48 | lib64/
49 | parts/
50 | sdist/
51 | var/
52 | *.egg-info/
53 | .installed.cfg
54 | *.egg
55 |
56 | # IntelliJ Idea family of suites
57 | .idea
58 | *.iml
59 | ## File-based project format:
60 | *.ipr
61 | *.iws
62 | ## mpeltonen/sbt-idea plugin
63 | .idea_modules/
64 |
65 | # Briefcase build directories
66 | iOS/
67 | macOS/
68 | windows/
69 | android/
70 | linux/
71 | django/
72 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |
635 |     Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 |     along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 |     <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Flatting
2 | This project is based on [U-net](https://github.com/milesial/Pytorch-UNet).
3 |
4 | ## Install
5 | ### 1. Install required packages
6 | To run color filling, you need the following modules installed:
7 |
8 | - numpy
9 | - opencv-python
10 | - tqdm
11 | - pillow
12 | - cython
13 | - aiohttp
14 | - scikit-image
15 | - torch
16 | - torchvision
17 |
18 | You can install these dependencies via [Anaconda](https://www.anaconda.com/products/individual) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html).
19 | Miniconda is faster to install. (On Windows, choose the 64-bit Python 3.x version. Launch the Anaconda shell from the Start menu and navigate to this directory.)
20 | Then:
21 |
22 | conda env create -f environment.yml
23 | conda activate flatting
24 |
25 | To update an already-created environment when `environment.yml` changes (or to switch environments), activate it and then run `conda env update --file environment.yml --prune`.
26 |
27 | ### 2. Download pretrained models
28 | Download the [pretrained network model](https://drive.google.com/file/d/1NLooRQ8uZ3ZwQnAYjQAiGhOJqit5Q2_J/view?usp=sharing) and unzip `checkpoints.zip` into `./src/flatting/`.
29 |
30 | ### 3. Run
31 | You can run our backend directly by:
32 |
33 | cd src
34 | python -m flatting
35 |
36 |
37 |
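Once the server is up, you can exercise the API from Python. The following is a minimal sketch based on `client_debug.py` and `app.py`; it assumes the default aiohttp port (8080), a line-art image `input.png` in the working directory, and the `requests` package:

    import base64, json
    from io import BytesIO

    import requests
    from PIL import Image

    url = "http://localhost:8080/"

    # The root endpoint just reports that the server is running.
    print(requests.get(url).text)

    # Base64-encode the line art, as the plugin does.
    with open("input.png", "rb") as f:
        img_b64 = base64.encodebytes(f.read()).decode("utf-8")

    data = {
        "image": img_b64,        # base64-encoded PNG line art
        "net": "1024_base",      # which pretrained model to use
        "radius": 1,             # trapped-ball radius
        "resize": False,         # set True and add "newSize": [w, h] to resize
        "userName": "debug",     # only used for server-side logging
        "fileName": "input.png",
    }

    # The handlers accept a JSON-encoded string, as in client_debug.py.
    result = requests.post(url + "flatsingle", json=json.dumps(data)).json()

    # Responses are base64-encoded PNGs; decode e.g. the flat fill result.
    fill = Image.open(BytesIO(base64.b64decode(result["image"])))
    fill.save("fill_color.png")
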
38 | ### 4. Package
39 | If you just want to run the backend and don't want to touch the code, we provide a [portable backend (Windows only)](https://drive.google.com/file/d/1s9Z5Qgc9siWMu45iOetEUhuzNfJbjbGw/view?usp=sharing) packaged with pyinstaller (see sec 4b). Download it, unzip it anywhere, then run:
40 |
41 | cd flatting_server
42 | flatting_server.exe
43 |
44 | ### 4a. Packaging with Briefcase
45 | **Issues:** Although briefcase produces a cleaner package of our backend, it also seems to hide the running log; we currently don't have a good solution for this issue.
46 |
47 | Use `briefcase` [commands](https://docs.beeware.org/en/latest/tutorial/tutorial-1.html) for packaging. Briefcase can't compile Cython modules, so you must do that first. There is only one: compile it via `cythonize -i src/flatting/trapped_ball/adjacency_matrix.pyx`.
48 |
49 | To start the process, run:
50 |
51 | briefcase create
52 | briefcase build
53 |
54 | To run the standalone program:
55 |
56 | briefcase run
57 |
58 | To create an installer:
59 |
60 | briefcase package
61 |
62 | To update the standalone program when your code or dependencies change:
63 |
64 | briefcase update -r -d
65 |
66 | You can also simply run `briefcase run -u`.
67 |
68 | To debug this process, you can run your code from the entrypoint briefcase uses:
69 |
70 | briefcase dev
71 |
72 | This reveals some issues that are important for debugging, but not dependency issues, because it doesn't use briefcase's Python installation.
73 |
74 | On my setup, I have to manually edit `macOS/app/Flatting/Flatting.app/Contents/Resources/app_packages/torch/distributed/rpc/api.py` to insert the line `if docstring is None: continue` after line 443:
75 |
76 | assert docstring is not None, "RRef user-facing methods should all have docstrings."
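
Presumably the inserted guard has to land just above that assertion so it can skip methods whose docstring is missing before the assert fires; after the edit the patched region reads roughly like this (a sketch only; the surrounding torch code is omitted and exact line numbers vary by torch version):

    if docstring is None: continue  # inserted guard: skip methods without a docstring
    assert docstring is not None, "RRef user-facing methods should all have docstrings."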
77 |
78 | ### 4b. Packaging with pyinstaller
79 |
80 | If briefcase doesn't work, you can use [pyinstaller](https://www.pyinstaller.org/):
81 |
82 | pyinstaller --noconfirm flatting_server.spec
83 |
84 | ### 5. Install Photoshop plugin
85 | Download the [flatting plugin](https://drive.google.com/file/d/1HivdqU2Z2dIL2MvqzEYmCLO2_nDL2Cnk/view?usp=sharing) and unzip it anywhere.
86 | Then download the backend server by following the instructions inside "flatting plugin.zip".
87 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: flatting
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - numpy
8 | - tqdm
9 | - pillow
10 | - py-opencv
11 | - aiohttp
12 | - scikit-image
13 | - pytorch
14 | - torchvision
15 | - appdirs
16 | - pyinstaller
17 | - pip
18 | - pip:
19 | - briefcase
20 | prefix: /Users/anonymous/opt/miniconda3/envs/flatting
21 |
--------------------------------------------------------------------------------
/flatting_server.spec:
--------------------------------------------------------------------------------
1 | # -*- mode: python ; coding: utf-8 -*-
2 | import os.path
3 | block_cipher = None
4 | a = Analysis(['src/flatting_server.py'],
5 | ## pyinstaller flatting_server.spec must be run from
6 | ## ???
7 | pathex=[os.path.abspath(os.getcwd())],
8 | binaries=[],
9 | datas=[('src/flatting/checkpoints','checkpoints')],
10 | hiddenimports=[],
11 | hookspath=[],
12 | runtime_hooks=[],
13 | excludes=[],
14 | win_no_prefer_redirects=False,
15 | win_private_assemblies=False,
16 | cipher=block_cipher,
17 | noarchive=False)
18 | pyz = PYZ(a.pure, a.zipped_data,
19 | cipher=block_cipher)
20 | exe = EXE(pyz,
21 | a.scripts,
22 | [],
23 | exclude_binaries=True,
24 | name='flatting_server',
25 | debug=False,
26 | bootloader_ignore_signals=False,
27 | strip=False,
28 | upx=True,
29 | console=True,
30 | icon='src/flatting/resources/flatting.ico' )
31 | coll = COLLECT(exe,
32 | a.binaries,
33 | a.zipfiles,
34 | a.datas,
35 | strip=False,
36 | upx=True,
37 | upx_exclude=[],
38 | name='flatting_server')
39 | app = BUNDLE(coll,
40 | name='flatting_server.app',
41 | icon='src/flatting/resources/flatting.icns',
42 | bundle_identifier=None,
43 | info_plist={'NSHighResolutionCapable': 'True'}
44 | )
45 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.briefcase]
2 | project_name = "Flatting"
3 | bundle = "edu.gmu.cs.cragl.flatting"
4 | version = "0.0.1"
5 | url = "https://cragl.cs.gmu.edu/flatting"
6 | license = "Proprietary"
7 | author = 'Chuan Yan'
8 | author_email = "cyan3@gmu.edu"
9 |
10 | [tool.briefcase.app.flatting]
11 | formal_name = "Flatting"
12 | description = "Back-end for Photoshop flatting plugin"
13 | icon = "src/flatting/resources/flatting"
14 | sources = ['src/flatting']
15 | requires = [
16 | "numpy",
17 | "tqdm",
18 | "pillow",
19 | "opencv-python-headless",
20 | "aiohttp",
21 | "scikit-image",
22 | "torch",
23 | "torchvision"
24 | ]
25 |
26 | [tool.briefcase.app.flatting.macOS]
27 | requires = []
28 |
29 | [tool.briefcase.app.flatting.linux]
30 | requires = []
31 | system_requires = []
32 |
33 | [tool.briefcase.app.flatting.windows]
34 | requires = []
35 |
36 | # Mobile deployments
37 | [tool.briefcase.app.flatting.iOS]
38 | requires = []
39 |
40 | [tool.briefcase.app.flatting.android]
41 | requires = []
42 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | numpy
3 | Pillow
4 | torch
5 | torchvision
6 | tensorboard
7 | future
8 | tqdm
9 |
--------------------------------------------------------------------------------
/src/flatting/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/__init__.py
--------------------------------------------------------------------------------
/src/flatting/__main__.py:
--------------------------------------------------------------------------------
1 | from . import app
2 |
3 | if __name__ == '__main__':
4 | ## https://docs.python.org/3/library/multiprocessing.html#multiprocessing.freeze_support
5 | if app.MULTIPROCESS: app.multiprocessing.freeze_support()
6 | app.main()
7 |
--------------------------------------------------------------------------------
/src/flatting/app.py:
--------------------------------------------------------------------------------
1 | from aiohttp import web
2 | from PIL import Image
3 | from io import BytesIO
4 | from datetime import datetime
5 | from os.path import join, exists
6 |
7 | import appdirs
8 | import numpy as np
9 | from . import flatting_api
10 | import base64
11 | import os
12 | import io
13 | import json
14 | import asyncio
15 | import multiprocessing
16 |
17 | MULTIPROCESS = True
18 | LOG = True
19 |
20 | if MULTIPROCESS:
21 | # Importing this module creates multiprocessing pools, which is problematic
22 | # in Briefcase and PyInstaller on macOS.
23 | from . import flatting_api_async
24 |
25 | routes = web.RouteTableDef()
26 |
27 | @routes.get('/')
28 | # seems the function name is not that important?
29 | async def hello(request):
30 | return web.Response(text="Flatting API server is running")
31 |
32 | ## Add more API entry points
33 | @routes.post('/flatsingle')
34 | async def flatsingle( request ):
35 | data = await request.json()
36 | try:
37 | data = json.loads(data)
38 | except:
39 | print("got dict directly")
40 |
41 | # convert to json
42 | img = to_pil(data['image'])
43 | net = str(data['net'])
44 | radii = int(data['radius'])
45 | resize = data['resize']
46 | if 'userName' in data:
47 | user = data['userName']
48 | img_name = data['fileName']
49 | if resize:
50 | w_new, h_new = data["newSize"]
51 | else:
52 | w_new = None
53 | h_new = None
54 |
55 | if MULTIPROCESS:
56 | flatted = await flatting_api_async.run_single(img, net, radii, resize, w_new, h_new, img_name)
57 | else:
58 | flatted = flatting_api.run_single(img, net, radii, resize, w_new, h_new, img_name)
59 |
60 | result = {}
61 | result['line_artist'] = to_base64(flatted['line_artist'])
62 | result['line_hint'] = to_base64(flatted['line_hint'])
63 | result['line_simplified'] = to_base64(flatted['line_simplified'])
64 | result['image'] = to_base64(flatted['fill_color'])
65 | result['fill_artist'] = to_base64(flatted['components_color'])
66 |
67 | if LOG:
68 | now = datetime.now()
69 | save_to_log(now, flatted['line_artist'], user, img_name, "line_artist", "flat")
70 | save_to_log(now, flatted['line_hint'], user, img_name, "line_hint", "flat")
71 | save_to_log(now, flatted['line_simplified'], user, img_name, "line_simplified", "flat")
72 | save_to_log(now, flatted['fill_color'], user, img_name, "fill_color", "flat")
73 | save_to_log(now, flatted['components_color'], user, img_name, "fill_color_floodfill", "flat")
74 | save_to_log(now, flatted['fill_color_neural'], user, img_name, "fill_color_neural", "flat")
75 | save_to_log(now, flatted['line_neural'], user, img_name, "line_neural", "flat")
76 | print("Log:\tlogs saved")
77 | return web.json_response( result )
78 |
79 | @routes.post('/merge')
80 | async def merge( request ):
81 | data = await request.json()
82 | try:
83 | data = json.loads(data)
84 | except:
85 | print("got dict directly")
86 |
87 | line_artist = to_pil(data['line_artist'])
88 | fill_neural = np.array(to_pil(data['fill_neural']))
89 | fill_artist = np.array(to_pil(data['fill_artist']))
90 | stroke = to_pil(data['stroke'])
91 | if 'userName' in data:
92 | user = data['userName']
93 | img_name = data['fileName']
94 | # palette = np.array(data['palette'])
95 |
96 | if MULTIPROCESS:
97 | merged = await flatting_api_async.merge(fill_neural, fill_artist, stroke, line_artist)
98 | else:
99 | merged = flatting_api.merge(fill_neural, fill_artist, stroke, line_artist)
100 |
101 | result = {}
102 | result['image'] = to_base64(merged['fill_color'])
103 | result['line_simplified'] = to_base64(merged['line_simplified'])
104 | if LOG:
105 | now = datetime.now()
106 | save_to_log(now, merged['line_simplified'], user, img_name, "line_simplified", "merge")
107 | save_to_log(now, merged['fill_color'], user, img_name, "fill_color", "merge")
108 | save_to_log(now, stroke, user, img_name, "merge_stroke", "merge")
109 | save_to_log(now, fill_artist, user, img_name, "fill_color_floodfill", "merge")
110 | print("Log:\tlogs saved")
111 |
112 | return web.json_response(result)
113 |
114 | @routes.post('/splitmanual')
115 | async def split_manual( request ):
116 | data = await request.json()
117 | try:
118 | data = json.loads(data)
119 | except:
120 | print("got dict directly")
121 |
122 | fill_neural = np.array(to_pil(data['fill_neural']))
123 | fill_artist = np.array(to_pil(data['fill_artist']))
124 | stroke = np.array(to_pil(data['stroke']))
125 | line_artist = to_pil(data['line_artist'])
126 | add_only = data['mode']
127 | if 'userName' in data:
128 | user = data['userName']
129 | img_name = data['fileName']
130 |
131 | if MULTIPROCESS:
132 | splited = await flatting_api_async.split_manual(fill_neural, fill_artist, stroke, line_artist, add_only)
133 | else:
134 | splited = flatting_api.split_manual(fill_neural, fill_artist, stroke, line_artist, add_only)
135 |
136 | result = {}
137 | result['line_artist'] = to_base64(splited['line_artist'])
138 | result['line_simplified'] = to_base64(splited['line_neural'])
139 | result['image'] = to_base64(splited['fill_color'])
140 | result['fill_artist'] = to_base64(splited['fill_artist'])
141 | result['line_hint'] = to_base64(splited['line_hint'])
142 |
143 | if LOG:
144 | now = datetime.now()
145 | save_to_log(now, splited['line_neural'], user, img_name, "line_simplified", "split_%s"%str(add_only))
146 | save_to_log(now, splited['line_artist'], user, img_name, "line_artist", "split_%s"%str(add_only))
147 | save_to_log(now, splited['fill_color'], user, img_name, "fill_color", "split_%s"%str(add_only))
148 | save_to_log(now, stroke, user, img_name, "split_stroke", "split_%s"%str(add_only))
149 | save_to_log(now, splited['fill_artist'], user, img_name, "fill_color_floodfill", "split_%s"%str(add_only))
150 | save_to_log(now, splited['line_hint'], user, img_name, "line_hint", "split_%s"%str(add_only))
151 | print("Log:\tlogs saved")
152 | return web.json_response(result)
153 |
154 | @routes.post('/flatlayers')
155 | async def export_fill_to_layers( request ):
156 | data = await request.json()
157 | try:
158 | data = json.loads(data)
159 | except:
160 | print("got dict directly")
161 |
162 | img = np.array(to_pil(data['fill_neural']))
163 | layers = flatting_api.export_layers(img)
164 | layers_base64 = []
165 | for layer in layers["layers"]:
166 | layers_base64.append(to_base64(layer))
167 |
168 | result = {}
169 | result["layersImage"] = layers_base64
170 |
171 | return web.json_response(result)
172 |
173 | def to_base64(array):
174 | '''
175 | A helper function to convert numpy array to png in base64 format
176 | '''
177 | with io.BytesIO() as output:
178 | if type(array) == np.ndarray:
179 | Image.fromarray(array).save(output, format='png')
180 | else:
181 | array.save(output, format='png')
182 | img = output.getvalue()
183 | img = base64.encodebytes(img).decode("utf-8")
184 | return img
185 |
186 | def to_pil(byte):
187 | '''
188 | A helper function to convert byte png to PIL.Image
189 | '''
190 | byte = base64.b64decode(byte)
191 | return Image.open(BytesIO(byte))
192 |
193 | def save_to_log(date, data, user, img_name, save_name, op):
194 | log_dir = appdirs.user_log_dir( "Flatting Server", "CraGL" )
195 | save_folder = "[%s][%s][%s_%s]"%(user, str(date.strftime("%d-%b-%Y %H-%M-%S")), img_name, op)
196 | save_folder = join( log_dir, save_folder)
197 | try:
198 | if exists(save_folder) == False:
199 | os.makedirs(save_folder)
200 | if type(data) == np.ndarray:
201 | Image.fromarray(data).save(join(save_folder, "%s.png"%save_name))
202 | else:
203 | data.save(join(save_folder, "%s.png"%save_name))
204 | except:
205 | print("Warning:\tsave log failed!")
206 |
207 | def main():
208 | '''
209 | import traceback
210 | with open('/Users/yotam/Work/GMU/flatting/code/log.txt','a') as f:
211 | f.write('=================================================================\n')
212 | f.write('__name__: %s\n' % __name__)
213 | traceback.print_stack(file=f)
214 | '''
215 | app = web.Application(client_max_size = 1024 * 1024 ** 2)
216 | app.add_routes(routes)
217 | web.run_app(app)
218 |
219 | ## From JavaScript:
220 | # let result = await fetch( url_of_server.py, { method: 'POST', body: JSON.stringify(data) } ).json();
221 |
222 | if __name__ == '__main__':
223 | main()
224 |
--------------------------------------------------------------------------------
/src/flatting/client_debug.py:
--------------------------------------------------------------------------------
1 | # need to write some test case
2 | import requests
3 | import base64
4 | import io
5 | import json
6 | import os
7 |
8 | from os.path import *
9 | from io import BytesIO
10 | from PIL import Image
11 |
12 | url = "http://jixuanzhi.asuscomm.com:8080/"
13 | image = "./trapped_ball/examples/01.png"
14 | # image = "test1.png"
15 | run_single_test = "flatsingle"
16 | run_multi_test = "flatmultiple"
17 | merge_test = "merge"
18 | split_auto_test = "splitauto"
19 | split_manual_test = "splitmanual"
20 | show_fillmap_test = "showfillmap"
21 |
22 | def png_to_base64(path_to_img):
23 | with open(path_to_img, 'rb') as f:
24 | return base64.encodebytes(f.read()).decode("utf-8")
25 |
26 | def test_case1():
27 | # case for run single test
28 | data = {}
29 | data['image'] = png_to_base64(image)
30 | data['net'] = '1024_base'
31 | data['radius'] = 1
32 | data['preview'] = False
33 |
34 | # convert to json
35 | result = requests.post(url+run_single_test, json = json.dumps(data))
36 | if result.status_code == 200:
37 | result = result.json()
38 | import pdb
39 | pdb.set_trace()
40 | line_sim = to_pil(result['line_artist'])
41 | line_sim.show()
42 | os.system("pause")
43 |
44 | line_sim = to_pil(result['image'])
45 | line_sim.show()
46 | os.system("pause")
47 | line_sim = to_pil(result['image_c'])
48 | line_sim.show()
49 | os.system("pause")
50 | line_sim = to_pil(result['line_simplified'])
51 | line_sim.show()
52 | os.system("pause")
53 |
54 | else:
55 | raise ValueError("Test failed")
56 |
57 | print("Done")
58 |
59 | def to_pil(byte):
60 | '''
61 | A helper function to convert byte png to PIL.Image
62 | '''
63 | byte = base64.b64decode(byte)
64 | return Image.open(BytesIO(byte))
65 |
66 | test_case1()
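A second test case could exercise the /flatlayers route the same way; a minimal, untested sketch (the request field `fill_neural` and the response key `layersImage` mirror export_fill_to_layers in app.py):

    def test_case_layers():
        data = {'fill_neural': png_to_base64(image)}
        result = requests.post(url + "flatlayers", json = json.dumps(data))
        if result.status_code == 200:
            layers = result.json()["layersImage"]
            to_pil(layers[0]).show()
        else:
            raise ValueError("Test failed")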
--------------------------------------------------------------------------------
/src/flatting/demo.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import argparse
3 | from os.path import *
4 | sys.path.append(join(dirname(abspath(__file__)), "trapped_ball"))
5 |
6 |
7 | import gradio as gr
8 | import numpy as np
9 | import torch
10 | import random
11 |
12 | from PIL import Image
13 | from torchvision import transforms as T
14 | from torchvision import utils
15 |
16 | # import model
17 | from unet import UNet
18 | from predict import predict_img
19 |
20 | # import trapped ball filling func
21 | from run import region_get_map
22 |
23 | from functools import partial
24 | from zipfile import ZipFile
25 |
26 | def to_t(array):
27 | return torch.Tensor(array).cuda().unsqueeze(0)
28 |
29 |
30 | def to_tensor(img):
31 |
32 | img_t = (
33 | torch.from_numpy(img).unsqueeze(-1)
34 | .to(torch.float32)
35 | .div(255)
36 | .add_(-0.5)
37 | .mul_(2)
38 | .permute(2, 0, 1)
39 | )
40 | return img_t.unsqueeze(0).cuda()
41 |
42 | # def denormalize(img):
43 | # # denormalize
44 | # inv_normalize = T.Normalize( mean=[-1, -1, -1], std=[2, 2, 2])
45 |
46 | # img_np = inv_normalize(img.repeat(3,1,1))
47 | # img_np = (img_np * 255).clamp(0, 255)
48 |
49 | # # to numpy
50 | # img_np = img_np.cpu().numpy().transpose((1,2,0))
51 |
52 | # return Image.fromarray(img_np.astype(np.uint8))
53 |
54 | def zip_files(files):
55 | with ZipFile("./flatting/gradio/all.zip", 'w') as zipObj:
56 | for f in files:
57 | zipObj.write(f)
58 | return "./flatting/gradio/all.zip"
59 |
60 | def split_to_4(img):
61 |
62 |     # for now, simply split the image into 4 even patches
63 | w, h = img.size
64 | h1 = h // 2
65 | w1 = w // 2
66 | img = np.array(img)
67 |
68 | # top left
69 | img1 = Image.fromarray(img[:h1, :w1])
70 |
71 | # top right
72 | img2 = Image.fromarray(img[:h1, w1:])
73 |
74 | # bottom left
75 | img3 = Image.fromarray(img[h1:, :w1])
76 |
77 | # bottom right
78 | img4 = Image.fromarray(img[h1:, w1:])
79 |
80 | return img1, img2, img3, img4
81 |
82 | def merge_to_1(imgs):
83 |
84 | img1, img2, img3, img4 = imgs
85 | img_top = np.concatenate((img1, img2), axis = 1)
86 | img_bottom = np.concatenate((img3, img4), axis = 1)
87 |
88 | return np.concatenate((img_top, img_bottom), axis = 0)
89 |
90 | def pred_and_fill(img, op, radius, patch, nets, outputs="./flatting/gradio"):
91 |
92 |     # initial output file paths
93 | outs = []
94 | outs.append(join(outputs, "%s_input.png"%op))
95 | outs.append(join(outputs, "%s_fill.png"%op))
96 | outs.append(join(outputs, "%s_fill_edge.png"%op))
97 | outs.append(join(outputs, "%s_fill_line.png"%op))
98 | outs.append(join(outputs, "%s_fill_line_full.png"%op))
99 |
100 |
101 | # predict full image
102 | # img = cv2.threshold(img, 240, 255, cv2.THRESH_BINARY)
103 | if patch == "False":
104 | # img_w = Image.new("RGBA", img.size, "WHITE")
105 | # try:
106 | # img_w.paste(img, None, img)
107 | # img = img_w.convert("L")
108 | # except:
109 | # print("Log:\tfailed to add white background")
110 |
111 | edge = predict_img(net=nets[op][0],
112 | full_img=img,
113 | device=nets[op][1],
114 | size = int(op.replace("_rand", "")))
115 | else:
116 |         print("Log:\tsplit input into 4 patches with model %s"%(op))
117 | # cut image into non-overlapping patches
118 | imgs = split_to_4(img)
119 |
120 | edges = []
121 | for patch in imgs:
122 | edge = predict_img(net=nets[op][0],
123 | full_img=patch,
124 | device=nets[op][1],
125 |                              size = int(op.replace("_rand", "")))
126 |
127 | edges.append(np.array(edge))
128 |
129 | edge = Image.fromarray(merge_to_1(edges))
130 |
131 |     print("Log:\ttrapped ball filling with radius %s"%radius)
132 | fill = region_get_map(edge.convert("L"),
133 | radius_set=[int(radius)], percentiles=[0],
134 | path_to_line_artist=img,
135 | return_numpy=True,
136 | preview = True)
137 |
138 | return edge, fill
139 |
140 | def initial_models(path_to_ckpt):
141 |
142 |     # find the latest model
143 | ckpt_list = []
144 |
145 | if ".pth" not in path_to_ckpt:
146 | for c in os.listdir(path_to_ckpt):
147 | if ".pth" in c:
148 | ckpt_list.append(c)
149 | ckpt_list.sort()
150 | path_to_ckpt = join(path_to_ckpt, ckpt_list[-1])
151 |
152 | assert exists(path_to_ckpt)
153 |
154 | # init model
155 | net = UNet(in_channels=1, out_channels=1, bilinear=True)
156 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
157 | net.to(device=device)
158 |
159 | # load model
160 | print("Log:\tload %s"%path_to_ckpt)
161 | try:
162 | net.load_state_dict(torch.load(path_to_ckpt, map_location=device))
163 | except:
164 | net = torch.nn.DataParallel(net)
165 | net.load_state_dict(torch.load(path_to_ckpt, map_location=device))
166 | net.eval()
167 |
168 | return net, device
169 |
170 | def initial_flatting_input():
171 |
172 | # inputs
173 | img = gr.inputs.Image(image_mode='L',
174 | invert_colors=False, source="upload", label="Input Image",
175 | type = "pil")
176 | # resize = gr.inputs.Radio(choices=["1024", "512", "256"], label="Resize")
177 | model = gr.inputs.Radio(choices=["1024", "1024_rand", "512", "512_rand"], label="Model")
178 | # split = gr.inputs.Radio(choices=["True", "False"], label="Split")
179 | radius = gr.inputs.Slider(minimum=1, maximum=10, step=1, default=7, label="kernel radius")
180 |
181 | # outputs
182 | out1 = gr.outputs.Image(type='pil', label='line prediction')
183 | out2 = gr.outputs.Image(type='pil', label='fill')
184 | # out5 = gr.outputs.File(label="all results")
185 |
186 | return [img, model, radius], [out1, out2]
187 | # return [img, resize], [out1, out2, out3, out4, out5]
188 |
189 | def start_demo(fn, inputs, outputs, examples):
190 | iface = gr.Interface(fn = fn, inputs = inputs, outputs = outputs, examples = examples, layout = "unaligned")
191 | iface.launch()
192 |
193 | def main():
194 |
195 |     # parse checkpoint path arguments
196 | parser = argparse.ArgumentParser()
197 | parser.add_argument("--ckpt-1024", type=str, default = "./checkpoints/base_1024/")
198 | parser.add_argument("--ckpt-512", type=str, default = "./checkpoints/base_512/")
199 | # parser.add_argument("--ckpt-256", type=str, default = "./checkpoints/base_256/")
200 | parser.add_argument("--ckpt-512-rand", type=str, default = "./checkpoints/rand_512/")
201 | # parser.add_argument("--ckpt-256-rand", type=str, default = "./checkpoints/rand_256/")
202 | parser.add_argument("--ckpt-1024-rand", type=str, default = "./checkpoints/rand_1024/")
203 |
204 | args = parser.parse_args()
205 |
206 |     # initialize models
207 | nets = {}
208 | nets["1024"] = initial_models(args.ckpt_1024)
209 | nets["1024_rand"] = initial_models(args.ckpt_1024_rand)
210 | nets["512"] = initial_models(args.ckpt_512)
211 | nets["512_rand"] = initial_models(args.ckpt_512_rand)
212 | # nets["256"] = initial_models(args.ckpt_256)
213 | # nets["256_rand"] = initial_models(args.ckpt_256_rand)
214 |
215 |
216 |     # construct examples
217 | example_path = "./flatting/validation"
218 | example_list = os.listdir(example_path)
219 | example_list.sort()
220 |
221 | examples = []
222 |
223 | for file in example_list:
224 | print("find %s"%file)
225 | img = os.path.join(example_path, file)
226 | model = random.choice(["512_rand"])
227 | radius = 2
228 | examples.append([img, model, radius])
229 |
230 | # initial pred func
231 | fn = partial(pred_and_fill, nets=nets, patch="False", outputs="./flatting/gradio")
232 |
233 | # bug fix
234 | fn.__name__ = fn.func.__name__
235 |
236 | # start
237 | inputs, outputs = initial_flatting_input()
238 | start_demo(fn=fn, inputs=inputs, outputs=outputs, examples=examples)
239 |
240 | def debug():
241 |     # parse checkpoint path arguments
242 | parser = argparse.ArgumentParser()
243 | parser.add_argument("--ckpt-1024", type=str, default = "./checkpoints/base_1024/")
244 | parser.add_argument("--ckpt-512", type=str, default = "./checkpoints/base_512/")
245 | # parser.add_argument("--ckpt-256", type=str, default = "./checkpoints/base_256/")
246 | parser.add_argument("--ckpt-512-rand", type=str, default = "./checkpoints/rand_512/")
247 | # parser.add_argument("--ckpt-256-rand", type=str, default = "./checkpoints/rand_256/")
248 | parser.add_argument("--ckpt-1024-rand", type=str, default = "./checkpoints/rand_1024/")
249 | args = parser.parse_args()
250 |
251 |     # initialize models
252 | nets = {}
253 | nets["1024"] = initial_models(args.ckpt_1024)
254 | nets["1024_rand"] = initial_models(args.ckpt_1024_rand)
255 | nets["512"] = initial_models(args.ckpt_512)
256 | nets["512_rand"] = initial_models(args.ckpt_512_rand)
257 | # nets["256"] = initial_models(args.ckpt_256)
258 | # nets["256_rand"] = initial_models(args.ckpt_256_rand)
259 |
260 |
261 |
262 | img = Image.open("./flatting/validation/train_008.png").convert("L")
263 | pred_and_fill(img, radius=2, op='512_rand', patch="False", nets=nets, outputs="./flatting/gradio")
264 |
265 | if __name__ == '__main__':
266 | main()
267 | # debug()
268 |
--------------------------------------------------------------------------------
/src/flatting/dice_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 |
4 |
5 | class DiceCoeff(Function):
6 | """Dice coeff for individual examples"""
7 |
8 | def forward(self, input, target):
9 | self.save_for_backward(input, target)
10 | eps = 0.0001
11 | self.inter = torch.dot(input.view(-1), target.view(-1))
12 | self.union = torch.sum(input) + torch.sum(target) + eps
13 |
14 | t = (2 * self.inter.float() + eps) / self.union.float()
15 | return t
16 |
17 | # This function has only a single output, so it gets only one gradient
18 | def backward(self, grad_output):
19 |
20 | input, target = self.saved_variables
21 | grad_input = grad_target = None
22 |
23 | if self.needs_input_grad[0]:
24 | grad_input = grad_output * 2 * (target * self.union - self.inter) \
25 | / (self.union * self.union)
26 | if self.needs_input_grad[1]:
27 | grad_target = None
28 |
29 | return grad_input, grad_target
30 |
31 |
32 | def dice_coeff(input, target):
33 | """Dice coeff for batches"""
34 | if input.is_cuda:
35 | s = torch.FloatTensor(1).cuda().zero_()
36 | else:
37 | s = torch.FloatTensor(1).zero_()
38 |
39 | for i, c in enumerate(zip(input, target)):
40 | s = s + DiceCoeff().forward(c[0], c[1])
41 |
42 | return s / (i + 1)
43 |
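For intuition, the coefficient computed above is D = (2 * <input, target> + eps) / (sum(input) + sum(target) + eps). A minimal numeric check (the import path is assumed; newer PyTorch versions may emit a deprecation warning for the legacy Function instantiation):

    import torch
    from flatting.dice_loss import dice_coeff   # import path assumed

    pred   = torch.tensor([[1., 0., 1., 1.]])
    target = torch.tensor([[1., 0., 0., 1.]])
    # intersection = 2, sum(pred) + sum(target) = 3 + 2 = 5, so D = 2 * 2 / 5 = 0.8
    print(dice_coeff(pred, target))             # ~tensor([0.8000])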
--------------------------------------------------------------------------------
/src/flatting/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from tqdm import tqdm
4 |
5 | from dice_loss import dice_coeff
6 |
7 |
8 | # Right, this function also needs to be updated
9 |
10 | def eval_net(net, loader, device):
11 |     """Evaluate with the Dice coefficient (without dense CRF post-processing)"""
12 | net.eval()
13 | mask_type = torch.float32 if net.n_classes == 1 else torch.long
14 | n_val = len(loader) # the number of batch
15 | tot = 0
16 |
17 | with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
18 | for batch in loader:
19 | imgs, true_masks = batch['image'], batch['mask']
20 | imgs = imgs.to(device=device, dtype=torch.float32)
21 | true_masks = true_masks.to(device=device, dtype=mask_type)
22 |
23 | with torch.no_grad():
24 | mask_pred = net(imgs)
25 |
26 | if net.n_classes > 1:
27 | tot += F.cross_entropy(mask_pred, true_masks).item()
28 | else:
29 | pred = torch.sigmoid(mask_pred)
30 | pred = (pred > 0.5).float()
31 | tot += dice_coeff(pred, true_masks).item()
32 | pbar.update()
33 |
34 | net.train()
35 | return tot / n_val
36 |
--------------------------------------------------------------------------------
/src/flatting/flatting_api_async.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from concurrent.futures import ProcessPoolExecutor
3 | import functools
4 |
5 | from . import flatting_api
6 |
7 | ## This controls the number of parallel processes.
8 | ## Keep in mind that parallel processes will load duplicate networks
9 | ## and compete for the same RAM, which could lead to thrashing.
10 | ## Pass `max_workers = N` for exactly `N` parallel processes.
11 | # import multiprocessing
12 | # HALF_CORES = max( multiprocessing.cpu_count()//2, 1 )
13 | executor_batch = ProcessPoolExecutor(4)
14 | executor_interactive = ProcessPoolExecutor(4)
15 |
16 | async def run_async( executor, f ):
17 | ## We expect this to be called from inside an existing loop.
18 | ## As a result, we call `get_running_loop()` instead of `get_event_loop()` so that
19 | ## it raises an error if our assumption is false, rather than creating a new loop.
20 | loop = asyncio.get_running_loop()
21 | data = await loop.run_in_executor( executor, f )
22 | return data
23 |
24 | async def run_single( *args, **kwargs ):
25 | return await run_async( executor_batch, functools.partial( flatting_api.run_single, *args, **kwargs ) )
26 |
27 | async def merge( *args, **kwargs ):
28 | return await run_async( executor_interactive, functools.partial( flatting_api.merge, *args, **kwargs ) )
29 |
30 | async def split_manual( *args, **kwargs ):
31 | return await run_async( executor_interactive, functools.partial( flatting_api.split_manual, *args, **kwargs ) )
32 |
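These wrappers offload the blocking flatting calls to a process pool so the aiohttp event loop keeps serving requests. A minimal sketch of how a handler would await one of them (the handler name and the argument passed through to flatting_api.run_single are illustrative, not the project's actual signature):

    from aiohttp import web
    from flatting import flatting_api_async

    async def flatsingle_handler(request):
        data = await request.json()
        # blocks a worker process from the pool, not the event loop
        result = await flatting_api_async.run_single(data)
        return web.json_response(result)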
--------------------------------------------------------------------------------
/src/flatting/hubconf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from unet import UNet as _UNet
3 |
4 | def unet_carvana(pretrained=False):
5 | """
6 | UNet model trained on the Carvana dataset ( https://www.kaggle.com/c/carvana-image-masking-challenge/data ).
7 | Set the scale to 1 (100%) when predicting.
8 | """
9 | net = _UNet(n_channels=3, n_classes=1, bilinear=True)
10 | if pretrained:
11 | checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v1.0/unet_carvana_scale1_epoch5.pth'
12 | net.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=True))
13 |
14 | return net
15 |
16 |
--------------------------------------------------------------------------------
/src/flatting/predict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 |
4 | import os
5 | # import os, sys
6 | from os.path import join, splitext
7 | # sys.path.append(join(dirname(abspath(__file__)), "trapped_ball"))
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn.functional as F
12 | import cv2
13 |
14 | from PIL import Image
15 | from torchvision import transforms as T
16 |
17 | from .unet import UNet
18 | from .utils.preprocessing import to_point_list, find_bbox, crop_img
19 | from .trapped_ball.run import region_get_map
20 |
21 |
22 | def to_tensor(img):
23 |
24 | transforms = T.Compose(
25 | [
26 | # to tensor will change the channel order and divide 255 if necessary
27 | T.ToTensor(),
28 | T.Normalize(0.5, 0.5, inplace = True)
29 | ]
30 | )
31 |
32 | return transforms(img)
33 |
34 | def denormalize(img):
35 | # denormalize
36 | inv_normalize = T.Normalize( mean=-1, std=2)
37 |
38 | img_np = inv_normalize(img.repeat(3,1,1)).clamp(0, 1)
39 | img_np = img_np * 255
40 |
41 | # to numpy
42 | img_np = img_np.cpu().numpy().transpose((1,2,0))
43 |
44 | return Image.fromarray(img_np.astype(np.uint8)).convert("L")
45 |
46 | def to_numpy(f, size, bbox=None):
47 |
48 | if type(f) == str:
49 | img = np.array(Image.open(f).convert("L"))
50 |
51 | else:
52 | img = np.array(f.convert("L"))
53 |
54 | if bbox != None:
55 | img = crop_img(bbox, img)
56 |
57 | h, w = img.shape
58 | ratio = size/w if w < h else size/h
59 |
60 | return cv2.resize(img, (int(w*ratio+0.5), int(h*ratio+0.5)), interpolation=cv2.INTER_AREA)
61 |
62 | def predict_img(net,
63 | full_img,
64 | device,
65 | size):
66 | net.eval()
67 |
68 | # corp image
69 | bbox = find_bbox(to_point_list(np.array(full_img)))
70 |
71 | # read image
72 | print("Log:\tpredict image at size %d"%size)
73 | img = to_tensor(to_numpy(full_img, size, bbox = bbox))
74 | img = img.unsqueeze(0)
75 | img = img.to(device=device, dtype=torch.float32)
76 |
77 | with torch.no_grad():
78 | output = net(img)
79 |
80 | output = denormalize(output[0])
81 |
82 | return output, bbox
83 |
84 |
85 | def get_args():
86 | parser = argparse.ArgumentParser(description='Predict edge from line art',
87 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
88 |
89 | parser.add_argument('--model', '-m', default='./checkpoints/exp1/CP_epoch2001.pth',
90 | metavar='FILE',
91 | help="Specify the file in which the model is stored")
92 |
93 | parser.add_argument('--input', '-i', type=str,
94 | help='filename of single input image, include path')
95 |
96 | parser.add_argument('--output', '-o', type=str,
97 | help='filename of single ouput image, include path')
98 |
99 | parser.add_argument('--input-path', type=str, default="./flatting/validation",
100 | help='path to input images')
101 |
102 | parser.add_argument('--output-path', type=str, default="./results/val",
103 | help='path to ouput images')
104 |
105 | return parser.parse_args()
106 |
107 |
108 | if __name__ == "__main__":
109 |
110 | args = get_args()
111 |
112 | in_files = args.input
113 |
114 | net = UNet(in_channels=1, out_channels=1, bilinear=True)
115 |
116 | logging.info("Loading model {}".format(args.model))
117 |
118 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
119 | logging.info(f'Using device {device}')
120 |
121 | net.to(device=device)
122 | net.load_state_dict(torch.load(args.model, map_location=device))
123 |
124 | logging.info("Model loaded !")
125 |
126 |     for f in os.listdir(args.input_path):
127 |         name, _ = splitext(f)
128 | 
129 |         logging.info("\nPredicting image {} ...".format(join(args.input_path, f)))
130 | 
131 |         # predict edge and save image
132 |         img = Image.open(join(args.input_path, f)).convert("L")
133 |         edge, _ = predict_img(net=net,
134 |                               full_img=img,
135 |                               device=device,
136 |                               size = 1024)
137 | 
138 |         edge.save(join(args.output_path, name + "_pred.png"))
139 | 
140 |         # trapped ball fill and save image
141 |         region_get_map(join(args.output_path, name + "_pred.png"),
142 |                        path_to_line_artist = join(args.input_path, f),
143 |                        output_path = args.output_path,
144 |                        radius_set=[1], percentiles=[0])
145 |
--------------------------------------------------------------------------------
/src/flatting/resources/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/resources/__init__.py
--------------------------------------------------------------------------------
/src/flatting/resources/flatting.icns:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/resources/flatting.icns
--------------------------------------------------------------------------------
/src/flatting/resources/flatting.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/resources/flatting.ico
--------------------------------------------------------------------------------
/src/flatting/resources/flatting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/resources/flatting.png
--------------------------------------------------------------------------------
/src/flatting/submit.py:
--------------------------------------------------------------------------------
1 | """ Submit code specific to the kaggle challenge"""
2 |
3 | import os
4 |
5 | import torch
6 | from PIL import Image
7 | import numpy as np
8 |
9 | from predict import predict_img
10 | from unet import UNet
11 |
12 | # credits to https://stackoverflow.com/users/6076729/manuel-lagunas
13 | def rle_encode(mask_image):
14 | pixels = mask_image.flatten()
15 | # We avoid issues with '1' at the start or end (at the corners of
16 | # the original image) by setting those pixels to '0' explicitly.
17 | # We do not expect these to be non-zero for an accurate mask,
18 | # so this should not harm the score.
19 | pixels[0] = 0
20 | pixels[-1] = 0
21 | runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
22 | runs[1::2] = runs[1::2] - runs[:-1:2]
23 | return runs
24 |
25 |
26 | def submit(net):
27 | """Used for Kaggle submission: predicts and encode all test images"""
28 | dir = 'data/test/'
29 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30 | N = len(list(os.listdir(dir)))
31 | with open('SUBMISSION.csv', 'a') as f:
32 | f.write('img,rle_mask\n')
33 | for index, i in enumerate(os.listdir(dir)):
34 | print('{}/{}'.format(index, N))
35 |
36 | img = Image.open(dir + i)
37 |
38 | mask = predict_img(net, img, device)
39 | enc = rle_encode(mask)
40 | f.write('{},{}\n'.format(i, ' '.join(map(str, enc))))
41 |
42 |
43 | if __name__ == '__main__':
44 | net = UNet(3, 1).cuda()
45 | net.load_state_dict(torch.load('MODEL.pth'))
46 | submit(net)
47 |
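To make the run-length encoding above concrete, a small worked example (1-indexed pixel positions, as in the Kaggle submission format):

    import numpy as np
    rle_encode(np.array([0, 1, 1, 1, 0, 0, 1, 0]))
    # -> array([2, 3, 7, 1]): a run of 3 foreground pixels starting at pixel 2,
    #    and a run of 1 starting at pixel 7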
--------------------------------------------------------------------------------
/src/flatting/tkapp.py:
--------------------------------------------------------------------------------
1 | from flatting.app import main as start_server
2 |
3 | ## Briefcase doesn't support tkinter
4 |
5 | import asyncio
6 | import tkinter as tk
7 |
8 | async def start_server_in_thread():
9 | def start_server_wrapper():
10 | asyncio.set_event_loop(asyncio.new_event_loop())
11 | start_server()
12 |
13 | # Run the server in a thread.
14 | import threading
15 | server = threading.Thread( name='flatting_server', target=start_server_wrapper )
16 | server.setDaemon( True )
17 | server.start()
18 |
19 | def start_gui():
20 | import tkinter as tk
21 | root = tk.Tk()
22 |
23 | ## Adapting: https://stackoverflow.com/questions/47895765/use-asyncio-and-tkinter-or-another-gui-lib-together-without-freezing-the-gui
24 | loop = asyncio.get_event_loop()
25 |
26 | INTERVAL = 1/30
27 | async def guiloop():
28 | while True:
29 | root.update()
30 | await asyncio.sleep( INTERVAL )
31 | task = loop.create_task( guiloop() )
32 |
33 | def shutdown():
34 | task.cancel()
35 | loop.stop()
36 | # loop.close()
37 | root.destroy()
38 |
39 | root.title("Flatting Backend")
40 | ## The port changes if we pass a port argument to `web.run_app`.
41 | tk.Label( root, text="Serving at http://127.0.0.1:8080" ).pack()
42 | tk.Button( root, text="Quit", command=shutdown ).pack()
43 |
44 | # tk.mainloop()
45 |
46 |
47 | def main():
48 | start_gui()
49 | start_server()
50 |
51 | if __name__ == '__main__':
52 | main()
53 |
--------------------------------------------------------------------------------
/src/flatting/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import sys
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch import optim
10 | from tqdm import tqdm
11 |
12 | from eval import eval_net
13 | from unet import UNet
14 |
15 | from torch.utils.tensorboard import SummaryWriter
16 | from utils.dataset import BasicDataset
17 | from torch.utils.data import DataLoader, random_split
18 | from torchvision import transforms as T
19 | from torchvision import utils
20 | from PIL import Image
21 | from io import BytesIO
22 |
23 | dir_line = './flatting/size_512/line_croped'
24 | dir_edge = './flatting/size_512/line_detection_croped'
25 | dir_checkpoint = './checkpoints'
26 |
27 | def denormalize(img):
28 | # denormalize
29 | inv_normalize = T.Normalize( mean=[-1], std=[2])
30 |
31 | img_np = inv_normalize(img)
32 | img_np = img_np.clamp(0, 1)
33 | # to numpy
34 | return img_np
35 |
36 | def train_net(net,
37 | device,
38 | epochs=100,
39 | batch_size=1,
40 | lr=0.001,
41 | val_percent=0.1,
42 | save_cp=True,
43 | crop_size = None):
44 |
45 | # dataset = BasicDataset(dir_line, dir_edge, crop_size = crop_size)
46 | logging.info("Loading training set to memory")
47 | lines_bytes, edges_bytes = load_to_ram(dir_line, dir_edge)
48 |
49 | dataset = BasicDataset(lines_bytes, edges_bytes, crop_size = crop_size)
50 |
51 | n_train = len(dataset)
52 |
53 | train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
54 |
55 |     # we don't need validation currently
56 | # val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)
57 |
58 | writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}')
59 |
60 | global_step = 0
61 |
62 | # logging.info(f'''Starting training:
63 | # Epochs: {epochs}
64 | # Batch size: {batch_size}
65 | # Learning rate: {lr}
66 | # Training size: {n_train}
67 | # Validation size: {n_val}
68 | # Checkpoints: {save_cp}
69 | # Device: {device.type}
70 | # Images scaling: {img_scale}
71 | # ''')
72 |
73 | logging.info(f'''Starting training:
74 | Epochs: {epochs}
75 | Batch size: {batch_size}
76 | Learning rate: {lr}
77 | Training size: {n_train}
78 | Checkpoints: {save_cp}
79 | Device: {device.type}
80 | Crop size: {crop_size}
81 | ''')
82 |
83 | #optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
84 | optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8)
85 |
86 | # how to use scheduler?
87 | # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
88 |
89 |     # since we are now generating images, we use an L1 loss
90 | # if net.n_classes > 1:
91 | # criterion = nn.CrossEntropyLoss()
92 | # else:
93 | # criterion = nn.BCEWithLogitsLoss()
94 |
95 | criterion = nn.L1Loss()
96 |
97 | for epoch in range(epochs):
98 | net.train()
99 |
100 | epoch_loss = 0
101 | with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
102 | for imgs, gts, mask1, mask2 in train_loader:
103 |
104 | # assert imgs.shape[1] == net.in_channels, \
105 | # f'Network has been defined with {net.in_channels} input channels, ' \
106 | # f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
107 | # 'the images are loaded correctly.'
108 |
109 | imgs = imgs.to(device=device, dtype=torch.float32)
110 | gts = gts.to(device=device, dtype=torch.float32)
111 |
112 | # forward
113 | pred = net(imgs)
114 |
115 | '''
116 | baseline
117 | '''
118 | # loss1 = criterion(pred, gts)
119 |
120 | '''
121 | weighted loss
122 | '''
123 | mask_1 = (1-mask1)
124 | mask_2 = 100 * (1-mask2)
125 | mask_3 = 0.5 * mask2
126 | mask_w = mask_1 + mask_2 + mask_3
127 | mask_w = mask_w.to(device=device, dtype=torch.float32)
128 | loss1 = criterion(pred * mask_w, gts * mask_w)
129 |
130 |                 '''
131 |                 point number loss
132 |                 the number of line (pure-black) pixels in the prediction and the ground
133 |                 truth should also be close: loss2 = |#zeros(gt) - #zeros(pred)|
134 |                 '''
135 | loss2 = criterion(
136 | ((denormalize(gts)==0).sum()).float(),
137 | ((denormalize(pred)==0).sum()).float()
138 | )
139 |
140 | # total loss
141 | loss = loss1 + 0.5 * torch.log(torch.abs(loss2 + 1))
142 |
143 | epoch_loss += loss.item()
144 | writer.add_scalar('Loss/total', loss.item(), global_step)
145 | writer.add_scalar('Loss/l1', loss1.item(), global_step)
146 | writer.add_scalar('Loss/point', loss2.item(), global_step)
147 |
148 | pbar.set_postfix(**{'loss (batch)': loss.item()})
149 |
150 | # back propagate
151 | optimizer.zero_grad()
152 | loss.backward()
153 | nn.utils.clip_grad_value_(net.parameters(), 0.1)
154 | optimizer.step()
155 |
156 | pbar.update(imgs.shape[0])
157 |
158 | global_step += 1
159 |
160 | # if global_step % (n_train // (10 * batch_size)) == 0:
161 | # for tag, value in net.named_parameters():
162 | # tag = tag.replace('.', '/')
163 | # writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
164 | # writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
165 | # val_score = eval_net(net, val_loader, device)
166 | # scheduler.step(val_score)
167 | # writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
168 |
169 | # if net.n_classes > 1:
170 | # logging.info('Validation cross entropy: {}'.format(val_score))
171 | # writer.add_scalar('Loss/test', val_score, global_step)
172 | # else:
173 | # logging.info('Validation Dice Coeff: {}'.format(val_score))
174 | # writer.add_scalar('Dice/test', val_score, global_step)
175 |
176 | # writer.add_images('images', imgs, global_step)
177 | # if net.n_classes == 1:
178 | # writer.add_images('masks/true', true_masks, global_step)
179 | # writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
180 |
181 | if global_step % 1000 == 0:
182 | sample = torch.cat((imgs, pred, gts), dim = 0)
183 | if os.path.exists("./results/train/") is False:
184 | logging.info("Creating ./results/train/")
185 | os.makedirs("./results/train/")
186 |
187 | utils.save_image(
188 | sample,
189 | f"./results/train/{str(global_step).zfill(6)}.png",
190 | nrow=int(batch_size),
191 | # nrow=int(sample.shape[0] ** 0.5),
192 | normalize=True,
193 | range=(-1, 1),
194 | )
195 |
196 | if save_cp and epoch % 100 == 0:
197 | try:
198 | os.mkdir(dir_checkpoint)
199 | logging.info('Created checkpoint directory')
200 | except OSError:
201 | pass
202 | torch.save(net.state_dict(),
203 | dir_checkpoint + f'/CP_epoch{epoch + 1}.pth')
204 | logging.info(f'Checkpoint {epoch + 1} saved !')
205 |
206 | writer.close()
207 |
208 | def save_to_ram(path_to_img):
209 |
210 | img = Image.open(path_to_img).convert("L")
211 | buffer = BytesIO()
212 | img.save(buffer, format='png')
213 |
214 | return buffer.getvalue()
215 |
216 | def load_to_ram(path_to_line, path_to_edge):
217 | lines = os.listdir(path_to_line)
218 | lines.sort()
219 |
220 | edges = os.listdir(path_to_edge)
221 | edges.sort()
222 |
223 | assert len(lines) == len(edges)
224 |
225 | lines_bytes = []
226 | edges_bytes = []
227 |
228 | # read everything into memory
229 | for img in tqdm(lines):
230 | assert img.replace("webp", "png") in edges
231 |
232 | lines_bytes.append(save_to_ram(os.path.join(path_to_line, img)))
233 | edges_bytes.append(save_to_ram(os.path.join(path_to_edge, img.replace("webp", "png"))))
234 |
235 | return lines_bytes, edges_bytes
236 |
237 |
238 | def get_args():
239 | parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
240 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
241 | parser.add_argument('-e', '--epochs', metavar='E', type=int, default=90000,
242 | help='Number of epochs', dest='epochs')
243 | parser.add_argument('-m', '--multi-gpu', action='store_true')
244 | parser.add_argument('-c', '--crop-size', metavar='C', type=int, default=512,
245 | help='the size of random cropping')
246 | parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,
247 | help='Batch size', dest='batchsize')
248 | parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.0001,
249 | help='Learning rate', dest='lr')
250 | parser.add_argument('-f', '--load', dest='load', type=str, default=False,
251 | help='Load model from a .pth file')
252 |
253 | return parser.parse_args()
254 |
255 |
256 | if __name__ == '__main__':
257 |
258 | __spec__ = None
259 |
260 | logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
261 | args = get_args()
262 |
263 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
264 | logging.info(f'Using device {device}')
265 |
266 | # Change here to adapt to your data
267 | # n_channels=3 for RGB images
268 | # n_classes is the number of probabilities you want to get per pixel
269 | # - For 1 class and background, use n_classes=1
270 | # - For 2 classes, use n_classes=1
271 | # - For N > 2 classes, use n_classes=N
272 |
273 | net = UNet(in_channels=1, out_channels=1, bilinear=True)
274 |
275 | if args.multi_gpu:
276 | logging.info("using data parallel")
277 | net = nn.DataParallel(net).cuda()
278 | else:
279 | net.to(device=device)
280 |
281 | # logging.info(f'Network:\n'
282 | # f'\t{net.in_channels} input channels\n'
283 | # f'\t{net.out_channels} output channels\n'
284 | # f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling'
285 | # )
286 |
287 | if args.load:
288 | net.load_state_dict(
289 | torch.load(args.load, map_location=device)
290 | )
291 | logging.info(f'Model loaded from {args.load}')
292 |
293 |
294 | # faster convolutions, but more memory
295 | # cudnn.benchmark = True
296 |
297 | try:
298 | train_net(net=net,
299 | epochs=args.epochs,
300 | batch_size=args.batchsize,
301 | lr=args.lr,
302 | device=device,
303 | crop_size=args.crop_size)
304 |
305 |     # this is interesting: save the model on keyboard interrupt
306 | except KeyboardInterrupt:
307 | torch.save(net.state_dict(), './checkpoints/INTERRUPTED.pth')
308 | # logging.info('Saved interrupt')
309 | try:
310 | sys.exit(0)
311 | except SystemExit:
312 | os._exit(0)
313 |
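In summary, the objective optimized by the loop above is, writing m = (1 - mask1) + 100 * (1 - mask2) + 0.5 * mask2 for the per-pixel weight:

    loss = mean| m * (pred - gt) | + 0.5 * log( |#zeros(gt) - #zeros(pred)| + 1 )

where #zeros counts pixels that denormalize exactly to 0 (pure-black line pixels); the second term nudges the prediction to contain roughly as many line pixels as the ground truth.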
--------------------------------------------------------------------------------
/src/flatting/trapped_ball/adjacency_matrix.pyx:
--------------------------------------------------------------------------------
1 | # cython: language_level=3
2 | # cython: boundscheck=False
3 | # cython: wraparound=False
4 |
5 | #import numpy as np
6 | #cimport numpy as np
7 |
8 | ## TODO: Allocate memory internally. See: https://stackoverflow.com/questions/18462785/what-is-the-recommended-way-of-allocating-memory-for-a-typed-memory-view
9 |
10 | def adjacency_matrix( image, num_regions ):
11 | '''
12 | Given:
13 | image: A 2D image of integer labels in the range [0,num_regions].
14 | num_regions: The number of regions in `image`.
15 | Returns:
16 | A: The adjacency matrix such that A[i,j] is 1 if region i is
17 | connected to region j and 0 otherwise.
18 | '''
19 |
20 | import numpy as np
21 | A = np.zeros( ( num_regions, num_regions ), dtype = int )
22 | adjacency_matrix_internal( image, A )
23 | return A
24 |
25 | cpdef long[:,:] adjacency_matrix_internal( long[:,:] image, long[:,:] A ) nogil:
26 | '''
27 | Given:
28 | image: A 2D image of integer labels in the range [0,num_regions].
29 | Returns:
30 | A: The adjacency matrix such that A[i,j] is 1 if region i is
31 | connected to region j and 0 otherwise.
32 |
33 | Note: `A` is an output parameter. Allocate space and pass it in.
34 | '''
35 |
36 | # A = np.zeros( ( num_regions, num_regions ), dtype = int )
37 | A[:] = 0
38 |
39 | cdef long nrow = image.shape[0]
40 | cdef long ncol = image.shape[1]
41 | cdef long i,j,region0,region1
42 |
43 | ## Sweep with left-right neighbors. Skip the right-most column.
44 | for i in range(nrow):
45 | for j in range(ncol-1):
46 | region0 = image[i,j]
47 | region1 = image[i,j+1]
48 | A[region0,region1] = 1
49 | A[region1,region0] = 1
50 |
51 | ## Sweep with top-bottom neighbors. Skip the bottom-most row.
52 | for i in range(nrow-1):
53 | for j in range(ncol):
54 | region0 = image[i,j]
55 | region1 = image[i+1,j]
56 | A[region0,region1] = 1
57 | A[region1,region0] = 1
58 |
59 | ## Sweep with top-left-to-bottom-right neighbors. Skip the bottom row and right column.
60 | for i in range(nrow-1):
61 | for j in range(ncol-1):
62 | region0 = image[i,j]
63 | region1 = image[i+1,j+1]
64 | A[region0,region1] = 1
65 | A[region1,region0] = 1
66 |
67 | ## Sweep with top-right-to-bottom-left neighbors. Skip the bottom row and left column.
68 | for i in range(nrow-1):
69 | for j in range(1,ncol):
70 | region0 = image[i,j]
71 | region1 = image[i+1,j-1]
72 | A[region0,region1] = 1
73 | A[region1,region0] = 1
74 |
75 | ## region will not connect to itself
76 | for i in range(len(A)):
77 | A[i, i] = 0
78 |
79 | return A
80 |
81 | def region_sizes( image, num_regions ):
82 | '''
83 | Given:
84 | image: A 2D image of integer labels in the range [0,num_regions].
85 | num_regions: The number of regions in `image`.
86 | Returns:
87 | sizes: An array of length `num_regions`. Each element stores the
88 | number of pixels with the corresponding region number.
89 | That is, region i has `region_sizes[i]` pixels.
90 | '''
91 |
92 | import numpy as np
93 | sizes = np.zeros( num_regions, dtype = int )
94 | region_sizes_internal( image, sizes )
95 | return sizes
96 |
97 | cpdef long[:] region_sizes_internal( long[:,:] image, long[:] sizes ) nogil:
98 | '''
99 | Given:
100 | image: A 2D image of integer labels in the range [0,num_regions].
101 | Returns:
102 | sizes: An array of length `num_regions`. Each element stores the
103 | number of pixels with the corresponding region number.
104 | That is, region i has `region_sizes[i]` pixels.
105 |
106 | Note: `sizes` is an output parameter. Allocate space and pass it in.
107 | '''
108 |
109 | # sizes = np.zeros( num_regions, dtype = int )
110 | sizes[:] = 0
111 |
112 | cdef long nrow = image.shape[0]
113 | cdef long ncol = image.shape[1]
114 | cdef long i,j,region
115 |
116 |     ## Count the number of pixels belonging to each region label.
117 | for i in range(nrow):
118 | for j in range(ncol):
119 | region = image[i,j]
120 | sizes[region] += 1
121 |
122 | return sizes
123 |
124 | def remap_labels( image, remaps ):
125 | '''
126 | Given:
127 | image: A 2D image of integer labels in the range [0,len(remaps)].
128 | remaps: A 1D array of integer remappings that occurred.
129 | Returns:
130 | image_remapped: An array the same size as `image` except labels have been
131 | remapped according to `remaps`.
132 | '''
133 |
134 | import numpy as np
135 | image_remapped = image.copy().astype( int )
136 | remaps = remaps.astype( int )
137 | remap_labels_internal( image_remapped, remaps )
138 | return image_remapped
139 |
140 | cpdef void remap_labels_internal( long[:,:] image, long[:] remaps ) nogil:
141 | '''
142 | Given:
143 | image: A 2D image of integer labels in the range [0,len(remaps)].
144 | remaps: A 1D array of integer remappings that occurred.
145 | Modifies in-place:
146 | image: `image` with labels remapped according to `remaps`.
147 | '''
148 |
149 | cdef long nrow = image.shape[0]
150 | cdef long ncol = image.shape[1]
151 | cdef long i,j,region
152 |
153 |     ## Remap every pixel, following remapping chains to their final label.
154 | for i in range(nrow):
155 | for j in range(ncol):
156 | region = image[i,j]
157 |             # the merge operations form a chain, so we need to follow the remappings one by one
158 |             # alternatively we could collapse the chains in remaps first; not sure which one would be faster
159 | while remaps[region] != region:
160 | image[i,j] = remaps[region]
161 | region = remaps[region]
162 |
163 | return
164 |
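A minimal usage sketch for the helpers above (assuming the extension has been compiled, e.g. via the pyximport fallback in trapped_ball/run.py, and that the label dtype matches the platform's C long, as noted there):

    import numpy as np
    from flatting.trapped_ball import adjacency_matrix   # import path assumed

    labels = np.array([[0, 0],
                       [0, 1]], dtype=np.int64)          # two touching regions, 0 and 1
    A = adjacency_matrix.adjacency_matrix(labels, 2)
    # A == [[0, 1], [1, 0]]: the regions are adjacent and the diagonal is zeroed
    sizes = adjacency_matrix.region_sizes(labels, 2)
    # sizes == [3, 1]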
--------------------------------------------------------------------------------
/src/flatting/trapped_ball/examples/01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/trapped_ball/examples/01.png
--------------------------------------------------------------------------------
/src/flatting/trapped_ball/examples/01_sim.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/trapped_ball/examples/01_sim.png
--------------------------------------------------------------------------------
/src/flatting/trapped_ball/examples/02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/trapped_ball/examples/02.png
--------------------------------------------------------------------------------
/src/flatting/trapped_ball/examples/tiny.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/trapped_ball/examples/tiny.png
--------------------------------------------------------------------------------
/src/flatting/trapped_ball/examples/tiny_sim.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nauhcnay/flat_magic_backend/5344f11c7a50c0a5b0d0876dcf68aa45b5a84687/src/flatting/trapped_ball/examples/tiny_sim.png
--------------------------------------------------------------------------------
/src/flatting/trapped_ball/run.py:
--------------------------------------------------------------------------------
1 |
2 | from .trappedball_fill import trapped_ball_fill_multi, flood_fill_multi, mark_fill, build_fill_map, merge_fill, show_fill_map, merger_fill_2nd
3 | from .trappedball_fill import get_ball_structuring_element, extract_line, to_masked_line
4 | from .thinning import thinning
5 | # from skimage.morphology import skeletonize
6 | from PIL import Image
7 | from tqdm import tqdm
8 |
9 | import argparse
10 | import cv2
11 | # import matplotlib.pyplot as plt
12 | import os
13 | import numpy as np
14 | from os.path import *
15 |
16 | # use cython adjacency matrix
17 | try:
18 | from . import adjacency_matrix
19 | ## If it's not already compiled, compile it.
20 | except:
21 | import pyximport
22 | pyximport.install()
23 | from . import adjacency_matrix
24 |
25 | def extract_skeleton(img):
26 |
27 | size = np.size(img)
28 | skel = np.zeros(img.shape,np.uint8)
29 | element = cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))
30 | done = False
31 |
32 | while done is False:
33 | eroded = cv2.erode(img,element)
34 | temp = cv2.dilate(eroded,element)
35 | temp = cv2.subtract(img,temp)
36 | skel = cv2.bitwise_or(skel,temp)
37 | img = eroded.copy()
38 |
39 | zeros = size - cv2.countNonZero(img)
40 | if zeros==size:
41 | done = True
42 |
43 | return skel
44 |
45 | def generate_masked_line(line_simplify, line_artist, line_artist_fullsize):
46 | line_masked = to_masked_line(line_simplify, line_artist, rk1=1, rk2=1, tn=1)
47 |
48 | # remove isolate points
49 | # it is not safe to do that at down scaled size
50 | # _, result = cv2.connectedComponents(255 - line_masked, connectivity=8)
51 |
52 | # up scale masked line to full size
53 | line_masked_fullsize_t = cv2.resize(line_masked.astype(np.uint8),
54 | (line_artist_fullsize.shape[1], line_artist_fullsize.shape[0]),
55 | interpolation = cv2.INTER_NEAREST)
56 |
57 |     # mask with the fullsize artist line again
58 | line_masked_fullsize = to_masked_line(line_masked_fullsize_t, line_artist_fullsize, rk1=7, rk2=1, tn=2)
59 |
60 | # remove isolate points
61 | _, temp = cv2.connectedComponents(255 - line_masked_fullsize, connectivity=4)
62 |
63 | def remove_stray_points(fillmap, drop_thres = 32):
64 | ids = np.unique(fillmap)
65 | result = np.ones(fillmap.shape) * 255
66 |
67 | for i in tqdm(ids):
68 | if i == 0: continue
69 | if len(np.where(fillmap == i)[0]) < drop_thres:
70 | # set them as background
71 | result[fillmap == i] = 255
72 | else:
73 | # set them as line
74 | result[fillmap == i] = 0
75 |
76 | return result
77 |
78 | line_masked_fullsize = remove_stray_points(temp, 16)
79 |
80 | return line_masked_fullsize
81 |
82 |
83 | def region_get_map(path_to_line_sim,
84 | path_to_line_artist=None,
85 | output_path=None,
86 | radius_set=[3,2,1],
87 | percentiles=[90, 0, 0],
88 | visualize_steps=False,
89 | return_numpy=False,
90 | preview=False):
91 | '''
92 |     Given:
93 |         the simplified line art (a path, numpy array or PIL Image), and optionally the original artist line art
94 |     Return:
95 |         the fill results: a PIL preview if preview=True, numpy fill images if return_numpy=True, otherwise the fill maps
96 | '''
97 | def read_png(path_to_png, to_grayscale=True):
98 | '''
99 | Given:
100 |             path_to_png, which may be a file path, a numpy array or a PIL Image
101 |         Return:
102 |             the image as a numpy array, plus a name used when saving results
103 | '''
104 |
105 | # if it is png file, open it
106 | if isinstance(path_to_png, str):
107 | # get file name
108 | _, file = os.path.split(path_to_png)
109 | name, _ = os.path.splitext(file)
110 |
111 | print("Log:\topen %s"%path_to_png)
112 | img_org = cv2.imread(path_to_png, cv2.IMREAD_COLOR)
113 | if to_grayscale:
114 | img = cv2.cvtColor(img_org, cv2.COLOR_BGR2GRAY)
115 | else:
116 | img = img_org
117 |
118 | elif isinstance(path_to_png, Image.Image):
119 | if to_grayscale:
120 | path_to_png = path_to_png.convert("L")
121 | img = np.array(path_to_png)
122 | name = "result"
123 |
124 | elif isinstance(path_to_png, np.ndarray):
125 | img = path_to_png
126 | if len(img.shape) > 2 and to_grayscale:
127 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
128 | name = "result"
129 |
130 | else:
131 | raise ValueError("The input data type %s is not supported"%str(type(path_to_png)))
132 |
133 | return img, name
134 |
135 | # read files
136 | img, name = read_png(path_to_line_sim)
137 | line_artist_fullsize, _ = read_png(path_to_line_artist)
138 | if len(line_artist_fullsize.shape) == 3:
139 | line_artist_fullsize = cv2.cvtColor(line_artist_fullsize, cv2.COLOR_BGR2GRAY)
140 | # line_artist_fullsize = cv2.adaptiveThreshold(line_artist_fullsize, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY,11,2)
141 |
142 | print("Log:\ttrapped ball filling")
143 | # threshold line arts before filling
144 | _, line_simplify = cv2.threshold(img, 200, 255, cv2.THRESH_BINARY)
145 | _, line_artist_fullsize = cv2.threshold(line_artist_fullsize, 200, 255, cv2.THRESH_BINARY)
146 | fills = []
147 |     result = line_simplify # this should be the line_simplify numpy array
148 | line = line_artist_fullsize.copy() # change to a shorter name
149 |
150 |     # maybe resizing the original line art is not a good idea
151 | if line.shape[:2] != line_simplify.shape[:2]:
152 | line = cv2.resize(line, (line_simplify.shape[1],line_simplify.shape[0]),
153 | interpolation = cv2.INTER_AREA)
154 | _, line_artist = cv2.threshold(line, 200, 255, cv2.THRESH_BINARY)
155 | assert len(radius_set) == len(percentiles)
156 |
157 |     # trapped ball filling
158 | for i in range(len(radius_set)):
159 | fill = trapped_ball_fill_multi(result, radius_set[i], percentile=percentiles[i])
160 | fills += fill
161 | result = mark_fill(result, fill)
162 | if visualize_steps:
163 | cv2.imwrite("%d.r%d_per%.2f.png"%(i+1, radius_set[i], percentiles[i]),
164 | show_fill_map(build_fill_map(result, fills)))
165 |
166 |     # fill up any regions that still remain
167 | fill = flood_fill_multi(result)
168 | fills += fill
169 |
170 | # convert fill mask to fill map
171 | fillmap_neural = build_fill_map(result, fills)
172 | if visualize_steps:
173 | i+=1
174 | cv2.imwrite("%d.final_fills.png"%i, show_fill_map(fillmap_neural))
175 |
176 | # final refine, remove tiny regions in the fill map
177 | fillmap_neural = merge_fill(fillmap_neural)
178 | if visualize_steps:
179 | i+=1
180 | cv2.imwrite("%d.merged.png"%i, show_fill_map(fillmap_neural))
181 |
182 | # remove the line art region
183 | fillmap_neural = thinning(fillmap_neural)
184 | if visualize_steps:
185 | i+=1
186 | cv2.imwrite("%d.fills_final.png"%i, show_fill_map(fillmap_neural))
187 |
188 | # upscale neural fill map back to original size
189 | fillmap_neural_fullsize = cv2.resize(fillmap_neural.astype(np.uint8),
190 | (line_artist_fullsize.shape[1], line_artist_fullsize.shape[0]),
191 | interpolation = cv2.INTER_NEAREST)
192 | fillmap_neural_fullsize = fillmap_neural_fullsize.astype(np.int32)
193 |
194 | if preview:
195 | fill_neural_fullsize = show_fill_map(fillmap_neural_fullsize)
196 | fill_neural_fullsize[line_artist_fullsize < 125] = 0
197 | return Image.fromarray(fill_neural_fullsize.astype(np.uint8))
198 |
199 | ## bleeding removal
200 |     # prepare the neural fill map and the flood-fill (FF) map at full size
201 | fill_neural = show_fill_map(fillmap_neural)
202 | fill_neural_line = fill_neural.copy()
203 | fill_neural_line[line_simplify < 200] = 0
204 | fillmap_artist_fullsize = np.ones(fillmap_neural_fullsize.shape, dtype=np.uint8) * 255
205 | fillmap_artist_fullsize[line_artist_fullsize < 125] = 0
206 | _, fillmap_artist_fullsize_c = cv2.connectedComponents(fillmap_artist_fullsize, connectivity=8)
207 |
208 | print("Log:\tcompute cartesian product")
209 | fillmap_neural_fullsize_c = fillmap_neural_fullsize.copy()
210 |
211 | fillmap_neural_fullsize[line_artist_fullsize < 125] = 0
212 | fillmap_neural_fullsize = verify_region(fillmap_neural_fullsize)
213 |
214 | fillmap_artist_fullsize = fillmap_cartesian_product(fillmap_artist_fullsize_c, fillmap_neural_fullsize)
215 | fillmap_artist_fullsize[line_artist_fullsize < 125] = 0
216 |
217 | # re-order both fillmaps
218 | fillmap_artist_fullsize = verify_region(fillmap_artist_fullsize, True)
219 | # fillmap_neural_fullsize = verify_region(fillmap_neural_fullsize, True)
220 |
221 | fillmap_neural_fullsize = bleeding_removal_yotam(fillmap_neural_fullsize_c, fillmap_artist_fullsize, th=0.0002)
222 | fillmap_neural_fullsize[line_artist_fullsize < 125] = 0
223 | fillmap_neural_fullsize = verify_region(fillmap_neural_fullsize, True)
224 |
225 | # convert final result to graph
226 | # we have adjacency matrix, we have fillmap, do we really need another graph for it?
227 | fillmap_artist_fullsize_c = thinning(fillmap_artist_fullsize_c)
228 | fillmap_neural_fullsize = thinning(fillmap_neural_fullsize)
229 |
230 | fill_artist_fullsize = show_fill_map(fillmap_artist_fullsize_c)
231 | fill_neural_fullsize = show_fill_map(fillmap_neural_fullsize)
232 | # fill_neural_fullsize[line_artist_fullsize < 125] = 0
233 |
234 | # if output_path is not None:
235 |
236 | # print("Log:\tsave final fill at %s"%os.path.join(output_path, str(name)+"_fill.png"))
237 | # cv2.imwrite(os.path.join(output_path, str(name)+"_fill.png"), fill_neural_fullsize)
238 |
239 | # print("Log:\tsave neural fill at %s"%os.path.join(output_path, str(name)+"_neural.png"))
240 | # cv2.imwrite(os.path.join(output_path, str(name)+"_neural.png"), fill_neural)
241 |
242 | # print("Log:\tsave fine fill at %s"%os.path.join(output_path, str(name)+"_fine.png"))
243 | # cv2.imwrite(os.path.join(output_path, str(name)+"_fine.png"),
244 | # show_fill_map(fillmap_artist_fullsize_c))
245 |
246 | print("Log:\tdone")
247 | if return_numpy:
248 | return fill_neural, fill_neural_line, fill_artist_fullsize, fill_neural_fullsize
249 | else:
250 | return fillmap_neural_fullsize, fillmap_artist_fullsize_c,\
251 | fill_neural_fullsize, fill_neural, fill_artist_fullsize
252 |
253 | def fillmap_cartesian_product(fill1, fill2):
254 | '''
255 |     Given:
256 |         fill1, the first fillmap (2D array of integer region labels)
257 |         fill2, the second fillmap, with the same shape as fill1
258 |     Return:
259 |         A new fillmap in which every distinct (fill1, fill2) label pair becomes its own region, e.g. pixels labelled (1, 2) form one region and pixels labelled (1, 3) another
260 | '''
261 | assert fill1.shape == fill2.shape
262 |
263 | if len(fill1.shape)==2:
264 | fill1 = np.expand_dims(fill1, axis=-1)
265 |
266 | if len(fill2.shape)==2:
267 | fill2 = np.expand_dims(fill2, axis=-1)
268 |
269 | # cat along channel
270 | fill_c = np.concatenate((fill1, fill2), axis=-1)
271 |
272 |     # regenerate all region labels
273 | labels, inv = np.unique(fill_c.reshape(-1, 2), return_inverse=True, axis=0)
274 | labels = tuple(map(tuple, labels)) # convert array to tuple
275 |
276 |     # assign a numeric label to each cartesian product tuple
277 | l_to_r = {}
278 | for i in range(len(labels)):
279 | l_to_r[labels[i]] = i+1
280 |
281 | # assign new labels back to fillmap
282 | # https://stackoverflow.com/questions/16992713/translate-every-element-in-numpy-array-according-to-key
283 | fill_c = np.array(list(map(l_to_r.get, labels)))[inv]
284 | fill_c = fill_c.reshape(fill1.shape[0:2])
285 |
286 | return fill_c
287 |
288 |
289 | # check each region for isolated sub-regions; if any exist, split them off and assign new region ids
290 | # Yotam: Can this function be replaced with a single call to cv2.connectedComponents()?
291 | # Chuan: I think not; to find the bleeding regions on the boundary, iteratively flood filling each region is necessary
292 | def verify_region(fillmap, reorder_only=False):
293 | fillmap = fillmap.copy().astype(np.int32)
294 | labels = np.unique(fillmap)
295 | h, w = fillmap.shape
296 | # split region
297 | # is this really necessary?
298 |     # yes, without this snippet the result will be bad at line boundaries
299 |     # intuitively, this is like an "alignment" of the smaller neural fill map to the large original line art
300 | # is it possible to crop the image before connectedComponents filling?
301 | next_label = labels.max() + 1
302 | if reorder_only == False:
303 | print("Log:\tsplit isolate regions in fillmap")
304 | for r in tqdm(labels):
305 | if r == 0: continue
306 |         # initial input fill map
307 | region = np.ones(fillmap.shape, dtype=np.uint8)
308 | region[fillmap != r] = 0
309 | '''
310 |         seems this makes it even slower, sadly
311 | need to find a better way
312 | '''
313 | # # try to split region
314 | # def find_bounding_box(region):
315 | # # find the pixel coordination of this region
316 | # points = np.array(np.where(region == 1)).T
317 | # t = points[:,0].min() # top
318 | # l = points[:,1].min() # left
319 | # b = points[:,0].max() # bottom
320 | # r = points[:,1].max() # right
321 | # return t, l, b, r
322 | # t, l, b, r = find_bounding_box(region)
323 | # region_cropped = region[t:b+1, l:r+1]
324 | # # fill_map_corpped = fill_map[t:b+1, l:r+1]
325 |
326 | _, region_verify = cv2.connectedComponents(region, connectivity=8)
327 |
328 | '''
329 |         seems this makes it even slower, sadly
330 | '''
331 | # padding 0 back to the region
332 | # region_padded = cv2.copyMakeBorder(region_verify, t, h-b-1, l, w-r-1, cv2.BORDER_CONSTANT, 0)
333 | # assert region_padded.shape == fillmap.shape
334 | # region_verify = region_padded
335 |
336 |
337 | # split region if necessary
338 | label_verify = np.unique(region_verify)
339 | if len(label_verify) > 2: # skip 0 and the first region
340 | for j in range(2, len(label_verify)):
341 | fillmap[region_verify == label_verify[j]] = next_label
342 | next_label += 1
343 |
344 | # re-order regions
345 | assert np.unique(fillmap).max() == next_label - 1
346 | old_to_new = [0] * next_label
347 | idx = 1
348 | l = len(old_to_new)
349 | labels = np.unique(fillmap)
350 | for i in range(l):
351 | if i in labels and i != 0:
352 | old_to_new[i] = idx
353 | idx += 1
354 | else:
355 | old_to_new[i] = 0
356 | old_to_new = np.array(old_to_new)
357 | fillmap_out = old_to_new[fillmap]
358 |
359 | # assert np.unique(fillmap_out).max()+1 == len(np.unique(fillmap_out))
360 | return fillmap_out
361 |
362 | def update_adj_matrix(A, source, target):
363 |
364 |     # update A: regions source and target are no longer neighbors
365 | # assert A[source, target] == 1
366 | A[source, target] = 0
367 |
368 | # assert A[target, source] == 1
369 | A[target, source] = 0
370 |
371 | # neighbors of s should become neighbor of max
372 | s_neighbors_x = np.where(A[source,:] == 1)
373 | s_neighbors_y = np.where(A[:,source] == 1)
374 | A[source, s_neighbors_x] = 0
375 | A[s_neighbors_y, source] = 0
376 |
377 | # neighbor of neighbors of s should use max instead of s
378 | A[s_neighbors_x, target] = 1
379 | A[target, s_neighbors_y] = 1
380 |
381 | return A
382 |
383 | def merge_to_ref(fill_map_ref, fill_map_source, r_idx, result):
384 |
385 | # this could be improved as well
386 | # r_idx is the region labels
387 | F = {} #mapping of large region to ref region
388 | for i in range(len(r_idx)):
389 | r = r_idx[i]
390 |
391 | if r == 0: continue
392 | label_mask = fill_map_source == r
393 | idx, count = np.unique(fill_map_ref[label_mask], return_counts=True)
394 | most_common = idx[np.argmax(count)]
395 | F[r] = most_common
396 |
397 | for r in r_idx:
398 | if r == 0: continue
399 | label_mask = fill_map_source == r
400 | result[label_mask] = F[r]
401 |
402 | return result
403 |
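For intuition, a tiny hand-made example of the majority-overlap rule merge_to_ref implements (all arrays below are illustrative only):

import numpy as np

fill_map_ref    = np.array([[1, 1, 2, 2]])   # neural fill map
fill_map_source = np.array([[3, 3, 3, 4]])   # fill map to be relabeled
result = merge_to_ref(fill_map_ref, fill_map_source,
                      np.unique(fill_map_source), np.zeros_like(fill_map_ref))
# region 3 overlaps ref region 1 on 2 of its 3 pixels -> mapped to 1
# region 4 lies entirely inside ref region 2          -> mapped to 2
print(result)  # [[1 1 1 2]]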
404 | def merge_small_fast(fill_map_ref, fill_map_source, th):
405 | '''
406 | Faster variant of merge_small: get_small_region and the neighbor-size lookup are vectorized with numpy instead of per-region Python loops.
407 |
408 | '''
409 |
410 | fill_map_source = fill_map_source.copy()
411 | fill_map_ref = fill_map_ref.copy()
412 |
413 | num_regions = len(np.unique(fill_map_source))
414 |
415 | # the definition of long int is different on windows and linux
416 | try:
417 | A = adjacency_matrix.adjacency_matrix(fill_map_source.astype(np.int32), num_regions)
418 | except:
419 | A = adjacency_matrix.adjacency_matrix(fill_map_source.astype(np.int64), num_regions)
420 |
421 | r_idx_source, r_count_source = np.unique(fill_map_source, return_counts=True)
422 |
423 |
424 |
425 | ## Labels should be contiguous.
426 | assert len(r_idx_source) == max(r_idx_source)+1
427 | ## A should have the same dimensions as number of labels.
428 | assert A.shape[0] == A.shape[1]
429 | assert A.shape[0] == len( r_idx_source )
430 | ARTIST_LINE_LABEL = 0
431 | def get_small_region(r_idx_source, r_count_source, th):
432 | return set(
433 | # 1. size less than threshold
434 | r_idx_source[ r_count_source < th ]
435 | ) | set(
436 | # 2. not the neighbor of artist line
437 | ## Which of `r_idx_source` have a 0 in the adjacency position for `ARTIST_LINE_LABEL`?
438 | r_idx_source[ A[r_idx_source,ARTIST_LINE_LABEL] == 0 ]
439 | )
440 |
441 | r_idx_source_small = get_small_region(r_idx_source, r_count_source, th)
442 |
443 | stop = False
444 |
445 | while len(r_idx_source_small) > 0 and stop == False:
446 |
447 | stop = True
448 |
449 | for s in r_idx_source_small:
450 | if s == ARTIST_LINE_LABEL: continue
451 |
452 | neighbors = np.where(A[s,:] == 1)[0]
453 |
454 | # remove line regions
455 | neighbors = neighbors[neighbors != ARTIST_LINE_LABEL]
456 |
457 | # skip if this region doesn't have neighbors
458 | if len(neighbors) == 0: continue
459 |
460 | # find region size
461 | # sizes = np.array([get_size(r_idx_source, r_count_source, n) for n in neighbors]).flatten()
462 | sizes = r_count_source[ neighbors ]
463 |
464 | # merge regions if necessary
465 | largest_index = np.argmax(sizes)
466 | if neighbors[largest_index] == ARTIST_LINE_LABEL and len(neighbors) > 1:
467 | # if its largest neighbor is the line region, drop it and fall back to the rest
468 | neighbors = np.delete(neighbors, largest_index)
469 | sizes = np.delete(sizes, largest_index)
470 |
471 | if len(neighbors) >= 1:
472 | label_mask = fill_map_source == s
473 | max_neighbor = neighbors[np.argmax(sizes)]
474 | A = update_adj_matrix(A, s, max_neighbor)
475 | fill_map_source[label_mask] = max_neighbor
476 | stop = False
477 | else:
478 | continue
479 |
480 | r_idx_source, r_count_source = np.unique(fill_map_source, return_counts=True)
481 | r_idx_source_small = get_small_region(r_idx_source, r_count_source, th)
482 |
483 | '''
484 | for debug
485 | after the first for loop, these 3 variables should have exactly the same values as merge_small_fast2's result
486 | '''
487 | # return r_idx_source_small, r_idx_source, r_count_source
488 | # return fill_map_source
489 |
490 | return fill_map_source
491 |
492 | def merge_small_fast2(fill_map_ref, fill_map_source, th):
493 | '''
494 | Fastest variant of merge_small: merges are recorded in a label remap array and applied to the fill map only once at the end via remap_labels.
495 | '''
496 |
497 | fill_map_source = fill_map_source.copy()
498 | fill_map_ref = fill_map_ref.copy()
499 |
500 | num_regions = len(np.unique(fill_map_source))
501 |
502 | # the definition of long int is different on windows and linux
503 | try:
504 | A = adjacency_matrix.adjacency_matrix(fill_map_source.astype(np.int32), num_regions)
505 | except:
506 | A = adjacency_matrix.adjacency_matrix(fill_map_source.astype(np.int64), num_regions)
507 |
508 | r_idx_source, r_count_source = np.unique(fill_map_source, return_counts=True)
509 | ## Convert them to masked arrays
510 | # why? so merged regions can be masked out in place instead of recomputing np.unique on every pass
511 | r_idx_source = np.ma.masked_array( r_idx_source )
512 | r_count_source = np.ma.masked_array( r_count_source )
513 |
514 |
515 | ## Labels should be contiguous.
516 | assert len(r_idx_source) == max(r_idx_source)+1
517 | ## A should have the same dimensions as number of labels.
518 | assert A.shape[0] == A.shape[1]
519 | assert A.shape[0] == len( r_idx_source )
520 | ARTIST_LINE_LABEL = 0
521 | def get_small_region(r_idx_source, r_count_source, th):
522 | return set(
523 | # 1. size less than threshold
524 | r_idx_source[ r_count_source < th ].compressed()
525 | ) | set(
526 | # 2. not the neighbor of artist line
527 | ## Which of `r_idx_source` have a 0 in the adjacency position for `ARTIST_LINE_LABEL`?
528 | r_idx_source[ A[r_idx_source,ARTIST_LINE_LABEL] == 0 ].compressed()
529 | )
530 |
531 | r_idx_source_small = get_small_region(r_idx_source, r_count_source, th)
532 |
533 | # since the region labels are always contiguous, it is safe to create a remap array like this
534 | # in other words, r_idx_source.max() + 1 == len(r_idx_source)
535 | remap = np.arange(len(r_idx_source))
536 |
537 | stop = False
538 |
539 | while len(r_idx_source_small) > 0 and stop == False:
540 |
541 | stop = True
542 |
543 | for s in r_idx_source_small:
544 | if s == ARTIST_LINE_LABEL: continue
545 |
546 | neighbors = np.where(A[s,:] == 1)[0]
547 |
548 | # remove line regions
549 | neighbors = neighbors[neighbors != ARTIST_LINE_LABEL]
550 |
551 | # skip if this region doesn't have neighbors
552 | if len(neighbors) == 0: continue
553 |
554 | # find region size
555 | # sizes = np.array([get_size(r_idx_source, r_count_source, n) for n in neighbors]).flatten()
556 | sizes = r_count_source[ neighbors ]
557 |
558 | # merge regions if necessary
559 | largest_index = np.argmax(sizes)
560 | if neighbors[largest_index] == ARTIST_LINE_LABEL and len(neighbors) > 1:
561 | # if its largest neighbor is the line region, drop it and fall back to the rest
562 | neighbors = np.delete(neighbors, largest_index)
563 | sizes = np.delete(sizes, largest_index)
564 |
565 | if len(neighbors) >= 1:
566 | max_neighbor = neighbors[np.argmax(sizes)]
567 | A = update_adj_matrix(A, s, max_neighbor)
568 | # record the operation of merge
569 | remap[s] = max_neighbor
570 | # update the region size
571 | r_count_source[max_neighbor] = r_count_source[max_neighbor] + r_count_source[s]
572 | # remove the merged region, however, we should keep the index unchanged
573 | r_count_source[s] = np.ma.masked
574 | r_idx_source[s] = np.ma.masked
575 | stop = False
576 | else:
577 | continue
578 |
579 | r_idx_source_small = get_small_region(r_idx_source, r_count_source, th)
580 |
581 | '''
582 | for debug
583 | after the first for loop, these 3 variables should have exactly the same values as in the other merge_small_fast variant
584 | '''
585 | # return r_idx_source_small, r_idx_source, r_count_source
586 | # adjacency_matrix.remap_labels( fill_map_source, remap )
587 | # return fill_map_source
588 |
589 | fill_map_source = adjacency_matrix.remap_labels( fill_map_source, remap )
590 |
591 | return fill_map_source
592 |
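adjacency_matrix.remap_labels is a compiled Cython helper whose source is not shown here; a plausible pure-numpy equivalent, assuming it only needs to resolve chained merges (a -> b -> c) and apply the final mapping to every pixel, is sketched below. The real implementation may differ.

import numpy as np

def remap_labels_numpy(fill_map, remap):
    # compose the remap with itself until every label points at its final target
    # (assumes the merge chains recorded in `remap` are acyclic)
    remap = np.asarray(remap).copy()
    while not np.array_equal(remap, remap[remap]):
        remap = remap[remap]
    # relabel the whole fill map with one fancy-indexing step
    return remap[fill_map]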
593 | def merge_small(fill_map_ref, fill_map_source, th):
594 | '''
595 | Given:
596 | fill_map_ref: 2D numpy array as neural fill map on neural line
597 | fill_map_source: Connected component fill map on artist line
598 | th: A threshold to identify small regions
599 | Returns:
600 | fill_map_source with every small region merged into its largest non-line neighbor
601 | '''
602 |
603 | # result_fast1 = merge_small_fast(fill_map_ref, fill_map_source, th)
604 | # result_fast2 = merge_small_fast2(fill_map_ref, fill_map_source, th)
605 | # assert ( result_fast1 == result_fast2 ).all()
606 | # r1, r2, r3 = merge_small_fast(fill_map_ref, fill_map_source, th)
607 | # s1, s2, s3 = merge_small_fast2(fill_map_ref, fill_map_source, th)
608 | # return result_fast1
609 |
610 | # make a copy of input, we don't want to affect the array outside of this function
611 | fill_map_source = fill_map_source.copy()
612 | fill_map_ref = fill_map_ref.copy()
613 |
614 | num_regions = len(np.unique(fill_map_source))
615 |
616 | # the definition of long int is different on windows and linux
617 | try:
618 | A = adjacency_matrix.adjacency_matrix(fill_map_source.astype(np.int32), num_regions)
619 | except:
620 | A = adjacency_matrix.adjacency_matrix(fill_map_source.astype(np.int64), num_regions)
621 |
622 | # find the label and size of each region
623 | r_idx_source, r_count_source = np.unique(fill_map_source, return_counts=True)
624 |
625 | def get_small_region(r_idx_source, r_count_source, th):
626 | '''
627 | Find the 'small' regions that need to be merged into their neighbors
628 | '''
629 | r_idx_source_small = []
630 | for i in range(len(r_idx_source)):
631 | # there are two kinds of regions that should be identified as small:
632 | # 1. size less than the threshold
633 | if r_count_source[i] < th:
634 | r_idx_source_small.append(r_idx_source[i])
635 | # 2. not a neighbor of the artist line; this type of region is not adjacent to any stroke lines,
636 | # so it needs to be merged into a neighbor that touches the strokes, no matter how big it is
637 | n = np.where(A[r_idx_source[i],:] == 1)[0]
638 | if 0 not in n:
639 | r_idx_source_small.append(r_idx_source[i])
640 | return r_idx_source_small
641 |
642 | # find the small regions that need to be merged
643 | r_idx_source_small = get_small_region(r_idx_source, r_count_source, th)
644 |
645 | # early stop sign
646 | stop = False
647 |
648 | # main loop to iteratively merge all small regions into its largest neighbor
649 | while len(r_idx_source_small) > 0 and stop == False:
650 |
651 | stop = True
652 | # each time process small regions in the list sequentially
653 | for s in r_idx_source_small:
654 | if s == 0: continue
655 |
656 | # get the pixel mask of region s
657 | label_mask = fill_map_source == s
658 |
659 | # find all neighbors of region s
660 | neighbors = np.where(A[s,:] == 1)[0]
661 |
662 | # remove line regions
663 | neighbors = neighbors[neighbors != 0]
664 |
665 | # skip if this region doesn't have neighbors
666 | if len(neighbors) == 0: continue
667 |
668 | # find region size of s's neighbors
669 | sizes = np.array([get_size(r_idx_source, r_count_source, n) for n in neighbors]).flatten()
670 |
671 | # merge region s into its largest neighbor
672 | if neighbors[np.argsort(sizes)[-1]] == 0 and len(neighbors) > 1:
673 | # if its largest neighbor is the line region, use the second largest instead
674 | max_neighbor = neighbors[np.argsort(sizes)[-2]]
675 | A = update_adj_matrix(A, s, max_neighbor)
676 | fill_map_source[label_mask] = max_neighbor
677 | stop = False
678 | elif len(neighbors) >= 1:
679 | # else merge it into its largest neighbor
680 | max_neighbor = neighbors[np.argsort(sizes)[-1]]
681 | A = update_adj_matrix(A, s, max_neighbor)
682 | fill_map_source[label_mask] = max_neighbor
683 | stop = False
684 | else:
685 | continue
686 |
687 | # re-search the small regions for next loop
688 | r_idx_source, r_count_source = np.unique(fill_map_source, return_counts=True)
689 | r_idx_source_small = get_small_region(r_idx_source, r_count_source, th)
690 |
691 | # assert ( fill_map_source == result_fast2 ).all()
692 | return fill_map_source
693 |
694 | def get_size(idx, count, r):
695 | assert r in idx
696 | assert r != 0
697 |
698 | return count[np.where(idx==r)]
699 |
700 | def bleeding_removal_yotam(fill_map_ref, fill_map_source, th):
701 |
702 | fill_map_ref = fill_map_ref.copy() # connected component fill map
703 | fill_map_source = fill_map_source.copy() # the cartesian product of connected component and neural fill map
704 |
705 | w, h = fill_map_ref.shape
706 | th = int(w * h * th)
707 |
708 | result = np.zeros(fill_map_ref.shape, dtype=np.int32)
709 | # 1. merge small regions which has neighbors
710 | # the int64 means long on linux but long long on windows, sad
711 | print("Log:\tmerge small regions")
712 | fill_map_source = merge_small_fast2(fill_map_ref, fill_map_source, th)
713 |
714 | # 2. merge large regions
715 | # now fill_map_source is clean (no bleeding), but it still contains many "broken" pieces that
716 | # should belong to the same semantic regions. So we can merge these "large but still broken" regions
717 | # together using the neural fill map.
718 | print("Log:\tmerge large regions")
719 | r_idx_source= np.unique(fill_map_source)
720 | result = merge_to_ref(fill_map_ref, fill_map_source, r_idx_source, result)
721 |
722 | return result
723 |
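A usage sketch of this two-step cleanup; fill_map_neural and fill_map_cc are placeholder names for two same-size integer label maps produced earlier in the pipeline:

# th is a fraction of the image area: with th = 0.0001 on a 2048x2048 image,
# regions smaller than int(2048 * 2048 * 0.0001) = 419 pixels count as "small"
cleaned = bleeding_removal_yotam(fill_map_neural, fill_map_cc, th=0.0001)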
724 | def sweep_line_merge(fillmap_neural_fullsize, fillmap_artist_fullsize, add_th, keep_th):
725 |
726 | assert fillmap_neural_fullsize.shape == fillmap_artist_fullsize.shape
727 |
728 | result = np.zeros(fillmap_neural_fullsize.shape)
729 |
730 | def to_sweep_list(fillmap):
731 | sweep_dict = {}
732 | sweep_ml = [] # left-most position (note: points[0].min() is actually the minimum row index), used as the sweep line's anchor
733 | sweep_list, sweep_count = np.unique(fillmap, return_counts=True)
734 | for i in range(len(sweep_list)):
735 | idx = sweep_list[i]
736 | if idx == 0: continue
737 | # 1. point sets 2. if have been merged 3. region area
738 | points = np.where(fillmap == idx)
739 | sweep_dict[idx] = [points, False, sweep_count[i]]
740 | sweep_ml.append(points[0].min())
741 |
742 | sweep_list = sweep_list[np.argsort(np.array(sweep_ml))]
743 |
744 | return sweep_list, sweep_dict
745 |
746 | # turn fill map to sweep list
747 | r_idx_ref, r_dict_ref = to_sweep_list(fillmap_neural_fullsize)
748 | r_idx_source, r_dict_artist = to_sweep_list(fillmap_artist_fullsize)
749 |
750 | skip = []
751 | for rn in tqdm(r_idx_ref):
752 |
753 | if rn == 0: continue
754 |
755 | r1 = np.zeros(fillmap_neural_fullsize.shape)
756 | r1[fillmap_neural_fullsize == rn] = 1
757 |
758 | for ra in r_idx_source:
759 | if ra == 0: continue
760 |
761 | # skip if this region has been merged
762 | if r_dict_artist[ra][1]: continue
763 |
764 | # compute the intersection area of these two regions
765 | r2 = np.zeros(r1.shape)
766 | r2[fillmap_artist_fullsize == ra] = 1
767 | iou = (r1 * r2).sum()
768 |
769 | # compute the ratio of the intersection to each region's area
770 | c1 = iou/r_dict_ref[rn][2]
771 | c2 = iou/r_dict_artist[ra][2]
772 |
773 | # merge
774 | # r1 and r2 are nearly identical: keep r2's shape but give it r1's label rn
775 | if c1 > 0.9 and c2 > 0.9:
776 | result[r_dict_artist[ra][0]] = rn
777 | r_dict_artist[ra][1] = True
778 | continue
779 |
780 | # # r1 is almost contained by r2, the keep r1
781 | # elif c1 > 0.9 and c2 < 0.6:
782 | # result[r_dict_ref[rn][0]] = rn
783 | # # todo:
784 | # # then we need refinement!
785 |
786 | # r2 is almost covered by r1, then merge r2 into r1
787 | elif c1 < 0.6 and c2 > 0.9:
788 | result[r_dict_artist[ra][0]] = rn
789 | r_dict_artist[ra][1] = True
790 |
791 | # r1 and r2 are not close, do nothing then
792 | else:
793 | # we probably could record the c1 and c2, see what the parameter looks like
794 | if c1 != 0 and c2 != 0:
795 | skip.append((c1,c2))
796 |
797 | return result.astype(np.uint8), skip
798 |
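The two ratios c1 and c2 drive every merge decision above; a toy computation with made-up masks:

import numpy as np

r1 = np.zeros((4, 4)); r1[1:3, 1:3] = 1   # neural region, area 4
r2 = np.zeros((4, 4)); r2[1:3, 1:4] = 1   # artist region, area 6
inter = (r1 * r2).sum()                   # overlap area = 4
c1 = inter / r1.sum()                     # 1.0  -> r1 lies entirely inside r2
c2 = inter / r2.sum()                     # ~0.67 -> r2 is only partly covered
# neither (c1 > 0.9 and c2 > 0.9) nor (c1 < 0.6 and c2 > 0.9) holds here,
# so this pair would be recorded in `skip` rather than merged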
799 |
800 | def show_region(region_bit):
801 | plt.imshow(show_fill_map(region_bit))
802 | plt.show()
803 |
804 | def get_figsize(img_num, row_num, img_size, dpi = 100):
805 | # inches = resolution / dpi
806 | # assume all images have the same resolution
807 | width = row_num * (img_size[0] + 200)
808 | height = round(img_num / row_num + 0.5) * (img_size[1] + 300)
809 | return width / dpi, height / dpi
810 |
811 | def visualize(test_folder, row_num = 3):
812 | '''
813 | Run region_get_map on every image in test_folder and show the color-filled results in a grid with row_num images per row.
814 | '''
815 |
816 | img_list = []
817 | for img in os.listdir(test_folder):
818 | img_list.append(img)
819 |
820 | img = Image.open(os.path.join(test_folder, img_list[0]))
821 |
822 |
823 | # visualize collected centers
824 | plt.rcParams["figure.figsize"] = get_figsize(len(img_list), row_num, img.size)
825 |
826 |
827 | i = 0
828 | for i in range((len(img_list)//row_num + 1 if len(img_list)%row_num != 0 else len(img_list)//row_num)):
829 | for j in range(row_num):
830 | plt.subplot(len(img_list)//row_num + 1 , row_num , i*row_num+j+1)
831 | if i*row_num+j < len(img_list):
832 | # the whole function contains two steps
833 | # 1. get region map
834 | img = region_get_map(os.path.join(test_folder, img_list[i*row_num+j]))
835 | # 2. fill the region map
836 | plt.imshow(show_fill_map(img))
837 | plt.title(img_list[i*row_num+j])
838 | return plt
839 |
840 | def radius_percentile_explor(radius_set, method_set, input, output):
841 | for radius in radius_set:
842 | for method in method_set:
843 | print("Log:\ttrying radius %d with percentile %s"%(radius, method))
844 |
845 | # get file name
846 | _, file = os.path.split(input)
847 | name, _ = os.path.splitext(file)
848 |
849 | # open image
850 | img_org = cv2.imread(input, cv2.IMREAD_COLOR)
851 | img = cv2.cvtColor(img_org, cv2.COLOR_BGR2GRAY)
852 |
853 | ret, binary = cv2.threshold(img, 220, 255, cv2.THRESH_BINARY)
854 | fills = []
855 | result = binary # this should be a binary numpy array
856 |
857 | # save result
858 | fill = trapped_ball_fill_multi(result, radius, percentile=method)
859 |
860 | outpath = os.path.join(output, name+"_%d"%radius+"_percentail %s.png"%str(method))
861 | out_map = show_fill_map(build_fill_map(result, fill))
862 | out_map[np.where(binary == 0)]=0
863 | cv2.imwrite(outpath, out_map)
864 |
865 |
866 | # outpath = os.path.join(output, name+"_%d"%radius+"_percentail%s_final.png"%str(method))
867 | # out_map = show_fill_map(thinning(build_fill_map(result, fill)))
868 | # cv2.imwrite(outpath, out_map)
869 | def radius_percentile_explor_repeat(radius_set, input, output, percentile_set = [100], repeat = 20):
870 | for r in radius_set:
871 | for p in percentile_set:
872 | _, file = os.path.split(input)
873 | name, _ = os.path.splitext(file)
874 |
875 | # open image
876 | img_org = cv2.imread(input, cv2.IMREAD_COLOR)
877 | img = cv2.cvtColor(img_org, cv2.COLOR_BGR2GRAY)
878 |
879 | ret, binary = cv2.threshold(img, 220, 255, cv2.THRESH_BINARY)
880 | fills = []
881 | result = binary # this should be a binary numpy array
882 |
883 | for i in range(1, repeat+1):
884 | print("Log:\ttrying radius %d with percentile %s, No.%d"%(r, str(p), i))
885 |
886 | # get file name
887 |
888 |
889 | # save result
890 | fill = trapped_ball_fill_multi(result, r, percentile=p)
891 | fills+=fill
892 |
893 | outpath = os.path.join(output, name+"_%d"%r+"_percentail %s_No %d.png"%(str(p), i))
894 | out_map = show_fill_map(build_fill_map(result, fills))
895 | out_map[np.where(binary == 0)]=0
896 |
897 | cv2.imwrite(outpath, out_map)
898 | result = mark_fill(result, fill)
899 |
900 | def trappedball_2pass_exp(path_line, path_line_sim, save_file=False):
901 |
902 | # open image
903 | line = cv2.imread(path_line, cv2.IMREAD_COLOR)
904 | line = cv2.cvtColor(line, cv2.COLOR_BGR2GRAY)
905 |
906 | line_sim = cv2.imread(path_line_sim, cv2.IMREAD_COLOR)
907 | line_sim = cv2.cvtColor(line_sim, cv2.COLOR_BGR2GRAY)
908 |
909 | _, line = cv2.threshold(line, 220, 255, cv2.THRESH_BINARY)
910 | _, binary = cv2.threshold(line_sim, 220, 255, cv2.THRESH_BINARY)
911 |
912 | result = binary
913 | fills = []
914 |
915 | # filling
916 | fill = trapped_ball_fill_multi(result, 1, percentile=0)
917 | fills += fill
918 | result = mark_fill(result, fill)
919 |
920 | # fill rest region
921 | fill = flood_fill_multi(result)
922 | fills += fill
923 |
924 | # merge
925 | fillmap = build_fill_map(result, fills)
926 | fillmap = merge_fill(fillmap)
927 |
928 | # thin
929 | fillmap = thinning(fillmap)
930 |
931 | # let's do 2nd pass merge!
932 | fillmap_full = cv2.resize(fillmap.astype(np.uint8),
933 | (line.shape[1], line.shape[0]),
934 | interpolation = cv2.INTER_NEAREST)
935 |
936 | # construct a full mask
937 | line_sim_scaled = cv2.resize(line_sim.astype(np.uint8),
938 | (line.shape[1], line.shape[0]),
939 | interpolation = cv2.INTER_NEAREST)
940 | # line_full = cv2.bitwise_and(line, line_sim_scaled)
941 | line_full = line
942 |
943 | # fillmap_full[np.where(line_full<220)] = 0
944 | fillmap_full[line_full<220] = 0
945 | fillmap_full = merger_fill_2nd(fillmap_full)[0]
946 | # fillmap_full = thinning(fillmap_full)
947 |
948 | '''
949 | save results
950 | '''
951 | if save_file:
952 | # show fill map
953 | fill_scaled = show_fill_map(fillmap)
954 | fill_scaled_v1 = show_fill_map(fillmap_full)
955 | fill_full = cv2.resize(fill_scaled.astype(np.uint8),
956 | (line.shape[1], line.shape[0]),
957 | interpolation = cv2.INTER_NEAREST)
958 | line_scaled = cv2.resize(line.astype(np.uint8),
959 | (line_sim.shape[1], line_sim.shape[0]),
960 | interpolation = cv2.INTER_NEAREST)
961 |
962 | # overlay strokes
963 | fill_scaled[np.where(line_scaled<220)] = 0
964 | fill_scaled_v1[np.where(line<220)] = 0
965 | fill_full[np.where(line<220)]=0
966 |
967 | # save result
968 | cv2.imwrite("fill_scaled.png", fill_scaled)
969 | cv2.imwrite("fill_scaled_v1.png", fill_scaled_v1)
970 | cv2.imwrite("fill_full.png", fill_full)
971 |
972 | return fillmap_full
973 | if __name__ == '__main__':
974 |
975 | __spec__ = None
976 |
977 | parser = argparse.ArgumentParser()
978 |
979 | parser.add_argument("--single", action = 'store_true', help="process and save a single image to output")
980 | parser.add_argument("--show-intermediate", action = 'store_true', help="save intermediate results")
981 | parser.add_argument("--visualize", action = 'store_true')
982 | parser.add_argument("--exp1", action = 'store_true', help="experiment of exploring the parameter")
983 | parser.add_argument("--exp3", action = 'store_true', help="experiment of exploring the parameter")
984 | parser.add_argument("--exp4", action = 'store_true', help="experiment of exploring the parameter")
985 | parser.add_argument("--exp5", action = 'store_true', help="experiment of exploring the parameter")
986 | parser.add_argument('--input', type=str, default="./flatting/line_white_background/image0001_line.png",
987 | help = "path to input image, support png file only")
988 | parser.add_argument('--output', type=str, default="./exp1",
989 | help = "the path to result saving folder")
990 |
991 | args = parser.parse_args()
992 |
993 | if args.single:
994 | bit_map = region_get_map(args.input, args.output)
995 | if args.visualize:
996 | show_region(bit_map)
997 | elif args.exp1:
998 | # define the range of parameters
999 | radius_set1 = list(range(7, 15))
1000 | method_set = list(range(0, 101, 5)) + ["mean"]
1001 | radius_percentile_explor(radius_set1, method_set, args.input, args.output)
1002 | elif args.exp3:
1003 | radius_set2 = list(range(1, 15))
1004 | radius_percentile_explor_repeat(radius_set2, args.input, "./exp3")
1005 | elif args.exp4:
1006 | # let's test 2 pass merge
1007 | line = "./examples/01.png"
1008 | line_sim = "./examples/01_sim.png"
1009 | # trappedball_2pass_exp(line, line_sim)
1010 | region_get_map(line_sim,
1011 | path_to_line_artist=line,
1012 | output_path='./',
1013 | radius_set=[1],
1014 | percentiles=[0],
1015 | visualize_steps=False,
1016 | return_numpy=False)
1017 | elif args.exp5:
1018 | # let's test 2 pass merge
1019 | line = "./examples/tiny.png"
1020 | line_sim = "./examples/tiny_sim.png"
1021 | # trappedball_2pass_exp(line, line_sim)
1022 | region_get_map(line_sim,
1023 | path_to_line_artist=line,
1024 | output_path='./',
1025 | radius_set=[1],
1026 | percentiles=[0],
1027 | visualize_steps=False,
1028 | return_numpy=False)
1029 | else:
1030 | in_path = "./flatting/size_2048/line_detection_croped"
1031 | out_path = "./exp4"
1032 | for img in os.listdir(in_path):
1033 | region_get_map(join(in_path, img), out_path, radius_set=[1], percentiles=[0])
1034 |
--------------------------------------------------------------------------------
/src/flatting/trapped_ball/thinning.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 |
5 | def thinning(fillmap, max_iter=100):
6 | """Fill area of line with surrounding fill color.
7 |
8 | # Arguments
9 | fillmap: an image.
10 | max_iter: max iteration number.
11 |
12 | # Returns
13 | an image.
14 | """
15 | line_id = 0
16 | h, w = fillmap.shape[:2]
17 | result = fillmap.copy()
18 |
19 | for iterNum in range(max_iter):
20 | # Get points of line. If there is no point, stop.
21 | line_points = np.where(result == line_id)
22 | if not len(line_points[0]) > 0:
23 | break
24 |
25 | # Get points between lines and fills.
26 | line_mask = np.full((h, w), 255, np.uint8)
27 | line_mask[line_points] = 0
28 | line_border_mask = cv2.morphologyEx(line_mask, cv2.MORPH_DILATE,
29 | cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)), anchor=(-1, -1),
30 | iterations=1) - line_mask
31 | line_border_points = np.where(line_border_mask == 255)
32 |
33 | result_tmp = result.copy()
34 | # Iterate over points, fill each point with nearest fill's id.
35 | for i, _ in enumerate(line_border_points[0]):
36 | x, y = line_border_points[1][i], line_border_points[0][i]
37 |
38 | if x - 1 > 0 and result[y][x - 1] != line_id:
39 | result_tmp[y][x] = result[y][x - 1]
40 | continue
41 |
42 | if x - 1 > 0 and y - 1 > 0 and result[y - 1][x - 1] != line_id:
43 | result_tmp[y][x] = result[y - 1][x - 1]
44 | continue
45 |
46 | if y - 1 > 0 and result[y - 1][x] != line_id:
47 | result_tmp[y][x] = result[y - 1][x]
48 | continue
49 |
50 | if y - 1 > 0 and x + 1 < w and result[y - 1][x + 1] != line_id:
51 | result_tmp[y][x] = result[y - 1][x + 1]
52 | continue
53 |
54 | if x + 1 < w and result[y][x + 1] != line_id:
55 | result_tmp[y][x] = result[y][x + 1]
56 | continue
57 |
58 | if x + 1 < w and y + 1 < h and result[y + 1][x + 1] != line_id:
59 | result_tmp[y][x] = result[y + 1][x + 1]
60 | continue
61 |
62 | if y + 1 < h and result[y + 1][x] != line_id:
63 | result_tmp[y][x] = result[y + 1][x]
64 | continue
65 |
66 | if y + 1 < h and x - 1 > 0 and result[y + 1][x - 1] != line_id:
67 | result_tmp[y][x] = result[y + 1][x - 1]
68 | continue
69 |
70 | result = result_tmp.copy()
71 |
72 | return result
73 |
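A usage sketch, assuming fillmap is an integer label map whose line pixels are labeled 0 (as produced by build_fill_map in trappedball_fill.py):

# grow the neighbouring fill labels into the line pixels, iteration by iteration
fillmap_thin = thinning(fillmap, max_iter=100)
# line pixels can only shrink, never grow
assert (fillmap_thin == 0).sum() <= (fillmap == 0).sum()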
--------------------------------------------------------------------------------
/src/flatting/trapped_ball/thinning_zhang.py:
--------------------------------------------------------------------------------
1 | """
2 | ===========================
3 | @Author : Linbo
4 | @Version: 1.0 25/10/2014
5 | This is the implementation of the
6 | Zhang-Suen Thinning Algorithm for skeletonization.
7 | ===========================
8 | """
9 |
10 | def neighbours(x,y,image):
11 | "Return 8-neighbours of image point P1(x,y), in a clockwise order"
12 | img = image
13 | x_1, y_1, x1, y1 = x-1, y-1, x+1, y+1
14 | return [ img[x_1][y], img[x_1][y1], img[x][y1], img[x1][y1], # P2,P3,P4,P5
15 | img[x1][y], img[x1][y_1], img[x][y_1], img[x_1][y_1] ] # P6,P7,P8,P9
16 |
17 | def transitions(neighbours):
18 | "No. of 0,1 patterns (transitions from 0 to 1) in the ordered sequence"
19 | n = neighbours + neighbours[0:1] # P2, P3, ... , P8, P9, P2
20 | return sum( (n1, n2) == (0, 1) for n1, n2 in zip(n, n[1:]) ) # (P2,P3), (P3,P4), ... , (P8,P9), (P9,P2)
21 |
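A quick worked example of the transition count (neighbourhood values chosen by hand):

# neighbours ordered P2..P9; the circular sequence 0,1,1,0,0,1,0,0 (back to P2)
# contains two 0->1 steps, at (P2,P3) and (P6,P7)
assert transitions([0, 1, 1, 0, 0, 1, 0, 0]) == 2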
22 | def zhangSuen(image):
23 | "the Zhang-Suen Thinning Algorithm"
24 | Image_Thinned = image.copy() # deepcopy to protect the original image
25 | changing1 = changing2 = 1 # the points to be removed (set as 0)
26 | while changing1 or changing2: # iterates until no further changes occur in the image
27 | # Step 1
28 | changing1 = []
29 | rows, columns = Image_Thinned.shape # x for rows, y for columns
30 | for x in range(1, rows - 1): # No. of rows
31 | for y in range(1, columns - 1): # No. of columns
32 | P2,P3,P4,P5,P6,P7,P8,P9 = n = neighbours(x, y, Image_Thinned)
33 | if (Image_Thinned[x][y] == 1 and # Condition 0: Point P1 in the object regions
34 | 2 <= sum(n) <= 6 and # Condition 1: 2<= N(P1) <= 6
35 | transitions(n) == 1 and # Condition 2: S(P1)=1
36 | P2 * P4 * P6 == 0 and # Condition 3
37 | P4 * P6 * P8 == 0): # Condition 4
38 | changing1.append((x,y))
39 | for x, y in changing1:
40 | Image_Thinned[x][y] = 0
41 | # Step 2
42 | changing2 = []
43 | for x in range(1, rows - 1):
44 | for y in range(1, columns - 1):
45 | P2,P3,P4,P5,P6,P7,P8,P9 = n = neighbours(x, y, Image_Thinned)
46 | if (Image_Thinned[x][y] == 1 and # Condition 0
47 | 2 <= sum(n) <= 6 and # Condition 1
48 | transitions(n) == 1 and # Condition 2
49 | P2 * P4 * P8 == 0 and # Condition 3
50 | P2 * P6 * P8 == 0): # Condition 4
51 | changing2.append((x,y))
52 | for x, y in changing2:
53 | Image_Thinned[x][y] = 0
54 | return Image_Thinned
55 |
56 | import matplotlib
57 | import matplotlib.pyplot as plt
58 | import skimage.io as io
59 |
60 |
61 | if __name__ == "__main__":
62 | "load image data"
63 | Img_Original = io.imread( './data/test1.bmp') # Gray image, rgb images need pre-conversion
64 |
65 | "Convert gray images to binary images using Otsu's method"
66 | from skimage import filters
67 | Otsu_Threshold = filters.threshold_otsu(Img_Original)
68 | BW_Original = Img_Original < Otsu_Threshold # must set object region as 1, background region as 0 !
69 |
70 | "Apply the algorithm on images"
71 | BW_Skeleton = zhangSuen(BW_Original)
72 | # BW_Skeleton = BW_Original
73 | "Display the results"
74 | fig, ax = plt.subplots(1, 2)
75 | ax1, ax2 = ax.ravel()
76 | ax1.imshow(BW_Original, cmap=plt.cm.gray)
77 | ax1.set_title('Original binary image')
78 | ax1.axis('off')
79 | ax2.imshow(BW_Skeleton, cmap=plt.cm.gray)
80 | ax2.set_title('Skeleton of the image')
81 | ax2.axis('off')
82 | plt.show()
83 |
--------------------------------------------------------------------------------
/src/flatting/trapped_ball/trappedball_fill.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import pdb
4 | import pickle
5 | import time
6 | # use cython adjacency matrix
7 | try:
8 | from . import adjacency_matrix
9 | ## If it's not already compiled, compile it.
10 | except:
11 | import pyximport
12 | pyximport.install()
13 | from . import adjacency_matrix
14 | # it seems that multi-threading will not help to reduce running time
15 | # https://medium.com/python-experiments/parallelising-in-python-mutithreading-and-mutiprocessing-with-practical-templates-c81d593c1c49
16 | from multiprocessing import Pool
17 | from multiprocessing import freeze_support
18 | from functools import partial
19 | from skimage.morphology import thin # thin() is used by to_masked_line below
20 |
21 |
22 |
23 | from tqdm import tqdm
24 | from PIL import Image
25 |
26 | def save_obj(fill_graph, save_path='fill_map.pickle'):
27 |
28 | with open(save_path, 'wb') as f:
29 | pickle.dump(fill_graph, f, protocol=pickle.HIGHEST_PROTOCOL)
30 |
31 | def load_obj(load_path='fill_map.pickle'):
32 |
33 | with open(load_path, 'rb') as f:
34 | fill_graph = pickle.load(f)
35 |
36 | return fill_graph
37 |
38 | def extract_line(fills_result):
39 |
40 | img = cv2.blur(fills_result,(5,5))
41 |
42 | # analyze the gradient of the flat image
43 | grad = cv2.Laplacian(img,cv2.CV_64F)
44 | grad = abs(grad).sum(axis = -1)
45 | grad_v, grad_c = np.unique(grad, return_counts=True)
46 |
47 | # remove the majority grad, which is 0
48 | assert np.where(grad_v==0) == np.where(grad_c==grad_c.max())
49 | grad_v = np.delete(grad_v, np.where(grad_v==0))
50 | grad_c = np.delete(grad_c, np.where(grad_c==grad_c.max()))
51 | print("Log:\tlen of grad_v %d"%len(grad_v))
52 | grad_c_cum = np.cumsum(grad_c)
53 |
54 | # if the number of distinct gradient values is greater than 100, the current
55 | # image probably contains very similar colors, so we should apply
56 | # another set of parameters to detect edges
57 | # this could be better if we can find the relation between them
58 | if len(grad_v) < 100:
59 | min_val = grad_v[np.where(grad_c_cum<=np.percentile(grad_c_cum, 25))[0].max()]
60 | max_val = grad_v[np.where(grad_c_cum<=np.percentile(grad_c_cum, 40))[0].max()]
61 | else:
62 | min_val = grad_v[np.where(grad_c_cum<=np.percentile(grad_c_cum, 1))[0].max()]
63 | max_val = grad_v[np.where(grad_c_cum<=np.percentile(grad_c_cum, 10))[0].max()]
64 |
65 | edges = cv2.Canny(img, min_val, max_val, L2gradient=True)
66 | return 255-edges
67 |
68 | def to_masked_line(line_sim, line_artist, rk1=None, rk2=None, ak=None, tn=1):
69 | '''
70 | Given:
71 | line_sim, simplified line, which is also the neural networks output
72 | line_artist, artist line, the original input
73 | rk, remove kernel (thickening kernel); use this option if the neural network outputs too thin a line
74 | ak, add kernel (thinning kernel); use this option if the neural network outputs too thick a line
75 | Return:
76 | the masked line for filling
77 | '''
78 | # 1. generate lines with unnecessary strokes removed
79 | if rk1 != None:
80 | kernel_remove1 = get_ball_structuring_element(rk1)
81 | # make the simplified line to cover the artist's line
82 | mask_remove = cv2.morphologyEx(line_sim, cv2.MORPH_ERODE, kernel_remove1)
83 | else:
84 | mask_remove = line_sim
85 |
86 | mask_remove = np.logical_and(line_artist==0, mask_remove==0)
87 |
88 | # 2. generate lines that added by line_sim
89 | if ak != None:
90 | kernel_add = get_ball_structuring_element(ak)
91 | # try to make the artist's line cover the simplified line
92 | mask_add = cv2.morphologyEx(line_sim, cv2.MORPH_DILATE, kernel_add)
93 | else:
94 | mask_add = line_sim
95 |
96 | # maybe we don't need that skeleton
97 | # mask_add = 255 - skeletonize((255 - mask_add)/255, method='lee')
98 |
99 | # let's try just thin it
100 | mask_add = 255 - thin(255 - mask_add, max_iter=tn).astype(np.uint8)*255
101 |
102 | if rk2 != None:
103 | kernel_remove2 = get_ball_structuring_element(rk2)
104 | line_artist = cv2.morphologyEx(line_artist, cv2.MORPH_ERODE, kernel_remove2)
105 | mask_add = np.logical_and(mask_add==0, np.logical_xor(mask_add==0, line_artist==0))
106 |
107 | # 3. combine and return the result
108 | mask = np.logical_or(mask_remove, mask_add).astype(np.uint8)*255
109 |
110 | # # 4. connect dot lines if exists
111 | # if connect != None:
112 | # kernel_con = get_ball_structuring_element(1)
113 | # for _ in range(connect):
114 | # mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel_con)
115 |
116 | return 255 - mask
117 |
118 |
119 | def get_ball_structuring_element(radius):
120 | """Get a ball shape structuring element with specific radius for morphology operation.
121 | The radius of ball usually equals to (leaking_gap_size / 2).
122 |
123 | # Arguments
124 | radius: radius of ball shape.
125 |
126 | # Returns
127 | an array of ball structuring element.
128 | """
129 | return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * radius + 1, 2 * radius + 1))
130 |
131 |
132 | def get_unfilled_point(image):
133 | """Get points belong to unfilled(value==255) area.
134 |
135 | # Arguments
136 | image: an image.
137 |
138 | # Returns
139 | an array of points.
140 | """
141 | y, x = np.where(image == 255)
142 |
143 | return np.stack((x.astype(int), y.astype(int)), axis=-1)
144 |
145 |
146 | def exclude_area(image, radius):
147 | """Perform erosion on image to exclude points near the boundary.
148 | We want to pick part using floodfill from the seed point after dilation.
149 | When the seed point is near boundary, it might not stay in the fill, and would
150 | not be a valid point for next floodfill operation. So we ignore these points with erosion.
151 |
152 | # Arguments
153 | image: an image.
154 | radius: radius of ball shape.
155 |
156 | # Returns
157 | an image after dilation.
158 | """
159 | # https://docs.opencv.org/3.4/d4/d86/group__imgproc__filter.html
160 | return cv2.morphologyEx(image, cv2.MORPH_ERODE, get_ball_structuring_element(radius), anchor=(-1, -1), iterations=1)
161 |
162 |
163 | def trapped_ball_fill_single(image, seed_point, radius):
164 | """Perform a single trapped ball fill operation.
165 |
166 | # Arguments
167 | image: an image. the image should consist of white background, black lines and black fills.
168 | the white area is unfilled area, and the black area is filled area.
169 | seed_point: seed point for trapped-ball fill, a tuple (integer, integer).
170 | radius: radius of ball shape.
171 | # Returns
172 | an image after filling.
173 | """
174 |
175 | ball = get_ball_structuring_element(radius)
176 |
177 | pass1 = np.full(image.shape, 255, np.uint8)
178 | pass2 = np.full(image.shape, 255, np.uint8)
179 |
180 | im_inv = cv2.bitwise_not(image) # invert so lines/filled pixels become non-zero in the flood-fill mask and block the fill
181 |
182 | # Floodfill the image
183 | mask1 = cv2.copyMakeBorder(im_inv, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)
184 |
185 | # retval, image, mask, rect = cv.floodFill(image, mask, seedPoint, newVal[, loDiff[, upDiff[, flags]]])
186 | # fill the background pixels; flood-filling cannot cross non-zero pixels in the input mask.
187 | _, pass1, _, _ = cv2.floodFill(pass1, mask1, seed_point, 0, 0, 0, 4) #seed point is the first unfilled point
188 |
189 | # Perform dilation on image. The fill areas between gaps became disconnected.
190 | # close any possible gaps that could be covered by the ball
191 | pass1 = cv2.morphologyEx(pass1, cv2.MORPH_DILATE, ball, anchor=(-1, -1), iterations=1)
192 | mask2 = cv2.copyMakeBorder(pass1, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)
193 |
194 | # Floodfill with seed point again to select one fill area.
195 | _, pass2, _, rect = cv2.floodFill(pass2, mask2, seed_point, 0, 0, 0, 4)
196 |
197 | # Perform erosion on the fill result to make the fill leaking-proof.
198 |
199 | pass2 = cv2.morphologyEx(pass2, cv2.MORPH_ERODE, ball, anchor=(-1, -1), iterations=1)
200 |
201 | return pass2
202 |
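A usage sketch; `binary` is a placeholder for a uint8 image with white (255) unfilled area and black (0) lines, as the docstring describes:

seeds = get_unfilled_point(exclude_area(binary, 4))
if len(seeds) > 0:
    # only the area that a ball of radius 4 cannot leak out of ends up as 0 in `fill`
    fill = trapped_ball_fill_single(binary, (seeds[0][0], seeds[0][1]), radius=4)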
203 |
204 | def trapped_ball_fill_multi(image, radius, percentile='mean', max_iter=1000, verbo=False):
205 | """Perform multi trapped ball fill operations until all valid areas are filled.
206 |
207 | # Arguments
208 | image: an image. The image should consist of white background, black lines and black fills.
209 | the white area is unfilled area, and the black area is filled area.
210 | radius: radius of ball shape.
211 | percentile: method for filtering the fills. 'mean' keeps fills whose area is at least the mean;
212 | an integer in [0, 100] keeps fills at or above that percentile of area.
213 | max_iter: max iteration number.
214 | # Returns
215 | an array of fills' points.
216 | """
217 | if verbo:
218 | print('trapped-ball ' + str(radius))
219 |
220 | unfill_area = image # unfill_area is a binary numpy array (contains 0 and 255 only); 0 means filled region
221 |
222 | h, w = image.shape
223 |
224 | filled_area, filled_area_size, result = [], [], []
225 |
226 | for _ in range(max_iter):
227 |
228 | # get the point list of unfilled regions
229 | points = get_unfilled_point(exclude_area(unfill_area, radius))
230 | # points = get_unfilled_point(unfill_area)
231 |
232 | # terminate if all points have been filled
233 | if not len(points) > 0:
234 | break
235 |
236 | # perform a single trapped-ball fill from the first unfilled point
237 | fill = trapped_ball_fill_single(unfill_area, (points[0][0], points[0][1]), radius)
238 |
239 | # update filled region
240 | unfill_area = cv2.bitwise_and(unfill_area, fill)
241 |
242 | # record filled region of each iter
243 | filled_area.append(np.where(fill == 0))
244 | filled_area_size.append(len(np.where(fill == 0)[0]))
245 |
246 | filled_area_size = np.asarray(filled_area_size)
247 |
248 | # a filter to remove the "half" filled regions
249 | if percentile == "mean":
250 | area_size_filter = np.mean(filled_area_size)
251 |
252 | elif type(percentile)==int:
253 | assert percentile>=0 and percentile<=100
254 | area_size_filter = np.percentile(filled_area_size, percentile)
255 | else:
256 | print("wrong percentile %s"%percentile)
257 | raise ValueError
258 |
259 | result_idx = np.where(filled_area_size >= area_size_filter)[0]
260 |
261 | # filter out all region that is less than the area_size_filter
262 | for i in result_idx:
263 | result.append(filled_area[i])
264 |
265 | # result is a list of point list for each filled region
266 | return result
267 |
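The percentile filter at the end is simply a cut on region areas; with the illustrative sizes below:

import numpy as np

filled_area_size = np.array([10, 40, 400, 5000])
# percentile=0 keeps everything; 'mean' keeps only areas >= 1362.5 (just 5000);
# percentile=50 keeps areas >= the median 220.0, i.e. the regions of size 400 and 5000
keep = np.where(filled_area_size >= np.percentile(filled_area_size, 50))[0]  # -> [2, 3]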
268 |
269 | def flood_fill_single(im, seed_point):
270 | """Perform a single flood fill operation.
271 |
272 | # Arguments
273 | image: an image. the image should consist of white background, black lines and black fills.
274 | the white area is unfilled area, and the black area is filled area.
275 | seed_point: seed point for trapped-ball fill, a tuple (integer, integer).
276 | # Returns
277 | an image after filling.
278 | """
279 | pass1 = np.full(im.shape, 255, np.uint8)
280 |
281 | im_inv = cv2.bitwise_not(im)
282 |
283 | mask1 = cv2.copyMakeBorder(im_inv, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)
284 | _, pass1, _, _ = cv2.floodFill(pass1, mask1, seed_point, 0, 0, 0, 4)
285 |
286 | return pass1
287 |
288 |
289 | def flood_fill_multi(image, max_iter=20000, verbo=False):
290 |
291 | """Perform multi flood fill operations until all valid areas are filled.
292 | This operation will fill all rest areas, which may result large amount of fills.
293 |
294 | # Arguments
295 | image: an image. the image should contain white background, black lines and black fills.
296 | the white area is unfilled area, and the black area is filled area.
297 | max_iter: max iteration number.
298 | # Returns
299 | an array of fills' points.
300 | """
301 | if verbo:
302 | print('floodfill')
303 |
304 | unfill_area = image
305 | filled_area = []
306 |
307 | for _ in range(max_iter):
308 | points = get_unfilled_point(unfill_area)
309 |
310 | if not len(points) > 0:
311 | break
312 |
313 | fill = flood_fill_single(unfill_area, (points[0][0], points[0][1]))
314 | unfill_area = cv2.bitwise_and(unfill_area, fill)
315 |
316 | filled_area.append(np.where(fill == 0))
317 |
318 | return filled_area
319 |
320 |
321 | def mark_fill(image, fills):
322 | """Mark filled areas with 0.
323 |
324 | # Arguments
325 | image: an image.
326 | fills: an array of fills' points.
327 | # Returns
328 | an image.
329 | """
330 | result = image.copy()
331 |
332 | for fill in fills:
333 | result[fill] = 0
334 |
335 | return result
336 |
337 |
338 | def build_fill_map(image, fills):
339 | """Make an image(array) with each pixel(element) marked with fills' id. id of line is 0.
340 |
341 | # Arguments
342 | image: an image.
343 | fills: an array of fills' points.
344 | # Returns
345 | an array.
346 | """
347 | result = np.zeros(image.shape[:2], int) # np.int was removed from recent NumPy versions
348 |
349 | for index, fill in enumerate(fills):
350 | result[fill] = index + 1
351 |
352 | return result
353 |
354 |
355 | def show_fill_map(fillmap):
356 | """Mark filled areas with colors. It is useful for visualization.
357 |
358 | # Arguments
359 | fillmap: an array with each pixel(element) marked
360 | with its fill's id (line id is 0).
361 | # Returns
362 | an image.
363 | """
364 | # Generate color for each fill randomly.
365 | colors = np.random.randint(0, 255, (np.max(fillmap) + 1, 3), dtype=np.uint8)
366 | # Id of line is 0, and its color is black.
367 | colors[0] = [0, 0, 0]
368 |
369 | return colors[fillmap]
370 |
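The whole colorization is a single fancy-indexing step: each pixel's label selects a row of the color table. A tiny example:

import numpy as np

fillmap = np.array([[0, 1],
                    [2, 1]])
colors = np.array([[  0,   0,   0],   # label 0 -> black (line)
                   [255,   0,   0],   # label 1 -> red
                   [  0, 255,   0]])  # label 2 -> green
rgb = colors[fillmap]                  # shape (2, 2, 3)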
371 |
372 | def get_bounding_rect(points):
373 | """Get a bounding rect of points.
374 |
375 | # Arguments
376 | points: array of points.
377 | # Returns
378 | rect coord
379 | """
380 | x1, y1, x2, y2 = np.min(points[1]), np.min(points[0]), np.max(points[1]), np.max(points[0])
381 | return x1, y1, x2, y2
382 |
383 |
384 | def get_border_bounding_rect(h, w, p1, p2, r):
385 | """Get a valid bounding rect in the image with border of specific size.
386 |
387 | # Arguments
388 | h: image max height.
389 | w: image max width.
390 | p1: start point of rect.
391 | p2: end point of rect.
392 | r: border radius.
393 | # Returns
394 | rect coord
395 | """
396 | x1, y1, x2, y2 = p1[0], p1[1], p2[0], p2[1]
397 |
398 | x1 = x1 - r if 0 < x1 - r else 0
399 | y1 = y1 - r if 0 < y1 - r else 0
400 | x2 = x2 + r + 1 if x2 + r + 1 < w else w # +1 because x2/y2 are used as exclusive slice ends
401 | y2 = y2 + r + 1 if y2 + r + 1 < h else h
402 |
403 | return x1, y1, x2, y2
404 |
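Because x2/y2 are exclusive ends, the returned rect can be used directly for slicing or for allocating a patch, exactly as get_border_point does below. A small check with arbitrary values:

import numpy as np

h, w = 100, 200
x1, y1, x2, y2 = get_border_bounding_rect(h, w, (50, 30), (60, 40), 2)
patch = np.zeros((y2 - y1, x2 - x1), np.uint8)  # same shape get_border_point allocates
assert (x1, y1, x2, y2) == (48, 28, 63, 43)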
405 |
406 | def get_border_point(points, rect, max_height, max_width):
407 | """Get border points of a fill area
408 |
409 | # Arguments
410 | points: points of fill .
411 | rect: bounding rect of fill.
412 | max_height: image max height.
413 | max_width: image max width.
414 | # Returns
415 | points , convex shape of points
416 | """
417 |
418 | # Get a local bounding rect.
419 | # a slightly padded local bounding box, so the border can be computed on a small patch instead of the full image
420 | border_rect = get_border_bounding_rect(max_height, max_width, rect[:2], rect[2:], 2)
421 |
422 | # Get fill in rect, all 0s
423 | fill = np.zeros((border_rect[3] - border_rect[1], border_rect[2] - border_rect[0]), np.uint8)
424 |
425 | # Move points to the rect.
426 | # offset points into the fill
427 | fill[(points[0] - border_rect[1], points[1] - border_rect[0])] = 255
428 |
429 | # Get shape.
430 | # pdb.set_trace()
431 | contours, _ = cv2.findContours(fill, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
432 | approx_shape = cv2.approxPolyDP(contours[0], 0.02 * cv2.arcLength(contours[0], True), True)
433 |
434 | # Get border pixel.
435 | # Structuring element in cross shape is used instead of box to get 4-connected border.
436 | '''
437 | # Cross-shaped Kernel
438 | >>> cv2.getStructuringElement(cv2.MORPH_CROSS,(5,5))
439 | array([[0, 0, 1, 0, 0],
440 | [0, 0, 1, 0, 0],
441 | [1, 1, 1, 1, 1],
442 | [0, 0, 1, 0, 0],
443 | [0, 0, 1, 0, 0]], dtype=uint8)
444 | '''
445 | cross = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)) # this is a cross-shaped kernel (see above)
446 | border_pixel_mask = cv2.morphologyEx(fill, cv2.MORPH_DILATE, cross, anchor=(-1, -1), iterations=1) - fill
447 | border_pixel_points = np.where(border_pixel_mask == 255)
448 |
449 | # Transform points back to fillmap.
450 | border_pixel_points = (border_pixel_points[0] + border_rect[1], border_pixel_points[1] + border_rect[0])
451 |
452 | return border_pixel_points, approx_shape
453 |
454 |
455 | def merge_fill(fillmap, max_iter=10, verbo=False):
456 | """Merge fill areas.
457 |
458 | # Arguments
459 | fillmap: an image.
460 | max_iter: max iteration number.
461 | # Returns
462 | an image.
463 | """
464 | max_height, max_width = fillmap.shape[:2]
465 | result = fillmap.copy()
466 |
467 | for i in range(max_iter):
468 | if verbo:
469 | print('merge ' + str(i + 1))
470 |
471 | # set stroke as black
472 | result[np.where(fillmap == 0)] = 0
473 |
474 | # get list of fill id
475 | fill_id = np.unique(result.flatten())
476 | fills = []
477 |
478 | for j in fill_id:
479 |
480 | # select one region each time
481 | point = np.where(result == j)
482 |
483 | fills.append({
484 | 'id': j,
485 | 'point': point,
486 | 'area': len(point[0]),
487 | 'rect': get_bounding_rect(point)
488 | })
489 |
490 | for j, f in enumerate(fills):
491 |
492 | # ignore lines
493 | if f['id'] == 0:
494 | continue
495 |
496 | # get the border shape of a region, but that may contain many noisy segmentations?
497 | border_points, approx_shape = get_border_point(f['point'], f['rect'], max_height, max_width)
498 | border_pixels = result[border_points] # region ids of the pixels just outside this region
499 | pixel_ids, counts = np.unique(border_pixels, return_counts=True)
500 |
501 | # remove id that equal 0
502 | ids = pixel_ids[np.nonzero(pixel_ids)]
503 | new_id = f['id']
504 | if len(ids) == 0:
505 | # points with lines around color change to line color
506 | # regions surrounded by line remain the same
507 | if f['area'] < 5:
508 | # if f['area'] < 32:
509 | new_id = 0
510 | else:
511 | # region id may be set to region with largest contact
512 | new_id = ids[0]
513 |
514 | # a point, because the convex shape only contains 1 point
515 | if len(approx_shape) == 1 or f['area'] == 1:
516 | result[f['point']] = new_id
517 |
518 | # so this means: small regions with a simple (2-5 vertex) outline are merged
519 | if len(approx_shape) in [2, 3, 4, 5] and f['area'] < 500:
520 | # if len(approx_shape) in [2, 3, 4, 5] and f['area'] < 10000:
521 | result[f['point']] = new_id
522 |
523 | if f['area'] < 250 and len(ids) == 1:
524 | # if f['area'] < 5000 and len(ids) == 1:
525 | result[f['point']] = new_id
526 |
527 | if f['area'] < 50:
528 | # if f['area'] < 100:
529 | result[f['point']] = new_id
530 |
531 | # if no merge happen, stop this process
532 | if len(fill_id) == len(np.unique(result.flatten())):
533 | break
534 |
535 | return result
536 |
537 | def search_point(points, point):
538 |
539 | idx = np.where((points == point).all(axis = 1))[0]
540 |
541 | return idx
542 |
543 | def extract_region_obsolete(points, point, width, height):
544 |
545 | # unfortunately, this function is too costly to run
546 |
547 | # get 8-connectivity neighbors
548 | point_list = []
549 |
550 | # search top left
551 | # point[0] is height
552 | # point[1] is width
553 | if point[0] > 0 and point[1] > 0:
554 | tl = np.array([point[0]-1, point[1]-1])
555 | idx = search_point(points, tl)
556 | if len(idx) == 0:
557 | pass
558 | elif len(idx) == 1:
559 | point_list.append(tl)
560 | # pop out current point
561 | points = np.delete(points, idx, axis=0)
562 | point_list += extract_region(points, tl, width, height)
563 | else:
564 | raise ValueError("There should not exist two identical points in the list!")
565 | # search top
566 | if point[0] > 0:
567 | t = np.array([point[0], point[1]-1])
568 | idx = search_point(points, t)
569 | if len(idx) == 0:
570 | pass
571 | elif len(idx) == 1:
572 | point_list.append(t)
573 | # pop out current point
574 | points = np.delete(points, idx, axis=0)
575 | point_list += extract_region(points, t, width, height)
576 | else:
577 | raise ValueError("There should not exist two identical points in the list!")
578 |
579 | # search top right
580 | if point[0] > 0 and point[1] < width:
581 | tr = np.array([point[0]-1, point[1]+1])
582 | idx = search_point(points, tr)
583 | if len(idx) == 0:
584 | pass
585 | elif len(idx) == 1:
586 | point_list.append(tr)
587 | # pop out current point
588 | points = np.delete(points, idx, axis=0)
589 | point_list += extract_region(points, tr, width, height)
590 | else:
591 | raise ValueError("There should not exist two identical points in the list!")
592 |
593 | # search mid left
594 | if point[1] > 0:
595 | ml = np.array([point[0], point[1]-1])
596 | idx = search_point(points, ml)
597 | if len(idx) == 0:
598 | pass
599 | elif len(idx) == 1:
600 | point_list.append(ml)
601 | # pop out current point
602 | points = np.delete(points, idx, axis=0)
603 | point_list += extract_region(points, ml, width, height)
604 | else:
605 | raise ValueError("There should not exist two identical points in the list!")
606 |
607 | # search mid right
608 | if point[1] < width:
609 | mr = np.array([point[0], point[1]+1])
610 | idx = search_point(points, mr)
611 | if len(idx) == 0:
612 | pass
613 | elif len(idx) == 1:
614 | point_list.append(mr)
615 | # pop out current point
616 | points = np.delete(points, idx, axis=0)
617 | point_list += extract_region(points, mr, width, height)
618 | else:
619 | raise ValueError("There should not exist two identical points in the list!")
620 |
621 | # search bottom left
622 | if point[0] < height and point[1] > 0:
623 | bl = np.array([point[0]+1, point[1]-1])
624 | idx = search_point(points, bl)
625 | if len(idx) == 0:
626 | pass
627 | elif len(idx) == 1:
628 | point_list.append(bl)
629 | # pop out current point
630 | points = np.delete(points, idx, axis=0)
631 | point_list += extract_region(points, bl, width, height)
632 | else:
633 | raise ValueError("There should not exist two identical points in the list!")
634 |
635 | # search bottom
636 | if point[0] < height:
637 | b = np.array([point[0]+1, point[1]])
638 | idx = search_point(points, b)
639 | if len(idx) == 0:
640 | pass
641 | elif len(idx) == 1:
642 | point_list.append(b)
643 | # pop out current point
644 | points = np.delete(points, idx, axis=0)
645 | point_list += extract_region(points, b, width, height)
646 | else:
647 | raise ValueError("There should not exist two identical points in the list!")
648 |
649 | # search bottom right
650 | if point[0] < height and point[1] < width:
651 | br = np.array([point[0]+1, point[1]+1])
652 | idx = search_point(points, br)
653 | if len(idx) == 0:
654 | pass
655 | elif len(idx) == 1:
656 | point_list.append(br)
657 | # pop out current point
658 | points = np.delete(points, idx, axis=0)
659 | point_list += extract_region(points, br, width, height)
660 | else:
661 | raise ValueError("There should not exist two identical points in the list!")
662 |
663 | return point_list
664 |
665 | def extract_region(image):
666 |
667 | # let's try flood fill
668 |
669 | return flood_fill_multi(image, max_iter=20000)
670 |
671 | def to_graph(fillmap, fillid):
672 |
673 | # how to speed up this part?
674 | # use another graph data structure
675 | # or maybe use list instead of dict
676 |
677 | fills = {}
678 | for j in tqdm(fillid):
679 |
680 | # select one region each time
681 | point = np.where(fillmap == j)
682 |
683 | fills[j] = {"point":point,
684 | "area": len(point[0]),
685 | "rect": get_bounding_rect(point),
686 | "neighbor":[]}
687 | return fills
688 |
689 | def to_fillmap(fillmap, fills):
690 |
691 | for j in fills:
692 | if fills[j] == None:
693 | continue
694 | fillmap[fills[j]['point']] = j
695 |
696 | return fillmap
697 |
698 | def merge_list_ordered(list1, list2, idx):
699 |
700 | for value in list2:
701 | if value not in list1 and value != idx and value != None:
702 | list1.append(value)
703 |
704 | return list1
705 |
706 | def merge_region(fills, source_idx, target_idx, result):
707 |
708 | # merge from source to target
709 | assert fills[source_idx] != None
710 | assert fills[target_idx] != None
711 |
712 | # update target region
713 | fills[target_idx]['point'] = (np.concatenate((fills[target_idx]['point'][0], fills[source_idx]['point'][0])),
714 | np.concatenate((fills[target_idx]['point'][1], fills[source_idx]['point'][1])))
715 |
716 | fills[target_idx]['area'] += fills[source_idx]['area']
717 | assert len(fills[target_idx]['point'][0]) + len(fills[source_idx]['point'][0]) == fills[target_idx]['area'] + fills[source_idx]['area']
718 |
719 | fills[target_idx]['neighbor'] = merge_list_ordered(fills[target_idx]['neighbor'],
720 | fills[source_idx]['neighbor'], target_idx)
721 |
722 | # update source's neighbor
723 | for n in fills[source_idx]['neighbor']:
724 | if n != None:
725 | if source_idx in fills[n]['neighbor']:
726 | t = fills[n]['neighbor'].index(source_idx)
727 | fills[n]['neighbor'][t] = None
728 | else:
729 | print("find one side neighbor")
730 |
731 | if target_idx not in fills[n]['neighbor'] and target_idx != n:
732 | fills[n]['neighbor'].append(target_idx)
733 |
734 | # remove source region
735 | fills[source_idx] = None
736 |
737 | return fills
738 |
739 | def split():
740 | # this might be a different function
741 | pass
742 |
743 | def list_region(fill_graph, th = None, verbo = True):
744 |
745 | regions = 0
746 | small_regions = []
747 | for key in fill_graph:
748 | if fill_graph[key] != None:
749 |
750 | if verbo:
751 | print("Log:\tregion %d with size %d"%(key, fill_graph[key]['area']))
752 | regions += 1
753 |
754 | if th == None:
755 | continue
756 |
757 | # collect small regions
758 | if fill_graph[key]['area'] < th:
759 | small_regions.append(key)
760 |
761 | if verbo:
762 | print("Log:\ttotal regions %d"%regions)
763 |
764 | return small_regions
765 |
766 | def visualize_graph(fills_graph, result, region=None):
767 | if region == None:
768 | Image.fromarray(show_fill_map(to_fillmap(result, fills_graph)).astype(np.uint8)).show()
769 | else:
770 | assert region in fills_graph
771 | show_map = np.zeros(result.shape, np.uint8)
772 | show_map[fills_graph[region]['point']] = 255
773 | Image.fromarray(show_map).show()
774 |
775 | def visualize_result(result, region=None):
776 | if region == None:
777 | Image.fromarray(show_fill_map(result).astype(np.uint8)).show()
778 | else:
779 | assert region in result
780 | show_map = np.zeros(result.shape, np.uint8)
781 | show_map[np.where(result == region)] = 255
782 | Image.fromarray(show_map).show()
783 |
784 | def graph_self_check(fill_graph):
785 |
786 | for key in fill_graph:
787 | if fill_graph[key] != None:
788 | if len(fill_graph[key]['neighbor']) > 0:
789 | if len(fill_graph[key]['neighbor']) != len(set(fill_graph[key]['neighbor'])):
790 | print("Log:\tfind duplicate neighbor!")
791 | for n in fill_graph[key]['neighbor']:
792 | if key not in fill_graph[n]['neighbor']:
793 | print("Log:\tfind missing neighbor")
794 | # print("Log:\tregion %d has %d points"%(key, fill_graph[key]['area']))
795 |
796 | def flood_fill_single_proc(region_id, img):
797 |
798 | # construct fill region
799 | fill_region = np.full(img.shape, 0, np.uint8)
800 | fill_region[np.where(img == region_id)] = 255
801 | return flood_fill_multi(fill_region, verbo=False)
802 |
803 | def flood_fill_multi_proc(func, fill_id, result, n_proc):
804 | print("Log:\tmulti process splitting bleeding regions")
805 | with Pool(processes=n_proc) as p:
806 | return p.map(partial(func, img=result), fill_id)
807 |
808 | def split_region(result, multi_proc=False):
809 |
810 | # # get list of fill id
811 | # fill_id = np.unique(result.flatten()).tolist()
812 | # fill_id.remove(0)
813 | # assert 0 not in fill_id
814 |
815 | # _, result = cv2.connectedComponents(result, connectivity=4)
816 | # # there will left some small regions, we can merge them into region 0 in the following step
817 |
818 | # # result = build_fill_map(result, fill_points)
819 |
820 | # fill_id_new = np.unique(result)
821 |
822 |     # generate the threshold for merging regions
823 | w, h = result.shape
824 | th = int(w*h*0.09)
825 | # get list of fill id
826 | fill_id = np.unique(result.flatten()).tolist()
827 | fill_id.remove(0)
828 | assert 0 not in fill_id
829 |
830 | fill_points = []
831 |
832 | # get each region ready to be filled
833 | if multi_proc:
834 | n_proc = 8
835 | start = time.process_time()
836 |
837 | fill_points_multi_proc = flood_fill_multi_proc(flood_fill_single_proc, fill_id, result, n_proc)
838 | for f in fill_points_multi_proc:
839 | fill_points += f
840 |
841 |         print("Multiprocessing time: {}secs\n".format((time.process_time()-start)))
842 |
843 | else:
844 |         # split each region if it is divided by the ink region
845 | start = time.process_time()
846 | for j in tqdm(fill_id):
847 |
848 | # skip strokes
849 | if j == 0:
850 | continue
851 |
852 | # generate fill mask of that region
853 | fill_region = np.full(result.shape, 0, np.uint8)
854 | fill_region[np.where(result == j)] = 255
855 |
856 |             # crop to a smaller window that only covers the current fill region, to speed this up
857 | # todo
858 |
859 |             # assign a new id to each connected sub-region
860 | fills = flood_fill_multi(fill_region, verbo=False)
861 |
862 | merge = []
863 | merge_idx = []
864 | for i in range(len(fills)):
865 | if len(fills[i][0]) > th:
866 | merge_idx.append(i)
867 |
868 | for i in range(len(merge_idx)):
869 | merge.append(fills[merge_idx[i]])
870 |
871 |             for i in sorted(merge_idx, reverse=True):  # pop from the back so earlier indices stay valid
872 | fills.pop(i)
873 |
874 | if len(merge) > 0:
875 | region_merged = merge.pop(0)
876 | for p in merge:
877 | region_merged = (np.concatenate((region_merged[0], p[0])), np.concatenate((region_merged[1], p[1])))
878 | fills.append(region_merged)
879 |
880 | fill_points += fills
881 | print("Single-processing time: {}secs\n".format((time.process_time()-start)))
882 |
883 | result = build_fill_map(result, fill_points)
884 | fill_id_new = np.unique(result)
885 |
886 | return result, fill_id_new
887 |
888 | def find_neighbor(result, fills_graph, max_height, max_width):
889 |
890 | fill_id_new = np.unique(result)
891 |
892 | for j in tqdm(fill_id_new):
893 |
894 | if j == 0:
895 | continue
896 |
897 | fill_region = np.zeros(result.shape, np.uint8)
898 | fill_region[np.where(result == j)] = 255
899 |
900 | # find boundary of each region
901 |         # sometimes the neighbor relation found here is not mutual, why?
902 | border_points, _ = get_border_point(fills_graph[j]['point'], fills_graph[j]['rect'], max_height, max_width)
903 |
904 | # construct a graph map of all regions
905 | neighbor_id = np.unique(result[border_points])
906 |
907 | # record neighbor information
908 | for k in neighbor_id:
909 | if k != 0:
910 | if k not in fills_graph[j]["neighbor"]:
911 | fills_graph[j]["neighbor"].append(k)
912 | if j not in fills_graph[k]["neighbor"]:
913 | fills_graph[k]["neighbor"].append(j)
914 |
915 | return fills_graph
916 |
917 | def find_min_neighbor(fills_graph, idx):
918 |
919 | neighbors = []
920 | neighbor_sizes = []
921 | for n in fills_graph[idx]["neighbor"]:
922 | if n != None:
923 | neighbors.append(n)
924 | neighbor_sizes.append(fills_graph[n]['area'])
925 | else:
926 | neighbors.append(n)
927 | neighbor_sizes.append(-1)
928 |
929 | # we need to sort the index
930 | sort_idx = sorted(range(len(neighbors)), key=lambda k: neighbor_sizes[k])
931 |
932 | # for i in sort_idx:
933 | # print("Log:\tregion %d with size %d"%(neighbors[i], neighbor_sizes[i]))
934 |
935 | return sort_idx, neighbor_sizes
936 |
937 | def merge_all_neighbor(fills_graph, idx, result):
938 |
939 | for n in fills_graph[idx]['neighbor']:
940 | result[fills_graph[n]['point']] = idx
941 |
942 | return result
943 |
944 | def check_all_neighbor(fills_graph, j, low_th, max_th):
945 |
946 |     min_neighbors, _ = find_min_neighbor(fills_graph, j)  # find_min_neighbor returns (sorted indices, neighbor sizes)
947 | for k in min_neighbors:
948 | nb = fills_graph[j]['neighbor'][k]
949 |
950 | if nb != None:
951 | if fills_graph[nb] != None and fills_graph[nb]['area'] <= low_th and nb != j:
952 | continue
953 | else:
954 | print("Log:\texclude region %d"%nb)
955 |
956 | def remove_bleeding(fills_graph, fill_id_new, max_iter, result, low_th, max_th):
957 |
958 | count = 0
959 | really_low_th = 100
960 | # max region absorb small neighbors
961 | for i in range(max_iter):
962 | # print('merge 2nd ' + str(i + 1))
963 | for j in tqdm(fill_id_new):
964 | if j == 0:
965 | continue
966 | if fills_graph[j] == None: # this region has been removed
967 | continue
968 | if fills_graph[j]['area'] < max_th:
969 | continue
970 |
971 | min_neighbors, min_neighbor_sizes = find_min_neighbor(fills_graph, j)
972 | # print("Log:\tfound region %d have %d neighbors"%(j, len(min_neighbors)))
973 |
974 | for k in min_neighbors:
975 |
976 | nb = fills_graph[j]['neighbor'][k]
977 |
978 | if min_neighbor_sizes[k] == -1:
979 | continue
980 |
981 | if nb != None:
982 | if fills_graph[nb] != None and fills_graph[nb]['area'] <= low_th and nb != j:
983 | fills_graph = merge_region(fills_graph, nb, j, result)
984 | count += 1
985 | else:
986 | fills_graph[j]['neighbor'][k] = None
987 |
988 |     # small regions join their largest neighbor
989 |     small_regions = list_region(fills_graph, low_th, False)
990 |     first_loop = True
991 |     num_small_before = len(small_regions)
992 |     num_small_after = len(small_regions)
993 |
994 |     while first_loop or num_small_before - num_small_after > 0:
995 | first_loop = False
996 | for s in small_regions:
997 |
998 | min_neighbors, min_neighbor_sizes = find_min_neighbor(fills_graph, s)
999 |
1000 | if len(min_neighbors) == 0 or min_neighbors == None or min_neighbor_sizes[min_neighbors[-1]] == -1:
1001 | if fills_graph[s]['area'] < really_low_th:
1002 | fills_graph = merge_region(fills_graph, s, 0, result)
1003 | continue
1004 |
1005 | t = fills_graph[s]['neighbor'][min_neighbors[-1]]
1006 | fills_graph = merge_region(fills_graph, s, t, result)
1007 | count += 1
1008 |
1009 | small_regions = list_region(fills_graph, low_th, False)
1010 |
1011 |         num_small_before = num_small_after
1012 |         num_small_after = len(small_regions)
1013 |
1014 | print("Log:\t %d neighbors merged"%count)
1015 |
1016 | return fills_graph
1017 |
1018 |
1019 | def merger_fill_2nd(fillmap, max_iter=10, low_th=0.001, max_th=0.01, debug=False):
1020 |
1021 |     """
1022 |     TODO: use multithreading in each step
1023 |     to make this function as fast as possible
1024 |     """
1025 |
1026 | max_height, max_width = fillmap.shape[:2]
1027 | result = fillmap.copy()
1028 | low_th = int(max_height*max_width*low_th)
1029 | max_th = int(max_height*max_width*max_th)
1030 |
1031 | # 1. convert filling map to graphs
1032 |     # this step takes 99% of the running time and needs a lot of optimization
1033 | if debug:
1034 | print("Log:\tload fill_map.pickle")
1035 | result = load_obj("fill_map.pickle")
1036 | fill_id_new = np.unique(result)
1037 | else:
1038 | print("Log:\tsplit bleeding regions")
1039 | result, fill_id_new = split_region(result)
1040 |
1041 |     # initialize the graph of regions
1042 | if debug:
1043 | print("Log:\tload fills_graph.pickle")
1044 | fills_graph_init = load_obj("fills_graph.pickle")
1045 | fills_graph = load_obj("fills_graph.pickle")
1046 | else:
1047 | print("Log:\tinitialize region graph")
1048 | fills_graph = to_graph(result, fill_id_new)
1049 |
1050 | # find neighbor
1051 | if debug:
1052 | print("Log:\tload fills_graph_n.pickle")
1053 | fills_graph = load_obj("fills_graph_n.pickle")
1054 | else:
1055 | print("Log:\tfind region neighbors")
1056 | fills_graph = find_neighbor(result, fills_graph, max_height, max_width)
1057 |
1058 | # self check if the graph is constructed correctly
1059 | graph_self_check(fills_graph)
1060 |
1061 | # 2. merge all small region to its largest neighbor
1062 | # this step seems fast, it only takes around 20s to finish
1063 | print("Log:\tremove leaking color")
1064 | fills_graph = remove_bleeding(fills_graph, fill_id_new, max_iter, result, low_th, max_th)
1065 |
1066 |     # 3. show the refined result
1067 | visualize_graph(fills_graph, result, region=None)
1068 |
1069 | # 4. map region graph back to fillmaps
1070 | result = to_fillmap(result, fills_graph)
1071 | return result, fills_graph
1072 |
--------------------------------------------------------------------------------
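
A minimal usage sketch for the region-merging pass above (merger_fill_2nd). The module import path, the source of the fillmap, and the output file name are assumptions for illustration; the sketch only assumes a fillmap that is an integer label map with 0 marking the line-art strokes, as the functions above expect.

import numpy as np
from PIL import Image

# assumed import path: merger_fill_2nd and show_fill_map are defined in the trapped_ball module above
from trappedball_fill import merger_fill_2nd, show_fill_map

fillmap = np.load("fillmap.npy")  # hypothetical precomputed label map (0 = strokes)
result, fills_graph = merger_fill_2nd(fillmap, max_iter=10, low_th=0.001, max_th=0.01)
Image.fromarray(show_fill_map(result).astype(np.uint8)).save("fillmap_merged.png")
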
/src/flatting/unet/__init__.py:
--------------------------------------------------------------------------------
1 | from .unet_model import UNet
2 |
--------------------------------------------------------------------------------
/src/flatting/unet/unet_model.py:
--------------------------------------------------------------------------------
1 | """ Full assembly of the parts to form the complete network """
2 |
3 | import torch.nn.functional as F
4 |
5 | from .unet_parts import *
6 |
7 |
8 | class UNet(nn.Module):
9 | def __init__(self, in_channels, out_channels, bilinear=True):
10 | super(UNet, self).__init__()
11 | self.in_channels = in_channels
12 | self.out_channels = out_channels
13 | self.bilinear = bilinear
14 |
15 | self.inc = DoubleConv(in_channels, 64)
16 | self.down1 = Down(64, 128)
17 | self.down2 = Down(128, 256)
18 | self.down3 = Down(256, 512)
19 | factor = 2 if bilinear else 1
20 | self.down4 = Down(512, 1024 // factor)
21 | self.up1 = Up(1024, 512 // factor, bilinear)
22 | self.up2 = Up(512, 256 // factor, bilinear)
23 | self.up3 = Up(256, 128 // factor, bilinear)
24 | self.up4 = Up(128, 64, bilinear)
25 | self.outc = OutConv(64, out_channels)
26 |
27 | def forward(self, x):
28 | x1 = self.inc(x)
29 | x2 = self.down1(x1)
30 | x3 = self.down2(x2)
31 | x4 = self.down3(x3)
32 | x5 = self.down4(x4)
33 | x = self.up1(x5, x4)
34 | x = self.up2(x, x3)
35 | x = self.up3(x, x2)
36 | x = self.up4(x, x1)
37 | logits = self.outc(x)
38 | return logits
39 |
--------------------------------------------------------------------------------
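
A minimal sketch of instantiating the UNet defined above and running a dummy forward pass. The 1-in / 1-out channel configuration and the 256x256 input size are assumptions for illustration, as is the package-style import path.

import torch
from flatting.unet import UNet  # assumed import path, exposed via the package __init__ above

net = UNet(in_channels=1, out_channels=1, bilinear=True)
net.eval()
with torch.no_grad():
    x = torch.randn(1, 1, 256, 256)  # (batch, channels, height, width)
    logits = net(x)
print(logits.shape)  # torch.Size([1, 1, 256, 256])
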
/src/flatting/unet/unet_parts.py:
--------------------------------------------------------------------------------
1 | """ Parts of the U-Net model """
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class DoubleConv(nn.Module):
9 | """(convolution => [BN] => ReLU) * 2"""
10 |
11 | def __init__(self, in_channels, out_channels, mid_channels=None):
12 | super().__init__()
13 | if not mid_channels:
14 | mid_channels = out_channels
15 | self.double_conv = nn.Sequential(
16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
17 | nn.BatchNorm2d(mid_channels),
18 | nn.ReLU(inplace=True),
19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
20 | nn.BatchNorm2d(out_channels),
21 | nn.ReLU(inplace=True)
22 | )
23 |
24 | def forward(self, x):
25 | return self.double_conv(x)
26 |
27 |
28 | class Down(nn.Module):
29 | """Downscaling with maxpool then double conv"""
30 |
31 | def __init__(self, in_channels, out_channels):
32 | super().__init__()
33 | self.maxpool_conv = nn.Sequential(
34 | nn.MaxPool2d(2),
35 | DoubleConv(in_channels, out_channels)
36 | )
37 |
38 | def forward(self, x):
39 | return self.maxpool_conv(x)
40 |
41 |
42 | class Up(nn.Module):
43 | """Upscaling then double conv"""
44 |
45 | def __init__(self, in_channels, out_channels, bilinear=True):
46 | super().__init__()
47 |
48 | # if bilinear, use the normal convolutions to reduce the number of channels
49 | if bilinear:
50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
52 | else:
53 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
54 | self.conv = DoubleConv(in_channels, out_channels)
55 |
56 |
57 | def forward(self, x1, x2):
58 | x1 = self.up(x1)
59 | # input is CHW
60 | diffY = x2.size()[2] - x1.size()[2]
61 | diffX = x2.size()[3] - x1.size()[3]
62 |
63 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
64 | diffY // 2, diffY - diffY // 2])
65 | # if you have padding issues, see
66 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
67 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
68 | x = torch.cat([x2, x1], dim=1)
69 | return self.conv(x)
70 |
71 |
72 | class OutConv(nn.Module):
73 | def __init__(self, in_channels, out_channels):
74 | super(OutConv, self).__init__()
75 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
76 |
77 | def forward(self, x):
78 | return self.conv(x)
79 |
--------------------------------------------------------------------------------
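
A small sketch of the size alignment done in Up.forward above: the upsampled tensor x1 is padded to the skip connection x2's spatial size before the channel-wise concatenation. The tensor shapes are arbitrary examples.

import torch
import torch.nn.functional as F

x1 = torch.randn(1, 64, 28, 28)  # upsampled decoder feature (example shape)
x2 = torch.randn(1, 64, 29, 29)  # encoder skip connection, one pixel larger
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
print(x.shape)  # torch.Size([1, 128, 29, 29])
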
/src/flatting/utils/add_white_background.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | from os.path import *
5 | from PIL import Image
6 | from tqdm import tqdm
7 |
8 | source = "L:\\2.Research_project\\3.flatting\\Pytorch-UNet\\flatting\\validation"
9 | target = "L:\\2.Research_project\\3.flatting\\Pytorch-UNet\\flatting\\validation"
10 |
11 | for img in tqdm(os.listdir(source)):
12 |
13 | if ".png" not in img: continue
14 |
15 | # open image
16 | img_a = Image.open(join(source, img))
17 |
18 |     # prepare white background
19 | img_w = Image.new("RGBA", img_a.size, "WHITE")
20 | try:
21 | img_w.paste(img_a, None, img_a)
22 | img_w.convert("RGB").save(join(target, img))
23 |     except Exception:
24 | print("Error:\tfailed on %s"%img)
25 |
--------------------------------------------------------------------------------
/src/flatting/utils/data_vis.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 |
4 | def plot_img_and_mask(img, mask):
5 | classes = mask.shape[2] if len(mask.shape) > 2 else 1
6 | fig, ax = plt.subplots(1, classes + 1)
7 | ax[0].set_title('Input image')
8 | ax[0].imshow(img)
9 | if classes > 1:
10 | for i in range(classes):
11 | ax[i+1].set_title(f'Output mask (class {i+1})')
12 | ax[i+1].imshow(mask[:, :, i])
13 | else:
14 | ax[1].set_title(f'Output mask')
15 | ax[1].imshow(mask)
16 | plt.xticks([]), plt.yticks([])
17 | plt.show()
18 |
--------------------------------------------------------------------------------
/src/flatting/utils/dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | import logging
5 | import cv2
6 | # import webp
7 |
8 | from os.path import *
9 | from os import listdir
10 |
11 | from PIL import Image
12 |
13 | from torch.utils.data import Dataset
14 | from torchvision import transforms as T
15 | from torch.nn import Threshold
16 | from io import BytesIO
17 |
18 |
19 | class BasicDataset(Dataset):
20 | # let's try to make this all work in memory
21 | '''
22 |     The original version, which reads images from disk
23 | uncomment to enable
24 | '''
25 | # def __init__(self, line_dir, edge_dir, radius = 2, crop_size = 0):
26 | # self.line_dir = line_dir
27 | # self.edge_dir = edge_dir
28 | # self.kernel = self.get_ball_structuring_element(radius)
29 |
30 | # self.crop_size = crop_size if crop_size != 0 else 1024
31 | # assert self.crop_size > 0
32 |
33 | # self.ids = listdir(line_dir)
34 | # self.length = len(self.ids)
35 | # assert self.length == len(listdir(edge_dir))
36 |
37 | # logging.info(f'Creating dataset with {len(self.ids)} examples')
38 |
39 | '''
40 |     The modified version, which keeps the whole data set in memory as raw image bytes
41 | '''
42 | def __init__(self, lines_bytes, edges_bytes, radius = 2, crop_size = 0):
43 |
44 | self.lines_bytes = lines_bytes
45 | self.edges_bytes = edges_bytes
46 |
47 | self.kernel = self.get_ball_structuring_element(radius)
48 |
49 | self.crop_size = crop_size if crop_size != 0 else 1024
50 | assert self.crop_size > 0
51 |
52 |
53 |
54 | self.length = len(lines_bytes)
55 | # self.length = len(self.ids)
56 |
57 | assert self.length == len(edges_bytes)
58 |
59 | logging.info(f'Creating dataset with {self.length} examples')
60 |
61 | def __len__(self):
62 | return self.length
63 |
64 | def get_ball_structuring_element(self, radius):
65 | """Get a ball shape structuring element with specific radius for morphology operation.
66 |         The radius of the ball usually equals (leaking_gap_size / 2).
67 |
68 | # Arguments
69 | radius: radius of ball shape.
70 |
71 | # Returns
72 | an array of ball structuring element.
73 | """
74 | return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * radius + 1, 2 * radius + 1))
75 |
76 | def __getitem__(self, i):
77 |
78 |
79 | '''
80 | The original version, uncomment to enable
81 | '''
82 | # idx = self.ids[i]
83 | # edge_path = join(self.edge_dir, idx.replace("webp", "png"))
84 | # line_path = join(self.line_dir, idx)
85 |
86 | # assert exists(edge_path), \
87 | # f'No edge map found for the ID {idx}: {edge_path}'
88 | # assert exists(line_path), \
89 | # f'No line art found for the ID {idx}: {line_path}'
90 |
91 | # remove white
92 | # if ".webp" in line_path:
93 | # line_np = np.array(webp.load_image(line_path, "RGB").convert("L"))
94 | # else:
95 | # line_np = np.array(Image.open(line_path))
96 |
97 | # if ".webp" in edge_path:
98 | # edge_np = np.array(webp.load_image(edge_path, "RGB"))
99 | # else:
100 | # edge_np = np.array(Image.open(edge_path))
101 | '''
102 |         the end of the original version
103 | '''
104 |
105 | '''
106 | The modified version
107 | '''
108 | line_bytes = self.lines_bytes[i]
109 | edge_bytes = self.edges_bytes[i]
110 |
111 |
112 | buffer = BytesIO(line_bytes)
113 | line_np = np.array(Image.open(buffer).convert("L"))
114 |
115 | buffer = BytesIO(edge_bytes)
116 | edge_np = np.array(Image.open(buffer).convert("L"))
117 | '''
118 | end of modified version
119 | '''
120 |
121 | '''
122 | The following part should be fine
123 | '''
124 |
125 | # crop_bbox = self.find_bbox(self.to_point_list(line_np))
126 | # line_np = self.crop_img(crop_bbox, line_np)
127 | # edge_np = self.crop_img(crop_bbox, edge_np)
128 |
129 | # line_np, edge_np = self.random_resize([line_np, edge_np])
130 |
131 | # or threshold by opencv?
132 | _, mask1_np = cv2.threshold(line_np, 125, 255, cv2.THRESH_BINARY)
133 | _, mask2_np = cv2.threshold(edge_np, 125, 255, cv2.THRESH_BINARY)
134 |
135 | # convert to tensor, and the following process should all be done by cuda
136 | line = self.to_tensor(line_np)
137 | edge = self.to_tensor(edge_np)
138 |
139 | mask1 = self.to_tensor(mask1_np, normalize = False)
140 | mask2 = self.to_tensor(mask2_np, normalize = False)
141 |
142 |         assert line.shape == edge.shape, \
143 | f'Line art and edge map {i} should be the same size, but are {line.shape} and {edge.shape}'
144 |
145 |
146 |
147 | imgs = self.augment(torch.cat((line, edge, mask1, mask2), dim=0))
148 |
149 | # it returns tensor at last
150 | return torch.chunk(imgs, 4, dim=0)
151 |
152 | def to_point_list(self, img_np):
153 | p = np.where(img_np < 220)
154 | return p
155 |
156 | def find_bbox(self, p):
157 | t = p[0].min()
158 | l = p[1].min()
159 | b = p[0].max()
160 | r = p[1].max()
161 | return t,l,b,r
162 |
163 | def crop_img(self, bbox, img_np):
164 | t,l,b,r = bbox
165 | return img_np[t:b, l:r]
166 |
167 | # def random_resize(self, img_np_list):
168 | # '''
169 |     #     Experiments show that random resize does not work well, so this function is obsolete and is just left here
170 | # as a record.
171 | # Don't try random resize in this way, it will not work!
172 |     #     Much slower convergence and no obviously better generalization ability
173 | # '''
174 | # size = self.crop_size * (1 + np.random.rand()/5)
175 |
176 | # # if the image is a very long or wide image, then split it before cropping
177 | # img_np_resize_list = []
178 | # for img_np in img_np_list:
179 | # if len(img_np.shape) == 2:
180 | # h, w = img_np.shape
181 | # else:
182 | # h, w, _ = img_np.shape
183 |
184 | # short_side = w if w < h else h
185 | # r = size / short_side
186 | # target_w = int(w*r+0.5)
187 | # target_h = int(h*r+0.5)
188 | # img_np = cv2.resize(img_np, (target_w, target_h), interpolation=cv2.INTER_AREA)
189 | # img_np_resize_list.append(img_np)
190 |
191 | # return img_np_resize_list
192 |
193 | def to_tensor(self, pil_img, normalize = True):
194 |
195 |         # assume the input is always grayscale
196 | if normalize:
197 | transforms = T.Compose(
198 | [
199 | # to tensor will change the channel order and divide 255 if necessary
200 | T.ToTensor(),
201 | T.Normalize(0.5, 0.5, inplace = True)
202 | ]
203 | )
204 | else:
205 | transforms = T.Compose(
206 | [
207 | # to tensor will change the channel order and divide 255 if necessary
208 | T.ToTensor(),
209 | ]
210 | )
211 |
212 | return transforms(pil_img)
213 |
214 | def augment(self, tensors):
215 | transforms = T.Compose(
216 | [
217 | T.RandomHorizontalFlip(),
218 | T.RandomVerticalFlip(),
219 | T.RandomCrop(size = self.crop_size)
220 |
221 | ]
222 | )
223 | return transforms(tensors)
224 |
--------------------------------------------------------------------------------
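
A hypothetical sketch of feeding BasicDataset above from in-memory PNG bytes and wrapping it in a DataLoader. The directory layout, import path, and crop size are assumptions for illustration; __getitem__ returns the (line, edge, mask1, mask2) tensors that the sketch unpacks.

from pathlib import Path
from torch.utils.data import DataLoader
from flatting.utils.dataset import BasicDataset  # assumed import path

# hypothetical folders holding matching line art and edge maps as PNG files
lines_bytes = [p.read_bytes() for p in sorted(Path("data/line").glob("*.png"))]
edges_bytes = [p.read_bytes() for p in sorted(Path("data/edge").glob("*.png"))]

dataset = BasicDataset(lines_bytes, edges_bytes, radius=2, crop_size=512)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
line, edge, mask1, mask2 = next(iter(loader))  # each of shape (4, 1, 512, 512)
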
/src/flatting/utils/ground_truth_creation.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os, sys
4 | sys.path.append("./line")
5 | import torch
6 |
7 | from PIL import Image
8 | from os.path import *
9 | from tqdm import tqdm
10 | from hed.run import estimate
11 | from line.thin import Thinner
12 | from skimage.morphology import skeletonize, thin
13 |
14 |
15 | # let's use an advanced edge detection algorithm
16 | def get_ball_structuring_element(radius):
17 | """Get a ball shape structuring element with specific radius for morphology operation.
18 |     The radius of the ball usually equals (leaking_gap_size / 2).
19 |
20 | # Arguments
21 | radius: radius of ball shape.
22 |
23 | # Returns
24 | an array of ball structuring element.
25 | """
26 | return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * radius + 1, 2 * radius + 1))
27 |
28 | def to_tensor(path_to_img):
29 |
30 | # img = Image.open(path_to_img)
31 | img_np = cv2.imread(path_to_img, cv2.IMREAD_COLOR)
32 | img_np = np.ascontiguousarray(img_np[:, :, ::-1].transpose(2, 0, 1).astype(np.float32) * (1.0 / 255.0))
33 |
34 | return torch.FloatTensor(img_np)
35 |
36 | def to_numpy(edge_tensor):
37 |
38 | edge_np = edge_tensor.clamp(0.0, 1.0).numpy().transpose(1, 2, 0)[:, :, 0] * 255.0
39 | edge_np = edge_np.astype(np.uint8)
40 |
41 | return edge_np
42 |
43 | def extract_skeleton(img):
44 |
45 | size = np.size(img)
46 | skel = np.zeros(img.shape,np.uint8)
47 | element = cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))
48 | done = False
49 |
50 | while done is False:
51 | eroded = cv2.erode(img,element)
52 | temp = cv2.dilate(eroded,element)
53 | temp = cv2.subtract(img,temp)
54 | skel = cv2.bitwise_or(skel,temp)
55 | img = eroded.copy()
56 |
57 | zeros = size - cv2.countNonZero(img)
58 | if zeros==size:
59 | done = True
60 |
61 | return skel
62 |
63 | def extract_gt_hed(input_line, input_flat, out_path):
64 |
65 | _, line_name = split(input_line)
66 |
67 | # extract edge by HED
68 | tenInput = to_tensor(input_flat)
69 | tenOutput = estimate(tenInput)
70 |
71 | # threshold the output
72 | edge = to_numpy(tenOutput)
73 | edge_thresh = cv2.adaptiveThreshold(255-edge, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2)
74 | # values = np.unique(edge)
75 | # lower_bound = np.percentile(values, 30)
76 | # _, edge_thresh = cv2.threshold(edge, lower_bound, 255, cv2.THRESH_BINARY)
77 |
78 | # get skeleton
79 | # thin = Thinner()
80 | # edge_thin = thin(Image.fromarray(edge)).detach().cpu().numpy().transpose(1,2,0)*255
81 | # edge_thin = edge_thin.astype(np.uint8).repeat(3, axis=-1)
82 |
83 |     # none of these work
84 | # edge_thin = cv2.ximgproc.thinning(edge)
85 | # edge_thin = extract_skeleton(255 - edge_thresh)
86 | # edge_thin = skeletonize(edge_thresh)
87 |
88 | # Image.fromarray(edge_thresh).save(join(out_path, line_name))
89 | cv2.imwrite(join(out_path, line_name), edge_thresh)
90 |
91 |
92 | def extract_gt(input_line, input_flat, out_path):
93 | # initialize
94 |
95 | print("Log:\topen %s"%input_flat)
96 | if exists(out_path) is False:
97 | os.makedirs(out_path)
98 |
99 | # canny edge detection
100 | img = cv2.imread(input_flat)
101 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
102 | img = cv2.blur(img,(5,5))
103 |
104 |     # analyze the gradient of the flat image
105 | grad = cv2.Laplacian(img,cv2.CV_64F)
106 | grad = abs(grad).sum(axis = -1)
107 | grad_v, grad_c = np.unique(grad, return_counts=True)
108 |
109 | # remove the majority grad, which is 0
110 |     assert grad_v[np.argmax(grad_c)] == 0  # the most frequent gradient value should be 0
111 | grad_v = np.delete(grad_v, np.where(grad_v==0))
112 | grad_c = np.delete(grad_c, np.where(grad_c==grad_c.max()))
113 | print("Log:\tlen of grad_v %d"%len(grad_v))
114 | grad_c_cum = np.cumsum(grad_c)
115 |
116 |     # if the number of distinct gradient values is greater than 100, the current
117 |     # image probably contains many very similar colors, so we should apply
118 |     # another set of parameters to detect edges
119 |     # this could be better if we can find the relation between them
120 | if len(grad_v) < 100:
121 | min_val = grad_v[np.where(grad_c_cum<=np.percentile(grad_c_cum, 25))[0].max()]
122 | max_val = grad_v[np.where(grad_c_cum<=np.percentile(grad_c_cum, 40))[0].max()]
123 | else:
124 | min_val = grad_v[np.where(grad_c_cum<=np.percentile(grad_c_cum, 1))[0].max()]
125 | max_val = grad_v[np.where(grad_c_cum<=np.percentile(grad_c_cum, 10))[0].max()]
126 |
127 | edges = cv2.Canny(img, min_val, max_val, L2gradient=True)
128 |
129 | # write result
130 | _, line_name = split(input_line)
131 | cv2.imwrite(join(out_path, line_name.replace("webp", "png")), 255-edges)
132 |
133 | def main():
134 |
135 | input_line_path = "../flatting/size_org/line/"
136 | input_flat_path = "../flatting/size_org/flat/"
137 | out_path = "../flatting/size_org/line_detection/"
138 |
139 | for img in tqdm(os.listdir(input_line_path)):
140 | input_line = join(input_line_path, img)
141 | input_flat = join(input_flat_path, img.replace("line", "flat"))
142 |
143 | # neural net base edge detection
144 | # extract_gt_hed(input_line, input_flat, out_path)
145 |
146 | # canny edge detection
147 | extract_gt(input_line, input_flat, out_path)
148 |
149 |
150 |
151 | if __name__=="__main__":
152 | main()
--------------------------------------------------------------------------------
/src/flatting/utils/move_to_duplicate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from os.path import *
4 |
5 | flat = "L:\\2.Research_project\\3.flatting\\flatting_trapped_ball\\flatting\\size_org\\flat"
6 | line = "L:\\2.Research_project\\3.flatting\\flatting_trapped_ball\\flatting\\size_org\\line"
7 | # flat = "L:\\2.Research_project\\3.flatting\\Pytorch-UNet\\flatting\\size_org\\OneDrive_2021-03-14\\[21-02-05] 318 SETS\\flat"
8 | # line = "L:\\2.Research_project\\3.flatting\\Pytorch-UNet\\flatting\\size_org\\OneDrive_2021-03-14\\[21-02-05] 318 SETS\\line"
9 | duplicate = "L:\\2.Research_project\\3.flatting\\flatting_trapped_ball\\flatting\\size_org\\duplicate"
10 | test = "L:\\2.Research_project\\3.flatting\\flatting_trapped_ball\\flatting\\size_org\\test"
11 |
12 |
13 | line_croped = "L:\\2.Research_project\\3.flatting\\flatting_trapped_ball\\flatting\\size_1024\\line_croped"
14 | line_detection_croped = "L:\\2.Research_project\\3.flatting\\flatting_trapped_ball\\flatting\\size_1024\\line_detection_croped"
15 | # with open("moving log.txt", "r") as f:
16 | # move_list = f.readlines()
17 |
18 | # for img in move_list:
19 | # img = img.replace("\n", "").replace("Log: moving ", "")
20 | # if "flat" in img:
21 | # shutil.move(join(duplicate, img), join(flat, img))
22 | # if "line" in img:
23 | # shutil.move(join(duplicate, img), join(line, img))
24 |
25 | flats = os.listdir(flat)
26 | lines = os.listdir(line)
27 |
28 | lines_croped = os.listdir(line_croped)
29 | lines_croped.sort()
30 | lines_detection_croped = os.listdir(line_detection_croped)
31 | lines_detection_croped.sort()
32 |
33 | assert len(lines_croped) == len(lines_detection_croped)
34 |
35 |
36 | '''
37 | Move to test folders, but I think those are not good for evaluation...
38 | '''
39 | for img in os.listdir(line):
40 | if img.replace("line", "flat") not in flats:
41 | print("Log:\tmoving %s"%img)
42 | shutil.move(join(line, img), join(test, img.replace(".png", "_line.png")))
43 |
44 | for img in os.listdir(flat):
45 | if img.replace("flat", "line") not in lines:
46 | print("Log:\tmoving %s"%img)
47 | os.remove(join(flat, img))
48 | # shutil.move(join(flat, img), join(test, img.replace(".png", "_flat.png")))
49 |
50 |
51 |
52 | '''
53 | Re-order all images in resized folder
54 | '''
55 | # count = 0
56 | # for i in range(len(lines_croped)):
57 | # assert lines_croped[i] == lines_detection_croped[i]
58 | # img = lines_croped[i]
59 | # os.rename(join(line_croped, img), join(line_croped, "%04d.png"%count))
60 | # os.rename(join(line_detection_croped, img), join(line_detection_croped, "%04d.png"%count))
61 | # count += 1
62 |
--------------------------------------------------------------------------------
/src/flatting/utils/polyvector/run_all_examples.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from os.path import *
3 | import os
4 | import subprocess
5 | import shutil
6 | import time
7 | import pdb
8 |
9 | def parse():
10 | # path to executable file
11 | p1 = normcase("./")
12 | p2 = 'polyvector_thing.exe'
13 | # path to input files
14 | p3 = normcase("E:\\OneDrive - George Mason University\\00.Projects\\01.Sketch cleanup\\00.Bechmark Dataset\\SSEB Dataset with GT")
15 | # path to output files, if necessary
16 | # p4 = normcase("./results")
17 | p4 = None
18 | parser = argparse.ArgumentParser(description='Batch Run Script')
19 | parser.add_argument('--exe',
20 | help ='path to executable file',
21 | default = join(p1,p2))
22 | parser.add_argument('--input',
23 | help='path to input files',
24 | default = p3)
25 | parser.add_argument('--result',
26 |                         help='path where results will be saved, if necessary',
27 | default = p4)
28 |
29 | return parser
30 |
31 | def main():
32 | args = parse().parse_args()
33 | for img in os.listdir(args.input):
34 | name, extension = splitext(img)
35 |
36 | if extension == '.png':
37 | subprocess.run([args.exe, "-noisy", join(args.input, img)])
38 |
39 |
40 |     # This may or may not be needed; think later in detail about how to turn it into a generic framework
41 | # path_to_result = join(args.result, name)
42 | # time.sleep(1)
43 | # if not exists(args.result):
44 | # os.mkdir(args.result)
45 | # if not exists(path_to_result):
46 | # os.mkdir(path_to_result)
47 |
48 | # for svg in os.listdir(normcase('./')):
49 | # if svg.endswith('.svg'):
50 | # # pdb.set_trace()
51 | # shutil.move(join(normcase('./'), svg),
52 | # join(path_to_result,svg))
53 | print("Done")
54 |
55 | if __name__ == "__main__":
56 | main()
57 |
--------------------------------------------------------------------------------
/src/flatting/utils/preprocessing.py:
--------------------------------------------------------------------------------
1 | # remove all white regions in the image, and apply the same crop to the ground truth
2 | # then resize all images by calling magick
3 | import os
4 | import cv2
5 | import numpy as np
6 |
7 | from os.path import *
8 | from PIL import Image
9 | from tqdm import tqdm
10 |
11 | def to_np(img_path, th = False):
12 |
13 | if th:
14 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
15 | _, img = cv2.threshold(img,220,255,cv2.THRESH_BINARY)
16 | # img = cv2.adaptiveThreshold(img,255,cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY,11,2)
17 | else:
18 | img = cv2.imread(img_path, cv2.IMREAD_COLOR)
19 |
20 | return img
21 |
22 | def to_point_list(img_np):
23 | p = np.where(img_np < 220)
24 | return p
25 |
26 | def find_bbox(p):
27 | t = p[0].min()
28 | l = p[1].min()
29 | b = p[0].max()+1
30 | r = p[1].max()+1
31 | return t,l,b,r
32 |
33 | def crop_img(bbox, img_np):
34 | t,l,b,r = bbox
35 | return img_np[t:b, l:r]
36 |
37 | def center_crop_resize(img_np, size, crop=False, th=False):
38 | # if the image is a very long or wide image, then split it before cropping
39 | if len(img_np.shape) == 2:
40 | h, w = img_np.shape
41 | else:
42 | h, w, _ = img_np.shape
43 |
44 | short_side = w if w < h else h
45 | r = size / short_side * 1.2
46 | target_w = int(w*r+0.5)
47 | target_h = int(h*r+0.5)
48 | img_np = cv2.resize(img_np, (target_w, target_h), interpolation = cv2.INTER_AREA)
49 | if th:
50 | _, img_np = cv2.threshold(img_np,250,255,cv2.THRESH_BINARY)
51 | # center crop image
52 | if crop:
53 | l = (target_w - size)//2
54 | t = (target_h - size)//2
55 | r = (target_w + size)//2
56 | b = (target_h + size)//2
57 | img_np = img_np[t:b, l:r]
58 | return img_np
59 |
60 | def try_split(img_np):
61 |
62 | img_list = []
63 |
64 | h, w = img_np.shape[:2]
65 | if h >= 2*w:
66 | splition = h // w
67 | for i in range(0, h-h//splition, h//splition):
68 | img_list.append(img_np[i:i+h//splition])
69 | elif w >= 2*h:
70 | splition = w // h
71 | for i in range(0, w-w//splition, w//splition):
72 | img_list.append(img_np[:,i:i+w//splition])
73 | else:
74 | img_list.append(img_np)
75 |
76 | return img_list
77 |
78 | def main():
79 | path_root = "../flatting"
80 | org = "size_org"
81 |
82 | crop_size = 512
83 | size = "size_%d"%crop_size
84 |
85 | path_to_img = join(path_root, org, "line")
86 | path_to_mask = join(path_root, org, "line_detection")
87 | out_path_img = join(path_root, size, "line_croped")
88 | out_path_mask = join(path_root, size, "line_detection_croped")
89 |
90 | counter = 0
91 | for img_name in tqdm(os.listdir(path_to_img)):
92 |
93 | mask_name = img_name
94 | assert exists(join(path_to_mask, mask_name))
95 |
96 | img_np_th = to_np(join(path_to_img, img_name), th=True)
97 | img_np = to_np(join(path_to_img, img_name))
98 | mask_np = to_np(join(path_to_mask, mask_name))
99 |
100 |         # remove additional blank area
101 | bbox = find_bbox(to_point_list(img_np_th))
102 | img_crop = crop_img(bbox, img_np)
103 | mask_crop = crop_img(bbox, mask_np)
104 |
105 |         # detect if an image needs to be split
106 | if False:
107 | img_crop_list = try_split(img_crop)
108 | mask_crop_list = try_split(mask_crop)
109 | assert len(img_crop_list) == len(mask_crop_list)
110 | else:
111 | img_crop_list = [img_crop]
112 | mask_crop_list = [mask_crop]
113 |
114 | # crop and resize each image
115 | for i in range(len(img_crop_list)):
116 |
117 | img = center_crop_resize(img_crop_list[i], crop_size)
118 | mask = center_crop_resize(mask_crop_list[i], crop_size, th=True)
119 |
120 | assert img.shape[:2] == mask.shape[:2]
121 |
122 | if True:
123 | cv2.imwrite(join(out_path_img, "%05d.png"%counter), img)
124 | cv2.imwrite(join(out_path_mask, "%05d.png"%counter), mask)
125 | else:
126 | cv2.imwrite(join(out_path_img, img_name), img)
127 | cv2.imwrite(join(out_path_mask, img_name), mask)
128 | counter += 1
129 |
130 |
131 |
132 | if __name__=="__main__":
133 | main()
134 |
--------------------------------------------------------------------------------
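
A hypothetical sketch of running the preprocessing steps above on a single image: threshold to find the drawn pixels, crop away the blank border, then resize so the short side roughly matches the crop size. The input and output file names and the import path are assumptions for illustration.

import cv2
from flatting.utils.preprocessing import to_np, to_point_list, find_bbox, crop_img, center_crop_resize  # assumed import path

img_np_th = to_np("example_line.png", th=True)  # hypothetical input file
img_np = to_np("example_line.png")
bbox = find_bbox(to_point_list(img_np_th))
img_crop = crop_img(bbox, img_np)
img_small = center_crop_resize(img_crop, 512)
cv2.imwrite("example_line_512.png", img_small)
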
/src/flatting_server.py:
--------------------------------------------------------------------------------
1 | # we need to import these modules at the top level, otherwise pyinstaller will not be able to import them
2 | # import sys, os
3 | # import pathlib
4 | # sys.path.append(pathlib.Path(__file__).parent.absolute()/"flatting")
5 | # sys.path.append(pathlib.Path(__file__).parent.absolute()/"flatting"/"trapped_ball")
6 | # from aiohttp import web
7 | # from PIL import Image
8 | # from io import BytesIO
9 | # import numpy as np
10 | # import flatting_api
11 | # import flatting_api_async
12 | # import base64
13 | # import io
14 | # import json
15 | # import asyncio
16 | # import multiprocessing
17 | # import cv2
18 | # import torch
19 | # from pathlib import Path
20 | # from os.path import *
21 | # from run import region_get_map, merge_to_ref, verify_region
22 | # from thinning import thinning
23 | # from predict import predict_img
24 | # from unet import UNet
25 | # import asyncio
26 | # from concurrent.futures import ProcessPoolExecutor
27 | # import functools
28 |
29 | if __name__ == '__main__':
30 | from flatting import app
31 | ## https://docs.python.org/3/library/multiprocessing.html#multiprocessing.freeze_support
32 | if app.MULTIPROCESS: app.multiprocessing.freeze_support()
33 | app.main()
34 |
--------------------------------------------------------------------------------