├── .github
│   └── workflows
│       └── main.yml
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── data
│   ├── imgs
│   │   └── .keep
│   └── masks
│       └── .keep
├── evaluate.py
├── hubconf.py
├── predict.py
├── requirements.txt
├── scripts
│   ├── download_data.bat
│   └── download_data.sh
├── train.py
├── unet
│   ├── __init__.py
│   ├── unet_model.py
│   └── unet_parts.py
└── utils
    ├── __init__.py
    ├── data_loading.py
    ├── dice_score.py
    └── utils.py
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | name: Publish Docker image
2 |
3 | on:
4 | push:
5 | branches: master
6 |
7 | jobs:
8 | push_to_registry:
9 | name: Push Docker image
10 | runs-on: ubuntu-latest
11 | steps:
12 | - name: Checkout
13 | uses: actions/checkout@v2
14 |
15 | - name: Set up Docker Buildx
16 | uses: docker/setup-buildx-action@v1
17 |
18 | - name: Log in to Docker Hub
19 | uses: docker/login-action@v1
20 | with:
21 | username: milesial
22 | password: ${{ secrets.DOCKER_PASSWORD }}
23 |
24 | - name: Log in to the Container registry
25 | uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9
26 | with:
27 | registry: ghcr.io
28 | username: ${{ github.repository_owner }}
29 | password: ${{ secrets.GITHUB_TOKEN }}
30 |
31 | - name: Extract metadata (tags, labels) for Docker
32 | id: meta
33 | uses: docker/metadata-action@v3
34 | with:
35 | images: milesial/unet
36 |
37 | - name: Build and push Docker image
38 | id: docker_build
39 | uses: docker/build-push-action@v2
40 | with:
41 | context: .
42 | push: true
43 | tags: |
44 | milesial/unet:latest
45 | ghcr.io/milesial/pytorch-unet:latest
46 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | data/
3 | __pycache__/
4 | checkpoints/
5 | *.pth
6 | *.jpg
7 | venv/
8 | .idea/
9 | wandb/
10 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:22.11-py3
2 |
3 | RUN rm -rf /workspace/*
4 | WORKDIR /workspace/unet
5 |
6 | ADD requirements.txt .
7 | RUN pip install --no-cache-dir --upgrade --pre pip
8 | RUN pip install --no-cache-dir -r requirements.txt
9 | ADD . .
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year> <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program> Copyright (C) <year> <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # U-Net: Semantic segmentation with PyTorch
2 |
3 |
4 |
5 |
6 |
7 | 
8 |
9 |
10 | Customized implementation of the [U-Net](https://arxiv.org/abs/1505.04597) in PyTorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge) from high definition images.
11 |
12 | - [Quick start](#quick-start)
13 | - [Without Docker](#without-docker)
14 | - [With Docker](#with-docker)
15 | - [Description](#description)
16 | - [Usage](#usage)
17 | - [Docker](#docker)
18 | - [Training](#training)
19 | - [Prediction](#prediction)
20 | - [Weights & Biases](#weights--biases)
21 | - [Pretrained model](#pretrained-model)
22 | - [Data](#data)
23 |
24 | ## Quick start
25 |
26 | ### Without Docker
27 |
28 | 1. [Install CUDA](https://developer.nvidia.com/cuda-downloads)
29 |
30 | 2. [Install PyTorch 1.13 or later](https://pytorch.org/get-started/locally/)
31 |
32 | 3. Install dependencies
33 | ```bash
34 | pip install -r requirements.txt
35 | ```
36 |
37 | 4. Download the data and run training:
38 | ```bash
39 | bash scripts/download_data.sh
40 | python train.py --amp
41 | ```
42 |
43 | ### With Docker
44 |
45 | 1. [Install Docker 19.03 or later:](https://docs.docker.com/get-docker/)
46 | ```bash
47 | curl https://get.docker.com | sh && sudo systemctl --now enable docker
48 | ```
49 | 2. [Install the NVIDIA container toolkit:](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)
50 | ```bash
51 | distribution=$(. /etc/os-release;echo $ID$VERSION_ID) \
52 | && curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - \
53 | && curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list
54 | sudo apt-get update
55 | sudo apt-get install -y nvidia-docker2
56 | sudo systemctl restart docker
57 | ```
58 | 3. [Download and run the image:](https://hub.docker.com/repository/docker/milesial/unet)
59 | ```bash
60 | sudo docker run --rm --shm-size=8g --ulimit memlock=-1 --gpus all -it milesial/unet
61 | ```
62 |
63 | 4. Download the data and run training:
64 | ```bash
65 | bash scripts/download_data.sh
66 | python train.py --amp
67 | ```
68 |
69 | ## Description
70 | This model was trained from scratch with 5k images and scored a [Dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 on over 100k test images.
71 |
72 | It can be easily used for multiclass segmentation, portrait segmentation, medical segmentation, ...
73 |
74 |
75 | ## Usage
76 | **Note : Use Python 3.6 or newer**
77 |
78 | ### Docker
79 |
80 | A docker image containing the code and the dependencies is available on [DockerHub](https://hub.docker.com/repository/docker/milesial/unet).
81 | You can download and jump in the container with ([docker >=19.03](https://docs.docker.com/get-docker/)):
82 |
83 | ```console
84 | docker run -it --rm --shm-size=8g --ulimit memlock=-1 --gpus all milesial/unet
85 | ```
86 |
87 |
88 | ### Training
89 |
90 | ```console
91 | > python train.py -h
92 | usage: train.py [-h] [--epochs E] [--batch-size B] [--learning-rate LR]
93 | [--load LOAD] [--scale SCALE] [--validation VAL] [--amp]
94 |
95 | Train the UNet on images and target masks
96 |
97 | optional arguments:
98 | -h, --help show this help message and exit
99 | --epochs E, -e E Number of epochs
100 | --batch-size B, -b B Batch size
101 | --learning-rate LR, -l LR
102 | Learning rate
103 | --load LOAD, -f LOAD Load model from a .pth file
104 | --scale SCALE, -s SCALE
105 | Downscaling factor of the images
106 | --validation VAL, -v VAL
107 | Percent of the data that is used as validation (0-100)
108 | --amp Use mixed precision
109 | ```
110 |
111 | By default, the `scale` is 0.5, so if you wish to obtain better results (but use more memory), set it to 1.
112 |
113 | Automatic mixed precision is also available with the `--amp` flag. [Mixed precision](https://arxiv.org/abs/1710.03740) allows the model to use less memory and to be faster on recent GPUs by using FP16 arithmetic. Enabling AMP is recommended.
114 |
115 |
116 | ### Prediction
117 |
118 | After training your model and saving it to `MODEL.pth`, you can easily test the output masks on your images via the CLI.
119 |
120 | To predict a single image and save it:
121 |
122 | `python predict.py -i image.jpg -o output.jpg`
123 |
124 | To predict multiple images and show them without saving them:
125 |
126 | `python predict.py -i image1.jpg image2.jpg --viz --no-save`
127 |
128 | ```console
129 | > python predict.py -h
130 | usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...]
131 | [--output OUTPUT [OUTPUT ...]] [--viz] [--no-save]
132 | [--mask-threshold MASK_THRESHOLD] [--scale SCALE]
133 |
134 | Predict masks from input images
135 |
136 | optional arguments:
137 | -h, --help show this help message and exit
138 | --model FILE, -m FILE
139 | Specify the file in which the model is stored
140 | --input INPUT [INPUT ...], -i INPUT [INPUT ...]
141 | Filenames of input images
142 | --output OUTPUT [OUTPUT ...], -o OUTPUT [OUTPUT ...]
143 | Filenames of output images
144 | --viz, -v Visualize the images as they are processed
145 | --no-save, -n Do not save the output masks
146 | --mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
147 | Minimum probability value to consider a mask pixel white
148 | --scale SCALE, -s SCALE
149 | Scale factor for the input images
150 | ```
151 | You can specify which model file to use with `--model MODEL.pth`.
152 |
153 | ## Weights & Biases
154 |
155 | The training progress can be visualized in real-time using [Weights & Biases](https://wandb.ai/). Loss curves, validation curves, weights and gradient histograms, as well as predicted masks are logged to the platform.
156 |
157 | When launching a training, a link will be printed in the console. Click on it to go to your dashboard. If you have an existing W&B account, you can link it
158 | by setting the `WANDB_API_KEY` environment variable. If not, it will create an anonymous run which is automatically deleted after 7 days.
159 |
160 |
161 | ## Pretrained model
162 | A [pretrained model](https://github.com/milesial/Pytorch-UNet/releases/tag/v3.0) is available for the Carvana dataset. It can also be loaded from torch.hub:
163 |
164 | ```python
165 | net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True, scale=0.5)
166 | ```
167 | Available scales are 0.5 and 1.0.
168 |
169 | ## Data
170 | The Carvana data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data).
171 |
172 | You can also download it using the helper script:
173 |
174 | ```
175 | bash scripts/download_data.sh
176 | ```
177 |
178 | The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively (note that the `imgs` and `masks` folders should not contain any sub-folder or any other files, due to the greedy data-loader). For Carvana, images are RGB and masks are black and white.
179 |
180 | You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`.
181 |
182 |
183 | ---
184 |
185 | Original paper by Olaf Ronneberger, Philipp Fischer, Thomas Brox:
186 |
187 | [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597)
188 |
189 | 
190 |
--------------------------------------------------------------------------------
/data/imgs/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/milesial/Pytorch-UNet/21d7850f2af30a9695bbeea75f3136aa538cfc4a/data/imgs/.keep
--------------------------------------------------------------------------------
/data/masks/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/milesial/Pytorch-UNet/21d7850f2af30a9695bbeea75f3136aa538cfc4a/data/masks/.keep
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from tqdm import tqdm
4 |
5 | from utils.dice_score import multiclass_dice_coeff, dice_coeff
6 |
7 |
@torch.inference_mode()
def evaluate(net, dataloader, device, amp):
    """Compute the mean Dice score of ``net`` over the batches of ``dataloader``.

    Each batch is a dict with ``'image'`` and ``'mask'`` entries. The network is put
    in eval mode for the duration of the evaluation and switched back to train mode
    before returning.

    Args:
        net: segmentation network exposing ``n_classes``.
        dataloader: validation loader yielding ``{'image': ..., 'mask': ...}`` batches.
        device: torch.device the batches are moved to.
        amp: run the forward pass under autocast (mixed precision).

    Returns:
        The Dice score averaged over batches. For multi-class networks, class 0
        (background) is excluded. Returns 0 when the loader yields no batches.
    """
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0

    # iterate over the validation set; MPS devices use the 'cpu' autocast device type
    with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
        for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
            image, mask_true = batch['image'], batch['mask']

            # move images and labels to correct device and type
            image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            mask_true = mask_true.to(device=device, dtype=torch.long)

            # predict the mask
            mask_pred = net(image)

            if net.n_classes == 1:
                assert mask_true.min() >= 0 and mask_true.max() <= 1, 'True mask indices should be in [0, 1]'
                # Drop the singleton channel dimension so the prediction has the same shape as
                # the target (train.py does the same squeeze(1) for its single-class loss);
                # otherwise the element-wise Dice computation would broadcast across the batch.
                mask_pred = (torch.sigmoid(mask_pred.squeeze(1)) > 0.5).float()
                # compute the Dice score
                dice_score += dice_coeff(mask_pred, mask_true.float(), reduce_batch_first=False)
            else:
                assert mask_true.min() >= 0 and mask_true.max() < net.n_classes, 'True mask indices should be in [0, n_classes['
                # convert to one-hot format
                mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
                mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
                # compute the Dice score, ignoring background
                dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False)

    net.train()
    # max(..., 1) avoids a division by zero when the loader is empty
    return dice_score / max(num_val_batches, 1)
41 |
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from unet import UNet as _UNet
3 |
def unet_carvana(pretrained=False, scale=0.5):
    """
    UNet model (3 input channels, 2 classes) for the Carvana dataset ( https://www.kaggle.com/c/carvana-image-masking-challenge/data ).
    When ``pretrained`` is set, weights trained at ``scale`` (0.5 or 1.0) are downloaded;
    use the same scale when predicting.
    """
    model = _UNet(n_channels=3, n_classes=2, bilinear=False)
    if not pretrained:
        return model

    if scale == 0.5:
        url = 'https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale0.5_epoch2.pth'
    elif scale == 1.0:
        url = 'https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale1.0_epoch2.pth'
    else:
        raise RuntimeError('Only 0.5 and 1.0 scales are available')

    weights = torch.hub.load_state_dict_from_url(url, progress=True)
    # checkpoints written by train.py also carry a 'mask_values' entry, which is not a model weight
    weights.pop('mask_values', None)
    model.load_state_dict(weights)
    return model
23 |
24 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from PIL import Image
9 | from torchvision import transforms
10 |
11 | from utils.data_loading import BasicDataset
12 | from unet import UNet
13 | from utils.utils import plot_img_and_mask
14 |
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    """Predict a segmentation mask for a single PIL image.

    The image is preprocessed with ``BasicDataset.preprocess`` at ``scale_factor`` and
    run through ``net``. The logits are resized back to the original image size
    (bilinear) before being turned into a mask.

    Returns:
        A numpy integer array, normally of shape (height, width) of ``full_img``.
        It holds class indices for multi-class networks, or 0/1 values (sigmoid
        probability > ``out_threshold``) for single-class networks.
    """
    net.eval()
    pixels = BasicDataset.preprocess(None, full_img, scale_factor, is_mask=False)
    batch = torch.from_numpy(pixels).unsqueeze(0).to(device=device, dtype=torch.float32)
    width, height = full_img.size  # PIL reports (width, height)

    with torch.no_grad():
        logits = F.interpolate(net(batch).cpu(), (height, width), mode='bilinear')
        if net.n_classes > 1:
            prediction = logits.argmax(dim=1)
        else:
            prediction = torch.sigmoid(logits) > out_threshold

    return prediction[0].long().squeeze().numpy()
34 |
35 |
def get_args():
    """Parse the command-line options of the prediction script."""
    parser = argparse.ArgumentParser(description='Predict masks from input images')
    add = parser.add_argument

    # model, input and output selection
    add('--model', '-m', default='MODEL.pth', metavar='FILE',
        help='Specify the file in which the model is stored')
    add('--input', '-i', metavar='INPUT', nargs='+', required=True,
        help='Filenames of input images')
    add('--output', '-o', metavar='OUTPUT', nargs='+',
        help='Filenames of output images')

    # behaviour switches
    add('--viz', '-v', action='store_true',
        help='Visualize the images as they are processed')
    add('--no-save', '-n', action='store_true',
        help='Do not save the output masks')

    # numeric options
    add('--mask-threshold', '-t', type=float, default=0.5,
        help='Minimum probability value to consider a mask pixel white')
    add('--scale', '-s', type=float, default=0.5,
        help='Scale factor for the input images')

    # architecture of the network the checkpoint is loaded into
    add('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    add('--classes', '-c', type=int, default=2, help='Number of classes')

    parsed = parser.parse_args()
    return parsed
53 |
54 |
def get_output_filenames(args):
    """Return the explicit ``args.output`` names, or derive ``<stem>_OUT.png`` from each input."""
    if args.output:
        return args.output
    return [f'{os.path.splitext(path)[0]}_OUT.png' for path in args.input]
60 |
61 |
def mask_to_image(mask: np.ndarray, mask_values):
    """Render a mask of class indices as a PIL image, painting class ``i`` with ``mask_values[i]``.

    The output array type depends on ``mask_values``:
    - a list of lists (e.g. RGB triplets) gives a (H, W, len(mask_values[0])) uint8 array;
    - exactly ``[0, 1]`` gives a boolean (H, W) array;
    - anything else gives a uint8 (H, W) array.

    A 3-D ``mask`` is first collapsed with argmax over its first axis.
    """
    height, width = mask.shape[-2], mask.shape[-1]
    first_value = mask_values[0]
    if isinstance(first_value, list):
        canvas = np.zeros((height, width, len(first_value)), dtype=np.uint8)
    elif mask_values == [0, 1]:
        canvas = np.zeros((height, width), dtype=bool)
    else:
        canvas = np.zeros((height, width), dtype=np.uint8)

    labels = np.argmax(mask, axis=0) if mask.ndim == 3 else mask
    for class_index, value in enumerate(mask_values):
        canvas[labels == class_index] = value

    return Image.fromarray(canvas)
77 |
78 |
if __name__ == '__main__':
    # Command-line entry point: load a checkpoint and predict a mask for every input image.
    args = get_args()
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    in_files = args.input
    out_files = get_output_filenames(args)

    net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)

    device_kind = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_kind)
    logging.info(f'Loading model {args.model}')
    logging.info(f'Using device {device}')

    net.to(device=device)
    # NOTE(review): torch.load unpickles arbitrary objects; only load checkpoints from trusted sources.
    checkpoint = torch.load(args.model, map_location=device)
    # train.py stores the dataset's mask values next to the weights; they are not model
    # parameters, so strip them before load_state_dict (default [0, 1] for binary masks).
    mask_values = checkpoint.pop('mask_values', [0, 1])
    net.load_state_dict(checkpoint)

    logging.info('Model loaded!')

    for index, image_path in enumerate(in_files):
        logging.info(f'Predicting image {image_path} ...')
        image = Image.open(image_path)

        predicted = predict_img(
            net=net,
            full_img=image,
            scale_factor=args.scale,
            out_threshold=args.mask_threshold,
            device=device,
        )

        if not args.no_save:
            target_path = out_files[index]
            rendered = mask_to_image(predicted, mask_values)
            rendered.save(target_path)
            logging.info(f'Mask saved to {target_path}')

        if args.viz:
            logging.info(f'Visualizing results for image {image_path}, close to continue...')
            plot_img_and_mask(image, predicted)
118 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.6.2
2 | numpy==1.23.5
3 | Pillow==9.3.0
4 | tqdm==4.64.1
5 | wandb==0.13.5
6 |
--------------------------------------------------------------------------------
/scripts/download_data.bat:
--------------------------------------------------------------------------------
@echo off
rem Download the Carvana training images and masks from Kaggle into data\imgs and data\masks,
rem relative to the current directory, which is also where train.py reads them from.
setlocal enabledelayedexpansion

if not exist "%userprofile%\.kaggle\kaggle.json" (
    rem USERNAME is a predefined Windows variable and set /p keeps the old value on an empty
    rem answer, so read into dedicated variables to never store the Windows login name.
    set /p KAGGLE_USERNAME=Kaggle username:
    echo.
    set /p KAGGLE_KEY=Kaggle API key:

    if not exist "%userprofile%\.kaggle" mkdir "%userprofile%\.kaggle"
    echo {"username":"!KAGGLE_USERNAME!","key":"!KAGGLE_KEY!"}>"%userprofile%\.kaggle\kaggle.json"
    attrib +R "%userprofile%\.kaggle\kaggle.json"
)

pip install kaggle --upgrade
if errorlevel 1 exit /b 1

if not exist data\imgs mkdir data\imgs
if not exist data\masks mkdir data\masks

rem Stop on a failed download or extraction instead of moving and deleting files afterwards.
kaggle competitions download -c carvana-image-masking-challenge -f train_hq.zip
if errorlevel 1 exit /b 1
powershell Expand-Archive train_hq.zip -DestinationPath data\imgs
if errorlevel 1 exit /b 1
move data\imgs\train_hq\* data\imgs\
rmdir /s /q data\imgs\train_hq
del /q train_hq.zip

kaggle competitions download -c carvana-image-masking-challenge -f train_masks.zip
if errorlevel 1 exit /b 1
powershell Expand-Archive train_masks.zip -DestinationPath data\masks
if errorlevel 1 exit /b 1
move data\masks\train_masks\* data\masks\
rmdir /s /q data\masks\train_masks
del /q train_masks.zip

exit /b 0
29 |
--------------------------------------------------------------------------------
/scripts/download_data.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Download the Carvana training images and masks from Kaggle into data/imgs and data/masks,
# relative to the current directory, which is also where train.py reads them from.
# Abort on the first failing command instead of moving/deleting files after a failed download.
set -euo pipefail

if [[ ! -f ~/.kaggle/kaggle.json ]]; then
  read -r -p "Kaggle username: " USERNAME
  # -s keeps the API key from being echoed to the terminal
  read -r -s -p "Kaggle API key: " APIKEY
  echo

  mkdir -p ~/.kaggle
  # umask 077 in a subshell: the credentials file is never readable by others, not even briefly
  (umask 077 && printf '{"username":"%s","key":"%s"}\n' "$USERNAME" "$APIKEY" > ~/.kaggle/kaggle.json)
  chmod 600 ~/.kaggle/kaggle.json
fi

pip install kaggle --upgrade

mkdir -p data/imgs data/masks

kaggle competitions download -c carvana-image-masking-challenge -f train_hq.zip
unzip train_hq.zip
mv train_hq/* data/imgs/
rm -d train_hq
rm train_hq.zip

kaggle competitions download -c carvana-image-masking-challenge -f train_masks.zip
unzip train_masks.zip
mv train_masks/* data/masks/
rm -d train_masks
rm train_masks.zip
28 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import random
5 | import sys
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torchvision.transforms as transforms
10 | import torchvision.transforms.functional as TF
11 | from pathlib import Path
12 | from torch import optim
13 | from torch.utils.data import DataLoader, random_split
14 | from tqdm import tqdm
15 |
16 | import wandb
17 | from evaluate import evaluate
18 | from unet import UNet
19 | from utils.data_loading import BasicDataset, CarvanaDataset
20 | from utils.dice_score import dice_loss
21 |
# Dataset and checkpoint locations, resolved relative to the current working directory.
dir_img = Path('data') / 'imgs'
dir_mask = Path('data') / 'masks'
dir_checkpoint = Path('checkpoints')
25 |
26 |
def _parameter_histograms(model):
    """Return W&B histograms of ``model``'s weights and gradients, skipping non-finite tensors."""
    histograms = {}
    for tag, value in model.named_parameters():
        tag = tag.replace('/', '.')
        if torch.isfinite(value).all():
            histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
        # Gradients are reset with zero_grad(set_to_none=True), so a parameter that took no part
        # in the last backward pass (or does not require grad) has grad None and must be skipped.
        if value.grad is not None and torch.isfinite(value.grad).all():
            histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())
    return histograms


def _predicted_mask_image(masks_pred, n_classes):
    """Return the hard mask predicted for the first sample of a batch, as a float CPU tensor."""
    if n_classes == 1:
        # A single logit channel: argmax over it would always be 0, so threshold the logit at 0
        # (sigmoid probability 0.5, the threshold evaluate.py uses) instead.
        return (masks_pred[0, 0] > 0).float().cpu()
    return masks_pred.argmax(dim=1)[0].float().cpu()


def train_model(
        model,
        device,
        epochs: int = 5,
        batch_size: int = 1,
        learning_rate: float = 1e-5,
        val_percent: float = 0.1,
        save_checkpoint: bool = True,
        img_scale: float = 0.5,
        amp: bool = False,
        weight_decay: float = 1e-8,
        momentum: float = 0.999,
        gradient_clipping: float = 1.0,
):
    """Train ``model`` on the images in ``dir_img`` and the masks in ``dir_mask``.

    Progress is logged to Weights & Biases. About five times per epoch, the model is
    evaluated on a held-out split; the validation Dice score drives a
    ReduceLROnPlateau scheduler.

    Args:
        model: segmentation network exposing ``n_channels`` and ``n_classes``.
        device: torch.device to train on.
        epochs: number of passes over the training split.
        batch_size: samples per batch, for both training and validation loaders.
        learning_rate: initial RMSprop learning rate.
        val_percent: fraction (0-1) of the dataset held out for validation.
        save_checkpoint: after every epoch, save the state dict plus the dataset's
            ``mask_values`` to ``dir_checkpoint``.
        img_scale: downscaling factor passed to the dataset.
        amp: enable mixed precision (autocast + gradient scaling).
        weight_decay: RMSprop weight decay.
        momentum: RMSprop momentum.
        gradient_clipping: maximum gradient norm passed to ``clip_grad_norm_``.
    """
    # 1. Create dataset (fall back to the generic dataset if the Carvana one cannot be built)
    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError, IndexError):
        dataset = BasicDataset(dir_img, dir_mask, img_scale)

    # 2. Split into train / validation partitions (seeded, so the split is identical on every run)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders; os.cpu_count() may return None, which DataLoader does not accept
    loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count() or 0, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # (Initialize logging)
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    experiment.config.update(
        dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
             val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp)
    )

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(model.parameters(),
                              lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
    global_step = 0
    # Run a validation round roughly 5 times per epoch; 0 (very small training set) disables it
    division_step = n_train // (5 * batch_size)

    # 5. Begin training
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images, true_masks = batch['image'], batch['mask']

                assert images.shape[1] == model.n_channels, \
                    f'Network has been defined with {model.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
                    masks_pred = model(images)
                    if model.n_classes == 1:
                        loss = criterion(masks_pred.squeeze(1), true_masks.float())
                        loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
                    else:
                        loss = criterion(masks_pred, true_masks)
                        loss += dice_loss(
                            F.softmax(masks_pred, dim=1).float(),
                            F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                            multiclass=True
                        )

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                # unscale before clipping so the clipping threshold applies to the real gradient norms
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                experiment.log({
                    'train loss': loss.item(),
                    'step': global_step,
                    'epoch': epoch
                })
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round
                if division_step > 0 and global_step % division_step == 0:
                    histograms = _parameter_histograms(model)

                    val_score = evaluate(model, val_loader, device, amp)
                    scheduler.step(val_score)

                    logging.info('Validation Dice score: {}'.format(val_score))
                    try:
                        experiment.log({
                            'learning rate': optimizer.param_groups[0]['lr'],
                            'validation Dice': val_score,
                            'images': wandb.Image(images[0].cpu()),
                            'masks': {
                                'true': wandb.Image(true_masks[0].float().cpu()),
                                'pred': wandb.Image(_predicted_mask_image(masks_pred, model.n_classes)),
                            },
                            'step': global_step,
                            'epoch': epoch,
                            **histograms
                        })
                    except Exception as exc:
                        # W&B logging is best-effort and must not abort training, but do not hide the failure
                        logging.warning(f'Failed to log validation results to W&B: {exc}')

        logging.info(f'Epoch {epoch} finished, mean training loss: {epoch_loss / max(len(train_loader), 1)}')

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            state_dict = model.state_dict()
            # keep the mask values with the weights so predict.py can map class indices back to pixel values
            state_dict['mask_values'] = dataset.mask_values
            torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
            logging.info(f'Checkpoint {epoch} saved!')
168 |
169 |
def get_args():
    """Parse the training command line.

    Returns:
        argparse.Namespace with the attributes ``epochs``, ``batch_size``, ``lr``,
        ``load``, ``scale``, ``val``, ``amp``, ``bilinear`` and ``classes``.
    """
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    add = parser.add_argument

    # Optimisation schedule
    add('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    add('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    add('--learning-rate', '-l', metavar='LR', type=float, default=1e-5, help='Learning rate', dest='lr')

    # Inputs: optional starting checkpoint, image scaling, validation split
    add('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    add('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    add('--validation', '-v', dest='val', type=float, default=10.0,
        help='Percent of the data that is used as validation (0-100)')

    # Model / precision switches
    add('--amp', action='store_true', default=False, help='Use mixed precision')
    add('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    add('--classes', '-c', type=int, default=2, help='Number of classes')

    parsed = parser.parse_args()
    return parsed
185 |
186 |
if __name__ == '__main__':
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    model = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
    model = model.to(memory_format=torch.channels_last)

    logging.info(f'Network:\n'
                 f'\t{model.n_channels} input channels\n'
                 f'\t{model.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

    if args.load:
        state_dict = torch.load(args.load, map_location=device)
        # Checkpoints written by train_model store the dataset's mask values next to
        # the weights; strip that entry so load_state_dict only sees model tensors.
        # Checkpoints from other sources may not have it, so a missing key is fine.
        state_dict.pop('mask_values', None)
        model.load_state_dict(state_dict)
        logging.info(f'Model loaded from {args.load}')

    model.to(device=device)

    # One configuration shared by the first attempt and the out-of-memory retry,
    # so the two calls cannot drift apart.
    train_kwargs = dict(
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        device=device,
        img_scale=args.scale,
        val_percent=args.val / 100,
        amp=args.amp,
    )
    try:
        train_model(model=model, **train_kwargs)
    except torch.cuda.OutOfMemoryError:
        logging.error('Detected OutOfMemoryError! '
                      'Enabling checkpointing to reduce memory usage, but this slows down training. '
                      'Consider enabling AMP (--amp) for fast and memory efficient training')
        torch.cuda.empty_cache()
        model.use_checkpointing()
        train_model(model=model, **train_kwargs)
239 |
--------------------------------------------------------------------------------
/unet/__init__.py:
--------------------------------------------------------------------------------
1 | from .unet_model import UNet
2 |
--------------------------------------------------------------------------------
/unet/unet_model.py:
--------------------------------------------------------------------------------
1 | """ Full assembly of the parts to form the complete network """
2 |
3 | from .unet_parts import *
4 |
5 |
class UNet(nn.Module):
    """U-Net: four Down encoder stages and four Up decoder stages with skip connections.

    Args:
        n_channels: number of channels of the input images.
        n_classes: number of output channels (one logit map per class).
        bilinear: if True the decoder upsamples with bilinear interpolation, otherwise
            with transposed convolutions. In bilinear mode the output widths of
            ``down4`` and ``up1``-``up3`` are halved (``factor`` below).
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

        # Activation checkpointing stays off until use_checkpointing() is called.
        self.checkpointing = False

    def _run(self, module, *inputs):
        """Apply ``module`` to ``inputs``, checkpointing its activations when enabled.

        With checkpointing on and autograd recording, the intermediate activations of
        ``module`` are not kept; they are recomputed during backward, trading compute
        for memory. Otherwise this is a plain ``module(*inputs)`` call.
        """
        # getattr: instances created without __init__ (e.g. unpickled older models)
        # have no flag and behave as before.
        if getattr(self, 'checkpointing', False) and torch.is_grad_enabled():
            # Import the submodule explicitly rather than relying on ``import torch``
            # having loaded it.
            from torch.utils.checkpoint import checkpoint
            # use_reentrant=False lets gradients reach the parameters of the first
            # stage even though the network input itself does not require grad.
            # NOTE(review): recomputation re-runs BatchNorm in training mode, which
            # presumably updates its running statistics a second time.
            return checkpoint(module, *inputs, use_reentrant=False)
        return module(*inputs)

    def forward(self, x):
        """Return per-class logits for the image batch ``x``."""
        x1 = self._run(self.inc, x)
        x2 = self._run(self.down1, x1)
        x3 = self._run(self.down2, x2)
        x4 = self._run(self.down3, x3)
        x5 = self._run(self.down4, x4)
        x = self._run(self.up1, x5, x4)
        x = self._run(self.up2, x, x3)
        x = self._run(self.up3, x, x2)
        x = self._run(self.up4, x, x1)
        logits = self._run(self.outc, x)
        return logits

    def use_checkpointing(self):
        """Enable activation checkpointing for every stage of the network.

        The submodules themselves are not replaced, so ``state_dict`` keys are
        unchanged and existing checkpoints still load.
        """
        self.checkpointing = True
--------------------------------------------------------------------------------
/unet/unet_parts.py:
--------------------------------------------------------------------------------
1 | """ Parts of the U-Net model """
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
class DoubleConv(nn.Module):
    """Two stacked ``3x3 conv -> BatchNorm -> ReLU`` stages.

    Padding of 1 keeps the spatial size. Channels go
    ``in_channels -> mid_channels -> out_channels``; ``mid_channels`` falls back to
    ``out_channels`` when it is not given (or is falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers += [
                # No conv bias: the following BatchNorm supplies its own shift.
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        features = self.double_conv(x)
        return features
26 |
27 |
class Down(nn.Module):
    """Encoder stage: halve the spatial size with 2x2 max-pooling, then apply DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled
40 |
41 |
class Up(nn.Module):
    """Decoder stage: upsample, pad to the skip tensor's size, concatenate, DoubleConv.

    Args:
        in_channels: channel count of the concatenated tensor fed to DoubleConv.
        out_channels: channel count produced by this stage.
        bilinear: if True, upsample with parameter-free bilinear interpolation and let
            DoubleConv's first convolution reduce the channels
            (``mid_channels=in_channels // 2``). Otherwise upsample with a 2x2 stride-2
            transposed convolution that maps ``in_channels`` to ``in_channels // 2``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
            return
        # if bilinear, use the normal convolutions to reduce the number of channels
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s height/width, and fuse the two.

        Both inputs are batched NCHW tensors; ``x2`` is the skip connection.
        """
        upsampled = self.up(x1)
        # NCHW: dim 2 is height, dim 3 is width.
        pad_h = x2.size()[2] - upsampled.size()[2]
        pad_w = x2.size()[3] - upsampled.size()[3]
        # F.pad order is (left, right, top, bottom); an odd remainder goes right/bottom.
        upsampled = F.pad(upsampled, [pad_w // 2, pad_w - pad_w // 2,
                                      pad_h // 2, pad_h - pad_h // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        fused = torch.cat([x2, upsampled], dim=1)  # skip features first, then upsampled ones
        return self.conv(fused)
69 |
70 |
class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to ``out_channels`` logit maps."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        logits = self.conv(x)
        return logits
78 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/milesial/Pytorch-UNet/21d7850f2af30a9695bbeea75f3136aa538cfc4a/utils/__init__.py
--------------------------------------------------------------------------------
/utils/data_loading.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 | from functools import lru_cache
6 | from functools import partial
7 | from itertools import repeat
8 | from multiprocessing import Pool
9 | from os import listdir
10 | from os.path import splitext, isfile, join
11 | from pathlib import Path
12 | from torch.utils.data import Dataset
13 | from tqdm import tqdm
14 |
15 |
def load_image(filename):
    """Load an image or mask file as a PIL image.

    ``.npy`` files are read with NumPy and ``.pt``/``.pth`` files with ``torch.load``.
    Any other extension is handed to ``PIL.Image.open``. Extension matching is
    case-insensitive.

    Args:
        filename: path (str or Path) of the file to load.

    Returns:
        A ``PIL.Image.Image``.
    """
    ext = splitext(filename)[1].lower()
    if ext == '.npy':
        return Image.fromarray(np.load(filename))
    if ext in ('.pt', '.pth'):
        # weights_only=True refuses to unpickle arbitrary objects from data files.
        # map_location='cpu' lets tensors saved from a GPU be converted to NumPy.
        tensor = torch.load(filename, map_location='cpu', weights_only=True)
        return Image.fromarray(tensor.detach().numpy())
    return Image.open(filename)
24 |
25 |
def unique_mask_values(idx, mask_dir, mask_suffix):
    """Return the distinct pixel values of the mask belonging to sample ``idx``.

    The mask is the file in ``mask_dir`` named ``idx + mask_suffix + '.<ext>'``. If
    several files match, the first one returned by the glob is used.

    Args:
        idx: sample identifier (image file name without extension).
        mask_dir: ``pathlib.Path`` of the mask directory.
        mask_suffix: string appended to ``idx`` to form the mask file stem.

    Returns:
        For single-channel masks, a 1-D array of unique values. For multi-channel
        masks, a 2-D array whose rows are the unique per-pixel value vectors.

    Raises:
        FileNotFoundError: no mask file matches ``idx``.
        ValueError: the loaded mask is neither 2-D nor 3-D.
    """
    from glob import escape  # local import keeps this block self-contained

    # Escape the stem so ids containing glob metacharacters ('[', '*', '?') match literally.
    matches = list(mask_dir.glob(escape(idx + mask_suffix) + '.*'))
    if not matches:
        raise FileNotFoundError(f'No mask file found for ID {idx} in {mask_dir} '
                                f'(expected {idx}{mask_suffix}.*)')
    mask = np.asarray(load_image(matches[0]))
    if mask.ndim == 2:
        return np.unique(mask)
    elif mask.ndim == 3:
        # Treat each pixel's channel vector as one value.
        mask = mask.reshape(-1, mask.shape[-1])
        return np.unique(mask, axis=0)
    else:
        raise ValueError(f'Loaded masks should have 2 or 3 dimensions, found {mask.ndim}')
36 |
37 |
class BasicDataset(Dataset):
    """Segmentation dataset pairing each image in ``images_dir`` with a mask in ``mask_dir``.

    Sample ids are the image file names without extension, excluding hidden files.
    The mask for id ``X`` is the file matching ``X + mask_suffix + '.*'`` in
    ``mask_dir``. On construction every mask is scanned in a process pool to collect
    the sorted list of distinct mask values. ``__getitem__`` then maps each mask
    pixel to the index of its value in that list.

    Args:
        images_dir: directory of input images.
        mask_dir: directory of masks.
        scale: resize factor applied to both image and mask, in ``(0, 1]``.
        mask_suffix: text between the id and the extension in mask file names.
    """

    def __init__(self, images_dir: str, mask_dir: str, scale: float = 1.0, mask_suffix: str = ''):
        self.images_dir = Path(images_dir)
        self.mask_dir = Path(mask_dir)
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'
        self.scale = scale
        self.mask_suffix = mask_suffix

        self.ids = [splitext(file)[0] for file in listdir(images_dir) if isfile(join(images_dir, file)) and not file.startswith('.')]
        if not self.ids:
            raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there')

        logging.info(f'Creating dataset with {len(self.ids)} examples')
        logging.info('Scanning mask files to determine unique values')
        with Pool() as p:
            unique = list(tqdm(
                p.imap(partial(unique_mask_values, mask_dir=self.mask_dir, mask_suffix=self.mask_suffix), self.ids),
                total=len(self.ids)
            ))

        # Sorted distinct mask values (scalars, or lists for multi-channel masks);
        # a value's position in this list is its class index.
        self.mask_values = sorted(np.unique(np.concatenate(unique), axis=0).tolist())
        logging.info(f'Unique mask values: {self.mask_values}')

    def __len__(self):
        """Number of samples (image files found at construction)."""
        return len(self.ids)

    @staticmethod
    def preprocess(mask_values, pil_img, scale, is_mask):
        """Resize ``pil_img`` by ``scale`` and convert it to a NumPy array.

        Args:
            mask_values: the dataset's sorted list of distinct mask values.
            pil_img: PIL image to convert.
            scale: factor applied to width and height (result truncated to int).
            is_mask: selects mask handling (nearest-neighbour resize, value to class
                index mapping) instead of image handling (bicubic resize, channels
                first, divided by 255 when any value exceeds 1).

        Returns:
            For masks, an ``(H, W)`` int64 array of class indices; pixels matching no
            entry of ``mask_values`` stay 0. For images, a ``(C, H, W)`` array.
        """
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
        # Nearest-neighbour for masks so resizing never invents values between labels.
        pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
        img = np.asarray(pil_img)

        if is_mask:
            mask = np.zeros((newH, newW), dtype=np.int64)
            for i, v in enumerate(mask_values):
                if img.ndim == 2:
                    mask[img == v] = i
                else:
                    # Multi-channel mask: a pixel matches only if every channel equals v.
                    mask[(img == v).all(-1)] = i

            return mask

        else:
            if img.ndim == 2:
                img = img[np.newaxis, ...]
            else:
                img = img.transpose((2, 0, 1))

            # NOTE(review): scaling is decided from the pixel values, not the dtype,
            # so an 8-bit image whose pixels are all <= 1 is returned unscaled.
            if (img > 1).any():
                img = img / 255.0

            return img

    def __getitem__(self, idx):
        """Load, resize and convert sample ``idx``.

        Returns:
            dict with ``'image'`` (float tensor, ``(C, H, W)``) and ``'mask'``
            (long tensor of class indices, ``(H, W)``).
        """
        from glob import escape  # local import keeps this block self-contained

        name = self.ids[idx]
        # Escape ids so names containing glob metacharacters ('[', '*', '?') match literally.
        mask_file = list(self.mask_dir.glob(escape(name + self.mask_suffix) + '.*'))
        img_file = list(self.images_dir.glob(escape(name) + '.*'))

        assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
        assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
        mask = load_image(mask_file[0])
        img = load_image(img_file[0])

        assert img.size == mask.size, \
            f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'

        img = self.preprocess(self.mask_values, img, self.scale, is_mask=False)
        mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True)

        return {
            'image': torch.as_tensor(img.copy()).float().contiguous(),
            'mask': torch.as_tensor(mask.copy()).long().contiguous()
        }
113 |
114 |
class CarvanaDataset(BasicDataset):
    """BasicDataset whose mask files are named ``<id>_mask.<ext>``."""

    def __init__(self, images_dir, mask_dir, scale=1):
        suffix = '_mask'
        super().__init__(images_dir, mask_dir, scale=scale, mask_suffix=suffix)
118 |
--------------------------------------------------------------------------------
/utils/dice_score.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 |
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Soft Dice coefficient ``2*sum(a*b) / (sum(a) + sum(b))`` between two masks.

    Args:
        input: predicted mask(s), same shape as ``target``. Either one 2-D mask or a
            3-D batch of masks.
        target: ground-truth mask(s).
        reduce_batch_first: only allowed for 3-D inputs. If True, sums also run over
            the batch, giving one pooled coefficient. If False, one coefficient is
            computed per 2-D mask and the results are averaged.
        epsilon: smoothing term added to numerator and denominator.

    Returns:
        0-D tensor with the (mean) coefficient. For non-negative masks that are both
        empty the coefficient is 1.
    """
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    pooled = reduce_batch_first and input.dim() != 2
    dims = (-1, -2, -3) if pooled else (-1, -2)

    numerator = 2 * (input * target).sum(dim=dims)
    denominator = input.sum(dim=dims) + target.sum(dim=dims)
    # Zero denominator (both masks empty, assuming non-negative values): substitute the
    # numerator so the ratio becomes epsilon / epsilon = 1.
    denominator = torch.where(denominator == 0, numerator, denominator)

    ratio = (numerator + epsilon) / (denominator + epsilon)
    return ratio.mean()
18 |
19 |
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Dice coefficient averaged over classes.

    Args:
        input: per-class predictions, same shape as ``target``. Presumably batched
            ``(N, C, H, W)``; this is required when ``reduce_batch_first`` is True.
        target: per-class (e.g. one-hot) ground truth.
        reduce_batch_first: if True, each class's Dice is computed with the batch and
            spatial positions pooled, and the per-class values are averaged. If
            False, one Dice is computed per (sample, class) mask and all of them are
            averaged.
        epsilon: smoothing term forwarded to ``dice_coeff``.

    Returns:
        0-D tensor with the mean coefficient.
    """
    if reduce_batch_first:
        assert input.dim() == 4, 'reduce_batch_first expects (N, C, H, W) inputs'
        # (N, C, H, W) -> (C, N*H, W): one 2-D "mask" per class that holds every
        # sample, so summing its last two dims pools the batch per class, and the
        # mean over the leading C dim is the class average. Flattening (N, C)
        # together instead would pool all classes into a single coefficient.
        per_class_input = input.transpose(0, 1).flatten(1, 2)
        per_class_target = target.transpose(0, 1).flatten(1, 2)
        return dice_coeff(per_class_input, per_class_target, reduce_batch_first=False, epsilon=epsilon)
    # (N, C, H, W) -> (N*C, H, W): one coefficient per (sample, class), then averaged.
    return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon)
23 |
24 |
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    """Dice loss to minimise: ``1 - coefficient``, in [0, 1] for masks valued in [0, 1].

    The coefficient is always computed with ``reduce_batch_first=True``, so inputs
    must be batched. Use ``multiclass=True`` for per-class channel inputs.
    """
    if multiclass:
        coefficient = multiclass_dice_coeff(input, target, reduce_batch_first=True)
    else:
        coefficient = dice_coeff(input, target, reduce_batch_first=True)
    return 1 - coefficient
29 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 |
def plot_img_and_mask(img, mask):
    """Show ``img`` next to one binary panel per class index in ``mask``.

    Args:
        img: image accepted by ``matplotlib.pyplot.imshow``.
        mask: class-index map. Classes ``0 .. mask.max()`` each get a panel showing
            ``mask == class``.
    """
    classes = int(mask.max()) + 1
    _, ax = plt.subplots(1, classes + 1)
    ax[0].set_title('Input image')
    ax[0].imshow(img)
    for i in range(classes):
        # Panel i + 1 displays class i, so label it with that class index.
        ax[i + 1].set_title(f'Mask (class {i})')
        ax[i + 1].imshow(mask == i)
    # plt.xticks/yticks only affect the current (last) axes; clear every panel.
    for axis in ax:
        axis.set_xticks([])
        axis.set_yticks([])
    plt.show()
14 |
--------------------------------------------------------------------------------