├── .github
│   └── ISSUE_TEMPLATE
│       └── bug_report.md
├── .gitignore
├── AUTHORS.rst
├── CONTRIBUTING.rst
├── HISTORY.rst
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.rst
├── demo
│   ├── .gitignore
│   ├── README.rst
│   ├── demo_radio_data.ipynb
│   └── demo_toy_problem.ipynb
├── docs
│   ├── Makefile
│   ├── authors.rst
│   ├── check_sphinx.py
│   ├── conf.py
│   ├── contributing.rst
│   ├── galaxies.png
│   ├── history.rst
│   ├── index.rst
│   ├── installation.rst
│   ├── modules.rst
│   ├── rfi.png
│   ├── stats.png
│   ├── tf_unet.rst
│   ├── toy_problem.png
│   └── usage.rst
├── postBuild
├── requirements.txt
├── scripts
│   ├── .gitignore
│   ├── __init__.py
│   ├── launcher.py
│   ├── rfi_launcher.py
│   ├── ufig_launcher.py
│   └── ultrasound_launcher.py
├── setup.py
└── tf_unet
    ├── __init__.py
    ├── image_gen.py
    ├── image_util.py
    ├── layers.py
    ├── unet.py
    └── util.py

/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 | 
8 | ---
9 | 
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 | 
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. 
16 | 
17 | **Expected behavior**
18 | A clear and concise description of what you expected to happen.
19 | 
20 | **Screenshots**
21 | If applicable, add screenshots to help explain your problem.
22 | 
23 | **Additional context**
24 | Add any other context about the problem here.
25 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.py[cod]
2 | 
3 | # C extensions
4 | *.so
5 | 
6 | # Packages
7 | *.egg
8 | *.egg-info
9 | dist
10 | build
11 | eggs
12 | parts
13 | bin
14 | var
15 | sdist
16 | develop-eggs
17 | .installed.cfg
18 | lib
19 | lib64
20 | .DS_Store
21 | 
22 | # Installer logs
23 | pip-log.txt
24 | 
25 | # Unit test / coverage reports
26 | .coverage
27 | .tox
28 | nosetests.xml
29 | *junit-*
30 | htmlcov
31 | 
32 | # Translations
33 | *.mo
34 | 
35 | # Mr Developer
36 | .mr.developer.cfg
37 | .project
38 | .pydevproject
39 | 
40 | # Complexity
41 | output/*.html
42 | output/*/index.html
43 | 
44 | # Sphinx
45 | docs/_build
46 | /unet_trained/
47 | /prediction/
48 | /bleien_data/
49 | /daint_unet_trained_rfi/
50 | /daint_unet_trained_rfi_bleien/
51 | /daint_unet_trained_rfi_bleien4/
52 | /sim_data/
53 | /ufig_images/
54 | /.hope/
55 | /.ipynb_checkpoints/
56 | /.settings/
--------------------------------------------------------------------------------
/AUTHORS.rst:
--------------------------------------------------------------------------------
1 | =======
2 | Credits
3 | =======
4 | 
5 | Development Lead
6 | ----------------
7 | 
8 | * `@jakeret `_
9 | 
10 | Contributors
11 | ------------
12 | 
13 | * `@FelixGruen `_
14 | * `@ameya005 `_
15 | * `@agrafix `_
16 | * `@AlessioM `_
17 | * `@FiLeonard `_
18 | * `@nikkou `_
19 | * `@wkeithvan `_
20 | * `@samsammurphy `_
21 | * `@siavashk `_
22 | 
23 | Citations
24 | ---------
25 | 
26 | As you use **tf_unet** for your exciting discoveries, please cite the paper that describes the package:
27 | 
28 | `J. Akeret, C. Chang, A. Lucchi, A.
Refregier, Published in Astronomy and Computing (2017) `_
29 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.rst:
--------------------------------------------------------------------------------
1 | ============
2 | Contributing
3 | ============
4 | 
5 | Contributions are welcome, and they are greatly appreciated! Every
6 | little bit helps, and credit will always be given.
7 | 
8 | You can contribute in many ways:
9 | 
10 | Types of Contributions
11 | ----------------------
12 | 
13 | Report Bugs
14 | ~~~~~~~~~~~
15 | 
16 | If you are reporting a bug, please include:
17 | 
18 | * Your operating system name and version.
19 | * Any details about your local setup that might be helpful in troubleshooting.
20 | * Detailed steps to reproduce the bug.
21 | 
22 | Fix Bugs
23 | ~~~~~~~~
24 | 
25 | Implement Features
26 | ~~~~~~~~~~~~~~~~~~
27 | 
28 | Write Documentation
29 | ~~~~~~~~~~~~~~~~~~~
30 | 
31 | Tensorflow Unet could always use more documentation, whether as part of the
32 | official Tensorflow Unet docs, in docstrings, or even on the web in blog posts,
33 | articles, and such.
34 | 
35 | Submit Feedback
36 | ~~~~~~~~~~~~~~~
37 | 
38 | If you are proposing a feature:
39 | 
40 | * Explain in detail how it would work.
41 | * Keep the scope as narrow as possible, to make it easier to implement.
42 | * Remember that this is a volunteer-driven project, and that contributions
43 |   are welcome :)
44 | 
45 | Pull Request Guidelines
46 | -----------------------
47 | 
48 | Before you submit a pull request, check that it meets these guidelines:
49 | 
50 | 1. The pull request should include tests.
51 | 2. If the pull request adds functionality, the docs should be updated. Put
52 |    your new functionality into a function with a docstring, and add the
53 |    feature to the list in README.rst.
54 | 3. The pull request should work for Python 2.6, 2.7, and 3.3, and for PyPy.
55 |    Make sure that the tests pass for all supported Python versions.
56 | 
57 | 
58 | Tips
59 | ----
60 | 
61 | To run a subset of tests::
62 | 
63 |     $ py.test test/test_tf_unet.py
--------------------------------------------------------------------------------
/HISTORY.rst:
--------------------------------------------------------------------------------
1 | .. :changelog:
2 | 
3 | History
4 | -------
5 | 
6 | 0.1.2 (2018-01-08)
7 | ++++++++++++++++++
8 | 
9 | * Name scopes to improve the TensorBoard layout
10 | * Move bias addition before dropout
11 | * Numerically stable cross-entropy computation
12 | * Parametrized verification batch size
13 | * Bugfix if all pixel values are 0
14 | * Cleaned examples
15 | 
16 | 0.1.1 (2017-12-29)
17 | ++++++++++++++++++
18 | 
19 | * Support for TensorFlow > 1.0.0
20 | * Clean package structure
21 | * Integration into mybinder
22 | 
23 | 0.1.0 (2016-08-18)
24 | ++++++++++++++++++
25 | 
26 | * First release on GitHub.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 | 
4 | Copyright (C) 2007 Free Software Foundation, Inc. 
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 | 
8 | Preamble
9 | 
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 | 
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. 
The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 
91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 
122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 
155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 
186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. 
This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 
256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 
287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 
317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 
386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 
421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 
486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. 
If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | Copyright (C) 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 
675 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include AUTHORS.rst 2 | include CONTRIBUTING.rst 3 | include HISTORY.rst 4 | include LICENSE 5 | include README.rst 6 | include requirements.txt 7 | include Makefile 8 | include docs/* 9 | include test/* -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help clean clean-pyc clean-build list test test-all coverage docs release sdist 2 | 3 | help: 4 | @echo "clean-build - remove build artifacts" 5 | @echo "clean-pyc - remove Python file artifacts" 6 | @echo "lint - check style with flake8" 7 | @echo "test - run tests quickly with the default Python" 8 | @echo "test-all - run tests on every Python version with tox" 9 | @echo "coverage - check code coverage quickly with the default Python" 10 | @echo "docs - generate Sphinx HTML documentation, including API docs" 11 | @echo "sdist - package" 12 | 13 | clean: clean-build clean-pyc 14 | 15 | clean-build: 16 | rm -fr build/ 17 | rm -fr dist/ 18 | rm -fr *.egg-info 19 | 20 | clean-pyc: 21 | find . -name '*.pyc' -exec rm -f {} + 22 | find . -name '*.pyo' -exec rm -f {} + 23 | find . -name '*~' -exec rm -f {} + 24 | find . 
-name '__pycache__' -exec rm -rf {} + 25 | 26 | lint: 27 | flake8 tf_unet test 28 | 29 | test: 30 | py.test 31 | 32 | test-all: 33 | tox 34 | 35 | coverage: 36 | coverage run --source tf_unet setup.py test 37 | coverage report -m 38 | coverage html 39 | open htmlcov/index.html 40 | 41 | docs: 42 | sphinx-apidoc -o docs/ tf_unet 43 | $(MAKE) -C docs clean 44 | $(MAKE) -C docs html 45 | open docs/_build/html/index.html 46 | 47 | sdist: clean 48 | pip freeze > requirements.txt 49 | python setup.py sdist 50 | ls -l dist -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ============================= 2 | Tensorflow Unet 3 | ============================= 4 | 5 | .. image:: https://readthedocs.org/projects/tf-unet/badge/?version=latest 6 | :target: http://tf-unet.readthedocs.io/en/latest/?badge=latest 7 | :alt: Documentation Status 8 | 9 | .. image:: http://img.shields.io/badge/arXiv-1609.09077-orange.svg?style=flat 10 | :target: http://arxiv.org/abs/1609.09077 11 | 12 | .. image:: https://img.shields.io/badge/ascl-1611.002-blue.svg?colorB=262255 13 | :target: http://ascl.net/1611.002 14 | 15 | .. image:: https://mybinder.org/badge.svg 16 | :target: https://mybinder.org/v2/gh/jakeret/tf_unet/master?filepath=demo%2Fdemo_toy_problem.ipynb 17 | 18 | 19 | .. warning:: 20 | 21 | This project has been discontinued in favour of a Tensorflow 2 compatible reimplementation, available at https://github.com/jakeret/unet 22 | 23 | 24 | This is a generic **U-Net** implementation as proposed by `Ronneberger et al. `_, developed with **Tensorflow**. The code has been developed and used for `Radio Frequency Interference mitigation using deep convolutional neural networks `_. 25 | 26 | The network can be trained to perform image segmentation on arbitrary imaging data.
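A training pair for such a segmentation task can be sketched with plain NumPy: random circles on a noisy background, with a one-hot ground-truth mask marking the circle pixels. This is a minimal, illustrative stand-in for the data providers that ship with the package (e.g. in ``tf_unet.image_gen``); the function name and parameters below are assumptions for illustration, not the library's actual API.

```python
import numpy as np

def toy_sample(nx=172, ny=172, n_circles=10, sigma=20, rng=None):
    """Generate one (image, label) pair: noisy circles and their mask.

    Illustrative sketch only -- not tf_unet's actual data provider.
    """
    rng = np.random.default_rng(rng)
    image = np.ones((nx, ny))
    mask = np.zeros((nx, ny), dtype=bool)
    yy, xx = np.mgrid[:nx, :ny]
    for _ in range(n_circles):
        # random centre, radius, and intensity for each circle
        a, b = rng.integers(0, nx), rng.integers(0, ny)
        r = rng.integers(3, 15)
        h = rng.uniform(1, 255)
        circle = (xx - b) ** 2 + (yy - a) ** 2 <= r ** 2
        image[circle] = h
        mask |= circle
    image += rng.normal(scale=sigma, size=image.shape)  # additive noise
    # one-hot labels: channel 0 = background, channel 1 = circle
    label = np.stack([~mask, mask], axis=-1).astype(np.float32)
    return image[..., np.newaxis].astype(np.float32), label

x, y = toy_sample(rng=42)
```

The returned shapes follow the usual convention for a two-class segmentation network: a single-channel input image and a per-pixel one-hot label map.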
Check out the `Usage `_ section or the included Jupyter notebooks for a `toy problem `_ or the `Radio Frequency Interference mitigation `_ discussed in our paper. 27 | 28 | The code is not tied to a specific segmentation task, so it can be used for anything from a toy problem of detecting circles in a noisy image. 29 | 30 | .. image:: https://raw.githubusercontent.com/jakeret/tf_unet/master/docs/toy_problem.png 31 | :alt: Segmentation of a toy problem. 32 | :align: center 33 | 34 | To more complex applications such as the detection of radio frequency interference (RFI) in radio astronomy. 35 | 36 | .. image:: https://raw.githubusercontent.com/jakeret/tf_unet/master/docs/rfi.png 37 | :alt: Segmentation of RFI in radio data. 38 | :align: center 39 | 40 | Or to the detection of galaxies and stars in wide-field imaging data. 41 | 42 | .. image:: https://raw.githubusercontent.com/jakeret/tf_unet/master/docs/galaxies.png 43 | :alt: Segmentation of galaxies. 44 | :align: center 45 | 46 | 47 | As you use **tf_unet** for your exciting discoveries, please cite the paper that describes the package:: 48 | 49 | 50 | @article{akeret2017radio, 51 | title={Radio frequency interference mitigation using deep convolutional neural networks}, 52 | author={Akeret, Joel and Chang, Chihway and Lucchi, Aurelien and Refregier, Alexandre}, 53 | journal={Astronomy and Computing}, 54 | volume={18}, 55 | pages={35--39}, 56 | year={2017}, 57 | publisher={Elsevier} 58 | } 59 | -------------------------------------------------------------------------------- /demo/.gitignore: -------------------------------------------------------------------------------- 1 | /.ipynb_checkpoints/ 2 | -------------------------------------------------------------------------------- /demo/README.rst: -------------------------------------------------------------------------------- 1 | ============================= 2 | Tensorflow Unet demos 3 | ============================= 4 | 5 | ..
image:: https://mybinder.org/badge.svg 6 | :target: https://mybinder.org/v2/gh/jakeret/tf_unet/master?filepath=demo%2Fdemo_toy_problem.ipynb 7 | 8 | 9 | 10 | Here you can find demo Jupyter notebooks for a `toy problem `_ or the `Radio Frequency Interference mitigation `_ discussed in our paper. 11 | 12 | 13 | You can use the binder badge to start an interactive version of the toy problem without having to install anything on your machine. 14 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 23 | 24 | help: 25 | @echo "Please use \`make ' where is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | 49 | clean: 50 | rm -rf $(BUILDDIR)/* 51 | 52 | html: 53 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 54 | @echo 55 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 56 | 57 | dirhtml: 58 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 59 | @echo 60 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 
61 | 62 | singlehtml: 63 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 64 | @echo 65 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 66 | 67 | pickle: 68 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 69 | @echo 70 | @echo "Build finished; now you can process the pickle files." 71 | 72 | json: 73 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 74 | @echo 75 | @echo "Build finished; now you can process the JSON files." 76 | 77 | htmlhelp: 78 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 79 | @echo 80 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 81 | ".hhp project file in $(BUILDDIR)/htmlhelp." 82 | 83 | qthelp: 84 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 85 | @echo 86 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 87 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 88 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/complexity.qhcp" 89 | @echo "To view the help file:" 90 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/complexity.qhc" 91 | 92 | devhelp: 93 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 94 | @echo 95 | @echo "Build finished." 96 | @echo "To view the help file:" 97 | @echo "# mkdir -p $$HOME/.local/share/devhelp/complexity" 98 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/complexity" 99 | @echo "# devhelp" 100 | 101 | epub: 102 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 103 | @echo 104 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 105 | 106 | latex: 107 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 108 | @echo 109 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 110 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 111 | "(use \`make latexpdf' here to do that automatically)." 
112 | 113 | latexpdf: 114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 115 | @echo "Running LaTeX files through pdflatex..." 116 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 117 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 118 | 119 | latexpdfja: 120 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 121 | @echo "Running LaTeX files through platex and dvipdfmx..." 122 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 123 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 124 | 125 | text: 126 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 127 | @echo 128 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 129 | 130 | man: 131 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 132 | @echo 133 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 134 | 135 | texinfo: 136 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 137 | @echo 138 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 139 | @echo "Run \`make' in that directory to run these through makeinfo" \ 140 | "(use \`make info' here to do that automatically)." 141 | 142 | info: 143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 144 | @echo "Running Texinfo files through makeinfo..." 145 | make -C $(BUILDDIR)/texinfo info 146 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 147 | 148 | gettext: 149 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 150 | @echo 151 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 152 | 153 | changes: 154 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 155 | @echo 156 | @echo "The overview file is in $(BUILDDIR)/changes." 157 | 158 | linkcheck: 159 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 160 | @echo 161 | @echo "Link check complete; look for any errors in the above output " \ 162 | "or in $(BUILDDIR)/linkcheck/output.txt." 
163 | 164 | doctest: 165 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 166 | @echo "Testing of doctests in the sources finished, look at the " \ 167 | "results in $(BUILDDIR)/doctest/output.txt." 168 | 169 | xml: 170 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 171 | @echo 172 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 173 | 174 | pseudoxml: 175 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 176 | @echo 177 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." -------------------------------------------------------------------------------- /docs/authors.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../AUTHORS.rst -------------------------------------------------------------------------------- /docs/check_sphinx.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Dec 2, 2013 3 | 4 | @author: jakeret 5 | ''' 6 | import py 7 | import subprocess 8 | def test_linkcheck(tmpdir): 9 | doctrees = tmpdir.join("doctrees") 10 | htmldir = tmpdir.join("html") 11 | subprocess.check_call( 12 | ["sphinx-build", "-blinkcheck", 13 | "-d", str(doctrees), ".", str(htmldir)]) 14 | 15 | def test_build_docs(tmpdir): 16 | doctrees = tmpdir.join("doctrees") 17 | htmldir = tmpdir.join("html") 18 | subprocess.check_call([ 19 | "sphinx-build", "-bhtml", 20 | "-d", str(doctrees), ".", str(htmldir)]) -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # complexity documentation build configuration file, created by 4 | # sphinx-quickstart on Tue Jul 9 22:26:36 2013. 5 | # 6 | # This file is execfile()d with the current directory set to its containing dir. 
7 | # 8 | # Note that not all possible configuration values are present in this 9 | # autogenerated file. 10 | # 11 | # All configuration values have a default; values that are commented out 12 | # serve to show the default. 13 | 14 | import sys, os 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | #sys.path.insert(0, os.path.abspath('.')) 20 | 21 | cwd = os.getcwd() 22 | parent = os.path.dirname(cwd) 23 | sys.path.insert(0, parent) 24 | 25 | import mock 26 | 27 | MOCK_MODULES = ['numpy', 'tensorflow'] 28 | for mod_name in MOCK_MODULES: 29 | sys.modules[mod_name] = mock.Mock() 30 | 31 | 32 | import tf_unet 33 | 34 | # -- General configuration ----------------------------------------------------- 35 | 36 | # If your documentation needs a minimal Sphinx version, state it here. 37 | #needs_sphinx = '1.0' 38 | 39 | # Add any Sphinx extension module names here, as strings. They can be extensions 40 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 41 | extensions = ['sphinx.ext.autodoc', 'sphinx.ext.coverage', 'sphinx.ext.mathjax', 'sphinx.ext.viewcode'] 42 | 43 | # Add any paths that contain templates here, relative to this directory. 44 | templates_path = ['_templates'] 45 | 46 | # The suffix of source filenames. 47 | source_suffix = '.rst' 48 | 49 | # The encoding of source files. 50 | #source_encoding = 'utf-8-sig' 51 | 52 | # The master toctree document. 53 | master_doc = 'index' 54 | 55 | # General information about the project. 56 | project = u'Tensorflow Unet' 57 | copyright = u'2016, ETH Zurich, Institute for Astronomy' 58 | 59 | # The version info for the project you're documenting, acts as replacement for 60 | # |version| and |release|, also used in various other places throughout the 61 | # built documents. 
62 | # 63 | # The short X.Y version. 64 | version = tf_unet.__version__ 65 | # The full version, including alpha/beta/rc tags. 66 | release = tf_unet.__version__ 67 | 68 | # The language for content autogenerated by Sphinx. Refer to documentation 69 | # for a list of supported languages. 70 | #language = None 71 | 72 | # There are two options for replacing |today|: either, you set today to some 73 | # non-false value, then it is used: 74 | #today = '' 75 | # Else, today_fmt is used as the format for a strftime call. 76 | #today_fmt = '%B %d, %Y' 77 | 78 | # List of patterns, relative to source directory, that match files and 79 | # directories to ignore when looking for source files. 80 | exclude_patterns = ['_build'] 81 | 82 | # The reST default role (used for this markup: `text`) to use for all documents. 83 | #default_role = None 84 | 85 | # If true, '()' will be appended to :func: etc. cross-reference text. 86 | #add_function_parentheses = True 87 | 88 | # If true, the current module name will be prepended to all description 89 | # unit titles (such as .. function::). 90 | #add_module_names = True 91 | 92 | # If true, sectionauthor and moduleauthor directives will be shown in the 93 | # output. They are ignored by default. 94 | #show_authors = False 95 | 96 | # The name of the Pygments (syntax highlighting) style to use. 97 | pygments_style = 'sphinx' 98 | 99 | # A list of ignored prefixes for module index sorting. 100 | #modindex_common_prefix = [] 101 | 102 | # If true, keep warnings as "system message" paragraphs in the built documents. 103 | #keep_warnings = False 104 | 105 | 106 | # -- Options for HTML output --------------------------------------------------- 107 | 108 | # The theme to use for HTML and HTML Help pages. See the documentation for 109 | # a list of builtin themes. 110 | # html_theme = 'default' 111 | 112 | # Theme options are theme-specific and customize the look and feel of a theme 113 | # further. 
For a list of options available for each theme, see the 114 | # documentation. 115 | #html_theme_options = {} 116 | 117 | # Add any paths that contain custom themes here, relative to this directory. 118 | #html_theme_path = [] 119 | 120 | # The name for this set of Sphinx documents. If None, it defaults to 121 | # " v documentation". 122 | #html_title = None 123 | 124 | # A shorter title for the navigation bar. Default is the same as html_title. 125 | #html_short_title = None 126 | 127 | # The name of an image file (relative to this directory) to place at the top 128 | # of the sidebar. 129 | #html_logo = None 130 | 131 | # The name of an image file (within the static path) to use as favicon of the 132 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 133 | # pixels large. 134 | #html_favicon = None 135 | 136 | # Add any paths that contain custom static files (such as style sheets) here, 137 | # relative to this directory. They are copied after the builtin static files, 138 | # so a file named "default.css" will overwrite the builtin "default.css". 139 | html_static_path = ['_static'] 140 | 141 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 142 | # using the given strftime format. 143 | #html_last_updated_fmt = '%b %d, %Y' 144 | 145 | # If true, SmartyPants will be used to convert quotes and dashes to 146 | # typographically correct entities. 147 | #html_use_smartypants = True 148 | 149 | # Custom sidebar templates, maps document names to template names. 150 | #html_sidebars = {} 151 | 152 | # Additional templates that should be rendered to pages, maps page names to 153 | # template names. 154 | #html_additional_pages = {} 155 | 156 | # If false, no module index is generated. 157 | #html_domain_indices = True 158 | 159 | # If false, no index is generated. 160 | #html_use_index = True 161 | 162 | # If true, the index is split into individual pages for each letter. 
163 | #html_split_index = False 164 | 165 | # If true, links to the reST sources are added to the pages. 166 | #html_show_sourcelink = True 167 | 168 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 169 | #html_show_sphinx = True 170 | 171 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 172 | #html_show_copyright = True 173 | 174 | # If true, an OpenSearch description file will be output, and all pages will 175 | # contain a tag referring to it. The value of this option must be the 176 | # base URL from which the finished HTML is served. 177 | #html_use_opensearch = '' 178 | 179 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 180 | #html_file_suffix = None 181 | 182 | # Output file base name for HTML help builder. 183 | htmlhelp_basename = 'tf_unetdoc' 184 | 185 | 186 | # -- Options for LaTeX output -------------------------------------------------- 187 | 188 | latex_elements = { 189 | # The paper size ('letterpaper' or 'a4paper'). 190 | #'papersize': 'letterpaper', 191 | 192 | # The font size ('10pt', '11pt' or '12pt'). 193 | #'pointsize': '10pt', 194 | 195 | # Additional stuff for the LaTeX preamble. 196 | #'preamble': '', 197 | } 198 | 199 | # Grouping the document tree into LaTeX files. List of tuples 200 | # (source start file, target name, title, author, documentclass [howto/manual]). 201 | latex_documents = [ 202 | ('index', 'tf_unet.tex', u'Tensorflow Unet Documentation', 203 | u'Joel Akeret', 'manual'), 204 | ] 205 | 206 | # The name of an image file (relative to this directory) to place at the top of 207 | # the title page. 208 | #latex_logo = None 209 | 210 | # For "manual" documents, if this is true, then toplevel headings are parts, 211 | # not chapters. 212 | #latex_use_parts = False 213 | 214 | # If true, show page references after internal links. 215 | #latex_show_pagerefs = False 216 | 217 | # If true, show URL addresses after external links. 
218 | #latex_show_urls = False 219 | 220 | # Documents to append as an appendix to all manuals. 221 | #latex_appendices = [] 222 | 223 | # If false, no module index is generated. 224 | #latex_domain_indices = True 225 | 226 | 227 | # -- Options for manual page output -------------------------------------------- 228 | 229 | # One entry per manual page. List of tuples 230 | # (source start file, name, description, authors, manual section). 231 | man_pages = [ 232 | ('index', 'tf_unet', u'Tensorflow Unet Documentation', 233 | [u'Joel Akeret'], 1) 234 | ] 235 | 236 | # If true, show URL addresses after external links. 237 | #man_show_urls = False 238 | 239 | 240 | # -- Options for Texinfo output ------------------------------------------------ 241 | 242 | # Grouping the document tree into Texinfo files. List of tuples 243 | # (source start file, target name, title, author, 244 | # dir menu entry, description, category) 245 | texinfo_documents = [ 246 | ('index', 'tf_unet', u'Tensorflow Unet Documentation', 247 | u'Joel Akeret', 'tf_unet', 'One line description of project.', 248 | 'Miscellaneous'), 249 | ] 250 | 251 | # Documents to append as an appendix to all manuals. 252 | #texinfo_appendices = [] 253 | 254 | # If false, no module index is generated. 255 | #texinfo_domain_indices = True 256 | 257 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 258 | #texinfo_show_urls = 'footnote' 259 | 260 | # If true, do not generate a @detailmenu in the "Top" node's menu. 261 | #texinfo_no_detailmenu = False 262 | 263 | try: 264 | import sphinx_eth_theme 265 | html_theme = "sphinx_eth_theme" 266 | html_theme_path = [sphinx_eth_theme.get_html_theme_path()] 267 | except ImportError: 268 | html_theme = 'default' -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. 
include:: ../CONTRIBUTING.rst -------------------------------------------------------------------------------- /docs/galaxies.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeret/tf_unet/0dcdf2ff1ebcc2ee59997d127a0c0be847168884/docs/galaxies.png -------------------------------------------------------------------------------- /docs/history.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../HISTORY.rst -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. complexity documentation master file, created by 2 | sphinx-quickstart on Tue Jul 9 22:26:36 2013. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | .. include:: ../README.rst 7 | 8 | Contents: 9 | ========= 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | 14 | installation 15 | usage 16 | tf_unet 17 | contributing 18 | authors 19 | history 20 | 21 | Feedback 22 | ======== 23 | 24 | If you have any suggestions or questions about **Tensorflow Unet** feel free to contact me 25 | on `GitHub `_. 26 | 27 | If you encounter any errors or problems with **Tensorflow Unet**, please let me know! -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Installation 3 | ============ 4 | 5 | The project is hosted on GitHub. 
Get a copy by running:: 6 | 7 | $ git clone https://github.com/jakeret/tf_unet.git 8 | 9 | 10 | Install the package like this:: 11 | 12 | $ cd tf_unet 13 | $ pip install -r requirements.txt 14 | $ python setup.py install --user 15 | 16 | Alternatively, if you want to develop new features:: 17 | 18 | $ cd tf_unet 19 | $ python setup.py develop --user 20 | 21 | Make sure `TensorFlow` is installed on your system. Installation instructions can be found `here `_. -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | tf_unet 2 | ======= 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | tf_unet 8 | -------------------------------------------------------------------------------- /docs/rfi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeret/tf_unet/0dcdf2ff1ebcc2ee59997d127a0c0be847168884/docs/rfi.png -------------------------------------------------------------------------------- /docs/stats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeret/tf_unet/0dcdf2ff1ebcc2ee59997d127a0c0be847168884/docs/stats.png -------------------------------------------------------------------------------- /docs/tf_unet.rst: -------------------------------------------------------------------------------- 1 | tf_unet Package 2 | =============== 3 | 4 | :mod:`unet` Module 5 | ------------------ 6 | 7 | .. automodule:: tf_unet.unet 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | :mod:`image_util` Module 13 | ------------------------ 14 | 15 | .. automodule:: tf_unet.image_util 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | :mod:`util` Module 21 | ------------------ 22 | 23 | .. 
automodule:: tf_unet.util 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | :mod:`layers` Module 29 | -------------------- 30 | 31 | .. automodule:: tf_unet.layers 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | -------------------------------------------------------------------------------- /docs/toy_problem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeret/tf_unet/0dcdf2ff1ebcc2ee59997d127a0c0be847168884/docs/toy_problem.png -------------------------------------------------------------------------------- /docs/usage.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | Usage 3 | ======== 4 | 5 | To use Tensorflow Unet in a project:: 6 | 7 | from tf_unet import unet, util, image_util 8 | 9 | #preparing data loading 10 | data_provider = image_util.ImageDataProvider("fishes/train/*.tif") 11 | 12 | #setup & training 13 | net = unet.Unet(layers=3, features_root=64, channels=1, n_class=2) 14 | trainer = unet.Trainer(net) 15 | path = trainer.train(data_provider, output_path, training_iters=32, epochs=100) 16 | 17 | #verification 18 | ... 19 | 20 | prediction = net.predict(path, data) 21 | 22 | unet.error_rate(prediction, util.crop_to_shape(label, prediction.shape)) 23 | 24 | img = util.combine_img_prediction(data, label, prediction) 25 | util.save_image(img, "prediction.jpg") 26 | 27 | Keep track of the learning progress using *Tensorboard*. **tf_unet** automatically outputs relevant summaries. 28 | 29 | .. image:: https://raw.githubusercontent.com/jakeret/tf_unet/master/docs/stats.png 30 | :alt: Training statistics in Tensorboard. 31 | :align: center 32 | 33 | 34 | More examples can be found in the Jupyter notebooks for a `toy problem `_ or for an `RFI problem `_. 35 | Further code is stored in the `scripts `_ folder. 
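Before training, each data provider normalizes the raw images: the absolute values are clipped to ``[a_min, a_max]`` and then rescaled into ``[0, 1]``. The following standalone sketch mirrors the normalization performed by ``BaseDataProvider._process_data`` in ``tf_unet.image_util`` (the ``process_data`` helper name and the sample values are illustrative, not part of the package API):

```python
import numpy as np

def process_data(data, a_min=-np.inf, a_max=np.inf):
    # Mirrors BaseDataProvider._process_data: clip the magnitudes,
    # shift to zero, then scale so the maximum becomes 1.
    data = np.clip(np.fabs(data), a_min, a_max)
    data -= np.amin(data)
    if np.amax(data) != 0:
        data /= np.amax(data)
    return data

raw = np.array([[-5.0, 0.0], [10.0, 20.0]])
norm = process_data(raw, a_min=0, a_max=15)
# After clipping, values span [0, 15]; normalization maps them to [0, 1].
```

Overriding ``_process_data`` in a custom provider replaces exactly this step, which is useful when a dataset needs mean/std standardization instead of min/max scaling.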
36 | -------------------------------------------------------------------------------- /postBuild: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # NOTE: this file needs to be executable in order to work. 4 | python setup.py install 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | click 2 | numpy 3 | Pillow 4 | tensorflow>=1.0.0 5 | matplotlib -------------------------------------------------------------------------------- /scripts/.gitignore: -------------------------------------------------------------------------------- 1 | /prediction/ 2 | /unet_trained* 3 | /daint_unet_trained_rfi_bleien/ 4 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeret/tf_unet/0dcdf2ff1ebcc2ee59997d127a0c0be847168884/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/launcher.py: -------------------------------------------------------------------------------- 1 | # tf_unet is free software: you can redistribute it and/or modify 2 | # it under the terms of the GNU General Public License as published by 3 | # the Free Software Foundation, either version 3 of the License, or 4 | # (at your option) any later version. 5 | # 6 | # tf_unet is distributed in the hope that it will be useful, 7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 9 | # GNU General Public License for more details. 10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with tf_unet. If not, see . 
13 | 14 | 15 | ''' 16 | Created on Jul 28, 2016 17 | 18 | author: jakeret 19 | 20 | Trains a tf_unet network to segment circles in noisy images. 21 | ''' 22 | 23 | from __future__ import print_function, division, absolute_import, unicode_literals 24 | import numpy as np 25 | from tf_unet import image_gen 26 | from tf_unet import unet 27 | from tf_unet import util 28 | 29 | 30 | if __name__ == '__main__': 31 | np.random.seed(98765) 32 | 33 | generator = image_gen.GrayScaleDataProvider(nx=572, ny=572, cnt=20, rectangles=False) 34 | 35 | net = unet.Unet(channels=generator.channels, 36 | n_class=generator.n_class, 37 | layers=3, 38 | features_root=16) 39 | 40 | trainer = unet.Trainer(net, optimizer="momentum", opt_kwargs=dict(momentum=0.2)) 41 | path = trainer.train(generator, "./unet_trained", 42 | training_iters=32, 43 | epochs=5, 44 | dropout=0.75,# probability to keep units 45 | display_step=2) 46 | 47 | x_test, y_test = generator(4) 48 | prediction = net.predict(path, x_test) 49 | 50 | print("Testing error rate: {:.2f}%".format(unet.error_rate(prediction, 51 | util.crop_to_shape(y_test, prediction.shape)))) 52 | -------------------------------------------------------------------------------- /scripts/rfi_launcher.py: -------------------------------------------------------------------------------- 1 | # tf_unet is free software: you can redistribute it and/or modify 2 | # it under the terms of the GNU General Public License as published by 3 | # the Free Software Foundation, either version 3 of the License, or 4 | # (at your option) any later version. 5 | # 6 | # tf_unet is distributed in the hope that it will be useful, 7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 9 | # GNU General Public License for more details. 10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with tf_unet. If not, see . 
13 | 14 | 15 | ''' 16 | Created on Jul 28, 2016 17 | 18 | author: jakeret 19 | 20 | Trains a tf_unet network to segment radio frequency interference pattern. 21 | Requires data from the Bleien Observatory or a HIDE&SEEK simulation. 22 | ''' 23 | 24 | from __future__ import print_function, division, absolute_import, unicode_literals 25 | import glob 26 | import click 27 | import h5py 28 | import numpy as np 29 | 30 | from tf_unet import unet 31 | from tf_unet import util 32 | from tf_unet.image_util import BaseDataProvider 33 | 34 | 35 | @click.command() 36 | @click.option('--data_root', default="./bleien_data") 37 | @click.option('--output_path', default="./daint_unet_trained_rfi_bleien") 38 | @click.option('--training_iters', default=32) 39 | @click.option('--epochs', default=100) 40 | @click.option('--restore', default=False) 41 | @click.option('--layers', default=5) 42 | @click.option('--features_root', default=64) 43 | def launch(data_root, output_path, training_iters, epochs, restore, layers, features_root): 44 | print("Using data from: %s"%data_root) 45 | data_provider = DataProvider(600, glob.glob(data_root+"/*")) 46 | 47 | net = unet.Unet(channels=data_provider.channels, 48 | n_class=data_provider.n_class, 49 | layers=layers, 50 | features_root=features_root, 51 | cost_kwargs=dict(regularizer=0.001), 52 | ) 53 | 54 | path = output_path if restore else util.create_training_path(output_path) 55 | trainer = unet.Trainer(net, optimizer="momentum", opt_kwargs=dict(momentum=0.2)) 56 | path = trainer.train(data_provider, path, 57 | training_iters=training_iters, 58 | epochs=epochs, 59 | dropout=0.5, 60 | display_step=2, 61 | restore=restore) 62 | 63 | x_test, y_test = data_provider(1) 64 | prediction = net.predict(path, x_test) 65 | 66 | print("Testing error rate: {:.2f}%".format(unet.error_rate(prediction, util.crop_to_shape(y_test, prediction.shape)))) 67 | 68 | 69 | class DataProvider(BaseDataProvider): 70 | """ 71 | Extends the BaseDataProvider to randomly 
select the next 72 | data chunk 73 | """ 74 | 75 | channels = 1 76 | n_class = 2 77 | 78 | def __init__(self, nx, files, a_min=30, a_max=210): 79 | super(DataProvider, self).__init__(a_min, a_max) 80 | self.nx = nx 81 | self.files = files 82 | 83 | assert len(files) > 0, "No training files" 84 | print("Number of files used: %s"%len(files)) 85 | self._cylce_file() 86 | 87 | def _read_chunck(self): 88 | with h5py.File(self.files[self.file_idx], "r") as fp: 89 | nx = fp["data"].shape[1] 90 | idx = np.random.randint(0, nx - self.nx) 91 | 92 | sl = slice(idx, (idx+self.nx)) 93 | data = fp["data"][:, sl] 94 | rfi = fp["mask"][:, sl] 95 | return data, rfi 96 | 97 | def _next_data(self): 98 | data, rfi = self._read_chunck() 99 | nx = data.shape[1] 100 | while nx < self.nx: 101 | self._cylce_file() 102 | data, rfi = self._read_chunck() 103 | nx = data.shape[1] 104 | 105 | return data, rfi 106 | 107 | def _cylce_file(self): 108 | self.file_idx = np.random.choice(len(self.files)) 109 | 110 | 111 | if __name__ == '__main__': 112 | launch() 113 | -------------------------------------------------------------------------------- /scripts/ufig_launcher.py: -------------------------------------------------------------------------------- 1 | # tf_unet is free software: you can redistribute it and/or modify 2 | # it under the terms of the GNU General Public License as published by 3 | # the Free Software Foundation, either version 3 of the License, or 4 | # (at your option) any later version. 5 | # 6 | # tf_unet is distributed in the hope that it will be useful, 7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 9 | # GNU General Public License for more details. 10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with tf_unet. If not, see . 
13 | 14 | 15 | ''' 16 | Created on Jul 28, 2016 17 | 18 | author: jakeret 19 | 20 | Trains a tf_unet network to segment stars and galaxies in a wide field image. 21 | Requires data from a UFIG simulation. 22 | ''' 23 | 24 | from __future__ import print_function, division, absolute_import, unicode_literals 25 | import click 26 | import numpy as np 27 | 28 | from scipy.ndimage import gaussian_filter 29 | import h5py 30 | 31 | from tf_unet import unet 32 | from tf_unet import util 33 | from tf_unet.image_util import BaseDataProvider 34 | 35 | 36 | @click.command() 37 | @click.option('--data_root', default="./ufig_images/1.h5") 38 | @click.option('--output_path', default="./unet_trained_ufig") 39 | @click.option('--training_iters', default=20) 40 | @click.option('--epochs', default=10) 41 | @click.option('--restore', default=False) 42 | @click.option('--layers', default=3) 43 | @click.option('--features_root', default=16) 44 | def launch(data_root, output_path, training_iters, epochs, restore, layers, features_root): 45 | data_provider = DataProvider(572, data_root) 46 | 47 | data, label = data_provider(1) 48 | weights = None#(1/3) / (label.sum(axis=2).sum(axis=1).sum(axis=0) / data.size) 49 | 50 | net = unet.Unet(channels=data_provider.channels, 51 | n_class=data_provider.n_class, 52 | layers=layers, 53 | features_root=features_root, 54 | cost_kwargs=dict(regularizer=0.001, 55 | class_weights=weights), 56 | ) 57 | 58 | path = output_path if restore else util.create_training_path(output_path) 59 | 60 | trainer = unet.Trainer(net, optimizer="adam", opt_kwargs=dict(beta1=0.91)) 61 | path = trainer.train(data_provider, path, 62 | training_iters=training_iters, 63 | epochs=epochs, 64 | dropout=0.5, 65 | display_step=2, 66 | restore=restore) 67 | 68 | prediction = net.predict(path, data) 69 | 70 | print("Testing error rate: {:.2f}%".format(unet.error_rate(prediction, util.crop_to_shape(label, prediction.shape)))) 71 | 72 | 73 | class DataProvider(BaseDataProvider): 74 | """ 
75 | Extends the BaseDataProvider to randomly select the next 76 | chunk of the image and randomly applies transformations to the data 77 | """ 78 | 79 | channels = 1 80 | n_class = 3 81 | 82 | def __init__(self, nx, path, a_min=0, a_max=20, sigma=1): 83 | super(DataProvider, self).__init__(a_min, a_max) 84 | self.nx = nx 85 | self.path = path 86 | self.sigma = sigma 87 | 88 | self._load_data() 89 | 90 | def _load_data(self): 91 | with h5py.File(self.path, "r") as fp: 92 | self.image = gaussian_filter(fp["image"].value, self.sigma) 93 | self.gal_map = fp["segmaps/galaxy"].value 94 | self.star_map = fp["segmaps/star"].value 95 | 96 | def _transpose_3d(self, a): 97 | return np.stack([a[..., i].T for i in range(a.shape[2])], axis=2) 98 | 99 | def _post_process(self, data, labels): 100 | op = np.random.randint(0, 4) 101 | if op == 0: 102 | if np.random.randint(0, 2) == 0: 103 | data, labels = self._transpose_3d(data[:,:,np.newaxis]), self._transpose_3d(labels) 104 | else: 105 | data, labels = np.rot90(data, op), np.rot90(labels, op) 106 | 107 | return data, labels 108 | 109 | def _next_data(self): 110 | ix = np.random.randint(0, self.image.shape[0] - self.nx) 111 | iy = np.random.randint(0, self.image.shape[1] - self.nx) 112 | 113 | slx = slice(ix, ix+self.nx) 114 | sly = slice(iy, iy+self.nx) 115 | 116 | data = self.image[slx, sly] 117 | gal_seg = self.gal_map[slx, sly] 118 | star_seg = self.star_map[slx, sly] 119 | 120 | labels = np.zeros((self.nx, self.nx, self.n_class), dtype=np.float32) 121 | labels[..., 1] = np.clip(gal_seg, 0, 1) 122 | labels[..., 2] = np.clip(star_seg, 0, 1) 123 | labels[..., 0] = (1+np.clip(labels[...,1] + labels[...,2], 0, 1))%2 124 | 125 | return data, labels 126 | 127 | 128 | if __name__ == '__main__': 129 | launch() 130 | -------------------------------------------------------------------------------- /scripts/ultrasound_launcher.py: -------------------------------------------------------------------------------- 1 | # tf_unet is free 
software: you can redistribute it and/or modify 2 | # it under the terms of the GNU General Public License as published by 3 | # the Free Software Foundation, either version 3 of the License, or 4 | # (at your option) any later version. 5 | # 6 | # tf_unet is distributed in the hope that it will be useful, 7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 9 | # GNU General Public License for more details. 10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with tf_unet. If not, see . 13 | 14 | 15 | ''' 16 | Created on Jul 28, 2016 17 | 18 | author: jakeret 19 | 20 | Trains a tf_unet network to segment nerves in the Ultrasound Kaggle Dataset. 21 | Requires the Kaggle dataset. 22 | ''' 23 | 24 | from __future__ import print_function, division, absolute_import, unicode_literals 25 | import os 26 | import click 27 | import numpy as np 28 | from PIL import Image 29 | 30 | 31 | from tf_unet import unet 32 | from tf_unet import util 33 | from tf_unet.image_util import ImageDataProvider 34 | 35 | IMG_SIZE = (290, 210) 36 | 37 | 38 | @click.command() 39 | @click.option('--data_root', default="../../ultrasound/train") 40 | @click.option('--output_path', default="./unet_trained_ultrasound") 41 | @click.option('--training_iters', default=20) 42 | @click.option('--epochs', default=100) 43 | @click.option('--restore', default=False) 44 | @click.option('--layers', default=3) 45 | @click.option('--features_root', default=32) 46 | def launch(data_root, output_path, training_iters, epochs, restore, layers, features_root): 47 | print("Using data from: %s"%data_root) 48 | 49 | if not os.path.exists(data_root): 50 | raise IOError("Kaggle Ultrasound Dataset not found") 51 | 52 | data_provider = DataProvider(search_path=data_root + "/*.tif", 53 | mean=100, 54 | std=56) 55 | 56 | net = unet.Unet(channels=data_provider.channels, 57 | n_class=data_provider.n_class, 
58 | layers=layers, 59 | features_root=features_root, 60 | #cost="dice_coefficient", 61 | ) 62 | 63 | path = output_path if restore else util.create_training_path(output_path) 64 | 65 | trainer = unet.Trainer(net, batch_size=1, norm_grads=False, optimizer="adam") 66 | path = trainer.train(data_provider, path, 67 | training_iters=training_iters, 68 | epochs=epochs, 69 | dropout=0.5, 70 | display_step=2, 71 | restore=restore) 72 | 73 | x_test, y_test = data_provider(1) 74 | prediction = net.predict(path, x_test) 75 | 76 | print("Testing error rate: {:.2f}%".format(unet.error_rate(prediction, util.crop_to_shape(y_test, prediction.shape)))) 77 | 78 | 79 | class DataProvider(ImageDataProvider): 80 | """ 81 | Extends the default ImageDataProvider to randomly select the next 82 | image and ensures that only data sets are used where the mask is not empty. 83 | The data then gets mean and std adjusted 84 | """ 85 | 86 | def __init__(self, mean, std, *args, **kwargs): 87 | super(DataProvider, self).__init__(*args, **kwargs) 88 | self.mean = mean 89 | self.std = std 90 | 91 | def _next_data(self): 92 | data, mask = super(DataProvider, self)._next_data() 93 | while mask.sum() == 0: 94 | self._cylce_file() 95 | data, mask = super(DataProvider, self)._next_data() 96 | 97 | return data, mask 98 | 99 | def _process_data(self, data): 100 | data -= self.mean 101 | data /= self.std 102 | 103 | return data 104 | 105 | def _load_file(self, path, dtype=np.float32): 106 | image = Image.open(path) 107 | return np.array(image.resize(IMG_SIZE), dtype) 108 | 109 | def _cylce_file(self): 110 | self.file_idx = np.random.choice(len(self.data_files)) 111 | 112 | 113 | if __name__ == '__main__': 114 | launch() 115 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | from setuptools import find_packages 5 | 6 | try: 7 | from 
setuptools import setup 8 | except ImportError: 9 | from distutils.core import setup 10 | 11 | 12 | readme = open('README.rst').read() 13 | history = open('HISTORY.rst').read().replace('.. :changelog:', '') 14 | 15 | requires = [] #during runtime 16 | tests_require=['pytest>=2.3'] #for testing 17 | 18 | PACKAGE_PATH = os.path.abspath(os.path.join(__file__, os.pardir)) 19 | 20 | setup( 21 | name='tf_unet', 22 | version='0.1.2', 23 | description='Unet TensorFlow implementation', 24 | long_description=readme + '\n\n' + history, 25 | author='Joel Akeret', 26 | url='https://github.com/jakeret/tf_unet', 27 | packages=find_packages(PACKAGE_PATH, "test"), 28 | package_dir={'tf_unet': 'tf_unet'}, 29 | include_package_data=True, 30 | install_requires=requires, 31 | license='GPLv3', 32 | zip_safe=False, 33 | keywords='tf_unet', 34 | classifiers=[ 35 | 'Development Status :: 2 - Pre-Alpha', 36 | "Intended Audience :: Science/Research", 37 | 'Intended Audience :: Developers', 38 | "License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)", 39 | 'Natural Language :: English', 40 | 'Programming Language :: Python :: 2.7', 41 | 'Programming Language :: Python :: 3', 42 | ], 43 | tests_require=tests_require, 44 | ) 45 | -------------------------------------------------------------------------------- /tf_unet/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Joel Akeret' 2 | __version__ = '0.1.2' 3 | __credits__ = 'ETH Zurich, Institute for Astronomy' 4 | -------------------------------------------------------------------------------- /tf_unet/image_gen.py: -------------------------------------------------------------------------------- 1 | # tf_unet is free software: you can redistribute it and/or modify 2 | # it under the terms of the GNU General Public License as published by 3 | # the Free Software Foundation, either version 3 of the License, or 4 | # (at your option) any later version. 
5 | # 6 | # tf_unet is distributed in the hope that it will be useful, 7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 9 | # GNU General Public License for more details. 10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with tf_unet. If not, see . 13 | 14 | 15 | ''' 16 | Toy example, generates images at random that can be used for training 17 | 18 | Created on Jul 28, 2016 19 | 20 | author: jakeret 21 | ''' 22 | from __future__ import print_function, division, absolute_import, unicode_literals 23 | 24 | import numpy as np 25 | from tf_unet.image_util import BaseDataProvider 26 | 27 | class GrayScaleDataProvider(BaseDataProvider): 28 | channels = 1 29 | n_class = 2 30 | 31 | def __init__(self, nx, ny, **kwargs): 32 | super(GrayScaleDataProvider, self).__init__() 33 | self.nx = nx 34 | self.ny = ny 35 | self.kwargs = kwargs 36 | rect = kwargs.get("rectangles", False) 37 | if rect: 38 | self.n_class=3 39 | 40 | def _next_data(self): 41 | return create_image_and_label(self.nx, self.ny, **self.kwargs) 42 | 43 | class RgbDataProvider(BaseDataProvider): 44 | channels = 3 45 | n_class = 2 46 | 47 | def __init__(self, nx, ny, **kwargs): 48 | super(RgbDataProvider, self).__init__() 49 | self.nx = nx 50 | self.ny = ny 51 | self.kwargs = kwargs 52 | rect = kwargs.get("rectangles", False) 53 | if rect: 54 | self.n_class=3 55 | 56 | 57 | def _next_data(self): 58 | data, label = create_image_and_label(self.nx, self.ny, **self.kwargs) 59 | return to_rgb(data), label 60 | 61 | def create_image_and_label(nx,ny, cnt = 10, r_min = 5, r_max = 50, border = 92, sigma = 20, rectangles=False): 62 | 63 | 64 | image = np.ones((nx, ny, 1)) 65 | label = np.zeros((nx, ny, 3), dtype=np.bool) 66 | mask = np.zeros((nx, ny), dtype=np.bool) 67 | for _ in range(cnt): 68 | a = np.random.randint(border, nx-border) 69 | b = np.random.randint(border, ny-border) 70 | r = 
np.random.randint(r_min, r_max) 71 | h = np.random.randint(1,255) 72 | 73 | y,x = np.ogrid[-a:nx-a, -b:ny-b] 74 | m = x*x + y*y <= r*r 75 | mask = np.logical_or(mask, m) 76 | 77 | image[m] = h 78 | 79 | label[mask, 1] = 1 80 | 81 | if rectangles: 82 | mask = np.zeros((nx, ny), dtype=np.bool) 83 | for _ in range(cnt//2): 84 | a = np.random.randint(nx) 85 | b = np.random.randint(ny) 86 | r = np.random.randint(r_min, r_max) 87 | h = np.random.randint(1,255) 88 | 89 | m = np.zeros((nx, ny), dtype=np.bool) 90 | m[a:a+r, b:b+r] = True 91 | mask = np.logical_or(mask, m) 92 | image[m] = h 93 | 94 | label[mask, 2] = 1 95 | 96 | label[..., 0] = ~(np.logical_or(label[...,1], label[...,2])) 97 | 98 | image += np.random.normal(scale=sigma, size=image.shape) 99 | image -= np.amin(image) 100 | image /= np.amax(image) 101 | 102 | if rectangles: 103 | return image, label 104 | else: 105 | return image, label[..., 1] 106 | 107 | 108 | 109 | 110 | def to_rgb(img): 111 | img = img.reshape(img.shape[0], img.shape[1]) 112 | img[np.isnan(img)] = 0 113 | img -= np.amin(img) 114 | img /= np.amax(img) 115 | blue = np.clip(4*(0.75-img), 0, 1) 116 | red = np.clip(4*(img-0.25), 0, 1) 117 | green= np.clip(44*np.fabs(img-0.5)-1., 0, 1) 118 | rgb = np.stack((red, green, blue), axis=2) 119 | return rgb 120 | 121 | -------------------------------------------------------------------------------- /tf_unet/image_util.py: -------------------------------------------------------------------------------- 1 | # tf_unet is free software: you can redistribute it and/or modify 2 | # it under the terms of the GNU General Public License as published by 3 | # the Free Software Foundation, either version 3 of the License, or 4 | # (at your option) any later version. 5 | # 6 | # tf_unet is distributed in the hope that it will be useful, 7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 9 | # GNU General Public License for more details. 10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with tf_unet. If not, see . 13 | 14 | ''' 15 | author: jakeret 16 | ''' 17 | from __future__ import print_function, division, absolute_import, unicode_literals 18 | 19 | import glob 20 | import numpy as np 21 | from PIL import Image 22 | 23 | 24 | class BaseDataProvider(object): 25 | """ 26 | Abstract base class for DataProvider implementation. Subclasses have to 27 | overwrite the `_next_data` method that loads the next data and label array. 28 | This implementation automatically clips the data with the given min/max and 29 | normalizes the values to (0,1]. To change this behavior the `_process_data` 30 | method can be overwritten. To enable some post processing such as data 31 | augmentation the `_post_process` method can be overwritten. 32 | 33 | :param a_min: (optional) min value used for clipping 34 | :param a_max: (optional) max value used for clipping 35 | 36 | """ 37 | 38 | channels = 1 39 | n_class = 2 40 | 41 | def __init__(self, a_min=None, a_max=None): 42 | self.a_min = a_min if a_min is not None else -np.inf 43 | self.a_max = a_max if a_max is not None else np.inf 44 | 45 | def _load_data_and_label(self): 46 | data, label = self._next_data() 47 | 48 | train_data = self._process_data(data) 49 | labels = self._process_labels(label) 50 | 51 | train_data, labels = self._post_process(train_data, labels) 52 | 53 | nx = train_data.shape[1] 54 | ny = train_data.shape[0] 55 | 56 | return train_data.reshape(1, ny, nx, self.channels), labels.reshape(1, ny, nx, self.n_class), 57 | 58 | def _process_labels(self, label): 59 | if self.n_class == 2: 60 | nx = label.shape[1] 61 | ny = label.shape[0] 62 | labels = np.zeros((ny, nx, self.n_class), dtype=np.float32) 63 | 64 | # It is the responsibility of the child class to make sure that the label 65 | # is a boolean array, but we add a check here just in case. 
66 | if label.dtype != 'bool': 67 | label = label.astype(np.bool) 68 | 69 | labels[..., 1] = label 70 | labels[..., 0] = ~label 71 | return labels 72 | 73 | return label 74 | 75 | def _process_data(self, data): 76 | # normalization 77 | data = np.clip(np.fabs(data), self.a_min, self.a_max) 78 | data -= np.amin(data) 79 | 80 | if np.amax(data) != 0: 81 | data /= np.amax(data) 82 | 83 | return data 84 | 85 | def _post_process(self, data, labels): 86 | """ 87 | Post processing hook that can be used for data augmentation 88 | 89 | :param data: the data array 90 | :param labels: the label array 91 | """ 92 | return data, labels 93 | 94 | def __call__(self, n): 95 | train_data, labels = self._load_data_and_label() 96 | nx = train_data.shape[1] 97 | ny = train_data.shape[2] 98 | 99 | X = np.zeros((n, nx, ny, self.channels)) 100 | Y = np.zeros((n, nx, ny, self.n_class)) 101 | 102 | X[0] = train_data 103 | Y[0] = labels 104 | for i in range(1, n): 105 | train_data, labels = self._load_data_and_label() 106 | X[i] = train_data 107 | Y[i] = labels 108 | 109 | return X, Y 110 | 111 | 112 | class SimpleDataProvider(BaseDataProvider): 113 | """ 114 | A simple data provider for numpy arrays. 115 | Assumes that the data and label are numpy array with the dimensions 116 | data `[n, X, Y, channels]`, label `[n, X, Y, classes]`. Where 117 | `n` is the number of images, `X`, `Y` the size of the image. 118 | 119 | :param data: data numpy array. Shape=[n, X, Y, channels] 120 | :param label: label numpy array. 
Shape=[n, X, Y, classes] 121 | :param a_min: (optional) min value used for clipping 122 | :param a_max: (optional) max value used for clipping 123 | 124 | """ 125 | 126 | def __init__(self, data, label, a_min=None, a_max=None): 127 | super(SimpleDataProvider, self).__init__(a_min, a_max) 128 | self.data = data 129 | self.label = label 130 | self.file_count = data.shape[0] 131 | self.n_class = label.shape[-1] 132 | self.channels = data.shape[-1] 133 | 134 | def _next_data(self): 135 | idx = np.random.choice(self.file_count) 136 | return self.data[idx], self.label[idx] 137 | 138 | 139 | class ImageDataProvider(BaseDataProvider): 140 | """ 141 | Generic data provider for images, supports gray scale and colored images. 142 | Assumes that the data images and label images are stored in the same folder 143 | and that the labels have a different file suffix 144 | e.g. 'train/fish_1.tif' and 'train/fish_1_mask.tif' 145 | Number of pixels in x and y of the images and masks should be even. 146 | 147 | Usage: 148 | data_provider = ImageDataProvider("..fishes/train/*.tif") 149 | 150 | :param search_path: a glob search pattern to find all data and label images 151 | :param a_min: (optional) min value used for clipping 152 | :param a_max: (optional) max value used for clipping 153 | :param data_suffix: suffix pattern for the data images. Default '.tif' 154 | :param mask_suffix: suffix pattern for the label images. Default '_mask.tif' 155 | :param shuffle_data: if the order of the loaded file path should be randomized. 
Default 'True' 156 | 157 | """ 158 | 159 | def __init__(self, search_path, a_min=None, a_max=None, data_suffix=".tif", mask_suffix='_mask.tif', shuffle_data=True): 160 | super(ImageDataProvider, self).__init__(a_min, a_max) 161 | self.data_suffix = data_suffix 162 | self.mask_suffix = mask_suffix 163 | self.file_idx = -1 164 | self.shuffle_data = shuffle_data 165 | 166 | self.data_files = self._find_data_files(search_path) 167 | 168 | if self.shuffle_data: 169 | np.random.shuffle(self.data_files) 170 | 171 | assert len(self.data_files) > 0, "No training files" 172 | print("Number of files used: %s" % len(self.data_files)) 173 | 174 | image_path = self.data_files[0] 175 | label_path = image_path.replace(self.data_suffix, self.mask_suffix) 176 | img = self._load_file(image_path) 177 | mask = self._load_file(label_path) 178 | self.channels = 1 if len(img.shape) == 2 else img.shape[-1] 179 | self.n_class = 2 if len(mask.shape) == 2 else mask.shape[-1] 180 | 181 | print("Number of channels: %s" % self.channels) 182 | print("Number of classes: %s" % self.n_class) 183 | 184 | def _find_data_files(self, search_path): 185 | all_files = glob.glob(search_path) 186 | return [name for name in all_files if self.data_suffix in name and self.mask_suffix not in name] 187 | 188 | def _load_file(self, path, dtype=np.float32): 189 | return np.array(Image.open(path), dtype) 190 | 191 | def _cycle_file(self): 192 | self.file_idx += 1 193 | if self.file_idx >= len(self.data_files): 194 | self.file_idx = 0 195 | if self.shuffle_data: 196 | np.random.shuffle(self.data_files) 197 | 198 | def _next_data(self): 199 | self._cycle_file() 200 | image_name = self.data_files[self.file_idx] 201 | label_name = image_name.replace(self.data_suffix, self.mask_suffix) 202 | 203 | img = self._load_file(image_name, np.float32) 204 | label = self._load_file(label_name, bool) 205 | 206 | return img, label 207 | --------------------------------------------------------------------------------
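The clip-and-rescale normalization performed by `BaseDataProvider._process_data` above can be checked in isolation. The sketch below reimplements it as a standalone NumPy function (the function name and the sample values are ours, chosen for illustration): clip the absolute values to `[a_min, a_max]`, shift the minimum to zero, and rescale to the unit interval.

```python
import numpy as np

def process_data(data, a_min=-np.inf, a_max=np.inf):
    # Mirrors BaseDataProvider._process_data: clip absolute values,
    # shift the minimum to zero and rescale so the maximum becomes 1.
    data = np.clip(np.fabs(data), a_min, a_max)
    data -= np.amin(data)
    if np.amax(data) != 0:
        data /= np.amax(data)
    return data

raw = np.array([-4.0, -1.0, 0.0, 2.0, 8.0])
norm = process_data(raw, a_min=0.0, a_max=5.0)
print(norm)  # all values lie in [0, 1]; the 8.0 outlier was clipped to 5.0
```

Note that subclasses overriding `_process_data` (e.g. for standardization instead of min/max scaling) opt out of this behavior entirely.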
/tf_unet/layers.py: -------------------------------------------------------------------------------- 1 | # tf_unet is free software: you can redistribute it and/or modify 2 | # it under the terms of the GNU General Public License as published by 3 | # the Free Software Foundation, either version 3 of the License, or 4 | # (at your option) any later version. 5 | # 6 | # tf_unet is distributed in the hope that it will be useful, 7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 9 | # GNU General Public License for more details. 10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with tf_unet. If not, see . 13 | 14 | 15 | ''' 16 | Created on Aug 19, 2016 17 | 18 | author: jakeret 19 | ''' 20 | from __future__ import print_function, division, absolute_import, unicode_literals 21 | 22 | import tensorflow as tf 23 | 24 | def weight_variable(shape, stddev=0.1, name="weight"): 25 | initial = tf.truncated_normal(shape, stddev=stddev) 26 | return tf.Variable(initial, name=name) 27 | 28 | def weight_variable_devonc(shape, stddev=0.1, name="weight_devonc"): 29 | return tf.Variable(tf.truncated_normal(shape, stddev=stddev), name=name) 30 | 31 | def bias_variable(shape, name="bias"): 32 | initial = tf.constant(0.1, shape=shape) 33 | return tf.Variable(initial, name=name) 34 | 35 | def conv2d(x, W, b, keep_prob_): 36 | with tf.name_scope("conv2d"): 37 | conv_2d = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID') 38 | conv_2d_b = tf.nn.bias_add(conv_2d, b) 39 | return tf.nn.dropout(conv_2d_b, keep_prob_) 40 | 41 | def deconv2d(x, W,stride): 42 | with tf.name_scope("deconv2d"): 43 | x_shape = tf.shape(x) 44 | output_shape = tf.stack([x_shape[0], x_shape[1]*2, x_shape[2]*2, x_shape[3]//2]) 45 | return tf.nn.conv2d_transpose(x, W, output_shape, strides=[1, stride, stride, 1], padding='VALID', name="conv2d_transpose") 46 | 47 | def max_pool(x,n): 48 | 
return tf.nn.max_pool(x, ksize=[1, n, n, 1], strides=[1, n, n, 1], padding='VALID') 49 | 50 | def crop_and_concat(x1,x2): 51 | with tf.name_scope("crop_and_concat"): 52 | x1_shape = tf.shape(x1) 53 | x2_shape = tf.shape(x2) 54 | # offsets for the top left corner of the crop 55 | offsets = [0, (x1_shape[1] - x2_shape[1]) // 2, (x1_shape[2] - x2_shape[2]) // 2, 0] 56 | size = [-1, x2_shape[1], x2_shape[2], -1] 57 | x1_crop = tf.slice(x1, offsets, size) 58 | return tf.concat([x1_crop, x2], 3) 59 | 60 | def pixel_wise_softmax(output_map): 61 | with tf.name_scope("pixel_wise_softmax"): 62 | max_axis = tf.reduce_max(output_map, axis=3, keepdims=True) 63 | exponential_map = tf.exp(output_map - max_axis) 64 | normalize = tf.reduce_sum(exponential_map, axis=3, keepdims=True) 65 | return exponential_map / normalize 66 | 67 | def cross_entropy(y_,output_map): 68 | return -tf.reduce_mean(y_*tf.log(tf.clip_by_value(output_map,1e-10,1.0)), name="cross_entropy") 69 | -------------------------------------------------------------------------------- /tf_unet/unet.py: -------------------------------------------------------------------------------- 1 | # tf_unet is free software: you can redistribute it and/or modify 2 | # it under the terms of the GNU General Public License as published by 3 | # the Free Software Foundation, either version 3 of the License, or 4 | # (at your option) any later version. 5 | # 6 | # tf_unet is distributed in the hope that it will be useful, 7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 9 | # GNU General Public License for more details. 10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with tf_unet. If not, see . 
13 | 14 | 15 | ''' 16 | Created on Jul 28, 2016 17 | 18 | author: jakeret 19 | ''' 20 | from __future__ import print_function, division, absolute_import, unicode_literals 21 | 22 | import os 23 | import shutil 24 | import numpy as np 25 | from collections import OrderedDict 26 | import logging 27 | 28 | import tensorflow as tf 29 | 30 | from tf_unet import util 31 | from tf_unet.layers import (weight_variable, weight_variable_devonc, bias_variable, 32 | conv2d, deconv2d, max_pool, crop_and_concat, pixel_wise_softmax, 33 | cross_entropy) 34 | 35 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') 36 | 37 | 38 | def create_conv_net(x, keep_prob, channels, n_class, layers=3, features_root=16, filter_size=3, pool_size=2, 39 | summaries=True): 40 | """ 41 | Creates a new convolutional unet for the given parametrization. 42 | 43 | :param x: input tensor, shape [?,nx,ny,channels] 44 | :param keep_prob: dropout probability tensor 45 | :param channels: number of channels in the input image 46 | :param n_class: number of output labels 47 | :param layers: number of layers in the net 48 | :param features_root: number of features in the first layer 49 | :param filter_size: size of the convolution filter 50 | :param pool_size: size of the max pooling operation 51 | :param summaries: Flag if summaries should be created 52 | """ 53 | 54 | logging.info( 55 | "Layers {layers}, features {features}, filter size {filter_size}x{filter_size}, pool size: {pool_size}x{pool_size}".format( 56 | layers=layers, 57 | features=features_root, 58 | filter_size=filter_size, 59 | pool_size=pool_size)) 60 | 61 | # Placeholder for the input image 62 | with tf.name_scope("preprocessing"): 63 | nx = tf.shape(x)[1] 64 | ny = tf.shape(x)[2] 65 | x_image = tf.reshape(x, tf.stack([-1, nx, ny, channels])) 66 | in_node = x_image 67 | batch_size = tf.shape(x_image)[0] 68 | 69 | weights = [] 70 | biases = [] 71 | convs = [] 72 | pools = OrderedDict() 73 | deconv = OrderedDict() 74 | 
dw_h_convs = OrderedDict() 75 | up_h_convs = OrderedDict() 76 | 77 | in_size = 1000 78 | size = in_size 79 | # down layers 80 | for layer in range(0, layers): 81 | with tf.name_scope("down_conv_{}".format(str(layer))): 82 | features = 2 ** layer * features_root 83 | stddev = np.sqrt(2 / (filter_size ** 2 * features)) 84 | if layer == 0: 85 | w1 = weight_variable([filter_size, filter_size, channels, features], stddev, name="w1") 86 | else: 87 | w1 = weight_variable([filter_size, filter_size, features // 2, features], stddev, name="w1") 88 | 89 | w2 = weight_variable([filter_size, filter_size, features, features], stddev, name="w2") 90 | b1 = bias_variable([features], name="b1") 91 | b2 = bias_variable([features], name="b2") 92 | 93 | conv1 = conv2d(in_node, w1, b1, keep_prob) 94 | tmp_h_conv = tf.nn.relu(conv1) 95 | conv2 = conv2d(tmp_h_conv, w2, b2, keep_prob) 96 | dw_h_convs[layer] = tf.nn.relu(conv2) 97 | 98 | weights.append((w1, w2)) 99 | biases.append((b1, b2)) 100 | convs.append((conv1, conv2)) 101 | 102 | size -= 2 * 2 * (filter_size // 2) # valid conv 103 | if layer < layers - 1: 104 | pools[layer] = max_pool(dw_h_convs[layer], pool_size) 105 | in_node = pools[layer] 106 | size /= pool_size 107 | 108 | in_node = dw_h_convs[layers - 1] 109 | 110 | # up layers 111 | for layer in range(layers - 2, -1, -1): 112 | with tf.name_scope("up_conv_{}".format(str(layer))): 113 | features = 2 ** (layer + 1) * features_root 114 | stddev = np.sqrt(2 / (filter_size ** 2 * features)) 115 | 116 | wd = weight_variable_devonc([pool_size, pool_size, features // 2, features], stddev, name="wd") 117 | bd = bias_variable([features // 2], name="bd") 118 | h_deconv = tf.nn.relu(deconv2d(in_node, wd, pool_size) + bd) 119 | h_deconv_concat = crop_and_concat(dw_h_convs[layer], h_deconv) 120 | deconv[layer] = h_deconv_concat 121 | 122 | w1 = weight_variable([filter_size, filter_size, features, features // 2], stddev, name="w1") 123 | w2 = weight_variable([filter_size, filter_size, 
features // 2, features // 2], stddev, name="w2") 124 | b1 = bias_variable([features // 2], name="b1") 125 | b2 = bias_variable([features // 2], name="b2") 126 | 127 | conv1 = conv2d(h_deconv_concat, w1, b1, keep_prob) 128 | h_conv = tf.nn.relu(conv1) 129 | conv2 = conv2d(h_conv, w2, b2, keep_prob) 130 | in_node = tf.nn.relu(conv2) 131 | up_h_convs[layer] = in_node 132 | 133 | weights.append((w1, w2)) 134 | biases.append((b1, b2)) 135 | convs.append((conv1, conv2)) 136 | 137 | size *= pool_size 138 | size -= 2 * 2 * (filter_size // 2) # valid conv 139 | 140 | # Output Map 141 | with tf.name_scope("output_map"): 142 | weight = weight_variable([1, 1, features_root, n_class], stddev) 143 | bias = bias_variable([n_class], name="bias") 144 | conv = conv2d(in_node, weight, bias, tf.constant(1.0)) 145 | output_map = tf.nn.relu(conv) 146 | up_h_convs["out"] = output_map 147 | 148 | if summaries: 149 | with tf.name_scope("summaries"): 150 | for i, (c1, c2) in enumerate(convs): 151 | tf.summary.image('summary_conv_%02d_01' % i, get_image_summary(c1)) 152 | tf.summary.image('summary_conv_%02d_02' % i, get_image_summary(c2)) 153 | 154 | for k in pools.keys(): 155 | tf.summary.image('summary_pool_%02d' % k, get_image_summary(pools[k])) 156 | 157 | for k in deconv.keys(): 158 | tf.summary.image('summary_deconv_concat_%02d' % k, get_image_summary(deconv[k])) 159 | 160 | for k in dw_h_convs.keys(): 161 | tf.summary.histogram("dw_convolution_%02d" % k + '/activations', dw_h_convs[k]) 162 | 163 | for k in up_h_convs.keys(): 164 | tf.summary.histogram("up_convolution_%s" % k + '/activations', up_h_convs[k]) 165 | 166 | variables = [] 167 | for w1, w2 in weights: 168 | variables.append(w1) 169 | variables.append(w2) 170 | 171 | for b1, b2 in biases: 172 | variables.append(b1) 173 | variables.append(b2) 174 | 175 | return output_map, variables, int(in_size - size) 176 | 177 | 178 | class Unet(object): 179 | """ 180 | A unet implementation 181 | 182 | :param channels: number of channels 
in the input image 183 | :param n_class: number of output labels 184 | :param cost: (optional) name of the cost function. Default is 'cross_entropy' 185 | :param cost_kwargs: (optional) kwargs passed to the cost function. See Unet._get_cost for more options 186 | """ 187 | 188 | def __init__(self, channels, n_class, cost="cross_entropy", cost_kwargs={}, **kwargs): 189 | tf.reset_default_graph() 190 | 191 | self.n_class = n_class 192 | self.summaries = kwargs.get("summaries", True) 193 | 194 | self.x = tf.placeholder("float", shape=[None, None, None, channels], name="x") 195 | self.y = tf.placeholder("float", shape=[None, None, None, n_class], name="y") 196 | self.keep_prob = tf.placeholder(tf.float32, name="dropout_probability") # dropout (keep probability) 197 | 198 | logits, self.variables, self.offset = create_conv_net(self.x, self.keep_prob, channels, n_class, **kwargs) 199 | 200 | self.cost = self._get_cost(logits, cost, cost_kwargs) 201 | 202 | self.gradients_node = tf.gradients(self.cost, self.variables) 203 | 204 | with tf.name_scope("cross_entropy"): 205 | self.cross_entropy = cross_entropy(tf.reshape(self.y, [-1, n_class]), 206 | tf.reshape(pixel_wise_softmax(logits), [-1, n_class])) 207 | 208 | with tf.name_scope("results"): 209 | self.predicter = pixel_wise_softmax(logits) 210 | self.correct_pred = tf.equal(tf.argmax(self.predicter, 3), tf.argmax(self.y, 3)) 211 | self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32)) 212 | 213 | def _get_cost(self, logits, cost_name, cost_kwargs): 214 | """ 215 | Constructs the cost function, either cross_entropy, weighted cross_entropy or dice_coefficient. 
216 | Optional arguments are: 217 | class_weights: weights for the different classes in case of multi-class imbalance 218 | regularizer: weight of the L2 regularization term added to the loss function 219 | """ 220 | 221 | with tf.name_scope("cost"): 222 | flat_logits = tf.reshape(logits, [-1, self.n_class]) 223 | flat_labels = tf.reshape(self.y, [-1, self.n_class]) 224 | if cost_name == "cross_entropy": 225 | class_weights = cost_kwargs.pop("class_weights", None) 226 | 227 | if class_weights is not None: 228 | class_weights = tf.constant(np.array(class_weights, dtype=np.float32)) 229 | 230 | weight_map = tf.multiply(flat_labels, class_weights) 231 | weight_map = tf.reduce_sum(weight_map, axis=1) 232 | 233 | loss_map = tf.nn.softmax_cross_entropy_with_logits_v2(logits=flat_logits, 234 | labels=flat_labels) 235 | weighted_loss = tf.multiply(loss_map, weight_map) 236 | 237 | loss = tf.reduce_mean(weighted_loss) 238 | 239 | else: 240 | loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=flat_logits, 241 | labels=flat_labels)) 242 | elif cost_name == "dice_coefficient": 243 | eps = 1e-5 244 | prediction = pixel_wise_softmax(logits) 245 | intersection = tf.reduce_sum(prediction * self.y) 246 | union = eps + tf.reduce_sum(prediction) + tf.reduce_sum(self.y) 247 | loss = -(2 * intersection / union) 248 | 249 | else: 250 | raise ValueError("Unknown cost function: %s" % cost_name) 251 | 252 | regularizer = cost_kwargs.pop("regularizer", None) 253 | if regularizer is not None: 254 | regularizers = sum([tf.nn.l2_loss(variable) for variable in self.variables]) 255 | loss += (regularizer * regularizers) 256 | 257 | return loss 258 | 259 | def predict(self, model_path, x_test): 260 | """ 261 | Uses the model to create a prediction for the given data 262 | 263 | :param model_path: path to the model checkpoint to restore 264 | :param x_test: Data to predict on.
Shape [n, nx, ny, channels] 265 | :returns prediction: The unet prediction Shape [n, px, py, labels] (px=nx-self.offset/2) 266 | """ 267 | 268 | init = tf.global_variables_initializer() 269 | with tf.Session() as sess: 270 | # Initialize variables 271 | sess.run(init) 272 | 273 | # Restore model weights from previously saved model 274 | self.restore(sess, model_path) 275 | 276 | y_dummy = np.empty((x_test.shape[0], x_test.shape[1], x_test.shape[2], self.n_class)) 277 | prediction = sess.run(self.predicter, feed_dict={self.x: x_test, self.y: y_dummy, self.keep_prob: 1.}) 278 | 279 | return prediction 280 | 281 | def save(self, sess, model_path): 282 | """ 283 | Saves the current session to a checkpoint 284 | 285 | :param sess: current session 286 | :param model_path: path to file system location 287 | """ 288 | 289 | saver = tf.train.Saver() 290 | save_path = saver.save(sess, model_path) 291 | return save_path 292 | 293 | def restore(self, sess, model_path): 294 | """ 295 | Restores a session from a checkpoint 296 | 297 | :param sess: current session instance 298 | :param model_path: path to file system checkpoint location 299 | """ 300 | 301 | saver = tf.train.Saver() 302 | saver.restore(sess, model_path) 303 | logging.info("Model restored from file: %s" % model_path) 304 | 305 | 306 | class Trainer(object): 307 | """ 308 | Trains a unet instance 309 | 310 | :param net: the unet instance to train 311 | :param batch_size: size of training batch 312 | :param verification_batch_size: size of verification batch 313 | :param norm_grads: (optional) true if normalized gradients should be added to the summaries 314 | :param optimizer: (optional) name of the optimizer to use (momentum or adam) 315 | :param opt_kwargs: (optional) kwargs passed to the learning rate (momentum opt) and to the optimizer 316 | 317 | """ 318 | 319 | def __init__(self, net, batch_size=1, verification_batch_size = 4, norm_grads=False, optimizer="momentum", opt_kwargs={}): 320 | self.net = net 321 | 
self.batch_size = batch_size 322 | self.verification_batch_size = verification_batch_size 323 | self.norm_grads = norm_grads 324 | self.optimizer = optimizer 325 | self.opt_kwargs = opt_kwargs 326 | 327 | def _get_optimizer(self, training_iters, global_step): 328 | if self.optimizer == "momentum": 329 | learning_rate = self.opt_kwargs.pop("learning_rate", 0.2) 330 | decay_rate = self.opt_kwargs.pop("decay_rate", 0.95) 331 | momentum = self.opt_kwargs.pop("momentum", 0.2) 332 | 333 | self.learning_rate_node = tf.train.exponential_decay(learning_rate=learning_rate, 334 | global_step=global_step, 335 | decay_steps=training_iters, 336 | decay_rate=decay_rate, 337 | staircase=True) 338 | 339 | optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate_node, momentum=momentum, 340 | **self.opt_kwargs).minimize(self.net.cost, 341 | global_step=global_step) 342 | elif self.optimizer == "adam": 343 | learning_rate = self.opt_kwargs.pop("learning_rate", 0.001) 344 | self.learning_rate_node = tf.Variable(learning_rate, name="learning_rate") 345 | 346 | optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate_node, 347 | **self.opt_kwargs).minimize(self.net.cost, 348 | global_step=global_step) 349 | 350 | return optimizer 351 | 352 | def _initialize(self, training_iters, output_path, restore, prediction_path): 353 | global_step = tf.Variable(0, name="global_step") 354 | 355 | self.norm_gradients_node = tf.Variable(tf.constant(0.0, shape=[len(self.net.gradients_node)]), name="norm_gradients") 356 | 357 | if self.net.summaries and self.norm_grads: 358 | tf.summary.histogram('norm_grads', self.norm_gradients_node) 359 | 360 | tf.summary.scalar('loss', self.net.cost) 361 | tf.summary.scalar('cross_entropy', self.net.cross_entropy) 362 | tf.summary.scalar('accuracy', self.net.accuracy) 363 | 364 | self.optimizer = self._get_optimizer(training_iters, global_step) 365 | tf.summary.scalar('learning_rate', self.learning_rate_node) 366 | 367 | self.summary_op = 
tf.summary.merge_all() 368 | init = tf.global_variables_initializer() 369 | 370 | self.prediction_path = prediction_path 371 | abs_prediction_path = os.path.abspath(self.prediction_path) 372 | output_path = os.path.abspath(output_path) 373 | 374 | if not restore: 375 | logging.info("Removing '{:}'".format(abs_prediction_path)) 376 | shutil.rmtree(abs_prediction_path, ignore_errors=True) 377 | logging.info("Removing '{:}'".format(output_path)) 378 | shutil.rmtree(output_path, ignore_errors=True) 379 | 380 | if not os.path.exists(abs_prediction_path): 381 | logging.info("Allocating '{:}'".format(abs_prediction_path)) 382 | os.makedirs(abs_prediction_path) 383 | 384 | if not os.path.exists(output_path): 385 | logging.info("Allocating '{:}'".format(output_path)) 386 | os.makedirs(output_path) 387 | 388 | return init 389 | 390 | def train(self, data_provider, output_path, training_iters=10, epochs=100, dropout=0.75, display_step=1, 391 | restore=False, write_graph=False, prediction_path='prediction'): 392 | """ 393 | Launches the training process 394 | 395 | :param data_provider: callable returning training and verification data 396 | :param output_path: path where to store checkpoints 397 | :param training_iters: number of training mini-batch iterations per epoch 398 | :param epochs: number of epochs 399 | :param dropout: dropout probability 400 | :param display_step: number of steps between outputting stats 401 | :param restore: Flag if previous model should be restored 402 | :param write_graph: Flag if the computation graph should be written as protobuf file to the output path 403 | :param prediction_path: path where to save predictions on each epoch 404 | """ 405 | save_path = os.path.join(output_path, "model.ckpt") 406 | if epochs == 0: 407 | return save_path 408 | 409 | init = self._initialize(training_iters, output_path, restore, prediction_path) 410 | 411 | with tf.Session() as sess: 412 | if write_graph: 413 | tf.train.write_graph(sess.graph_def, output_path, "graph.pb",
False) 414 | 415 | sess.run(init) 416 | 417 | if restore: 418 | ckpt = tf.train.get_checkpoint_state(output_path) 419 | if ckpt and ckpt.model_checkpoint_path: 420 | self.net.restore(sess, ckpt.model_checkpoint_path) 421 | 422 | test_x, test_y = data_provider(self.verification_batch_size) 423 | pred_shape = self.store_prediction(sess, test_x, test_y, "_init") 424 | 425 | summary_writer = tf.summary.FileWriter(output_path, graph=sess.graph) 426 | logging.info("Start optimization") 427 | 428 | avg_gradients = None 429 | for epoch in range(epochs): 430 | total_loss = 0 431 | for step in range((epoch * training_iters), ((epoch + 1) * training_iters)): 432 | batch_x, batch_y = data_provider(self.batch_size) 433 | 434 | # Run optimization op (backprop) 435 | _, loss, lr, gradients = sess.run( 436 | (self.optimizer, self.net.cost, self.learning_rate_node, self.net.gradients_node), 437 | feed_dict={self.net.x: batch_x, 438 | self.net.y: util.crop_to_shape(batch_y, pred_shape), 439 | self.net.keep_prob: dropout}) 440 | 441 | if self.net.summaries and self.norm_grads: 442 | avg_gradients = _update_avg_gradients(avg_gradients, gradients, step) 443 | norm_gradients = [np.linalg.norm(gradient) for gradient in avg_gradients] 444 | self.norm_gradients_node.assign(norm_gradients).eval() 445 | 446 | if step % display_step == 0: 447 | self.output_minibatch_stats(sess, summary_writer, step, batch_x, 448 | util.crop_to_shape(batch_y, pred_shape)) 449 | 450 | total_loss += loss 451 | 452 | self.output_epoch_stats(epoch, total_loss, training_iters, lr) 453 | self.store_prediction(sess, test_x, test_y, "epoch_%s" % epoch) 454 | 455 | save_path = self.net.save(sess, save_path) 456 | logging.info("Optimization Finished!") 457 | 458 | return save_path 459 | 460 | def store_prediction(self, sess, batch_x, batch_y, name): 461 | prediction = sess.run(self.net.predicter, feed_dict={self.net.x: batch_x, 462 | self.net.y: batch_y, 463 | self.net.keep_prob: 1.}) 464 | pred_shape = prediction.shape 
465 | 466 | loss = sess.run(self.net.cost, feed_dict={self.net.x: batch_x, 467 | self.net.y: util.crop_to_shape(batch_y, pred_shape), 468 | self.net.keep_prob: 1.}) 469 | 470 | logging.info("Verification error= {:.1f}%, loss= {:.4f}".format(error_rate(prediction, 471 | util.crop_to_shape(batch_y, 472 | prediction.shape)), 473 | loss)) 474 | 475 | img = util.combine_img_prediction(batch_x, batch_y, prediction) 476 | util.save_image(img, "%s/%s.jpg" % (self.prediction_path, name)) 477 | 478 | return pred_shape 479 | 480 | def output_epoch_stats(self, epoch, total_loss, training_iters, lr): 481 | logging.info( 482 | "Epoch {:}, Average loss: {:.4f}, learning rate: {:.4f}".format(epoch, (total_loss / training_iters), lr)) 483 | 484 | def output_minibatch_stats(self, sess, summary_writer, step, batch_x, batch_y): 485 | # Calculate batch loss and accuracy 486 | summary_str, loss, acc, predictions = sess.run([self.summary_op, 487 | self.net.cost, 488 | self.net.accuracy, 489 | self.net.predicter], 490 | feed_dict={self.net.x: batch_x, 491 | self.net.y: batch_y, 492 | self.net.keep_prob: 1.}) 493 | summary_writer.add_summary(summary_str, step) 494 | summary_writer.flush() 495 | logging.info( 496 | "Iter {:}, Minibatch Loss= {:.4f}, Training Accuracy= {:.4f}, Minibatch error= {:.1f}%".format(step, 497 | loss, 498 | acc, 499 | error_rate( 500 | predictions, 501 | batch_y))) 502 | 503 | 504 | def _update_avg_gradients(avg_gradients, gradients, step): 505 | if avg_gradients is None: 506 | avg_gradients = [np.zeros_like(gradient) for gradient in gradients] 507 | for i in range(len(gradients)): 508 | avg_gradients[i] = (avg_gradients[i] * (1.0 - (1.0 / (step + 1)))) + (gradients[i] / (step + 1)) 509 | 510 | return avg_gradients 511 | 512 | 513 | def error_rate(predictions, labels): 514 | """ 515 | Return the error rate based on dense predictions and 1-hot labels. 
516 | """ 517 | 518 | return 100.0 - ( 519 | 100.0 * 520 | np.sum(np.argmax(predictions, 3) == np.argmax(labels, 3)) / 521 | (predictions.shape[0] * predictions.shape[1] * predictions.shape[2])) 522 | 523 | 524 | def get_image_summary(img, idx=0): 525 | """ 526 | Make an image summary for 4d tensor image with index idx 527 | """ 528 | 529 | V = tf.slice(img, (0, 0, 0, idx), (1, -1, -1, 1)) 530 | V -= tf.reduce_min(V) 531 | V /= tf.reduce_max(V) 532 | V *= 255 533 | 534 | img_w = tf.shape(img)[1] 535 | img_h = tf.shape(img)[2] 536 | V = tf.reshape(V, tf.stack((img_w, img_h, 1))) 537 | V = tf.transpose(V, (2, 0, 1)) 538 | V = tf.reshape(V, tf.stack((-1, img_w, img_h, 1))) 539 | return V 540 | -------------------------------------------------------------------------------- /tf_unet/util.py: -------------------------------------------------------------------------------- 1 | # tf_unet is free software: you can redistribute it and/or modify 2 | # it under the terms of the GNU General Public License as published by 3 | # the Free Software Foundation, either version 3 of the License, or 4 | # (at your option) any later version. 5 | # 6 | # tf_unet is distributed in the hope that it will be useful, 7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 9 | # GNU General Public License for more details. 10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with tf_unet. If not, see . 
13 | 14 | 15 | ''' 16 | Created on Aug 10, 2016 17 | 18 | author: jakeret 19 | ''' 20 | from __future__ import print_function, division, absolute_import, unicode_literals 21 | 22 | import os 23 | 24 | import numpy as np 25 | from PIL import Image 26 | 27 | def plot_prediction(x_test, y_test, prediction, save=False): 28 | import matplotlib 29 | import matplotlib.pyplot as plt 30 | 31 | test_size = x_test.shape[0] 32 | fig, ax = plt.subplots(test_size, 3, figsize=(12,12), sharey=True, sharex=True) 33 | 34 | x_test = crop_to_shape(x_test, prediction.shape) 35 | y_test = crop_to_shape(y_test, prediction.shape) 36 | 37 | ax = np.atleast_2d(ax) 38 | for i in range(test_size): 39 | cax = ax[i, 0].imshow(x_test[i]) 40 | plt.colorbar(cax, ax=ax[i,0]) 41 | cax = ax[i, 1].imshow(y_test[i, ..., 1]) 42 | plt.colorbar(cax, ax=ax[i,1]) 43 | pred = prediction[i, ..., 1] 44 | pred -= np.amin(pred) 45 | pred /= np.amax(pred) 46 | cax = ax[i, 2].imshow(pred) 47 | plt.colorbar(cax, ax=ax[i,2]) 48 | if i==0: 49 | ax[i, 0].set_title("x") 50 | ax[i, 1].set_title("y") 51 | ax[i, 2].set_title("pred") 52 | fig.tight_layout() 53 | 54 | if save: 55 | fig.savefig(save) 56 | else: 57 | fig.show() 58 | plt.show() 59 | 60 | def to_rgb(img): 61 | """ 62 | Converts the given array into a RGB image. If the number of channels is not 63 | 3 the array is tiled such that it has 3 channels. Finally, the values are 64 | rescaled to [0,255) 65 | 66 | :param img: the array to convert [nx, ny, channels] 67 | 68 | :returns img: the rgb image [nx, ny, 3] 69 | """ 70 | img = np.atleast_3d(img) 71 | channels = img.shape[2] 72 | if channels < 3: 73 | img = np.tile(img, 3) 74 | 75 | img[np.isnan(img)] = 0 76 | img -= np.amin(img) 77 | if np.amax(img) != 0: 78 | img /= np.amax(img) 79 | 80 | img *= 255 81 | return img 82 | 83 | def crop_to_shape(data, shape): 84 | """ 85 | Crops the array to the given image shape by removing the border (expects a tensor of shape [batches, nx, ny, channels]. 
86 | 87 | :param data: the array to crop 88 | :param shape: the target shape 89 | """ 90 | diff_nx = (data.shape[1] - shape[1]) 91 | diff_ny = (data.shape[2] - shape[2]) 92 | 93 | offset_nx_left = diff_nx // 2 94 | offset_nx_right = diff_nx - offset_nx_left 95 | offset_ny_left = diff_ny // 2 96 | offset_ny_right = diff_ny - offset_ny_left 97 | 98 | # use explicit end indices so a zero-width border does not produce an empty slice 99 | cropped = data[:, offset_nx_left:data.shape[1] - offset_nx_right, offset_ny_left:data.shape[2] - offset_ny_right] 100 | 101 | assert cropped.shape[1] == shape[1] 102 | assert cropped.shape[2] == shape[2] 103 | return cropped 104 | 105 | def expand_to_shape(data, shape, border=0): 106 | """ 107 | Expands the array to the given image shape by padding it with a border (expects a tensor of shape [batches, nx, ny, channels]). 108 | 109 | :param data: the array to expand 110 | :param shape: the target shape 111 | """ 112 | diff_nx = shape[1] - data.shape[1] 113 | diff_ny = shape[2] - data.shape[2] 114 | 115 | offset_nx_left = diff_nx // 2 116 | offset_nx_right = diff_nx - offset_nx_left 117 | offset_ny_left = diff_ny // 2 118 | offset_ny_right = diff_ny - offset_ny_left 119 | 120 | expanded = np.full(shape, border, dtype=np.float32) 121 | expanded[:, offset_nx_left:shape[1] - offset_nx_right, offset_ny_left:shape[2] - offset_ny_right] = data 122 | 123 | return expanded 124 | 125 | def combine_img_prediction(data, gt, pred): 126 | """ 127 | Combines the data, ground truth and the prediction into one rgb image 128 | 129 | :param data: the data tensor 130 | :param gt: the ground truth tensor 131 | :param pred: the prediction tensor 132 | 133 | :returns img: the concatenated rgb image 134 | """ 135 | ny = pred.shape[2] 136 | ch = data.shape[3] 137 | img = np.concatenate((to_rgb(crop_to_shape(data, pred.shape).reshape(-1, ny, ch)), 138 | to_rgb(crop_to_shape(gt[..., 1], pred.shape).reshape(-1, ny, 1)), 139 | to_rgb(pred[..., 1].reshape(-1, ny, 1))), axis=1) 140 | return img 141 | 142 | def save_image(img, path): 143 | """ 144 | Writes the image to disk 145 | 146 | :param
img: the rgb image to save 147 | :param path: the target path 148 | """ 149 | Image.fromarray(img.round().astype(np.uint8)).save(path, 'JPEG', dpi=[300,300], quality=90) 150 | 151 | 152 | def create_training_path(output_path, prefix="run_"): 153 | """ 154 | Enumerates a new path using the prefix under the given output_path 155 | :param output_path: the root path 156 | :param prefix: (optional) defaults to `run_` 157 | :return: the generated path as string in the form `output_path`/`prefix` + zero-padded index 158 | """ 159 | idx = 0 160 | path = os.path.join(output_path, "{:}{:03d}".format(prefix, idx)) 161 | while os.path.exists(path): 162 | idx += 1 163 | path = os.path.join(output_path, "{:}{:03d}".format(prefix, idx)) 164 | return path 165 | --------------------------------------------------------------------------------
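The centered-crop logic used throughout `util.py` (and implicitly by the valid-convolution U-Net, whose output is smaller than its input) can be exercised with plain NumPy. This standalone sketch mirrors `crop_to_shape` on a toy tensor of arbitrary size; the explicit end indices make the zero-size-border case work as well:

```python
import numpy as np

def crop_to_shape(data, shape):
    # Centered crop of a [batch, nx, ny, channels] tensor to the target
    # nx/ny given in `shape` (mirrors tf_unet.util.crop_to_shape).
    dx = data.shape[1] - shape[1]
    dy = data.shape[2] - shape[2]
    x0, y0 = dx // 2, dy // 2
    return data[:, x0:data.shape[1] - (dx - x0), y0:data.shape[2] - (dy - y0)]

data = np.arange(2 * 6 * 6 * 1).reshape(2, 6, 6, 1)
cropped = crop_to_shape(data, (2, 4, 4, 1))
print(cropped.shape)  # (2, 4, 4, 1); a border of width 1 was removed on each side
```

During training this is what aligns the label batch with the network output: `util.crop_to_shape(batch_y, pred_shape)` trims the ground truth by `offset // 2` pixels per side before it is fed to the loss.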