├── .github └── workflows │ └── build.yml ├── .gitignore ├── LICENSE ├── README.md ├── environment.yml ├── main.py ├── requirements.txt ├── sample_analysis.ipynb ├── setup.py └── src ├── augmentations.py ├── constants.py ├── data_utils.py ├── inference.py ├── multi_head_unet.py ├── post_process.py ├── post_process_utils.py ├── spatial_augmenter.py └── viz_utils.py /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build and Check 2 | # This specifies when the workflow should run. It's set to trigger on any push 3 | # and any pull request to the main branch. 4 | on: 5 | push: 6 | branches: [main] 7 | pull_request: 8 | branches: [main] 9 | 10 | # This ensures that only the latest run for a given branch or workflow is active, 11 | # canceling any in-progress runs if a new one is triggered. 12 | concurrency: 13 | group: ${{ github.workflow }}-${{ github.ref }} 14 | cancel-in-progress: true 15 | 16 | # Defines the job called 'build'. 17 | jobs: 18 | build: 19 | # Specifies the type of runner that the job will execute on. 20 | runs-on: ubuntu-latest 21 | 22 | # A matrix to run jobs across multiple versions of Python. 23 | strategy: 24 | matrix: 25 | python-version: [3.8, 3.9, 3.10.13, 3.11, 3.12.2] 26 | 27 | # Steps define a sequence of tasks that will be executed as part of the job. 28 | steps: 29 | # Checks-out repository under $GITHUB_WORKSPACE, so the workflow can access it. 30 | - uses: actions/checkout@v3 31 | 32 | # Sets up a Python environment with the version specified in the matrix, 33 | # allowing the workflow to execute actions with Python. 34 | - name: Set up Python ${{ matrix.python-version }} 35 | uses: actions/setup-python@v4 36 | with: 37 | python-version: ${{ matrix.python-version }} 38 | 39 | # Installs the necessary dependencies to build and check the Python package. 40 | # Includes pip, wheel, twine, and the build module. 41 | - name: Install dependencies 42 | run: python -m pip install --upgrade pip wheel twine build 43 | 44 | # Builds the package using the Python build module, which creates both source 45 | # distribution and wheel distribution files in the dist/ directory. 46 | - name: Build package 47 | run: python -m build 48 | 49 | # Uses Twine to check the built packages (.whl files) in the dist/ directory, 50 | # ensuring compliance with PyPI standards. 51 | - name: Check package 52 | run: twine check --strict dist/*.whl 53 | 54 | # Uploads the built wheel files as artifacts, which can be downloaded 55 | # after the workflow run completes. 56 | - name: Upload artifacts 57 | uses: actions/upload-artifact@v2 58 | with: 59 | name: wheels 60 | path: dist/*.whl 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | convnextv2_tiny_focal_fulldata_0/* 162 | *.out 163 | logs/* 164 | convnextv2_large_focal_fulldata_0/* 165 | tmp.ipynb 166 | convnextv2_base_focal_fulldata_0/* 167 | pannuke_convnextv2_tiny_1/* 168 | get_wsi_sizes.ipynb 169 | main_debug.py 170 | figures/* 171 | *.txt 172 | model_weights/* 173 | cpu_pp_pannuke.sh 174 | cpu_pp.sh 175 | debug_inference.sh 176 | dummy.sh 177 | run_eos_inference.sh 178 | run_inference_container_2.sh 179 | run_inference_container_pannuke.sh 180 | run_inference_container.sh 181 | run_wsi_validation_set_lizard_testo.sh 182 | run_wsi_validation_set_lizard.sh 183 | run_wsi_validation_set_pannuke.sh 184 | get_sizes.py 185 | image_loader.ipynb 186 | lizard_convnextv2_base.zip 187 | lizard_convnextv2_large.zip 188 | lizard_convnextv2_tiny.zip 189 | pannuke_convnextv2_tiny_1.zip 190 | pannuke_convnextv2_tiny_2.zip 191 | pannuke_convnextv2_tiny_3.zip 192 | run_inference_container_gpu.sh 193 | run_inference_container_gpu2.sh 194 | sizes.csv 195 | lizard_convnextv2_base/* 196 | lizard_convnextv2_large/* 197 | lizard_convnextv2_tiny/* 198 | pannuke_convnextv2_tiny_2/* 199 | pannuke_convnextv2_tiny_3/* 200 | sample/* 201 | testo.py 202 | test_out.csv 203 | cpu_pp_pannuke_debug.sh 204 | run_inference_container_pannuke_debug.sh 205 | run_mit_inference.sh 206 | sample_cls.bmp 207 | sample_cls.jpg 208 | sample_he.jpg 209 | testo.sh 210 | debug.py 211 | direct_cpu_pp.sh 212 | run_inference_container_jc.sh 213 | sample_analysis.ipynb 214 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. 
You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 
102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 
163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. 
This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 
287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 
344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 
408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. 
For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 
520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 
646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | Copyright (C) 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HoVer-NeXt Inference 2 | HoVer-NeXt is a fast and efficient nuclei segmentation and classification pipeline. 3 | 4 | Supported are a variety of data formats, including all OpenSlide supported datatypes, `.npy` numpy array dumps, and common image formats such as JPEG and PNG. 5 | If you are having trouble with using this repository, please create an issue and we will be happy to help! 6 | 7 | For training code, please check the [hover-next training repository](https://github.com/digitalpathologybern/hover_next_train) 8 | 9 | Find the Publication here: [https://openreview.net/pdf?id=3vmB43oqIO](https://openreview.net/pdf?id=3vmB43oqIO) 10 | 11 | ## Setup 12 | 13 | Environments for train and inference are the same so if you already have set the environment up for training, you can use it for inference as well. 14 | 15 | Otherwise: 16 | 17 | ```bash 18 | conda env create -f environment.yml 19 | conda activate hovernext 20 | pip install torch==2.1.1 torchvision==0.16.1 --index-url https://download.pytorch.org/whl/cu118 21 | ``` 22 | 23 | or use predefined [docker/singularity container](#docker-and-apptainersingularity-container) 24 | 25 | ## Model Weights 26 | 27 | Weights are hosted on [Zenodo](https://zenodo.org/records/10635618) 28 | By specifying one of the ID's listed, weights are **automatically** downloaded and loaded. 
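If you prefer to fetch the weights yourself, the sketch below shows one possible manual route. It assumes `wget` and `unzip` are available and that the archive unpacks into a folder named after the ID; the URL is the Large Lizard entry from the table below.

```bash
# Hypothetical manual download of the "lizard_convnextv2_large" checkpoint.
# The unpacked folder must sit in the same directory as main.py.
wget -O lizard_convnextv2_large.zip \
    "https://zenodo.org/records/10635618/files/lizard_convnextv2_large.zip?download=1"
unzip lizard_convnextv2_large.zip
```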
| Dataset | ID | Weights |
|--------------|--------|-----|
| Lizard-Mitosis | "lizard_convnextv2_large" | [Large](https://zenodo.org/records/10635618/files/lizard_convnextv2_large.zip?download=1) |
| | "lizard_convnextv2_base" | [Base](https://zenodo.org/records/10635618/files/lizard_convnextv2_base.zip?download=1) |
| | "lizard_convnextv2_tiny" | [Tiny](https://zenodo.org/records/10635618/files/lizard_convnextv2_tiny.zip?download=1) |
| PanNuke | "pannuke_convnextv2_tiny_1" | [Tiny Fold 1](https://zenodo.org/records/10635618/files/pannuke_convnextv2_tiny_1.zip?download=1) |
| | "pannuke_convnextv2_tiny_2" | [Tiny Fold 2](https://zenodo.org/records/10635618/files/pannuke_convnextv2_tiny_2.zip?download=1) |
| | "pannuke_convnextv2_tiny_3" | [Tiny Fold 3](https://zenodo.org/records/10635618/files/pannuke_convnextv2_tiny_3.zip?download=1) |

If you are downloading weights manually, unzip them so that the folder (e.g. `lizard_convnextv2_large`) sits in the same directory as `main.py`.

## WSI Inference

This pipeline uses OpenSlide to read images and therefore supports all formats that OpenSlide supports.
If you want to run this pipeline on custom ome.tif files, ensure that the necessary metadata such as resolution, downsampling and dimensions are available.
Additionally, czi is supported via pylibCZIrw.
Before running a slide, choose [appropriate parameters for your machine](#optimizing-inference-for-your-machine).

To run a single slide:

```bash
python3 main.py \
    --input "/path-to-wsi/wsi.svs" \
    --output_root "results/" \
    --cp "lizard_convnextv2_large" \
    --tta 4 \
    --inf_workers 16 \
    --pp_tiling 10 \
    --pp_workers 16
```

To run multiple slides, specify a glob pattern such as `"/path-to-folder/*.mrxs"` or provide a list of paths as a `.txt` file.

### Slurm

If you are running on a Slurm cluster, you might consider separating inference and post-processing to improve GPU utilization.
Use the `--only_inference` parameter and submit another job with the same parameters, but without `--only_inference`.

## NPY / Image inference

NPY and image inference work the same way as WSI inference; however, the output files are only a Zarr array.

```bash
python3 main.py \
    --input "/path-to-file/file.npy" \
    --output_root "/results/" \
    --cp "lizard_convnextv2_large" \
    --tta 4 \
    --inf_workers 16 \
    --pp_tiling 10 \
    --pp_workers 16
```

Support for other data types is easy to implement; check the NPYDataloader for reference.

## Optimizing inference for your machine:

1. Keep the WSI on the local machine or on a fast-access network location.
2. If you have multiple machines, e.g. CPU-only machines, you can move post-processing to one of them.
3. `--tta 4` yields robust results at very high speed.
4. `--inf_workers` should be set to the number of available cores.
5. `--pp_workers` should be set to the number of available cores minus 1, with `--pp_tiling` set to a low value at which the machine does not run out of memory. E.g. on a 16-core machine, `--pp_workers 16 --pp_tiling 8` is good; if you are running out of memory, increase `--pp_tiling`. A full example is shown below.
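Putting these recommendations together, a tuned invocation for a hypothetical 16-core GPU machine could look like the following (paths and checkpoint ID are placeholders; adjust the worker counts to your core count):

```bash
python3 main.py \
    --input "/path-to-wsi/wsi.svs" \
    --output_root "results/" \
    --cp "lizard_convnextv2_large" \
    --tta 4 \
    --inf_workers 16 \
    --pp_workers 16 \
    --pp_tiling 8
```

If post-processing still runs out of memory, increase `--pp_tiling`, which splits the slide into more, smaller tiles per dimension.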
92 | 93 | ## Using the output files for downstream analysis: 94 | 95 | By default, the pipeline produces an instance-map, a class-lookup with centroids and a number of .tsv files to load in QuPath. 96 | sample_analysis.ipynb shows exemplarily how to use the files. 97 | 98 | ## Docker and Apptainer/Singularity Container: 99 | 100 | Download the singularity image from [Zenodo](https://zenodo.org/records/10649470/files/hover_next.sif) 101 | 102 | ```bash 103 | # don't forget to mount your local directory 104 | export APPTAINER_BINDPATH="/storage" 105 | apptainer exec --nv /path-to-container/hover_next.sif \ 106 | python3 /path-to-repo/main.py \ 107 | --input "/path-to-wsi/*.svs" \ 108 | --output_root "results/" \ 109 | --cp "lizard_convnextv2_large" \ 110 | --tta 4 111 | ``` 112 | # License 113 | 114 | This repository is licensed under GNU General Public License v3.0 (See License Info). 115 | If you are intending to use this repository for commercial usecases, please check the licenses of all python packages referenced in the Setup section / described in the requirements.txt and environment.yml. 116 | 117 | # Citation 118 | 119 | If you are using this code, please cite: 120 | ``` 121 | @inproceedings{baumann2024hover, 122 | title={HoVer-NeXt: A Fast Nuclei Segmentation and Classification Pipeline for Next Generation Histopathology}, 123 | author={Baumann, Elias and Dislich, Bastian and Rumberger, Josef Lorenz and Nagtegaal, Iris D and Martinez, Maria Rodriguez and Zlobec, Inti}, 124 | booktitle={Medical Imaging with Deep Learning}, 125 | year={2024} 126 | } 127 | ``` 128 | and 129 | ``` 130 | @INPROCEEDINGS{rumberger2022panoptic, 131 | author={Rumberger, Josef Lorenz and Baumann, Elias and Hirsch, Peter and Janowczyk, Andrew and Zlobec, Inti and Kainmueller, Dagmar}, 132 | booktitle={2022 IEEE International Symposium on Biomedical Imaging Challenges (ISBIC)}, 133 | title={Panoptic segmentation with highly imbalanced semantic labels}, 134 | year={2022}, 135 | pages={1-4}, 136 | doi={10.1109/ISBIC56247.2022.9854551}} 137 | ``` -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: hovernext 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python=3.11.5 7 | - openslide 8 | - pip: 9 | - -r file:requirements.txt -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import sys 4 | from timeit import default_timer as timer 5 | from datetime import timedelta 6 | import torch 7 | from glob import glob 8 | from src.inference import inference_main, get_inference_setup 9 | from src.post_process import post_process_main 10 | from src.data_utils import copy_img 11 | 12 | torch.backends.cudnn.benchmark = True 13 | print(torch.cuda.device_count(), " cuda devices") 14 | 15 | 16 | def prepare_input(params): 17 | """ 18 | Check if input is a text file, glob pattern, or a directory, and return a list of input files 19 | 20 | Parameters 21 | ---------- 22 | params: dict 23 | input parameters from argparse 24 | 25 | """ 26 | print("input specified: ", params["input"]) 27 | if params["input"].endswith(".txt"): 28 | if os.path.exists(params["input"]): 29 | with open(params["input"], "r") as f: 30 | input_list = f.read().splitlines() 31 | else: 32 | raise FileNotFoundError("input file not found") 33 | else: 34 
| input_list = sorted(glob(params["input"].rstrip())) 35 | return input_list 36 | 37 | 38 | def get_input_type(params): 39 | """ 40 | Check if input is an image, numpy array, or whole slide image, and return the input type 41 | If you are trying to process other images that are supported by opencv (e.g. tiff), you can add the extension to the list 42 | 43 | Parameters 44 | ---------- 45 | params: dict 46 | input parameters from argparse 47 | """ 48 | params["ext"] = os.path.splitext(params["p"])[-1] 49 | if params["ext"] == ".npy": 50 | params["input_type"] = "npy" 51 | elif params["ext"] in [".jpg", ".png", ".jpeg", ".bmp"]: 52 | params["input_type"] = "img" 53 | else: 54 | params["input_type"] = "wsi" 55 | return params 56 | 57 | 58 | def main(params: dict): 59 | """ 60 | Start nuclei segmentation and classification pipeline using specified parameters from argparse 61 | 62 | Parameters 63 | ---------- 64 | params: dict 65 | input parameters from argparse 66 | """ 67 | 68 | if params["metric"] not in ["mpq", "f1", "pannuke"]: 69 | params["metric"] = "f1" 70 | print("invalid metric, falling back to f1") 71 | else: 72 | print("optimizing postprocessing for: ", params["metric"]) 73 | 74 | params["root"] = os.path.dirname(__file__) 75 | params["data_dirs"] = [ 76 | os.path.join(params["root"], c) for c in params["cp"].split(",") 77 | ] 78 | 79 | print("saving results to:", params["output_root"]) 80 | print("loading model from:", params["data_dirs"]) 81 | 82 | # Run per tile inference and store results 83 | params, models, augmenter, color_aug_fn = get_inference_setup(params) 84 | 85 | input_list = prepare_input(params) 86 | print("Running inference on", len(input_list), "file(s)") 87 | 88 | for inp in input_list: 89 | start_time = timer() 90 | params["p"] = inp.rstrip() 91 | params = get_input_type(params) 92 | print("Processing ", params["p"]) 93 | if params["cache"] is not None: 94 | print("Caching input at:") 95 | params["p"] = copy_img(params["p"], params["cache"]) 96 | print(params["p"]) 97 | 98 | params, z = inference_main(params, models, augmenter, color_aug_fn) 99 | print( 100 | "::: finished or skipped inference after", 101 | timedelta(seconds=timer() - start_time), 102 | ) 103 | process_timer = timer() 104 | if params["only_inference"]: 105 | try: 106 | z[0].store.close() 107 | z[1].store.close() 108 | except TypeError: 109 | # if z is None, z cannot be indexed -> throws a TypeError 110 | pass 111 | print("Exiting after inference") 112 | sys.exit(2) 113 | # Stitch tiles together and postprocess to get instance segmentation 114 | if not os.path.exists(os.path.join(params["output_dir"], "pinst_pp.zip")): 115 | print("running post-processing") 116 | 117 | z_pp = post_process_main( 118 | params, 119 | z, 120 | ) 121 | if not params["keep_raw"]: 122 | try: 123 | os.remove(params["model_out_p"] + "_inst.zip") 124 | os.remove(params["model_out_p"] + "_cls.zip") 125 | except FileNotFoundError: 126 | pass 127 | else: 128 | z_pp = None 129 | print( 130 | "::: postprocessing took", 131 | timedelta(seconds=timer() - process_timer), 132 | "total elapsed time", 133 | timedelta(seconds=timer() - start_time), 134 | ) 135 | if z_pp is not None: 136 | z_pp.store.close() 137 | print("done") 138 | sys.exit(0) 139 | 140 | 141 | if __name__ == "__main__": 142 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 143 | print(device) 144 | 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument( 147 | "--input", 148 | type=str, 149 | default=None, 150 | help="path to wsi, 
glob pattern or text file containing paths", 151 | required=True, 152 | ) 153 | parser.add_argument( 154 | "--output_root", type=str, default=None, help="output directory", required=True 155 | ) 156 | parser.add_argument( 157 | "--cp", 158 | type=str, 159 | default=None, 160 | help="comma separated list of checkpoint folders to consider", 161 | ) 162 | parser.add_argument( 163 | "--only_inference", 164 | action="store_true", 165 | help="split inference to gpu and cpu node/ only run inference", 166 | ) 167 | parser.add_argument( 168 | "--metric", type=str, default="f1", help="metric to optimize for pp" 169 | ) 170 | parser.add_argument("--batch_size", type=int, default=64, help="batch size") 171 | parser.add_argument( 172 | "--tta", 173 | type=int, 174 | default=4, 175 | help="test time augmentations, number of views (4= results from 4 different augmentations are averaged for each sample)", 176 | ) 177 | parser.add_argument( 178 | "--save_polygon", 179 | action="store_true", 180 | help="save output as polygons to load in qupath", 181 | ) 182 | parser.add_argument( 183 | "--tile_size", 184 | type=int, 185 | default=256, 186 | help="tile size, models are trained on 256x256", 187 | ) 188 | parser.add_argument( 189 | "--overlap", 190 | type=float, 191 | default=0.96875, 192 | help="overlap between tiles, at 0.5mpp, 0.96875 is best, for 0.25mpp use 0.9375 for better results", 193 | ) 194 | parser.add_argument( 195 | "--inf_workers", 196 | type=int, 197 | default=4, 198 | help="number of workers for inference dataloader, maximally set this to number of cores", 199 | ) 200 | parser.add_argument( 201 | "--inf_writers", 202 | type=int, 203 | default=2, 204 | help="number of writers for inference dataloader, default 2 should be sufficient" 205 | + ", \ tune based on core availability and delay between final inference step and inference finalization", 206 | ) 207 | parser.add_argument( 208 | "--pp_tiling", 209 | type=int, 210 | default=8, 211 | help="tiling factor for post processing, number of tiles per dimension, 8 = 64 tiles", 212 | ) 213 | parser.add_argument( 214 | "--pp_overlap", 215 | type=int, 216 | default=256, 217 | help="overlap for postprocessing tiles, put to around tile_size", 218 | ) 219 | parser.add_argument( 220 | "--pp_workers", 221 | type=int, 222 | default=16, 223 | help="number of workers for postprocessing, maximally set this to number of cores", 224 | ) 225 | parser.add_argument( 226 | "--keep_raw", 227 | action="store_true", 228 | help="keep raw predictions (can be large files for particularly for pannuke)", 229 | ) 230 | parser.add_argument("--cache", type=str, default=None, help="cache path") 231 | params = vars(parser.parse_args()) 232 | main(params) 233 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | openslide-python 2 | scikit-learn 3 | scikit-image 4 | scipy 5 | opencv-python 6 | pandas 7 | tqdm 8 | itk 9 | matplotlib 10 | mahotas 11 | pandas 12 | jupyterlab 13 | zarr 14 | tifffile 15 | h5py 16 | segmentation-models-pytorch 17 | networkx==2.8.7 18 | libpysal 19 | Pillow 20 | shapely 21 | staintools 22 | albumentations 23 | spams-bin 24 | toml 25 | numcodecs 26 | imagecodecs 27 | timm==0.9.6 28 | geojson 29 | pylibCZIrw -------------------------------------------------------------------------------- /sample_analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | 
"cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import zarr\n", 10 | "import numpy as np\n", 11 | "import json\n", 12 | "from main import main\n", 13 | "\n", 14 | "'''\n", 15 | "Run inference on a small sample image from TCGA\n", 16 | "'''\n", 17 | "params = {\n", 18 | "\n", 19 | "}\n", 20 | "\n", 21 | "main(params)" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "'''\n", 31 | "Instance map: 2D full-size matrix where each pixels value corresponds to the associated instance (value>0) or background (value=0)\n", 32 | "'''\n", 33 | "\n", 34 | "# open: file-like interaction with zarr-array\n", 35 | "instance_map = zarr.open(\"pinst_pp.zip\", mode=\"r\")\n", 36 | "# selecting a ROI will yield a numpy array\n", 37 | "roi = instance_map[10000:20000,10000:20000]\n", 38 | "# or with [:] to load the entire array\n", 39 | "full_instance_map = instance_map[:]\n", 40 | "# alternatively, use load, which will directly create a numpy array:\n", 41 | "full_instance_map = zarr.load(\"pinst_pp.zip\") \n", 42 | "\n", 43 | "'''\n", 44 | "Class dictionary: Lookup for the instance map, also contains centroid coordinates. If only centroid coordinates are of interest, you can skip loading the instance map.\n", 45 | "'''\n", 46 | "\n", 47 | "# load the dictionary\n", 48 | "with open(\"class_inst.json\",\"r\") as f:\n", 49 | " class_info = json.load(f)\n", 50 | "# create a centroid info array\n", 51 | "centroid_array = np.array([[int(k),v[0],*v[1]] for k,v in class_info.items()])\n", 52 | "# [instance_id, class_id, y, x]\n", 53 | "\n", 54 | "# or alternatively create a lookup for the instance map to get a corresponding class map\n", 55 | "pcls_list = np.array([0] + [v[0] for v in class_info.values()])\n", 56 | "pcls_keys = np.array([\"0\"] + list(class_info.keys())).astype(int)\n", 57 | "lookup = np.zeros(pcls_keys.max() + 1,dtype=np.uint8)\n", 58 | "lookup[pcls_keys] = pcls_list\n", 59 | "cls_map = lookup[full_instance_map]" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [] 68 | } 69 | ], 70 | "metadata": { 71 | "language_info": { 72 | "name": "python" 73 | } 74 | }, 75 | "nbformat": 4, 76 | "nbformat_minor": 2 77 | } 78 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | # Read the content of README file 4 | with open('README.md', encoding='utf-8') as f: 5 | long_description = f.read() 6 | 7 | setup( 8 | name="hover_next_inference", 9 | version="0.1", 10 | packages=find_packages(), 11 | long_description=long_description, 12 | long_description_content_type='text/markdown', 13 | ) 14 | -------------------------------------------------------------------------------- /src/augmentations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torchvision.transforms.transforms import ColorJitter, RandomApply, GaussianBlur 4 | 5 | 6 | rgb_from_hed = np.array( 7 | [[0.65, 0.70, 0.29], [0.07, 0.99, 0.11], [0.27, 0.57, 0.78]], dtype=np.float32 8 | ) 9 | hed_from_rgb = np.linalg.inv(rgb_from_hed) 10 | 11 | 12 | def torch_rgb2hed(img: torch.Tensor, hed_t: torch.Tensor, e: float): 13 | """ 14 | convert rgb torch tensor to hed torch 
tensor (adopted from skimage) 15 | 16 | Parameters 17 | ---------- 18 | img : torch.Tensor 19 | rgb image tensor (B, C, H, W) or (C, H, W) 20 | hed_t : torch.Tensor 21 | hed transform tensor (3, 3) 22 | e : float 23 | epsilon 24 | 25 | Returns 26 | ------- 27 | torch.Tensor 28 | hed image tensor (B, C, H, W) or (C, H, W) 29 | """ 30 | img = img.movedim(-3, -1) 31 | 32 | img = torch.clamp(img, min=e) 33 | img = torch.log(img) / torch.log(e) 34 | img = torch.matmul(img, hed_t) 35 | return img.movedim(-1, -3) 36 | 37 | 38 | def torch_hed2rgb(img: torch.Tensor, rgb_t: torch.Tensor, e: float): 39 | """ 40 | convert rgb torch tensor to hed torch tensor (adopted from skimage) 41 | 42 | Parameters 43 | ---------- 44 | img : torch.Tensor 45 | hed image tensor (B, C, H, W) or (C, H, W) 46 | hed_t : torch.Tensor 47 | hed inverse transform tensor (3, 3) 48 | e : float 49 | epsilon 50 | 51 | Returns 52 | ------- 53 | torch.Tensor 54 | RGB image tensor (B, C, H, W) or (C, H, W) 55 | """ 56 | e = -torch.log(e) 57 | img = img.movedim(-3, -1) 58 | img = torch.matmul(-(img * e), rgb_t) 59 | img = torch.exp(img) 60 | img = torch.clamp(img, 0, 1) 61 | return img.movedim(-1, -3) 62 | 63 | 64 | class Hed2Rgb(torch.nn.Module): 65 | """ 66 | Pytorch module to convert hed image tensors to rgb 67 | """ 68 | 69 | def __init__(self, rank): 70 | super().__init__() 71 | self.e = torch.tensor(1e-6).to(rank) 72 | self.rgb_t = torch.from_numpy(rgb_from_hed).to(rank) 73 | self.rank = rank 74 | 75 | def forward(self, img): 76 | return torch_hed2rgb(img, self.rgb_t, self.e) 77 | 78 | 79 | class Rgb2Hed(torch.nn.Module): 80 | """ 81 | Pytorch module to convert rgb image tensors to hed 82 | """ 83 | 84 | def __init__(self, rank): 85 | super().__init__() 86 | self.e = torch.tensor(1e-6).to(rank) 87 | self.hed_t = torch.from_numpy(hed_from_rgb).to(rank) 88 | self.rank = rank 89 | 90 | def forward(self, img): 91 | return torch_rgb2hed(img, self.hed_t, self.e) 92 | 93 | 94 | class HedNormalizeTorch(torch.nn.Module): 95 | """ 96 | Pytorch augmentation module to apply HED stain augmentation 97 | 98 | Parameters 99 | ---------- 100 | sigma : float 101 | sigma for linear scaling of HED channels 102 | bias : float 103 | bias for additive scaling of HED channels 104 | """ 105 | 106 | def __init__(self, sigma, bias, rank, *args, **kwargs) -> None: 107 | super().__init__(*args, **kwargs) 108 | self.sigma = sigma 109 | self.bias = bias 110 | self.rank = rank 111 | self.rgb2hed = Rgb2Hed(rank=rank) 112 | self.hed2rgb = Hed2Rgb(rank=rank) 113 | 114 | def rng(self, val, batch_size): 115 | return torch.empty(batch_size, 3).uniform_(-val, val).to(self.rank) 116 | 117 | def color_norm_hed(self, img): 118 | B = img.shape[0] 119 | sigmas = self.rng(self.sigma, B) 120 | biases = self.rng(self.bias, B) 121 | return (img * (1 + sigmas.view(*sigmas.shape, 1, 1))) + biases.view( 122 | *biases.shape, 1, 1 123 | ) 124 | 125 | def forward(self, img): 126 | if img.dim() == 3: 127 | img = img.view(1, *img.shape) 128 | hed = self.rgb2hed(img) 129 | hed = self.color_norm_hed(hed) 130 | return self.hed2rgb(hed) 131 | 132 | 133 | class GaussianNoise(torch.nn.Module): 134 | """ 135 | Pytorch augmentation module to apply gaussian noise 136 | 137 | Parameters 138 | ---------- 139 | sigma : float 140 | sigma for uniform distribution to sample from 141 | rank : str or int or torch.device 142 | device to put the module to 143 | """ 144 | 145 | def __init__(self, sigma, rank): 146 | super().__init__() 147 | self.sigma = sigma 148 | self.rank = rank 149 | 150 | def 
forward(self, img): 151 | noise = torch.empty(img.shape).uniform_(-self.sigma, self.sigma).to(self.rank) 152 | return img + noise 153 | 154 | 155 | def color_augmentations(train, sigma=0.05, bias=0.03, s=0.2, rank=0): 156 | """ 157 | Color augmentation function (in theory can set to train to have more variance 158 | with high test time augmentations) 159 | 160 | Parameters 161 | ---------- 162 | train : bool 163 | during training, the model uses more augmentation than during inference, 164 | set to true for more variance in colors 165 | sigma: float 166 | parameter for hed augmentation 167 | bias: float 168 | parameter for hed augmentation 169 | s: float 170 | parameter for color jitter 171 | rank: int or torch.device or str 172 | device to use for augmentation 173 | 174 | Returns 175 | ------- 176 | torch.nn.Sequential 177 | sequential augmentation module 178 | """ 179 | if train: 180 | color_jitter = ColorJitter( 181 | 0.8 * s, 0.0 * s, 0.8 * s, 0.2 * s 182 | ) # brightness, contrast, saturation, hue 183 | 184 | data_transforms = torch.nn.Sequential( 185 | RandomApply([HedNormalizeTorch(sigma, bias, rank=rank)], p=0.75), 186 | RandomApply([color_jitter], p=0.3), 187 | RandomApply([GaussianNoise(0.02, rank)], p=0.3), 188 | RandomApply([GaussianBlur(kernel_size=15, sigma=(0.1, 0.1))], p=0.3), 189 | ) 190 | else: 191 | data_transforms = torch.nn.Sequential(HedNormalizeTorch(sigma, bias, rank=rank)) 192 | return data_transforms 193 | -------------------------------------------------------------------------------- /src/constants.py: -------------------------------------------------------------------------------- 1 | ### Size thresholds for nuclei (in pixels), pannuke is less conservative 2 | # These have been optimized for the conic challenge, but can be changed 3 | # to get more small nuclei (e.g. 
by setting all min_threshs to 0) 4 | MIN_THRESHS_LIZARD = [30, 30, 20, 20, 30, 30, 15] 5 | MAX_THRESHS_LIZARD = [5000, 5000, 5000, 5000, 5000, 5000, 5000] 6 | MIN_THRESHS_PANNUKE = [10, 10, 10, 10, 10] 7 | MAX_THRESHS_PANNUKE = [20000, 20000, 20000, 3000, 10000] 8 | 9 | # Maximal size of holes to remove from a nucleus 10 | MAX_HOLE_SIZE = 128 11 | 12 | # Colors for geojson output 13 | COLORS_LIZARD = [ 14 | [0, 255, 0], # neu 15 | [255, 0, 0], # epi 16 | [0, 0, 255], # lym 17 | [0, 128, 0], # pla 18 | [0, 255, 255], # eos 19 | [255, 179, 102], # con 20 | [255, 0, 255], # mitosis 21 | ] 22 | 23 | COLORS_PANNUKE = [ 24 | [255, 0, 0], # neo 25 | [0, 127, 255], # inf 26 | [255, 179, 102], # con 27 | [0, 0, 0], # dead 28 | [0, 255, 0], # epi 29 | ] 30 | 31 | # text labels for lizard 32 | CLASS_LABELS_LIZARD = { 33 | "neutrophil": 1, 34 | "epithelial-cell": 2, 35 | "lymphocyte": 3, 36 | "plasma-cell": 4, 37 | "eosinophil": 5, 38 | "connective-tissue-cell": 6, 39 | "mitosis": 7, 40 | } 41 | 42 | # text labels for pannuke 43 | CLASS_LABELS_PANNUKE = { 44 | "neoplastic": 1, 45 | "inflammatory": 2, 46 | "connective": 3, 47 | "dead": 4, 48 | "epithelial": 5, 49 | } 50 | 51 | # magnifiation and resolutions for WSI dataloader 52 | LUT_MAGNIFICATION_X = [10, 20, 40, 80] 53 | LUT_MAGNIFICATION_MPP = [0.97, 0.485, 0.2425, 0.124] 54 | 55 | CONIC_MPP = 0.5 56 | PANNUKE_MPP = 0.25 57 | 58 | # parameters for test time augmentations, do not change 59 | TTA_AUG_PARAMS = { 60 | "mirror": {"prob_x": 0.5, "prob_y": 0.5, "prob": 0.75}, 61 | "translate": {"max_percent": 0.03, "prob": 0.0}, 62 | "scale": {"min": 0.8, "max": 1.2, "prob": 0.0}, 63 | "zoom": {"min": 0.8, "max": 1.2, "prob": 0.0}, 64 | "rotate": {"rot90": True, "prob": 0.75}, 65 | "shear": {"max_percent": 0.1, "prob": 0.0}, 66 | "elastic": {"alpha": [120, 120], "sigma": 8, "prob": 0.0}, 67 | } 68 | 69 | # current valid pre-trained weights to be automatically downloaded and used in HoVer-NeXt 70 | VALID_WEIGHTS = [ 71 | "lizard_convnextv2_large", 72 | "lizard_convnextv2_base", 73 | "lizard_convnextv2_tiny", 74 | "pannuke_convnextv2_tiny_1", 75 | "pannuke_convnextv2_tiny_2", 76 | "pannuke_convnextv2_tiny_3", 77 | ] -------------------------------------------------------------------------------- /src/data_utils.py: -------------------------------------------------------------------------------- 1 | import openslide 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset 5 | from typing import Optional, List, Tuple, Callable 6 | from skimage.morphology import remove_small_objects, disk, dilation 7 | import PIL 8 | import pathlib 9 | import cv2 10 | from src.constants import LUT_MAGNIFICATION_MPP, LUT_MAGNIFICATION_X 11 | from shutil import copy2, copytree 12 | import os 13 | from pylibCZIrw import czi as pyczi 14 | 15 | 16 | def copy_img(im_path, cache_dir): 17 | """ 18 | Helper function to copy WSI to cache directory 19 | 20 | Parameters 21 | ---------- 22 | im_path : str 23 | path to the WSI 24 | cache_dir : str 25 | path to the cache directory 26 | 27 | Returns 28 | ------- 29 | str 30 | path to the copied WSI 31 | """ 32 | file, ext = os.path.splitext(im_path) 33 | if ext == ".mrxs": 34 | copy2(im_path, cache_dir) 35 | copytree( 36 | file, os.path.join(cache_dir, os.path.split(file)[-1]), dirs_exist_ok=True 37 | ) 38 | else: 39 | copy2(im_path, cache_dir) 40 | return os.path.join(cache_dir, os.path.split(im_path)[-1]) 41 | 42 | 43 | def normalize_min_max(x: np.ndarray, mi, ma, clip=False, eps=1e-20, dtype=np.float32): 44 | """ 45 
| Min max scaling for input array 46 | 47 | Parameters 48 | ---------- 49 | x : np.ndarray 50 | input array 51 | mi : float or int 52 | minimum value 53 | ma : float or int 54 | maximum value 55 | clip : bool, optional 56 | clip values be between 0 and 1, False by default 57 | eps : float 58 | epsilon value to avoid division by zero 59 | dtype : type 60 | data type of the output array 61 | 62 | Returns 63 | ------- 64 | np.ndarray 65 | normalized array 66 | """ 67 | if mi is None: 68 | mi = np.min(x) 69 | if ma is None: 70 | ma = np.max(x) 71 | if dtype is not None: 72 | x = x.astype(dtype, copy=False) 73 | mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype, copy=False) 74 | ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype, copy=False) 75 | eps = dtype(eps) 76 | 77 | x = (x - mi) / (ma - mi + eps) 78 | 79 | if clip: 80 | x = np.clip(x, 0, 1) 81 | return x 82 | 83 | 84 | def center_crop(t, croph, cropw): 85 | """ 86 | Center crop input tensor in last two axes to height and width 87 | """ 88 | h, w = t.shape[-2:] 89 | startw = w // 2 - (cropw // 2) 90 | starth = h // 2 - (croph // 2) 91 | return t[..., starth : starth + croph, startw : startw + cropw] 92 | 93 | 94 | class czi_wrapper: 95 | def __init__(self, path, levels=11, sharpen_img=True): 96 | """ 97 | Wrapper to load czi files without openslide, but with the same endpoints 98 | 99 | Parameters 100 | ------------ 101 | path: str 102 | Path to the wsi (.czi) 103 | levels: int, optional 104 | number of artificially created levels 105 | sharpen_img: bool, optional 106 | whether to sharpen the image (cohort dependent) 107 | 108 | Examples 109 | ----------- 110 | Use as a replacement for openslide.open_slide: 111 | >>> sl = czi_wrapper(path) 112 | >>> sl.read_region(...) 113 | """ 114 | self.path = path 115 | self.levels = levels 116 | self.sharpen_img = sharpen_img 117 | self.level_dimensions = None 118 | self.level_downsamples = None 119 | self.properties = {} 120 | self.associated_images = {} 121 | try: 122 | self._generate_dictionaries() 123 | except: 124 | raise RuntimeError(f"issue with {self.path}") 125 | 126 | @staticmethod 127 | def _convert_rect_to_tuple(rect): 128 | return rect.x, rect.y, rect.w, rect.h 129 | 130 | @staticmethod 131 | def _sharpen(img_o): 132 | img_b = cv2.GaussianBlur(img_o, ksize=[3, 3], sigmaX=1, sigmaY=1) 133 | img_s = cv2.addWeighted(img_o, 3.0, img_b, -2.0, 0) 134 | return img_s 135 | 136 | def _generate_dictionaries(self): 137 | with pyczi.open_czi(self.path) as sl: 138 | total_bounding_rectangle = sl.total_bounding_rectangle 139 | meta = sl.metadata["ImageDocument"]["Metadata"] 140 | 141 | self.associated_images["thumbnail"] = PIL.Image.fromarray( 142 | cv2.cvtColor(sl.read(zoom=0.005), cv2.COLOR_BGR2RGB) 143 | ) 144 | 145 | x, y, w, h = self._convert_rect_to_tuple(total_bounding_rectangle) 146 | self.level_dimensions = tuple( 147 | (int(w / (2**i)), int(h / (2.0**i))) for i in range(self.levels) 148 | ) 149 | self.level_downsamples = tuple(2.0**i for i in range(self.levels)) 150 | mpp = { 151 | m["@Id"]: float(m["Value"]) * 1e6 152 | for m in meta["Scaling"]["Items"]["Distance"] 153 | } 154 | self.properties["openslide.mpp-x"] = mpp["X"] 155 | self.properties["openslide.mpp-y"] = mpp["Y"] 156 | self.tx = x 157 | self.ty = y 158 | 159 | def read_region(self, crds, level, size): 160 | with pyczi.open_czi(self.path) as sl: 161 | img = sl.read( 162 | # plane={"T": 0, "Z": 0, "C": 0}, 163 | zoom=1.0 / (2**level), 164 | roi=( 165 | self.tx + crds[0], 166 | self.ty + crds[1], 167 | size[0] * 
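                    # (added note) this mirrors openslide's read_region semantics: crds are
                    # full-resolution (level-0) coordinates while size is given in level-`level`
                    # pixels, so the ROI extent is scaled by 2**level and zoom = 1 / 2**level
                    # brings the returned image back to `size`.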
(2**level), 168 | size[1] * (2**level), 169 | ), 170 | ) 171 | 172 | if self.sharpen_img: 173 | img = self._sharpen(img) 174 | return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 175 | 176 | 177 | # Adapted from https://github.com/christianabbet/SRA 178 | # Original Author: Christian Abbet 179 | class WholeSlideDataset(Dataset): 180 | def __init__( 181 | self, 182 | path: str, 183 | crop_sizes_px: Optional[List[int]] = None, 184 | crop_magnifications: Optional[List[float]] = None, 185 | transform: Optional[Callable] = None, 186 | padding_factor: Optional[float] = 0.5, 187 | remove_background: Optional[bool] = True, 188 | remove_oob: Optional[bool] = True, 189 | remove_alpha: Optional[bool] = True, 190 | ratio_object_thresh: Optional[float] = 1e-3, 191 | ) -> None: 192 | """ 193 | Load a crop as a dataset format. The object is iterable. 194 | Parameters 195 | ---------- 196 | path: str 197 | Path to the whole slide as a "*.tif, *.svs, *.mrxs format" 198 | crop_sizes_px: list of int, optional 199 | List of crops output size in pixel, default value is [224]. 200 | crop_magnifications: list of float, optional 201 | List of crops magnification level, default value is [20]. 202 | transform: callable, optional 203 | Transformation to apply to crops, default value is None. So far, only one augmentation for all crops 204 | is possible. 205 | padding_factor: float, optional 206 | Padding value when creating reference grid. Distance between two consecutive crops as a proportion of the 207 | first listed crop size. Default value is 0.5. 208 | remove_background: bool, optional 209 | Remove background crops if their average intensity value is below the threshold value (240). Default value 210 | is True. 211 | remove_oob: bool, optional 212 | Remove all crops where its representation at a specific magnification is out of bound (out of the scanned 213 | image). Default value is True. 214 | remove_alpha: bool, optional 215 | Remove alpha channel when extracting patches to create a RGB image (instead of RGBA). More suitable to ML 216 | input transforms. Default value is True. 217 | ratio_object_thresh: float, optional 218 | Size of the object ot remove. THe value isexpressed as a ratio with respect to the area of the whole slide. 219 | Default value is 1e-3 (e.i., 1%). 220 | Raises 221 | ------ 222 | WholeSlideError 223 | If it is not possible to load the WSIs. 
224 | Examples 225 | -------- 226 | Load a slide at 40x with a crop size of 256px: 227 | >>> wsi = WholeSlideDataset( 228 | path="/path/to/slide/.mrxs", 229 | crop_sizes_px=[256], 230 | crop_magnifications=[40.], 231 | ) 232 | """ 233 | 234 | extension = pathlib.Path(path).suffix 235 | if ( 236 | extension != ".svs" 237 | and extension != ".mrxs" 238 | and extension != ".tif" 239 | and extension != ".czi" 240 | ): 241 | raise NotImplementedError( 242 | "Only *.svs, *.tif, *.czi, and *.mrxs files supported" 243 | ) 244 | 245 | # Load and create slide and affect default values 246 | self.path = path 247 | self.s = ( 248 | openslide.open_slide(self.path) 249 | if extension != ".czi" 250 | else czi_wrapper(self.path) 251 | ) 252 | self.crop_sizes_px = crop_sizes_px 253 | self.crop_magnifications = crop_magnifications 254 | self.transform = transform 255 | self.padding_factor = padding_factor 256 | self.remove_alpha = remove_alpha 257 | self.mask = None 258 | 259 | if self.crop_sizes_px is None: 260 | self.crop_sizes_px = [224] 261 | 262 | if self.crop_magnifications is None: 263 | self.crop_magnifications = [20] 264 | 265 | # Dimension of the slide at different levels 266 | self.level_dimensions = self.s.level_dimensions 267 | # Down sampling factor at each level 268 | self.level_downsamples = self.s.level_downsamples 269 | # Get average micro meter per pixel (MPP) for the slide 270 | try: 271 | self.mpp = 0.5 * ( 272 | float(self.s.properties[openslide.PROPERTY_NAME_MPP_X]) 273 | + float(self.s.properties[openslide.PROPERTY_NAME_MPP_Y]) 274 | ) 275 | except KeyError: 276 | print("'No resolution found in WSI metadata, using default .2425") 277 | self.mpp = 0.2425 278 | # raise IndexError('No resolution found in WSI metadata. Impossible to build pyramid.') 279 | 280 | # Extract level magnifications 281 | self.level_magnifications = self._get_magnifications( 282 | self.mpp, self.level_downsamples 283 | ) 284 | # Consider reference level as the level with highest resolution 285 | self.crop_reference_level = 0 286 | 287 | # Build reference grid / crop centers 288 | self.crop_reference_cxy = self._build_reference_grid( 289 | crop_size_px=self.crop_sizes_px[0], 290 | crop_magnification=self.crop_magnifications[0], 291 | padding_factor=padding_factor, 292 | level_magnification=self.level_magnifications[self.crop_reference_level], 293 | level_shape=self.level_dimensions[self.crop_reference_level], 294 | ) 295 | 296 | # Assume the whole slide has an associated image 297 | if remove_background and "thumbnail" in self.s.associated_images: 298 | # Extract image thumbnail from slide metadata 299 | img_thumb = self.s.associated_images["thumbnail"] 300 | # Get scale factor compared to reference size 301 | mx = img_thumb.size[0] / self.level_dimensions[self.crop_reference_level][0] 302 | my = img_thumb.size[1] / self.level_dimensions[self.crop_reference_level][1] 303 | # Compute foreground mask 304 | self.mask = self._foreground_mask( 305 | img_thumb, ratio_object_thresh=ratio_object_thresh 306 | ) 307 | # pad with 1, to avoid rounding error: 308 | self.mask = np.pad( 309 | self.mask, ((0, 1), (0, 1)), mode="constant", constant_values=False 310 | ) 311 | # Select subset of point that are part of the foreground 312 | id_valid = self.mask[ 313 | np.round(my * self.crop_reference_cxy[:, 1]).astype(int), 314 | np.round(mx * self.crop_reference_cxy[:, 0]).astype(int), 315 | ] 316 | self.crop_reference_cxy = self.crop_reference_cxy[id_valid] 317 | 318 | # Build grid for all levels 319 | self.crop_metadatas = 
self._build_crop_metadatas( 320 | self.crop_sizes_px, 321 | self.crop_magnifications, 322 | self.level_magnifications, 323 | self.crop_reference_cxy, 324 | self.crop_reference_level, 325 | ) 326 | 327 | # Remove samples that are oob from sampling 328 | if remove_oob: 329 | # Compute oob sa,ples 330 | oob_id = self._oob_id( 331 | self.crop_metadatas, self.level_dimensions[self.crop_reference_level] 332 | ) 333 | # Select only smaples that are within bounds. 334 | self.crop_reference_cxy = self.crop_reference_cxy[~oob_id] 335 | self.crop_metadatas = self.crop_metadatas[:, ~oob_id] 336 | 337 | @staticmethod 338 | def _pil_rgba2rgb( 339 | image: PIL.Image, default_background: Optional[List[int]] = None 340 | ) -> PIL.Image: 341 | """ 342 | Convert RGBA image to RGB format using default background color. 343 | From https://stackoverflow.com/questions/9166400/convert-rgba-png-to-rgb-with-pil/9459208#9459208 344 | Parameters 345 | ---------- 346 | image: PIL.Image 347 | Input RBA image to convert. 348 | default_background: list of int, optional 349 | Value to us as background hen alpha channel is not 255. Default value is white (255, 255, 255). 350 | Returns 351 | ------- 352 | Image with alpha channel removed. 353 | """ 354 | if default_background is None: 355 | default_background = (255, 255, 255) 356 | if type(image) == np.ndarray: 357 | if image.shape[-1] == 3: 358 | return image 359 | else: 360 | return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) 361 | else: 362 | image.load() 363 | background = PIL.Image.new("RGB", image.size, default_background) 364 | background.paste(image, mask=image.split()[3]) 365 | return background 366 | 367 | @staticmethod 368 | def _oob_id( 369 | crop_grid: np.ndarray, 370 | level_shape: List[int], 371 | ) -> np.ndarray: 372 | """ 373 | Check is the samples are within bounds. 374 | Parameters 375 | ---------- 376 | crop_grid: array_like 377 | Input crop meta data of C element where C is the number of crops. For each crop 378 | level_shape: list of int 379 | Dimension of the image 380 | Returns 381 | ------- 382 | """ 383 | # Extract top left coordinated 384 | tx, ty = crop_grid[:, :, 2], crop_grid[:, :, 3] 385 | # Extract top right coordinated 386 | bx, by = crop_grid[:, :, 4], crop_grid[:, :, 5] 387 | # Check for boundaries 388 | oob_id = (tx < 0) | (ty < 0) | (bx > level_shape[0]) | (by > level_shape[1]) 389 | return np.any(oob_id, axis=0) 390 | 391 | @staticmethod 392 | def _build_crop_metadatas( 393 | crop_sizes_px: List[int], 394 | crop_magnifications: List[float], 395 | level_magnifications: List[float], 396 | crop_reference_grid: np.ndarray, 397 | crop_reference_level: int, 398 | ) -> np.ndarray: 399 | """ 400 | Build metadata for each crops definitions. 401 | Parameters 402 | ---------- 403 | crop_sizes_px: list of int, optional 404 | List of crops output size in pixel, default value is [224]. 405 | crop_magnifications: list of float, optional 406 | List of crops magnification level, default value is [20]. 407 | level_magnifications: list of float 408 | List of available magnifications (one for each level) 409 | crop_reference_grid: 410 | Reference grid with shape [Nx2] where N is the number of samples. The column represent x and y coordinates 411 | of the center of the crops respectively. 412 | crop_reference_level: int 413 | Reference level used to compute the reference grid. 414 | Returns 415 | ------- 416 | metas: array_like 417 | Meta data were each entry correspond to the metadata a the crop and [mag, level, tx, ty, cx, cy, bx, 418 | by, s_src, s_tar]. 
With mag = magnification of the crop, level = level at which the crop was extracted, 419 | (tx, ty) = top left coordinate of the crop, (cx, cy) = center coordinate of the crop, (bx, by) = bottom 420 | right coordinates of the crop, s_src = size of the crop at the level, s_tar = siz of the crop after 421 | applying rescaling. 422 | """ 423 | 424 | crop_grids = [] 425 | for t_size, t_mag in zip(crop_sizes_px, crop_magnifications): 426 | # Level that we use to extract current slide region 427 | t_level = WholeSlideDataset._get_optimal_level(t_mag, level_magnifications) 428 | # Scale factor between the reference magnification and the magnification used 429 | t_scale = ( 430 | level_magnifications[t_level] 431 | / level_magnifications[crop_reference_level] 432 | ) 433 | # Final image size at the current level / magnification 434 | t_level_size = t_size / (t_mag / level_magnifications[t_level]) 435 | # Offset to recenter image 436 | t_shift = (t_level_size / t_scale) // 2 437 | # Return grid as format: [level, tx, ty, bx, by, level_size, size] 438 | grid_ = np.concatenate( 439 | ( 440 | t_mag 441 | * np.ones(len(crop_reference_grid))[:, np.newaxis], # Magnification 442 | t_level * np.ones(len(crop_reference_grid))[:, np.newaxis], # Level 443 | crop_reference_grid - t_shift, # (tx, ty) coordinates 444 | crop_reference_grid, # (cx, cy) coordinates values 445 | crop_reference_grid + t_shift, # (bx, by) coordinates 446 | t_level_size 447 | * np.ones(len(crop_reference_grid))[ 448 | :, np.newaxis 449 | ], # original images size 450 | t_size 451 | * np.ones(len(crop_reference_grid))[ 452 | :, np.newaxis 453 | ], # target image size 454 | ), 455 | axis=1, 456 | ) 457 | crop_grids.append(grid_) 458 | return np.array(crop_grids) 459 | 460 | @staticmethod 461 | def _get_optimal_level( 462 | magnification: float, level_magnifications: List[float] 463 | ) -> int: 464 | """ 465 | Estimate the optimal level to extract crop. It the wanted level do nt exist, use a level with higher resolution 466 | (lower level) and resize crop. 467 | Parameters 468 | ---------- 469 | magnification: float 470 | Wanted output magnification 471 | level_magnifications: list of float 472 | List of available magnifications (one for each level) 473 | Returns 474 | ------- 475 | optimal_level: int 476 | Estimated optimal level for crop extraction 477 | """ 478 | 479 | # Get the highest level that is a least as high resolution as the wanted target. 480 | if magnification <= np.max(level_magnifications): 481 | optimal_level = np.nonzero(np.array(level_magnifications) >= magnification)[ 482 | 0 483 | ][-1] 484 | else: 485 | # If no suitable candidates are found, use max resolution 486 | optimal_level = 0 487 | print( 488 | "Slide magnifications {} do not match expected target magnification {}".format( 489 | magnification, level_magnifications 490 | ) 491 | ) 492 | 493 | return optimal_level 494 | 495 | @staticmethod 496 | def _get_magnifications( 497 | mpp: float, 498 | level_downsamples: List[float], 499 | error_max: Optional[float] = 1e-1, 500 | ) -> List[float]: 501 | """ 502 | Compute estimated magnification for each level. The computation rely on the definition of LUT_MAGNIFICATION_X 503 | and LUT_MAGNIFICATION_MPP that are mapped. For example the assumption is 20x -> ~0.5MPP and 40x -> ~0.25MPP. 504 | Parameters 505 | ---------- 506 | mpp: float 507 | Resolution of the slide (and the scanner). 508 | level_downsamples: lost of float 509 | Down sampling factors for each level as a list of floats. 
510 | error_max: float, optional 511 | Maximum relative error accepted when trying to match magnification to predefined factors. Default value 512 | is 1e-1. 513 | Returns 514 | ------- 515 | level_magnifications: list of float 516 | Return the estimated magnifications for each level. 517 | """ 518 | 519 | error_mag = np.abs((np.array(LUT_MAGNIFICATION_MPP) - mpp) / mpp) 520 | # if np.min(error_mag) > error_max: 521 | # print('Error too large for mpp matching: mpp={}, error={}'.format(mpp, np.min(error_mag))) 522 | 523 | return LUT_MAGNIFICATION_X[np.argmin(error_mag)] / np.round( 524 | level_downsamples 525 | ).astype(int) 526 | 527 | @staticmethod 528 | def _foreground_mask( 529 | img: PIL.Image.Image, 530 | intensity_thresh: Optional[int] = 240, 531 | ratio_object_thresh: Optional[float] = 1e-4, 532 | ) -> np.ndarray: 533 | """ 534 | Compute foreground mask the slide base on the input image. Usually the embedded thumbnail is used. 535 | Parameters 536 | ---------- 537 | img: PIL.Image.Image 538 | Downscaled version of the slide as a PIL image 539 | intensity_thresh: int 540 | Intensity threshold applied on te grayscale version of the image to distinguish background from foreground. 541 | The default value is 240. 542 | ratio_object_thresh: float 543 | Minimal ratio of the object to consider as a relevant region. Ratio is applied on the area of the object. 544 | Returns 545 | ------- 546 | mask: np.ndarray 547 | Masked version of the input image where '0', '1' indicates regions belonging to background and foreground 548 | respectively. 549 | """ 550 | 551 | # Convert image to grayscale 552 | mask = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY) 553 | # Blur image to remove hih frequencies 554 | mask = cv2.blur(mask, (5, 5)) 555 | # Apply threshold on background intensity 556 | mask = mask < intensity_thresh 557 | # Remove smallest object as a ratio of original image size 558 | mask = remove_small_objects( 559 | mask, min_size=ratio_object_thresh * np.prod(mask.shape) 560 | ) 561 | # Add final margin to avoid cutting edges 562 | disk_edge = np.ceil(np.max(mask.shape) * ratio_object_thresh).astype(int) 563 | mask = dilation(mask, disk(max(1, disk_edge))) 564 | 565 | return mask 566 | 567 | @staticmethod 568 | def _build_reference_grid( 569 | crop_size_px: int, 570 | crop_magnification: float, 571 | padding_factor: float, 572 | level_magnification: float, 573 | level_shape: List[int], 574 | ) -> np.ndarray: 575 | """ 576 | Build reference grid for cropping location. The grid is usually computed at the lowest magnification. 577 | Parameters 578 | ---------- 579 | crop_size_px: int 580 | Output size in pixel. 581 | crop_magnification: float 582 | Magnification value. 583 | padding_factor: float 584 | Padding factor to use. Define the interval between two consecutive crops. 585 | level_magnification: float 586 | Selected magnification. 587 | level_shape: list of int 588 | Size of the image at the selected level. 589 | Returns 590 | ------- 591 | (cx, cy): list of int 592 | Center coordinate of the crop. 
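        Example (hypothetical numbers, added for illustration)
        -------------------------------------------------------
        With crop_size_px=256, crop_magnification=20, level_magnification=20,
        padding_factor=0.5 and level_shape=(10000, 8000), the crop size at the
        selected level is 256 px, so n_w = floor(2 * (10000 / 256 - 1)) = 76,
        margin_w = (10000 - 0.5 * 75 * 256) // 2 = 200, and the x centers are
        200, 328, ..., 9800, i.e. spaced by padding_factor * 256 = 128 px.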
593 | """ 594 | 595 | # Define the size of the crop at the selected level 596 | level_crop_size_px = int( 597 | (level_magnification / crop_magnification) * crop_size_px 598 | ) 599 | 600 | # Compute the number of crops for each dimensions (rows and columns) 601 | n_w = np.floor( 602 | (1 / padding_factor) * (level_shape[0] / level_crop_size_px - 1) 603 | ).astype(int) 604 | n_h = np.floor( 605 | (1 / padding_factor) * (level_shape[1] / level_crop_size_px - 1) 606 | ).astype(int) 607 | 608 | # Compute the residual margin at each side of the image 609 | margin_w = ( 610 | int(level_shape[0] - padding_factor * (n_w - 1) * level_crop_size_px) // 2 611 | ) 612 | margin_h = ( 613 | int(level_shape[1] - padding_factor * (n_h - 1) * level_crop_size_px) // 2 614 | ) 615 | 616 | # Compute the final center for the cropping 617 | c_x = (np.arange(n_w) * level_crop_size_px * padding_factor + margin_w).astype( 618 | int 619 | ) 620 | c_y = (np.arange(n_h) * level_crop_size_px * padding_factor + margin_h).astype( 621 | int 622 | ) 623 | c_x, c_y = np.meshgrid(c_x, c_y) 624 | 625 | return np.array([c_x.flatten(), c_y.flatten()]).T 626 | 627 | def __len__(self) -> int: 628 | return len(self.crop_reference_cxy) 629 | 630 | def __getitem__(self, idx: int) -> Tuple[List[object], List[object]]: 631 | """ 632 | Get slide element as a function of the index idx. 633 | Parameters 634 | ---------- 635 | idx: int 636 | Index of the crop 637 | Returns 638 | ------- 639 | imgs: List of PIL.Image 640 | List of extracted crops for this index. 641 | metas: List of List of float 642 | Meta data were each entry correspond to the metadata a the crop and [mag, level, tx, ty, cx, cy, bx, 643 | by, s_src, s_tar]. With mag = magnification of the crop, level = level at which the crop was extracted, 644 | (tx, ty) = top left coordinate of the crop, (cx, cy) = center coordinate of the crop, (bx, by) = bottom 645 | right coordinates of the crop, s_src = size of the crop at the level, s_tar = siz of the crop after 646 | applying rescaling. 647 | """ 648 | # Extract metadata for crops 649 | mag, level, tx, ty, cx, cy, bx, by, s_src, s_tar = self.crop_metadatas[0][idx] 650 | # Extract crop 651 | img = self.s.read_region( 652 | (int(tx), int(ty)), int(level), size=(int(s_src), int(s_src)) 653 | ) 654 | # If needed, resize crop to match output shape 655 | if s_src != s_tar: 656 | img = img.resize((int(s_tar), int(s_tar))) 657 | # Append images and metadatas 658 | if self.remove_alpha: 659 | img = self._pil_rgba2rgb(img) 660 | if self.transform is not None: 661 | img = self.transform(img) 662 | 663 | img = normalize_min_max(np.array(img), 0, 255) 664 | 665 | return torch.Tensor(np.array(img)), [ 666 | mag, 667 | level, 668 | tx, 669 | ty, 670 | cx, 671 | cy, 672 | bx, 673 | by, 674 | s_src, 675 | s_tar, 676 | ] 677 | 678 | 679 | class NpyDataset(Dataset): 680 | def __init__( 681 | self, 682 | path, 683 | crop_size_px, 684 | padding_factor=0.5, 685 | remove_bg=True, 686 | ratio_object_thresh=5e-1, 687 | min_tiss=0.1, 688 | ): 689 | """ 690 | Torch Dataset to load from NPY files. 691 | 692 | Parameters 693 | ---------- 694 | path : str 695 | Path to the NPY file. 696 | crop_size_px : int 697 | Size of the extracted tiles in pixels. e.g 256 -> 256x256 tiles 698 | padding_factor : float, optional 699 | Padding value when creating reference grid. Distance between two consecutive crops as a proportion of the 700 | first listed crop size. 701 | remove_bg : bool, optional 702 | Remove background crops if their saturation value is above 5. 
Default value is True. 703 | ratio_object_thresh : float, optional 704 | Objects are removed if they are smaller than ratio*largest object 705 | min_tiss : float, optional 706 | Threshold value to consider a crop as tissue. Default value is 0.1. 707 | """ 708 | self.path = path 709 | self.crop_size_px = crop_size_px 710 | self.padding_factor = padding_factor 711 | self.ratio_object_thresh = ratio_object_thresh 712 | self.min_tiss = min_tiss 713 | self.remove_bg = remove_bg 714 | self.store = np.load(path) 715 | if self.store.ndim == 3: 716 | self.store = self.store[np.newaxis, :] 717 | if self.store.dtype != np.uint8: 718 | print("converting input dtype to uint8") 719 | self.store = self.store.astype(np.uint8) 720 | self.orig_shape = self.store.shape 721 | self.store = np.pad( 722 | self.store, 723 | [ 724 | (0, 0), 725 | (self.crop_size_px, self.crop_size_px), 726 | (self.crop_size_px, self.crop_size_px), 727 | (0, 0), 728 | ], 729 | "constant", 730 | constant_values=255, 731 | ) 732 | self.msks, self.fg_amount = self._foreground_mask() 733 | 734 | self.grid = self._calc_grid() 735 | self.idx = self._create_idx() 736 | 737 | # TODO No idea what kind of exceptions could happen. 738 | # If you are having issues with this dataloader, create an issue. 739 | 740 | def _foreground_mask(self, h_tresh=5): 741 | # print("computing fg masks") 742 | ret = [] 743 | fg_amount = [] 744 | for im in self.store: 745 | msk = ( 746 | cv2.blur(cv2.cvtColor(im, cv2.COLOR_RGB2HSV)[..., 1], (50, 50)) 747 | > h_tresh 748 | ) 749 | comp, labl, size, cent = cv2.connectedComponentsWithStats( 750 | msk.astype(np.uint8) * 255 751 | ) 752 | selec = size[1:, -1] / size[1:, -1].max() > self.ratio_object_thresh 753 | ids = np.arange(1, comp)[selec] 754 | fin_msk = np.isin(labl, ids) 755 | ret.append(fin_msk) 756 | fg_amount.append(np.mean(fin_msk)) 757 | 758 | return ret, fg_amount 759 | 760 | def _calc_grid(self): 761 | _, h, w, _ = self.store.shape 762 | n_w = np.floor( 763 | (w - self.crop_size_px) / (self.crop_size_px * self.padding_factor) 764 | ) 765 | n_h = np.floor( 766 | (h - self.crop_size_px) / (self.crop_size_px * self.padding_factor) 767 | ) 768 | margin_w = ( 769 | int(w - (self.padding_factor * n_w * self.crop_size_px + self.crop_size_px)) 770 | // 2 771 | ) 772 | margin_h = ( 773 | int(h - (self.padding_factor * n_h * self.crop_size_px + self.crop_size_px)) 774 | // 2 775 | ) 776 | c_x = ( 777 | np.arange(n_w + 1) * self.crop_size_px * self.padding_factor + margin_w 778 | ).astype(int) 779 | c_y = ( 780 | np.arange(n_h + 1) * self.crop_size_px * self.padding_factor + margin_h 781 | ).astype(int) 782 | c_x, c_y = np.meshgrid(c_x, c_y) 783 | return np.array([c_y.flatten(), c_x.flatten()]).T 784 | 785 | def _create_idx(self): 786 | crd_list = [] 787 | for i, msk in enumerate(self.msks): 788 | if self.remove_bg: 789 | valid_crd = [ 790 | np.mean( 791 | msk[ 792 | crd[0] : crd[0] + self.crop_size_px, 793 | crd[1] : crd[1] + self.crop_size_px, 794 | ] 795 | ) 796 | > self.min_tiss 797 | for crd in self.grid 798 | ] 799 | crd_subset = self.grid[valid_crd, :] 800 | crd_list.append( 801 | np.concatenate( 802 | [np.repeat(i, crd_subset.shape[0]).reshape(-1, 1), crd_subset], 803 | -1, 804 | ) 805 | ) 806 | else: 807 | crd_list.append( 808 | np.concatenate( 809 | [np.repeat(i, self.grid.shape[0]).reshape(-1, 1), self.grid], -1 810 | ) 811 | ) 812 | return np.vstack(crd_list) 813 | 814 | def __len__(self) -> int: 815 | return self.idx.shape[0] 816 | 817 | def __getitem__(self, idx): 818 | c, x, y = self.idx[idx] 819 | 
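        # (added note) each row of self.idx is (image_index, row, col): top-left corners
        # on the padded array (the input was padded by crop_size_px on every side), so
        # crops that start near the original image border remain fully in bounds.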
out_img = self.store[c, x : x + self.crop_size_px, y : y + self.crop_size_px] 820 | out_img = normalize_min_max(out_img, 0, 255) 821 | return out_img, (c, x, y) 822 | 823 | 824 | class ImageDataset(NpyDataset): 825 | """ 826 | Torch Dataset to load from NPY files. 827 | 828 | Parameters 829 | ---------- 830 | path : str 831 | Path to the Image, needs to be supported by opencv 832 | crop_size_px : int 833 | Size of the extracted tiles in pixels. e.g 256 -> 256x256 tiles 834 | padding_factor : float, optional 835 | Padding value when creating reference grid. Distance between two consecutive crops as a proportion of the 836 | first listed crop size. 837 | remove_bg : bool, optional 838 | Remove background crops if their saturation value is above 5. Default value is True. 839 | ratio_object_thresh : float, optional 840 | Objects are removed if they are smaller than ratio*largest object 841 | min_tiss : float, optional 842 | Threshold value to consider a crop as tissue. Default value is 0.1. 843 | """ 844 | 845 | def __init__( 846 | self, 847 | path, 848 | crop_size_px, 849 | padding_factor=0.5, 850 | remove_bg=True, 851 | ratio_object_thresh=5e-1, 852 | min_tiss=0.1, 853 | ): 854 | self.path = path 855 | self.crop_size_px = crop_size_px 856 | self.padding_factor = padding_factor 857 | self.ratio_object_thresh = ratio_object_thresh 858 | self.min_tiss = min_tiss 859 | self.remove_bg = remove_bg 860 | self.store = self._load_image() 861 | 862 | self.orig_shape = self.store.shape 863 | self.store = np.pad( 864 | self.store, 865 | [ 866 | (0, 0), 867 | (self.crop_size_px, self.crop_size_px), 868 | (self.crop_size_px, self.crop_size_px), 869 | (0, 0), 870 | ], 871 | "constant", 872 | constant_values=255, 873 | ) 874 | self.msks, self.fg_amount = self._foreground_mask() 875 | self.grid = self._calc_grid() 876 | self.idx = self._create_idx() 877 | 878 | def _load_image(self): 879 | img = cv2.imread(self.path) 880 | if img.shape[-1] == 4: 881 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB) 882 | elif img.shape[-1] == 3: 883 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 884 | else: 885 | raise NotImplementedError("Image is neither RGBA nor RGB") 886 | return img[np.newaxis, ...] 887 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import toml 4 | import requests 5 | from concurrent.futures import ThreadPoolExecutor 6 | import concurrent.futures 7 | from typing import List, Union, Tuple 8 | import torch 9 | import numpy as np 10 | import zarr 11 | import zipfile 12 | from numcodecs import Blosc 13 | from torch.utils.data import DataLoader 14 | from tqdm.auto import tqdm 15 | from scipy.special import softmax 16 | from src.multi_head_unet import get_model, load_checkpoint 17 | from src.data_utils import WholeSlideDataset, NpyDataset, ImageDataset 18 | from src.augmentations import color_augmentations 19 | from src.spatial_augmenter import SpatialAugmenter 20 | from src.constants import TTA_AUG_PARAMS, VALID_WEIGHTS 21 | 22 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 23 | 24 | def inference_main( 25 | params: dict, 26 | models, 27 | augmenter, 28 | color_aug_fn, 29 | ): 30 | """ 31 | Inference function for a single input file. 32 | 33 | Parameters 34 | ---------- 35 | params: dict 36 | Parameter store, defined in initial main 37 | models: List[torch.nn.Module] 38 | list of models to run inference with, e.g. 
multiple folds or a single model in a list 39 | augmenter: SpatialAugmenter 40 | Augmentation module for geometric transformations 41 | color_aug_fn: torch.nn.Sequential 42 | Color Augmentation module 43 | 44 | Returns 45 | ---------- 46 | params: dict 47 | Parameter store, defined in initial main and modified by this function 48 | z: Union(Tuple[zarr.ZipStore, zarr.ZipStore], None) 49 | instance and class segmentation results as zarr stores, kept open for further processing. None if inference was skipped. 50 | """ 51 | # print(repr(params["p"])) 52 | fn = params["p"].split(os.sep)[-1].split(params["ext"])[0] 53 | params["output_dir"] = os.path.join(params["output_root"], fn) 54 | if not os.path.isdir(params["output_dir"]): 55 | os.makedirs(params["output_dir"]) 56 | params["model_out_p"] = os.path.join( 57 | params["output_dir"], fn + "_raw_" + str(params["tile_size"]) 58 | ) 59 | prog_path = os.path.join(params["output_dir"], "progress.txt") 60 | 61 | if os.path.exists(os.path.join(params["output_dir"], "pinst_pp.zip")): 62 | print( 63 | "inference and postprocessing already completed, delete output or specify different output path to re-run" 64 | ) 65 | return params, None 66 | 67 | if ( 68 | os.path.exists(params["model_out_p"] + "_inst.zip") 69 | & (os.path.exists(params["model_out_p"] + "_cls.zip")) 70 | & (not os.path.exists(prog_path)) 71 | ): 72 | try: 73 | z_inst = zarr.open(params["model_out_p"] + "_inst.zip", mode="r") 74 | z_cls = zarr.open(params["model_out_p"] + "_cls.zip", mode="r") 75 | print("Inference already completed", z_inst.shape, z_cls.shape) 76 | return params, (z_inst, z_cls) 77 | except (KeyError, zipfile.BadZipFile): 78 | z_inst = None 79 | z_cls = None 80 | print( 81 | "something went wrong with previous output files, rerunning inference" 82 | ) 83 | 84 | z_inst = None 85 | z_cls = None 86 | 87 | if not torch.cuda.is_available(): 88 | print("trying to run inference on CPU, aborting...") 89 | print("if this is intended, remove this check") 90 | raise Exception("No GPU available") 91 | 92 | # create datasets from specified input 93 | 94 | if params["input_type"] == "npy": 95 | dataset = NpyDataset( 96 | params["p"], 97 | params["tile_size"], 98 | padding_factor=params["overlap"], 99 | ratio_object_thresh=0.3, 100 | min_tiss=0.1, 101 | ) 102 | elif params["input_type"] == "img": 103 | dataset = ImageDataset( 104 | params["p"], 105 | params["tile_size"], 106 | padding_factor=params["overlap"], 107 | ratio_object_thresh=0.3, 108 | min_tiss=0.1, 109 | ) 110 | else: 111 | level = 40 if params["pannuke"] else 20 112 | dataset = WholeSlideDataset( 113 | params["p"], 114 | crop_sizes_px=[params["tile_size"]], 115 | crop_magnifications=[level], 116 | padding_factor=params["overlap"], 117 | remove_background=True, 118 | ratio_object_thresh=0.0001, 119 | ) 120 | 121 | # setup output files to write to, also create dummy file to resume inference if interruped 122 | 123 | z_inst = zarr.open( 124 | params["model_out_p"] + "_inst.zip", 125 | mode="w", 126 | shape=(len(dataset), 3, params["tile_size"], params["tile_size"]), 127 | chunks=(params["batch_size"], 3, params["tile_size"], params["tile_size"]), 128 | dtype="f4", 129 | compressor=Blosc(cname="lz4", clevel=3, shuffle=Blosc.SHUFFLE), 130 | ) 131 | z_cls = zarr.open( 132 | params["model_out_p"] + "_cls.zip", 133 | mode="w", 134 | shape=( 135 | len(dataset), 136 | params["out_channels_cls"], 137 | params["tile_size"], 138 | params["tile_size"], 139 | ), 140 | chunks=( 141 | params["batch_size"], 142 | 
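            # (added note) chunking with batch_size as the leading dimension means each
            # batched write in dump_results() maps onto whole chunks; the class softmax
            # is stored as uint8 (probabilities * 255) to keep the ZipStore compact.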
params["out_channels_cls"], 143 | params["tile_size"], 144 | params["tile_size"], 145 | ), 146 | dtype="u1", 147 | compressor=Blosc(cname="lz4", clevel=3, shuffle=Blosc.BITSHUFFLE), 148 | ) 149 | # creating progress file to restart inference if it was interrupted 150 | with open(prog_path, "w") as f: 151 | f.write("0") 152 | inf_start = 0 153 | 154 | dataloader = DataLoader( 155 | dataset, 156 | batch_size=params["batch_size"], 157 | shuffle=False, 158 | num_workers=params["inf_workers"], 159 | pin_memory=True, 160 | ) 161 | 162 | # IO thread to write output in parallel to inference 163 | def dump_results(res, z_cls, z_inst, prog_path): 164 | cls_, inst_, zc_ = res 165 | if cls_ is None: 166 | return 167 | cls_ = (softmax(cls_.astype(np.float32), axis=1) * 255).astype(np.uint8) 168 | z_cls[zc_ : zc_ + cls_.shape[0]] = cls_ 169 | z_inst[zc_ : zc_ + inst_.shape[0]] = inst_.astype(np.float32) 170 | with open(prog_path, "w") as f: 171 | f.write(str(zc_)) 172 | return 173 | 174 | # Separate thread for IO 175 | with ThreadPoolExecutor(max_workers=params["inf_writers"]) as executor: 176 | futures = [] 177 | # run inference 178 | zc = inf_start 179 | for raw, _ in tqdm(dataloader): 180 | raw = raw.to(device, non_blocking=True).float() 181 | raw = raw.permute(0, 3, 1, 2) # BHWC -> BCHW 182 | with torch.inference_mode(): 183 | ct, inst = batch_pseudolabel_ensemb( 184 | raw, models, params["tta"], augmenter, color_aug_fn 185 | ) 186 | futures.append( 187 | executor.submit( 188 | dump_results, 189 | (ct.cpu().detach().numpy(), inst.cpu().detach().numpy(), zc), 190 | z_cls, 191 | z_inst, 192 | prog_path, 193 | ) 194 | ) 195 | 196 | zc += params["batch_size"] 197 | 198 | # Block until all data is written 199 | for _ in concurrent.futures.as_completed(futures): 200 | pass 201 | # clean up 202 | if os.path.exists(prog_path): 203 | os.remove(prog_path) 204 | return params, (z_inst, z_cls) 205 | 206 | 207 | def batch_pseudolabel_ensemb( 208 | raw: torch.Tensor, 209 | models: List[torch.nn.Module], 210 | nviews: int, 211 | aug: SpatialAugmenter, 212 | color_aug_fn: torch.nn.Sequential, 213 | ): 214 | """ 215 | Run inference step on batch of images with test time augmentations 216 | 217 | Parameters 218 | ---------- 219 | 220 | raw: torch.Tensor 221 | batch of input images 222 | models: List[torch.nn.Module] 223 | list of models to run inference with, e.g. 
multiple folds or a single model in a list 224 | nviews: int 225 | Number of test-time augmentation views to aggregate 226 | aug: SpatialAugmenter 227 | Augmentation module for geometric transformations 228 | color_aug_fn: torch.nn.Sequential 229 | Color Augmentation module 230 | 231 | Returns 232 | ---------- 233 | 234 | ct: torch.Tensor 235 | Per pixel class predictions as a tensor of shape (batch_size, n_classes+1, tilesize, tilesize) 236 | inst: torch.Tensor 237 | Per pixel 3 class prediction map with boundary, background and foreground classes, shape (batch_size, 3, tilesize, tilesize) 238 | """ 239 | tmp_3c_view = [] 240 | tmp_ct_view = [] 241 | # ensure that at least one view is run, even when specifying 1 view with many models 242 | if nviews <= 0: 243 | out_fast = [] 244 | with torch.inference_mode(): 245 | for mod in models: 246 | with torch.autocast(device_type="cuda", dtype=torch.float16): 247 | out_fast.append(mod(raw)) 248 | out_fast = torch.stack(out_fast, axis=0).nanmean(0) 249 | ct = out_fast[:, 5:].softmax(1) 250 | inst = out_fast[:, 2:5].softmax(1) 251 | else: 252 | for _ in range(nviews): 253 | aug.interpolation = "bilinear" 254 | view_aug = aug.forward_transform(raw) 255 | aug.interpolation = "nearest" 256 | view_aug = torch.clamp(color_aug_fn(view_aug), 0, 1) 257 | out_fast = [] 258 | with torch.inference_mode(): 259 | for mod in models: 260 | with torch.autocast(device_type="cuda", dtype=torch.float16): 261 | out_fast.append(aug.inverse_transform(mod(view_aug))) 262 | out_fast = torch.stack(out_fast, axis=0).nanmean(0) 263 | tmp_3c_view.append(out_fast[:, 2:5].softmax(1)) 264 | tmp_ct_view.append(out_fast[:, 5:].softmax(1)) 265 | ct = torch.stack(tmp_ct_view).nanmean(0) 266 | inst = torch.stack(tmp_3c_view).nanmean(0) 267 | return ct, inst 268 | 269 | 270 | def get_inference_setup(params): 271 | """ 272 | get model/ models and load checkpoint, create augmentation functions and set up parameters for inference 273 | """ 274 | models = [] 275 | for pth in params["data_dirs"]: 276 | if not os.path.exists(pth): 277 | pth = download_weights(os.path.split(pth)[-1]) 278 | 279 | checkpoint_path = f"{pth}/train/best_model" 280 | mod_params = toml.load(f"{pth}/params.toml") 281 | params["out_channels_cls"] = mod_params["out_channels_cls"] 282 | params["inst_channels"] = mod_params["inst_channels"] 283 | model = get_model( 284 | enc=mod_params["encoder"], 285 | out_channels_cls=params["out_channels_cls"], 286 | out_channels_inst=params["inst_channels"], 287 | ).to(device) 288 | model = load_checkpoint(model, checkpoint_path, device) 289 | model.eval() 290 | models.append(copy.deepcopy(model)) 291 | # create augmentation functions on device 292 | augmenter = SpatialAugmenter(TTA_AUG_PARAMS).to(device) 293 | color_aug_fn = color_augmentations(False, rank=device) 294 | 295 | if mod_params["dataset"] == "pannuke": 296 | params["pannuke"] = True 297 | else: 298 | params["pannuke"] = False 299 | print( 300 | "processing input using", 301 | "pannuke" if params["pannuke"] else "lizard", 302 | "trained model", 303 | ) 304 | 305 | return params, models, augmenter, color_aug_fn 306 | 307 | def download_weights(model_code): 308 | if model_code in VALID_WEIGHTS: 309 | url = f"https://zenodo.org/records/10635618/files/{model_code}.zip" 310 | print("downloading",model_code,"weights to",os.getcwd()) 311 | try: 312 | response = requests.get(url, stream=True, timeout=15.0) 313 | except requests.exceptions.Timeout: 314 | print("Timeout") 315 | total_size = 
int(response.headers.get("content-length", 0))
316 |         block_size = 1024  # 1 Kibibyte
317 |         with tqdm(total=total_size, unit="iB", unit_scale=True) as t:
318 |             with open("cache.zip", "wb") as f:
319 |                 for data in response.iter_content(block_size):
320 |                     t.update(len(data))
321 |                     f.write(data)
322 |         with zipfile.ZipFile("cache.zip", "r") as zip:
323 |             zip.extractall("")
324 |         os.remove("cache.zip")
325 |         return model_code
326 |     else:
327 |         raise ValueError("Model id not found in valid identifiers, please select one of", VALID_WEIGHTS)
328 |
-------------------------------------------------------------------------------- /src/multi_head_unet.py: --------------------------------------------------------------------------------
1 | import segmentation_models_pytorch as smp
2 |
3 | # from segmentation_models_pytorch.encoders import get_encoder
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from collections import OrderedDict
8 | from segmentation_models_pytorch.base import modules as md
9 | import segmentation_models_pytorch.base.initialization as init
10 | import timm
11 |
12 |
13 | def load_checkpoint(model, cp_path, device):
14 |     """
15 |     load checkpoint and fix DataParallel/DistributedDataParallel key prefixes
16 |     """
17 |
18 |     cp = torch.load(cp_path, map_location=device)
19 |     try:
20 |         model.load_state_dict(cp["model_state_dict"])
21 |
22 |         print("successfully loaded model weights")
23 |     except:
24 |         print("trying secondary checkpoint loading")
25 |         state_dict = cp["model_state_dict"]
26 |         new_state_dict = OrderedDict()
27 |         for k, v in state_dict.items():
28 |             name = k[7:]  # remove 'module.' of DataParallel/DistributedDataParallel
29 |             new_state_dict[name] = v
30 |
31 |         model.load_state_dict(new_state_dict)
32 |         print("successfully loaded model weights")
33 |     return model
34 |
35 |
36 | class TimmEncoderFixed(nn.Module):
37 |     """
38 |     Modified version of timm encoder.
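    A minimal usage sketch (added; shapes assume the repository's default
    ConvNeXt-V2 tiny backbone and an untrained encoder):

    >>> enc = TimmEncoderFixed("convnextv2_tiny.fcmae_ft_in22k_in1k",
    ...                        pretrained=False, depth=4)
    >>> feats = enc(torch.zeros(1, 3, 256, 256))
    >>> [tuple(f.shape[-2:]) for f in feats]  # input + features at strides 4, 8, 16, 32
    [(256, 256), (64, 64), (32, 32), (16, 16), (8, 8)]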
39 | Original from: https://github.com/huggingface/pytorch-image-models 40 | 41 | """ 42 | 43 | def __init__( 44 | self, 45 | name, 46 | pretrained=True, 47 | in_channels=3, 48 | depth=5, 49 | output_stride=32, 50 | drop_rate=0.5, 51 | drop_path_rate=0.25, 52 | ): 53 | super().__init__() 54 | kwargs = dict( 55 | in_chans=in_channels, 56 | features_only=True, 57 | pretrained=pretrained, 58 | out_indices=tuple(range(depth)), 59 | drop_rate=drop_rate, 60 | drop_path_rate=drop_path_rate, 61 | ) 62 | 63 | self.model = timm.create_model(name, **kwargs) 64 | 65 | self._in_channels = in_channels 66 | self._out_channels = [ 67 | in_channels, 68 | ] + self.model.feature_info.channels() 69 | self._depth = depth 70 | self._output_stride = output_stride 71 | 72 | def forward(self, x): 73 | features = self.model(x) 74 | features = [ 75 | x, 76 | ] + features 77 | return features 78 | 79 | @property 80 | def out_channels(self): 81 | return self._out_channels 82 | 83 | @property 84 | def output_stride(self): 85 | return min(self._output_stride, 2**self._depth) 86 | 87 | 88 | def get_model( 89 | enc="convnextv2_tiny.fcmae_ft_in22k_in1k", 90 | out_channels_cls=8, 91 | out_channels_inst=5, 92 | pretrained=True, 93 | ): 94 | depth = 4 if "next" in enc else 5 95 | encoder = TimmEncoderFixed( 96 | name=enc, 97 | pretrained=pretrained, 98 | in_channels=3, 99 | depth=depth, 100 | output_stride=32, 101 | drop_rate=0.5, 102 | drop_path_rate=0.0, 103 | ) 104 | 105 | decoder_channels = (256, 128, 64, 32, 16)[:depth] 106 | decoder_inst = UnetDecoder( 107 | encoder_channels=encoder.out_channels, 108 | decoder_channels=decoder_channels, 109 | n_blocks=len(decoder_channels), 110 | use_batchnorm=False, 111 | center=False, 112 | attention_type=None, 113 | next="next" in enc, 114 | ) 115 | decoder_ct = UnetDecoder( 116 | encoder_channels=encoder.out_channels, 117 | decoder_channels=decoder_channels, 118 | n_blocks=len(decoder_channels), 119 | use_batchnorm=False, 120 | center=False, 121 | attention_type=None, 122 | next="next" in enc, 123 | ) 124 | head_inst = smp.base.SegmentationHead( 125 | in_channels=decoder_inst.blocks[-1].conv2[0].out_channels, 126 | out_channels=out_channels_inst, # instance channels 127 | activation=None, 128 | kernel_size=1, 129 | ) 130 | head_ct = smp.base.SegmentationHead( 131 | in_channels=decoder_ct.blocks[-1].conv2[0].out_channels, 132 | out_channels=out_channels_cls, 133 | activation=None, 134 | kernel_size=1, 135 | ) 136 | 137 | decoders = [decoder_inst, decoder_ct] 138 | heads = [head_inst, head_ct] 139 | model = MultiHeadModel(encoder, decoders, heads) 140 | return model 141 | 142 | 143 | class Conv2dReLU(nn.Sequential): 144 | def __init__( 145 | self, 146 | in_channels, 147 | out_channels, 148 | kernel_size, 149 | padding=0, 150 | stride=1, 151 | use_batchnorm=True, 152 | ): 153 | conv = nn.Conv2d( 154 | in_channels, 155 | out_channels, 156 | kernel_size, 157 | stride=stride, 158 | padding=padding, 159 | bias=not (use_batchnorm), 160 | ) 161 | relu = nn.ReLU() 162 | 163 | if use_batchnorm: 164 | bn = nn.BatchNorm2d(out_channels) 165 | 166 | else: 167 | bn = nn.Identity() 168 | 169 | super(Conv2dReLU, self).__init__(conv, bn, relu) 170 | 171 | 172 | class DecoderBlock(nn.Module): 173 | def __init__( 174 | self, 175 | in_channels, 176 | skip_channels, 177 | out_channels, 178 | use_batchnorm=True, 179 | attention_type=None, 180 | ): 181 | super().__init__() 182 | self.conv1 = md.Conv2dReLU( 183 | in_channels + skip_channels, 184 | out_channels, 185 | kernel_size=3, 186 | padding=1, 187 | 
use_batchnorm=use_batchnorm, 188 | ) 189 | self.attention1 = md.Attention( 190 | attention_type, in_channels=in_channels + skip_channels 191 | ) 192 | self.conv2 = md.Conv2dReLU( 193 | out_channels, 194 | out_channels, 195 | kernel_size=3, 196 | padding=1, 197 | use_batchnorm=use_batchnorm, 198 | ) 199 | self.attention2 = md.Attention(attention_type, in_channels=out_channels) 200 | 201 | def forward(self, x, skip=None): 202 | x = F.interpolate(x, scale_factor=2, mode="nearest") 203 | if skip is not None: 204 | x = torch.cat([x, skip], dim=1) 205 | x = self.attention1(x) 206 | x = self.conv1(x) 207 | x = self.conv2(x) 208 | x = self.attention2(x) 209 | return x 210 | 211 | 212 | class CenterBlock(nn.Sequential): 213 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 214 | conv1 = md.Conv2dReLU( 215 | in_channels, 216 | out_channels, 217 | kernel_size=3, 218 | padding=1, 219 | use_batchnorm=use_batchnorm, 220 | ) 221 | conv2 = md.Conv2dReLU( 222 | out_channels, 223 | out_channels, 224 | kernel_size=3, 225 | padding=1, 226 | use_batchnorm=use_batchnorm, 227 | ) 228 | super().__init__(conv1, conv2) 229 | 230 | 231 | class UnetDecoder(nn.Module): 232 | def __init__( 233 | self, 234 | encoder_channels, 235 | decoder_channels, 236 | n_blocks=5, 237 | use_batchnorm=False, 238 | attention_type=None, 239 | center=False, 240 | next=False, 241 | ): 242 | super().__init__() 243 | 244 | if n_blocks != len(decoder_channels): 245 | raise ValueError( 246 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( 247 | n_blocks, len(decoder_channels) 248 | ) 249 | ) 250 | 251 | # remove first skip with same spatial resolution 252 | encoder_channels = encoder_channels[1:] 253 | # reverse channels to start from head of encoder 254 | encoder_channels = encoder_channels[::-1] 255 | 256 | # computing blocks input and output channels 257 | head_channels = encoder_channels[0] 258 | in_channels = [head_channels] + list(decoder_channels[:-1]) 259 | skip_channels = list(encoder_channels[1:]) + [0] 260 | out_channels = decoder_channels 261 | 262 | if center: 263 | self.center = CenterBlock( 264 | head_channels, head_channels, use_batchnorm=use_batchnorm 265 | ) 266 | else: 267 | self.center = nn.Identity() 268 | 269 | # combine decoder keyword arguments 270 | kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) 271 | blocks = [ 272 | DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) 273 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) 274 | ] 275 | if next: 276 | blocks.append( 277 | DecoderBlock(out_channels[-1], 0, out_channels[-1] // 2, **kwargs) 278 | ) 279 | self.blocks = nn.ModuleList(blocks) 280 | 281 | def forward(self, *features): 282 | features = features[1:] # remove first skip with same spatial resolution 283 | features = features[::-1] # reverse channels to start from head of encoder 284 | 285 | head = features[0] 286 | skips = features[1:] 287 | 288 | x = self.center(head) 289 | for i, decoder_block in enumerate(self.blocks): 290 | skip = skips[i] if i < len(skips) else None 291 | x = decoder_block(x, skip) 292 | 293 | return x 294 | 295 | 296 | class MultiHeadModel(torch.nn.Module): 297 | def __init__(self, encoder, decoder_list, head_list): 298 | super(MultiHeadModel, self).__init__() 299 | self.encoder = nn.ModuleList([encoder])[0] 300 | self.decoders = nn.ModuleList(decoder_list) 301 | self.heads = nn.ModuleList(head_list) 302 | self.initialize() 303 | 304 | def initialize(self): 305 | for decoder in self.decoders: 
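# (added note) initialize() applies segmentation_models_pytorch's default weight initialization to the freshly created decoders and heads; the timm encoder keeps its (optionally pretrained) weights.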
306 |             init.initialize_decoder(decoder)
307 |         for head in self.heads:
308 |             init.initialize_head(head)
309 | 
310 |     def check_input_shape(self, x):
311 |         h, w = x.shape[-2:]
312 |         output_stride = self.encoder.output_stride
313 |         if h % output_stride != 0 or w % output_stride != 0:
314 |             new_h = (
315 |                 (h // output_stride + 1) * output_stride
316 |                 if h % output_stride != 0
317 |                 else h
318 |             )
319 |             new_w = (
320 |                 (w // output_stride + 1) * output_stride
321 |                 if w % output_stride != 0
322 |                 else w
323 |             )
324 |             raise RuntimeError(
325 |                 f"Wrong input shape height={h}, width={w}. Expected image height and width "
326 |                 f"divisible by {output_stride}. Consider padding your images to shape ({new_h}, {new_w})."
327 |             )
328 | 
329 |     def forward(self, x):
330 |         """Sequentially pass `x` through the model's encoder, decoders and heads"""
331 | 
332 |         # self.check_input_shape(x)
333 | 
334 |         features = self.encoder(x)
335 |         decoder_outputs = []
336 |         for decoder in self.decoders:
337 |             decoder_outputs.append(decoder(*features))
338 | 
339 |         masks = []
340 |         for head, decoder_output in zip(self.heads, decoder_outputs):
341 |             masks.append(head(decoder_output))
342 | 
343 |         return torch.cat(masks, 1)
344 | 
345 |     @torch.no_grad()
346 |     def predict(self, x):
347 |         """Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()`
348 |         Args:
349 |             x: 4D torch tensor with shape (batch_size, channels, height, width)
350 |         Return:
351 |             prediction: 4D torch tensor with shape (batch_size, classes, height, width)
352 |         """
353 |         if self.training:
354 |             self.eval()
355 | 
356 |         x = self.forward(x)
357 | 
358 |         return x
359 | 
--------------------------------------------------------------------------------
/src/post_process.py:
--------------------------------------------------------------------------------
1 | from src.post_process_utils import (
2 |     work,
3 |     write,
4 |     get_pp_params,
5 |     get_shapes,
6 |     get_tile_coords,
7 | )
8 | from src.viz_utils import create_tsvs, create_polygon_output
9 | from src.data_utils import NpyDataset, ImageDataset
10 | from typing import List, Tuple
11 | import zarr
12 | from numcodecs import Blosc
13 | from concurrent.futures import ProcessPoolExecutor
14 | import concurrent.futures
15 | import json
16 | import os
17 | from typing import Union
18 | from tqdm.auto import tqdm
19 | from src.viz_utils import create_geojson
20 | from src.constants import (
21 |     CLASS_LABELS_LIZARD,
22 |     CLASS_LABELS_PANNUKE,
23 | )
24 | 
25 | 
26 | def post_process_main(
27 |     params: dict,
28 |     z: Union[Tuple[zarr.ZipStore, zarr.ZipStore], None] = None,
29 | ):
30 |     """
31 |     Post processing function for inference results.
Computes stitched output maps and refines prediction results and produces instance and class maps 32 | 33 | Parameters 34 | ---------- 35 | 36 | params: dict 37 | Parameter store, defined in initial main 38 | 39 | Returns 40 | ---------- 41 | z_pp: zarr.ZipStore 42 | instance segmentation results as zarr store, kept open for further processing 43 | 44 | """ 45 | # get best parameters for respective evaluation metric 46 | 47 | params = get_pp_params(params, True) 48 | params, ds_coord = get_shapes(params, len(params["best_fg_thresh_cl"])) 49 | 50 | tile_crds = get_tile_coords( 51 | params["out_img_shape"], 52 | params["pp_tiling"], 53 | pad_size=params["pp_overlap"], 54 | npy=params["input_type"] != "wsi", 55 | ) 56 | if params["input_type"] == "wsi": 57 | pinst_out = zarr.zeros( 58 | shape=( 59 | params["out_img_shape"][-1], 60 | params["out_img_shape"][-2], 61 | ), 62 | dtype="i4", 63 | compressor=Blosc(cname="lz4", clevel=3, shuffle=Blosc.SHUFFLE), 64 | ) 65 | 66 | else: 67 | pinst_out = zarr.zeros( 68 | shape=(params["orig_shape"][0], *params["orig_shape"][-2:]), 69 | dtype="i4", 70 | compressor=Blosc(cname="lz4", clevel=3, shuffle=Blosc.SHUFFLE), 71 | ) 72 | 73 | executor = ProcessPoolExecutor(max_workers=params["pp_workers"]) 74 | tile_processors = [ 75 | executor.submit(work, tcrd, ds_coord, z, params) for tcrd in tile_crds 76 | ] 77 | pcls_out = {} 78 | running_max = 0 79 | class_labels = [] 80 | res_poly = [] 81 | for future in tqdm( 82 | concurrent.futures.as_completed(tile_processors), total=len(tile_processors) 83 | ): 84 | pinst_out, pcls_out, running_max, class_labels, res_poly = write( 85 | pinst_out, pcls_out, running_max, future.result(), params, class_labels, res_poly 86 | ) 87 | executor.shutdown(wait=False) 88 | 89 | if params["output_dir"] is not None: 90 | print("saving final output") 91 | zarr.save(os.path.join(params["output_dir"], "pinst_pp.zip"), pinst_out) 92 | print("storing class dictionary...") 93 | with open(os.path.join(params["output_dir"], "class_inst.json"), "w") as fp: 94 | json.dump(pcls_out, fp) 95 | 96 | if params["input_type"] == "wsi": 97 | print("saving geojson coordinates for qupath...") 98 | create_tsvs(pcls_out, params) 99 | # TODO this is way to slow for large images 100 | if params["save_polygon"]: 101 | pred_keys = CLASS_LABELS_PANNUKE if params["pannuke"] else CLASS_LABELS_LIZARD 102 | create_geojson( 103 | res_poly, 104 | class_labels, 105 | dict((v, k) for k, v in pred_keys.items()), 106 | params, 107 | ) 108 | 109 | return pinst_out 110 | -------------------------------------------------------------------------------- /src/post_process_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import zarr 4 | import gc 5 | import json 6 | import os 7 | import time 8 | import openslide 9 | from skimage.segmentation import watershed 10 | from scipy.ndimage import find_objects 11 | from numcodecs import Blosc 12 | from src.viz_utils import cont 13 | from skimage.measure import regionprops 14 | from src.constants import ( 15 | MIN_THRESHS_LIZARD, 16 | MIN_THRESHS_PANNUKE, 17 | MAX_THRESHS_LIZARD, 18 | MAX_THRESHS_PANNUKE, 19 | MAX_HOLE_SIZE, 20 | LUT_MAGNIFICATION_MPP, 21 | LUT_MAGNIFICATION_X, 22 | ) 23 | from src.data_utils import center_crop, WholeSlideDataset, NpyDataset, ImageDataset 24 | 25 | 26 | def update_dicts(pinst_, pcls_, pcls_out, t_, old_ids, initial_ids): 27 | props = [(p.label, p.centroid) for p in regionprops(pinst_)] 28 | pcls_new = {} 29 | for id_, 
cen in props: 30 | try: 31 | pcls_new[str(id_)] = (pcls_[str(id_)], (cen[0] + t_[2], cen[1] + t_[0])) 32 | except KeyError: 33 | pcls_new[str(id_)] = (pcls_out[str(id_)], (cen[0] + t_[2], cen[1] + t_[0])) 34 | 35 | new_ids = [p[0] for p in props] 36 | 37 | for i in np.setdiff1d(old_ids, new_ids): 38 | try: 39 | del pcls_out[str(i)] 40 | except KeyError: 41 | pass 42 | for i in np.setdiff1d(new_ids, initial_ids): 43 | try: 44 | del pcls_new[str(i)] 45 | except KeyError: 46 | pass 47 | return pcls_out | pcls_new 48 | 49 | 50 | def write(pinst_out, pcls_out, running_max, res, params, class_labels, res_poly): 51 | pinst_, pcls_, max_, t_, skip = res 52 | if not skip: 53 | if params["input_type"] != "wsi": 54 | pinst_.vindex[pinst_[:] != 0] += running_max 55 | pcls_ = {str(int(k) + running_max): v for k, v in pcls_.items()} 56 | props = [(p.label, p.centroid) for p in regionprops(pinst_)] 57 | pcls_new = {} 58 | for id_, cen in props: 59 | pcls_new[str(id_)] = (pcls_[str(id_)], (t_[-1], cen[0], cen[1])) 60 | 61 | running_max += max_ 62 | pcls_out |= pcls_new 63 | pinst_out[t_[-1]] = np.asarray(pinst_, dtype=np.int32) 64 | 65 | else: 66 | pinst_ = np.asarray(pinst_, dtype=np.int32) 67 | ov_regions, local_regions, which = get_overlap_regions( 68 | t_, params["pp_overlap"], pinst_out.shape 69 | ) 70 | msk = pinst_ != 0 71 | pinst_[msk] += running_max 72 | pcls_ = {str(int(k) + running_max): v for k, v in pcls_.items()} 73 | running_max += max_ 74 | initial_ids = np.unique(pinst_[msk]) 75 | old_ids = [] 76 | 77 | for reg, loc, whi in zip(ov_regions, local_regions, which): 78 | if reg is None: 79 | continue 80 | 81 | written = np.array( 82 | pinst_out[reg[2] : reg[3], reg[0] : reg[1]], dtype=np.int32 83 | ) 84 | old_ids.append(np.unique(written[written != 0])) 85 | 86 | small, large = get_subregions(whi, written.shape) 87 | subregion = written[ 88 | small[0] : small[1], small[2] : small[3] 89 | ] # 1/4 of the region 90 | larger_subregion = written[ 91 | large[0] : large[1], large[2] : large[3] 92 | ] # 1/2 of the region 93 | keep = np.unique(subregion[subregion != 0]) 94 | if len(keep) == 0: 95 | continue 96 | 97 | keep_objects = find_objects( 98 | larger_subregion, max_label=max(keep) 99 | ) # [keep-1] 100 | pinst_reg = pinst_[loc[2] : loc[3], loc[0] : loc[1]][ 101 | large[0] : large[1], large[2] : large[3] 102 | ] 103 | 104 | for id_ in keep: 105 | obj = keep_objects[id_ - 1] 106 | if obj is None: 107 | continue 108 | written_mask = larger_subregion[obj] == id_ 109 | pinst_reg[obj][written_mask] = id_ 110 | 111 | old_ids = np.concatenate(old_ids) 112 | pcls_out = update_dicts(pinst_, pcls_, pcls_out, t_, old_ids, initial_ids) 113 | pinst_out[t_[2] : t_[3], t_[0] : t_[1]] = pinst_ 114 | if params["save_polygon"]: 115 | props = [(p.label, p.image, p.bbox) for p in regionprops(np.asarray(pinst_))] 116 | class_labels_partial = [pcls_out[str(p[0])] for p in props] 117 | res_poly_partial = [cont(i, [t_[2], t_[0]]) for i in props] 118 | class_labels.extend(class_labels_partial) 119 | res_poly.extend(res_poly_partial) 120 | # res.task_done() 121 | 122 | return pinst_out, pcls_out, running_max, class_labels, res_poly 123 | 124 | 125 | def work(tcrd, ds_coord, z, params): 126 | out_img = gen_tile_map( 127 | tcrd, 128 | ds_coord, 129 | params["ccrop"], 130 | model_out_p=params["model_out_p"], 131 | which="_inst", 132 | dim=params["out_img_shape"][-3], 133 | z=z, 134 | npy=params["input_type"] != "wsi", 135 | ) 136 | out_cls = gen_tile_map( 137 | tcrd, 138 | ds_coord, 139 | params["ccrop"], 140 | 
model_out_p=params["model_out_p"], 141 | which="_cls", 142 | dim=params["out_cls_shape"][-3], 143 | z=z, 144 | npy=params["input_type"] != "wsi", 145 | ) 146 | if params["input_type"] != "wsi": 147 | out_img = out_img[ 148 | :, 149 | params["tile_size"] : -params["tile_size"], 150 | params["tile_size"] : -params["tile_size"], 151 | ] 152 | out_cls = out_cls[ 153 | :, 154 | params["tile_size"] : -params["tile_size"], 155 | params["tile_size"] : -params["tile_size"], 156 | ] 157 | best_min_threshs = MIN_THRESHS_PANNUKE if params["pannuke"] else MIN_THRESHS_LIZARD 158 | best_max_threshs = MAX_THRESHS_PANNUKE if params["pannuke"] else MAX_THRESHS_LIZARD 159 | 160 | # using apply_func to apply along axis for npy stacks 161 | pred_inst, skip = faster_instance_seg( 162 | out_img, out_cls, params["best_fg_thresh_cl"], params["best_seed_thresh_cl"] 163 | ) 164 | del out_img 165 | gc.collect() 166 | max_hole_size = MAX_HOLE_SIZE if params["pannuke"] else (MAX_HOLE_SIZE // 4) 167 | if skip: 168 | pred_inst = zarr.array( 169 | pred_inst, compressor=Blosc(cname="zstd", clevel=3, shuffle=Blosc.SHUFFLE) 170 | ) 171 | 172 | return (pred_inst, {}, 0, tcrd, skip) 173 | pred_inst = post_proc_inst( 174 | pred_inst, 175 | max_hole_size, 176 | ) 177 | pred_ct = make_ct(out_cls, pred_inst) 178 | del out_cls 179 | gc.collect() 180 | 181 | processed = remove_obj_cls(pred_inst, pred_ct, best_min_threshs, best_max_threshs) 182 | # TODO why is this here? 183 | pred_inst, pred_ct = processed 184 | max_inst = np.max(pred_inst) 185 | pred_inst = zarr.array( 186 | pred_inst.astype(np.int32), 187 | compressor=Blosc(cname="zstd", clevel=3, shuffle=Blosc.SHUFFLE), 188 | ) 189 | return (pred_inst, pred_ct, max_inst, tcrd, skip) 190 | 191 | 192 | def get_overlap_regions(tcrd, pad_size, out_img_shape): 193 | top = [tcrd[0], tcrd[0] + 2 * pad_size, tcrd[2], tcrd[3]] if tcrd[0] != 0 else None 194 | bottom = ( 195 | [tcrd[1] - 2 * pad_size, tcrd[1], tcrd[2], tcrd[3]] 196 | if tcrd[1] != out_img_shape[-2] 197 | else None 198 | ) 199 | left = [tcrd[0], tcrd[1], tcrd[2], tcrd[2] + 2 * pad_size] if tcrd[2] != 0 else None 200 | right = ( 201 | [tcrd[0], tcrd[1], tcrd[3] - 2 * pad_size, tcrd[3]] 202 | if tcrd[3] != out_img_shape[-1] 203 | else None 204 | ) 205 | d_top = [0, 2 * pad_size, 0, tcrd[3] - tcrd[2]] 206 | d_bottom = [ 207 | tcrd[1] - tcrd[0] - 2 * pad_size, 208 | tcrd[1] - tcrd[0], 209 | 0, 210 | tcrd[3] - tcrd[2], 211 | ] 212 | d_left = [0, tcrd[1] - tcrd[0], 0, 2 * pad_size] 213 | d_right = [ 214 | 0, 215 | tcrd[1] - tcrd[0], 216 | tcrd[3] - tcrd[2] - 2 * pad_size, 217 | tcrd[3] - tcrd[2], 218 | ] 219 | return ( 220 | [top, bottom, left, right], 221 | [d_top, d_bottom, d_left, d_right], 222 | ["top", "bottom", "left", "right"], 223 | ) # 224 | 225 | 226 | def get_subregions(which, shape): 227 | """ 228 | Note that the names are incorrect :), inconsistency to be fixed with coordinates and xy swap 229 | """ 230 | if which == "top": 231 | return [0, shape[0], 0, shape[1] // 4], [0, shape[0], 0, shape[1] // 2] 232 | elif which == "bottom": 233 | return [0, shape[0], (shape[1] * 3) // 4, shape[1]], [ 234 | 0, 235 | shape[0], 236 | shape[1] // 2, 237 | shape[1], 238 | ] 239 | elif which == "left": 240 | return [0, shape[0] // 4, 0, shape[1]], [0, shape[0] // 2, 0, shape[1]] 241 | elif which == "right": 242 | return [(shape[0] * 3) // 4, shape[0], 0, shape[1]], [ 243 | shape[0] // 2, 244 | shape[0], 245 | 0, 246 | shape[1], 247 | ] 248 | 249 | else: 250 | raise ValueError("Invalid which") 251 | 252 | 253 | def expand_bbox(bbox, 
pad_size, img_size): 254 | return [ 255 | max(0, bbox[0] - pad_size), 256 | max(0, bbox[1] - pad_size), 257 | min(img_size[0], bbox[2] + pad_size), 258 | min(img_size[1], bbox[3] + pad_size), 259 | ] 260 | 261 | 262 | def get_tile_coords(shape, splits, pad_size, npy): 263 | if npy: 264 | tile_crds = [[0, shape[-2], 0, shape[-1], i] for i in range(shape[0])] 265 | return tile_crds 266 | 267 | else: 268 | shape = shape[-2:] 269 | tile_crds = [] 270 | ts_1 = np.array_split(np.arange(0, shape[0]), splits) 271 | ts_2 = np.array_split(np.arange(0, shape[1]), splits) 272 | for i in ts_1: 273 | for j in ts_2: 274 | x_start = 0 if i[0] < pad_size else i[0] - pad_size 275 | x_end = shape[0] if i[-1] + pad_size > shape[0] else i[-1] + pad_size 276 | y_start = 0 if j[0] < pad_size else j[0] - pad_size 277 | y_end = shape[1] if j[-1] + pad_size > shape[1] else j[-1] + pad_size 278 | tile_crds.append([x_start, x_end, y_start, y_end]) 279 | return tile_crds 280 | 281 | 282 | def proc_tile(t, ccrop, which="_cls"): 283 | t = center_crop(t, ccrop, ccrop) 284 | if which == "_cls": 285 | t = t[1:] 286 | t = t.reshape(t.shape[0], -1) 287 | out = np.zeros(t.shape, dtype=bool) 288 | out[t.argmax(axis=0), np.arange(t.shape[1])] = 1 289 | t = out.reshape(-1, ccrop, ccrop) 290 | 291 | else: 292 | t = t[:2].astype(np.float16) 293 | return t 294 | 295 | 296 | def gen_tile_map( 297 | tile_crd, 298 | ds_coord, 299 | ccrop, 300 | model_out_p="", 301 | which="_cls", 302 | dim=5, 303 | z=None, 304 | npy=False, 305 | ): 306 | if z is None: 307 | z = zarr.open(model_out_p + f"{which}.zip", mode="r") 308 | else: 309 | if which == "_cls": 310 | z = z[1] 311 | else: 312 | z = z[0] 313 | cadj = (z.shape[-1] - ccrop) // 2 314 | tx, ty, tz = None, None, None 315 | dtype = bool if which == "_cls" else np.float16 316 | 317 | if npy: 318 | # TODO fix npy 319 | coord_filter = ds_coord[:, 0] == tile_crd[-1] 320 | ds_coord_subset = ds_coord[coord_filter] 321 | zero_map = np.zeros( 322 | (dim, tile_crd[1] - tile_crd[0], tile_crd[3] - tile_crd[2]), dtype=dtype 323 | ) 324 | else: 325 | zero_map = np.zeros( 326 | (dim, tile_crd[3] - tile_crd[2], tile_crd[1] - tile_crd[0]), dtype=dtype 327 | ) 328 | coord_filter = ( 329 | ((ds_coord[:, 0]) < tile_crd[1]) 330 | & ((ds_coord[:, 0] + ccrop) > tile_crd[0]) 331 | & ((ds_coord[:, 1]) < tile_crd[3]) 332 | & ((ds_coord[:, 1] + ccrop) > tile_crd[2]) 333 | ) 334 | ds_coord_subset = ds_coord[coord_filter] - np.array([tile_crd[0], tile_crd[2]]) 335 | 336 | z_address = np.arange(ds_coord.shape[0])[coord_filter] 337 | for _, (crd, tile) in enumerate(zip(ds_coord_subset, z[z_address])): 338 | if npy: 339 | tz, ty, tx = crd 340 | else: 341 | tx, ty = crd 342 | tx = tx 343 | ty = ty 344 | p_shift = [abs(i) if i < 0 else 0 for i in [ty, tx]] 345 | n_shift = [ 346 | crd - (i + ccrop) if (i + ccrop) > crd else 0 347 | for i, crd in zip([ty, tx], zero_map.shape[1:3]) 348 | ] 349 | try: 350 | zero_map[ 351 | :, 352 | ty + p_shift[0] : ty + ccrop + n_shift[0], 353 | tx + p_shift[1] : tx + ccrop + n_shift[1], 354 | ] = proc_tile(tile, ccrop, which)[ 355 | ..., 356 | p_shift[0] : ccrop + n_shift[0], 357 | p_shift[1] : ccrop + n_shift[1], 358 | ] 359 | 360 | except: 361 | print(zero_map.shape) 362 | print(tx) 363 | print(ty) 364 | print(ccrop) 365 | print(tile.shape) 366 | raise ValueError 367 | return zero_map 368 | 369 | 370 | def faster_instance_seg(out_img, out_cls, best_fg_thresh_cl, best_seed_thresh_cl): 371 | _, rois = cv2.connectedComponents((out_img[0] > 0).astype(np.uint8), connectivity=8) 372 | bboxes = 
find_objects(rois) 373 | del rois 374 | gc.collect() 375 | skip = False 376 | labelling = zarr.zeros( 377 | out_cls.shape[1:], 378 | dtype=np.int32, 379 | compressor=Blosc(cname="lz4", clevel=3, shuffle=Blosc.BITSHUFFLE), 380 | ) 381 | if len(bboxes) == 0: 382 | skip = True 383 | return labelling, skip 384 | max_inst = 0 385 | for bb in bboxes: 386 | bg_pred = out_img[(slice(0, 1, None), *bb)].squeeze() 387 | if ( 388 | (np.array(bg_pred.shape[-2:]) <= 2).any() 389 | | (np.array(bg_pred.shape).sum() <= 64) 390 | | (len(bg_pred.shape) < 2) 391 | ): 392 | continue 393 | fg_pred = out_img[(slice(1, 2, None), *bb)].squeeze() 394 | sem = out_cls[(slice(0, len(best_fg_thresh_cl), None), *bb)] 395 | ws_surface = 1.0 - fg_pred # .astype(np.float32) 396 | fg = np.zeros_like(ws_surface, dtype="bool") 397 | seeds = np.zeros_like(ws_surface, dtype="bool") 398 | 399 | for cl, fg_t in enumerate(best_fg_thresh_cl): 400 | mask = sem[cl] 401 | fg[mask] |= (1.0 - bg_pred[mask]) > fg_t 402 | seeds[mask] |= fg_pred[mask] > best_seed_thresh_cl[cl] 403 | 404 | del fg_pred, bg_pred, sem, mask 405 | gc.collect() 406 | _, markers = cv2.connectedComponents((seeds).astype(np.uint8), connectivity=8) 407 | del seeds 408 | gc.collect() 409 | bb_ws = watershed(ws_surface, markers, mask=fg, connectivity=2) 410 | del ws_surface, markers, fg 411 | gc.collect() 412 | bb_ws[bb_ws != 0] += max_inst 413 | labelling[bb] = bb_ws 414 | max_inst = np.max(bb_ws) 415 | del bb_ws 416 | gc.collect() 417 | return labelling, skip 418 | 419 | 420 | def get_wsi(wsi_path, read_ds=32, pannuke=False, tile_size=256, padding_factor=0.96875): 421 | # TODO change this so it works with non-rescaled version as well 422 | ccrop = int(tile_size * padding_factor) 423 | level = 40 if pannuke else 20 424 | crop_adj = int((tile_size - ccrop) // 2) 425 | 426 | ws_ds = WholeSlideDataset( 427 | wsi_path, 428 | crop_sizes_px=[tile_size], 429 | crop_magnifications=[level], 430 | padding_factor=padding_factor, 431 | ratio_object_thresh=0.0001, 432 | ) 433 | sl = ws_ds.s # openslide.open_slide(wsi_path) 434 | sl_info = get_openslide_info(sl) 435 | target_level = np.argwhere(np.isclose(sl_info["level_downsamples"], read_ds)).item() 436 | ds_coord = ws_ds.crop_metadatas[0] 437 | ds_coord[:, 2:4] -= np.array([sl_info["bounds_x"], sl_info["bounds_y"]]) 438 | 439 | ds_coord[:, 2:4] += tile_size - ccrop 440 | w, h = np.max(ds_coord[:, 2:4], axis=0) 441 | 442 | raw = np.asarray( 443 | sl.read_region( 444 | ( 445 | sl_info["bounds_x"] + crop_adj, 446 | sl_info["bounds_y"] + crop_adj, 447 | ), 448 | target_level, 449 | ( 450 | int((w + ccrop) // (sl_info["level_downsamples"][target_level])), 451 | int((h + ccrop) // (sl_info["level_downsamples"][target_level])), 452 | ), 453 | ) 454 | ) 455 | raw = raw[..., :3] 456 | sl.close() 457 | return raw 458 | 459 | 460 | def post_proc_inst( 461 | pred_inst, 462 | hole_size=50, 463 | ): 464 | pshp = pred_inst.shape 465 | pred_inst = np.asarray(pred_inst) 466 | init = find_objects(pred_inst) 467 | init_large = [] 468 | adj = 8 469 | for i, sl in enumerate(init): 470 | if sl: 471 | slx1 = sl[0].start - adj if (sl[0].start - adj) > 0 else 0 472 | slx2 = sl[0].stop + adj if (sl[0].stop + adj) < pshp[0] else pshp[0] 473 | sly1 = sl[1].start - adj if (sl[1].start - adj) > 0 else 0 474 | sly2 = sl[1].stop + adj if (sl[1].stop + adj) < pshp[1] else pshp[1] 475 | init_large.append( 476 | (i + 1, (slice(slx1, slx2, None), slice(sly1, sly2, None))) 477 | ) 478 | out = np.zeros(pshp, dtype=np.int32) 479 | i = 1 480 | for sl in init_large: 
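# (added note) per padded object bounding box: take this instance's binary mask, fill holes smaller than hole_size, and write it back under a new sequential label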
481 | rm_small_hole = remove_small_holescv2(pred_inst[sl[1]] == (sl[0]), hole_size) 482 | out[sl[1]][rm_small_hole > 0] = i 483 | i += 1 484 | 485 | del pred_inst 486 | gc.collect() 487 | 488 | after_sh = find_objects(out) 489 | out_ = np.zeros(out.shape, dtype=np.int32) 490 | i_ = 1 491 | for i, sl in enumerate(after_sh): 492 | i += 1 493 | if sl: 494 | nr_objects, relabeled = cv2.connectedComponents( 495 | (out[sl] == i).astype(np.uint8), connectivity=8 496 | ) 497 | for new_lab in range(1, nr_objects): 498 | out_[sl] += (relabeled == new_lab) * i_ 499 | i_ += 1 500 | return out_ 501 | 502 | 503 | def make_ct(pred_class, instance_map): 504 | if type(pred_class) != np.ndarray: 505 | pred_class = pred_class[:] 506 | slices = find_objects(instance_map) 507 | pred_class = np.rollaxis(pred_class, 0, 3) 508 | # pred_class = softmax(pred_class,0) 509 | out = [] 510 | out.append((0, 0)) 511 | for i, sl in enumerate(slices): 512 | i += 1 513 | if sl: 514 | inst = instance_map[sl] == i 515 | i_cls = pred_class[sl][inst] 516 | i_cls = np.sum(i_cls, axis=0).argmax() + 1 517 | out.append((i, i_cls)) 518 | out_ = np.array(out) 519 | pred_ct = {str(k): int(v) for k, v in out_ if v != 0} 520 | return pred_ct 521 | 522 | 523 | def remove_obj_cls(pred_inst, pred_cls_dict, best_min_threshs, best_max_threshs): 524 | out_oi = np.zeros_like(pred_inst, dtype=np.int64) 525 | i_ = 1 526 | out_oc = [] 527 | out_oc.append((0, 0)) 528 | slices = find_objects(pred_inst) 529 | 530 | for i, sl in enumerate(slices): 531 | i += 1 532 | px = np.sum([pred_inst[sl] == i]) 533 | cls_ = pred_cls_dict[str(i)] 534 | if (px > best_min_threshs[cls_ - 1]) & (px < best_max_threshs[cls_ - 1]): 535 | out_oc.append((i_, cls_)) 536 | out_oi[sl][pred_inst[sl] == i] = i_ 537 | i_ += 1 538 | out_oc = np.array(out_oc) 539 | out_dict = {str(k): int(v) for k, v in out_oc if v != 0} 540 | return out_oi, out_dict 541 | 542 | 543 | def remove_small_holescv2(img, sz): 544 | # this is still pretty slow but at least its a bit faster than other approaches? 545 | img = np.logical_not(img).astype(np.uint8) 546 | 547 | nb_blobs, im_with_separated_blobs, stats, _ = cv2.connectedComponentsWithStats(img) 548 | # stats (and the silenced output centroids) gives some information about the blobs. See the docs for more information. 549 | # here, we're interested only in the size of the blobs, contained in the last column of stats. 
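# (added note) since the mask was inverted above, each component here is either true background or a hole inside an object;
# components with area >= sz are kept as background, and the final logical_not below returns the objects with smaller holes filled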
550 | sizes = stats[1:, -1] 551 | nb_blobs -= 1 552 | im_result = np.zeros((img.shape), dtype=np.uint16) 553 | for blob in range(nb_blobs): 554 | if sizes[blob] >= sz: 555 | im_result[im_with_separated_blobs == blob + 1] = 1 556 | 557 | im_result = np.logical_not(im_result) 558 | return im_result 559 | 560 | 561 | def get_pp_params(params, mit_eval=False): 562 | eval_metric = params["metric"] 563 | fg_threshs = [] 564 | seed_threshs = [] 565 | for exp in params["data_dirs"]: 566 | mod_path = os.path.join(params["root"], exp) 567 | if "pannuke" in exp: 568 | with open( 569 | os.path.join(mod_path, "pannuke_test_param_dict.json"), "r" 570 | ) as js: 571 | dt = json.load(js) 572 | fg_threshs.append(dt[f"best_fg_{eval_metric}"]) 573 | seed_threshs.append(dt[f"best_seed_{eval_metric}"]) 574 | elif mit_eval: 575 | with open(os.path.join(mod_path, "liz_test_param_dict.json"), "r") as js: 576 | dt = json.load(js) 577 | fg_tmp = dt[f"best_fg_{eval_metric}"] 578 | seed_tmp = dt[f"best_seed_{eval_metric}"] 579 | with open(os.path.join(mod_path, "mit_test_param_dict.json"), "r") as js: 580 | dt = json.load(js) 581 | fg_tmp[-1] = dt[f"best_fg_{eval_metric}"][-1] 582 | seed_tmp[-1] = dt[f"best_seed_{eval_metric}"][-1] 583 | fg_threshs.append(fg_tmp) 584 | seed_threshs.append(seed_tmp) 585 | else: 586 | with open(os.path.join(mod_path, "param_dict.json"), "r") as js: 587 | dt = json.load(js) 588 | fg_threshs.append(dt[f"best_fg_{eval_metric}"]) 589 | seed_threshs.append(dt[f"best_seed_{eval_metric}"]) 590 | params["best_fg_thresh_cl"] = np.mean(fg_threshs, axis=0) 591 | params["best_seed_thresh_cl"] = np.mean(seed_threshs, axis=0) 592 | print(params["best_fg_thresh_cl"], params["best_seed_thresh_cl"]) 593 | 594 | return params 595 | 596 | 597 | def get_shapes(params, nclasses): 598 | padding_factor = params["overlap"] 599 | tile_size = params["tile_size"] 600 | ds_factor = 1 601 | if params["input_type"] in ["img", "npy"]: 602 | if params["input_type"] == "npy": 603 | dataset = NpyDataset( 604 | params["p"], 605 | tile_size, 606 | padding_factor=padding_factor, 607 | ratio_object_thresh=0.3, 608 | min_tiss=0.1, 609 | ) 610 | else: 611 | dataset = ImageDataset( 612 | params["p"], 613 | params["tile_size"], 614 | padding_factor=params["overlap"], 615 | ratio_object_thresh=0.3, 616 | min_tiss=0.1, 617 | ) 618 | params["orig_shape"] = dataset.orig_shape[:-1] 619 | ds_coord = np.array(dataset.idx).astype(int) 620 | shp = dataset.store.shape 621 | 622 | ccrop = int(dataset.padding_factor * dataset.crop_size_px) 623 | coord_adj = (dataset.crop_size_px - ccrop) // 2 624 | ds_coord[:, 1:] += coord_adj 625 | out_img_shape = (shp[0], 2, shp[1], shp[2]) 626 | out_cls_shape = (shp[0], nclasses, shp[1], shp[2]) 627 | else: 628 | level = 40 if params["pannuke"] else 20 629 | dataset = WholeSlideDataset( 630 | params["p"], 631 | crop_sizes_px=[tile_size], 632 | crop_magnifications=[level], 633 | padding_factor=padding_factor, 634 | ratio_object_thresh=0.0001, 635 | ) 636 | 637 | print("getting coords:") 638 | ds_coord = dataset.crop_metadatas[0][:, 2:4].copy() 639 | try: 640 | sl = dataset.s 641 | bounds_x = int(sl.properties["openslide.bounds-x"]) # 158208 642 | bounds_y = int(sl.properties["openslide.bounds-y"]) # 28672 643 | except KeyError: 644 | bounds_x = 0 645 | bounds_y = 0 646 | 647 | ds_coord -= np.array([bounds_x, bounds_y]) 648 | 649 | ccrop = int(tile_size * padding_factor) 650 | rel_res = np.isclose(dataset.mpp, LUT_MAGNIFICATION_MPP, rtol=0.2) 651 | if sum(rel_res) != 1: 652 | raise NotImplementedError( 
653 | "Currently no support for images with this resolution. Check src.constants in LUT_MAGNIFICATION_MPP and LUT_MAGNIFICATION_X to add the resultion - downsampling pair" 654 | ) 655 | else: 656 | ds_factor = LUT_MAGNIFICATION_X[rel_res.argmax()] / level 657 | # if ds_factor < 1: 658 | # raise NotImplementedError( 659 | # "The specified model does not support images at this resolution. Consider supplying a higher resolution image" 660 | # ) 661 | ds_coord /= ds_factor 662 | 663 | ds_coord += (tile_size - ccrop) // 2 664 | ds_coord = ds_coord.astype(int) 665 | h, w = np.max(ds_coord, axis=0) 666 | out_img_shape = (2, int(h + ccrop), int(w + ccrop)) 667 | out_cls_shape = (nclasses, int(h + ccrop), int(w + ccrop)) 668 | params["ds_factor"] = ds_factor 669 | params["out_img_shape"] = out_img_shape 670 | params["out_cls_shape"] = out_cls_shape 671 | params["ccrop"] = ccrop 672 | 673 | return params, ds_coord 674 | 675 | 676 | def get_openslide_info(sl: openslide.OpenSlide): 677 | level_count = len(sl.level_downsamples) 678 | try: 679 | mpp_x = float(sl.properties[openslide.PROPERTY_NAME_MPP_X]) 680 | mpp_y = float(sl.properties[openslide.PROPERTY_NAME_MPP_Y]) 681 | except KeyError: 682 | print("'No resolution found in WSI metadata, using default .2425") 683 | mpp_x = 0.2425 684 | mpp_y = 0.2425 685 | try: 686 | bounds_x, bounds_y = ( 687 | int(sl.properties["openslide.bounds-x"]), 688 | int(sl.properties["openslide.bounds-y"]), 689 | ) 690 | except KeyError: 691 | bounds_x = 0 692 | bounds_y = 0 693 | level_downsamples = sl.level_downsamples 694 | 695 | level_mpp_x = [mpp_x * i for i in level_downsamples] 696 | level_mpp_y = [mpp_y * i for i in level_downsamples] 697 | return { 698 | "level_count": level_count, 699 | "mpp_x": mpp_x, 700 | "mpp_y": mpp_y, 701 | "bounds_x": bounds_x, 702 | "bounds_y": bounds_y, 703 | "level_downsamples": level_downsamples, 704 | "level_mpp_x": level_mpp_x, 705 | "level_mpp_y": level_mpp_y, 706 | } 707 | -------------------------------------------------------------------------------- /src/spatial_augmenter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from torchvision.transforms.transforms import GaussianBlur 5 | import math 6 | 7 | 8 | class SpatialAugmenter( 9 | torch.nn.Module, 10 | ): 11 | 12 | def __init__(self, params, interpolation="bilinear", padding_mode="zeros"): 13 | """ 14 | params= { 15 | 'mirror': {'prob': float [0,1], 'prob_x': float [0,1],'prob_y': float [0,1]}, 16 | 'translate': {'max_percent':float [0,1], 'prob': float [0,1]}, 17 | 'scale': {'min': float, 'max':float, 'prob': float [0,1]}, 18 | 'zoom': {'min': float, 'max':float, 'prob': float [0,1]}, 19 | 'rotate': {'rot90': bool, 'max_degree': int [0,360], 'prob': float [0,1]}, 20 | 'shear': {'max_percent': float [0,1], 'prob': float [0,1]}, 21 | 'elastic': {'alpha': list[float|int], 'sigma': float|int, 'prob': float [0,1]}} 22 | """ 23 | super(SpatialAugmenter, self).__init__() 24 | self.params = params 25 | self.mode = "forward" 26 | self.random_state = {} 27 | # fill dict so that augmentation functions can be tested 28 | for key in self.params.keys(): 29 | self.random_state[key] = {} 30 | self.interpolation = interpolation 31 | self.padding_mode = padding_mode 32 | 33 | def forward_transform(self, img, label=None, random_state=None): 34 | self.mode = "forward" 35 | self.device = img.device 36 | if random_state: 37 | self.random_state = random_state 38 | else: 39 | for key 
in self.params.keys(): 40 | self.random_state[key] = { 41 | "prob": bool(np.random.binomial(1, self.params[key]["prob"])) 42 | } 43 | for key in list(self.params.keys()): 44 | if self.random_state[key]["prob"]: 45 | # print('Do transform: ', key) 46 | func = getattr(self, key) 47 | img, label = func(img, label=label, random_state=random_state) 48 | if label is not None: 49 | return img, label 50 | else: 51 | return img 52 | 53 | def inverse_transform(self, img, label=None, random_state=None): 54 | self.mode = "inverse" 55 | self.device = img.device 56 | keylist = list(self.params.keys()) 57 | keylist.reverse() 58 | if random_state: 59 | self.random_state = random_state 60 | for key in keylist: 61 | if self.random_state[key]["prob"]: 62 | # print('Do inverse transform: ', key) 63 | func = getattr(self, key) 64 | img, label = func(img, label=label) 65 | if label is not None: 66 | return img, label 67 | else: 68 | return img 69 | 70 | def mirror(self, img, label, random_state=None): 71 | if self.mode == "forward" and not random_state: 72 | self.random_state["mirror"]["x"] = bool( 73 | np.random.binomial(1, self.params["mirror"]["prob_x"]) 74 | ) 75 | self.random_state["mirror"]["y"] = bool( 76 | np.random.binomial(1, self.params["mirror"]["prob_y"]) 77 | ) 78 | # 79 | x = self.random_state["mirror"]["x"] 80 | y = self.random_state["mirror"]["y"] 81 | if x: 82 | x = -1 83 | else: 84 | x = 1 85 | if y: 86 | y = -1 87 | else: 88 | y = 1 89 | theta = torch.tensor( 90 | [[[x, 0.0, 0.0], [0.0, y, 0.0]]], device=self.device, dtype=img.dtype 91 | ) 92 | grid = F.affine_grid( 93 | theta.repeat(img.size()[0], 1, 1), img.size(), align_corners=False 94 | ) 95 | if label is not None: 96 | return F.grid_sample( 97 | img, 98 | grid, 99 | mode=self.interpolation, 100 | padding_mode=self.padding_mode, 101 | align_corners=False, 102 | ), F.grid_sample( 103 | label, 104 | grid, 105 | mode="nearest", 106 | padding_mode=self.padding_mode, 107 | align_corners=False, 108 | ) 109 | else: 110 | return ( 111 | F.grid_sample( 112 | img, 113 | grid, 114 | mode=self.interpolation, 115 | padding_mode=self.padding_mode, 116 | align_corners=False, 117 | ), 118 | None, 119 | ) 120 | 121 | def translate(self, img, label, random_state=None): 122 | if self.mode == "forward" and not random_state: 123 | x = np.random.uniform( 124 | -self.params["translate"]["max_percent"], 125 | self.params["translate"]["max_percent"], 126 | ) 127 | y = np.random.uniform( 128 | -self.params["translate"]["max_percent"], 129 | self.params["translate"]["max_percent"], 130 | ) 131 | self.random_state["translate"]["x"] = x 132 | self.random_state["translate"]["y"] = y 133 | elif self.mode == "inverse": 134 | x = -1 * self.random_state["translate"]["x"] 135 | y = -1 * self.random_state["translate"]["y"] 136 | else: 137 | x = self.random_state["translate"]["x"] 138 | y = self.random_state["translate"]["y"] 139 | theta = torch.tensor( 140 | [[[1.0, 0.0, x], [0.0, 1.0, y]]], device=self.device, dtype=img.dtype 141 | ) 142 | grid = F.affine_grid( 143 | theta.repeat(img.size()[0], 1, 1), img.size(), align_corners=False 144 | ) 145 | if label is not None: 146 | return F.grid_sample( 147 | img, 148 | grid, 149 | mode=self.interpolation, 150 | padding_mode=self.padding_mode, 151 | align_corners=False, 152 | ), F.grid_sample( 153 | label, 154 | grid, 155 | mode="nearest", 156 | padding_mode=self.padding_mode, 157 | align_corners=False, 158 | ) 159 | else: 160 | return ( 161 | F.grid_sample( 162 | img, 163 | grid, 164 | mode=self.interpolation, 165 | 
padding_mode=self.padding_mode, 166 | align_corners=False, 167 | ), 168 | None, 169 | ) 170 | 171 | def zoom(self, img, label, random_state=None): 172 | if self.mode == "forward" and not random_state: 173 | zoom_factor = np.random.uniform( 174 | self.params["scale"]["min"], self.params["scale"]["max"] 175 | ) 176 | self.random_state["zoom"]["factor"] = zoom_factor 177 | elif self.mode == "inverse": 178 | zoom_factor = 1 / self.random_state["zoom"]["factor"] 179 | else: 180 | zoom_factor = self.random_state["zoom"]["factor"] 181 | theta = torch.tensor( 182 | [[[zoom_factor, 0.0, 0.0], [0.0, zoom_factor, 0.0]]], 183 | device=self.device, 184 | dtype=img.dtype, 185 | ) 186 | grid = F.affine_grid( 187 | theta.repeat(img.size()[0], 1, 1), img.size(), align_corners=False 188 | ) 189 | if label is not None: 190 | return F.grid_sample( 191 | img, 192 | grid, 193 | mode=self.interpolation, 194 | padding_mode=self.padding_mode, 195 | align_corners=False, 196 | ), F.grid_sample( 197 | label, 198 | grid, 199 | mode="nearest", 200 | padding_mode=self.padding_mode, 201 | align_corners=False, 202 | ) 203 | else: 204 | return ( 205 | F.grid_sample( 206 | img, 207 | grid, 208 | mode=self.interpolation, 209 | padding_mode=self.padding_mode, 210 | align_corners=False, 211 | ), 212 | None, 213 | ) 214 | 215 | def scale(self, img, label, random_state=None): 216 | if self.mode == "forward" and not random_state: 217 | x = np.random.uniform( 218 | self.params["scale"]["min"], self.params["scale"]["max"] 219 | ) 220 | y = np.random.uniform( 221 | self.params["scale"]["min"], self.params["scale"]["max"] 222 | ) 223 | self.random_state["scale"]["x"] = x 224 | self.random_state["scale"]["y"] = y 225 | elif self.mode == "inverse": 226 | x = 1 / self.random_state["scale"]["x"] 227 | y = 1 / self.random_state["scale"]["y"] 228 | else: 229 | x = self.random_state["scale"]["x"] 230 | y = self.random_state["scale"]["y"] 231 | theta = torch.tensor( 232 | [[[x, 0.0, 0.0], [0.0, y, 0.0]]], device=self.device, dtype=img.dtype 233 | ) 234 | grid = F.affine_grid( 235 | theta.repeat(img.size()[0], 1, 1), img.size(), align_corners=False 236 | ) 237 | if label is not None: 238 | return F.grid_sample( 239 | img, 240 | grid, 241 | mode=self.interpolation, 242 | padding_mode=self.padding_mode, 243 | align_corners=False, 244 | ), F.grid_sample( 245 | label, 246 | grid, 247 | mode="nearest", 248 | padding_mode=self.padding_mode, 249 | align_corners=False, 250 | ) 251 | else: 252 | return ( 253 | F.grid_sample( 254 | img, 255 | grid, 256 | mode=self.interpolation, 257 | padding_mode=self.padding_mode, 258 | align_corners=False, 259 | ), 260 | None, 261 | ) 262 | 263 | def rotate(self, img, label, random_state=None): 264 | if self.mode == "forward" and not random_state: 265 | if ( 266 | "rot90" in self.params["rotate"].keys() 267 | and self.params["rotate"]["rot90"] 268 | ): 269 | degree = np.random.choice([-270, -180, -90, 90, 180, 270]) 270 | else: 271 | degree = np.random.uniform( 272 | -self.params["rotate"]["max_degree"], 273 | self.params["rotate"]["max_degree"], 274 | ) 275 | self.random_state["rotate"]["degree"] = degree 276 | elif self.mode == "inverse": 277 | degree = -1 * self.random_state["rotate"]["degree"] 278 | else: 279 | degree = self.random_state["rotate"]["degree"] 280 | rad = math.radians(degree) 281 | theta = torch.tensor( 282 | [ 283 | [ 284 | [math.cos(rad), -math.sin(rad), 0.0], 285 | [math.sin(rad), math.cos(rad), 0.0], 286 | ] 287 | ], 288 | device=self.device, 289 | dtype=img.dtype, 290 | ) 291 | grid = 
F.affine_grid( 292 | theta.repeat(img.size()[0], 1, 1), img.size(), align_corners=False 293 | ) 294 | if label is not None: 295 | return F.grid_sample( 296 | img, 297 | grid, 298 | mode=self.interpolation, 299 | padding_mode=self.padding_mode, 300 | align_corners=False, 301 | ), F.grid_sample( 302 | label, 303 | grid, 304 | mode="nearest", 305 | padding_mode=self.padding_mode, 306 | align_corners=False, 307 | ) 308 | else: 309 | return ( 310 | F.grid_sample( 311 | img, 312 | grid, 313 | mode=self.interpolation, 314 | padding_mode=self.padding_mode, 315 | align_corners=False, 316 | ), 317 | None, 318 | ) 319 | 320 | def shear(self, img, label, random_state=None): 321 | if self.mode == "forward" and not random_state: 322 | x = np.random.uniform( 323 | -self.params["shear"]["max_percent"], 324 | self.params["shear"]["max_percent"], 325 | ) 326 | y = np.random.uniform( 327 | -self.params["shear"]["max_percent"], 328 | self.params["shear"]["max_percent"], 329 | ) 330 | self.random_state["shear"]["x"] = x 331 | self.random_state["shear"]["y"] = y 332 | elif self.mode == "inverse": 333 | x = -self.random_state["shear"]["x"] 334 | y = -self.random_state["shear"]["y"] 335 | else: 336 | x = self.random_state["shear"]["x"] 337 | y = self.random_state["shear"]["y"] 338 | theta = torch.tensor( 339 | [[[1.0, x, 0.0], [y, 1.0, 0.0]]], device=self.device, dtype=img.dtype 340 | ) 341 | grid = F.affine_grid( 342 | theta.repeat(img.size()[0], 1, 1), img.size(), align_corners=False 343 | ) 344 | if label is not None: 345 | return F.grid_sample( 346 | img, 347 | grid, 348 | mode=self.interpolation, 349 | padding_mode=self.padding_mode, 350 | align_corners=False, 351 | ), F.grid_sample( 352 | label, 353 | grid, 354 | mode="nearest", 355 | padding_mode=self.padding_mode, 356 | align_corners=False, 357 | ) 358 | else: 359 | return ( 360 | F.grid_sample( 361 | img, 362 | grid, 363 | mode=self.interpolation, 364 | padding_mode=self.padding_mode, 365 | align_corners=False, 366 | ), 367 | None, 368 | ) 369 | 370 | def identity_grid(self, img): 371 | theta = torch.tensor( 372 | [[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], device=self.device, dtype=img.dtype 373 | ) 374 | return F.affine_grid( 375 | theta.repeat(img.size()[0], 1, 1), img.size(), align_corners=False 376 | ) 377 | 378 | def elastic(self, img, label, random_state=None): 379 | if self.mode == "forward" and not random_state: 380 | displacement = self.create_elastic_transformation( 381 | shape=list(img.shape[-2:]), 382 | alpha=self.params["elastic"]["alpha"], 383 | sigma=self.params["elastic"]["sigma"], 384 | ) 385 | self.random_state["elastic"]["displacement"] = displacement 386 | elif self.mode == "inverse": 387 | displacement = -1 * self.random_state["elastic"]["displacement"] 388 | else: 389 | displacement = self.random_state["elastic"]["displacement"] 390 | identity_grid = self.identity_grid(img) 391 | grid = identity_grid + displacement 392 | if label is not None: 393 | return F.grid_sample( 394 | img, 395 | grid, 396 | mode=self.interpolation, 397 | padding_mode=self.padding_mode, 398 | align_corners=False, 399 | ), F.grid_sample( 400 | label, 401 | grid, 402 | mode="nearest", 403 | padding_mode=self.padding_mode, 404 | align_corners=False, 405 | ) 406 | else: 407 | return ( 408 | F.grid_sample( 409 | img, 410 | grid, 411 | mode=self.interpolation, 412 | padding_mode=self.padding_mode, 413 | align_corners=False, 414 | ), 415 | None, 416 | ) 417 | 418 | def create_elastic_transformation(self, shape, alpha=[80, 80], sigma=8): 419 | 420 | blur = 
GaussianBlur(kernel_size=int(8 * sigma + 1), sigma=sigma) 421 | dx = ( 422 | blur( 423 | torch.rand(*shape, device=self.device).unsqueeze(0).unsqueeze(0) * 2 - 1 424 | ) 425 | * alpha[0] 426 | / shape[0] 427 | ) 428 | dy = ( 429 | blur( 430 | torch.rand(*shape, device=self.device).unsqueeze(0).unsqueeze(0) * 2 - 1 431 | ) 432 | * alpha[1] 433 | / shape[1] 434 | ) 435 | 436 | displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) 437 | return displacement 438 | -------------------------------------------------------------------------------- /src/viz_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import geojson 4 | import openslide 5 | import cv2 6 | from skimage.measure import regionprops 7 | from src.constants import ( 8 | CLASS_LABELS_LIZARD, 9 | CLASS_LABELS_PANNUKE, 10 | COLORS_LIZARD, 11 | COLORS_PANNUKE, 12 | CONIC_MPP, 13 | PANNUKE_MPP, 14 | ) 15 | 16 | 17 | def create_geojson(polygons, classids, lookup, params): 18 | features = [] 19 | colors = COLORS_PANNUKE if params["pannuke"] else COLORS_LIZARD 20 | if isinstance(classids[0], (list, tuple)): 21 | classids = [cid[0] for cid in classids] 22 | for i, (poly, cid) in enumerate(zip(polygons, classids)): 23 | poly = np.array(poly) 24 | poly = poly[:, [1, 0]] * params["ds_factor"] 25 | poly = poly.tolist() 26 | 27 | geom = geojson.Polygon([poly], precision=2) 28 | if not geom.is_valid: 29 | print(f"Polygon {i}:{[poly]} is not valid, skipping...") 30 | continue 31 | # poly.append(poly[0]) 32 | measurements = {classifications: 0 for classifications in lookup.values()} 33 | measurements[lookup[cid]] = 1 34 | feature = geojson.Feature( 35 | geometry=geojson.Polygon([poly], precision=2), 36 | properties={ 37 | "Name": f"Nuc {i}", 38 | "Type": "Polygon", 39 | "color": colors[cid - 1], 40 | "classification": lookup[cid], 41 | "measurements": measurements, 42 | "objectType": "tile" 43 | }, 44 | ) 45 | features.append(feature) 46 | feature_collection = geojson.FeatureCollection(features) 47 | with open(params["output_dir"] + "/poly.geojson", "w") as outfile: 48 | geojson.dump(feature_collection, outfile) 49 | 50 | 51 | def create_tsvs(pcls_out, params): 52 | pred_keys = CLASS_LABELS_PANNUKE if params["pannuke"] else CLASS_LABELS_LIZARD 53 | 54 | coord_array = np.array([[i[0], *i[1]] for i in pcls_out.values()]) 55 | classes = list(pred_keys.keys()) 56 | colors = ["-256", "-65536"] 57 | i = 0 58 | for pt in classes: 59 | file = os.path.join(params["output_dir"], "pred_" + pt + ".tsv") 60 | textfile = open(file, "w") 61 | 62 | textfile.write("x" + "\t" + "y" + "\t" + "name" + "\t" + "color" + "\n") 63 | textfile.writelines( 64 | [ 65 | str(element[2] * params["ds_factor"]) 66 | + "\t" 67 | + str(element[1] * params["ds_factor"]) 68 | + "\t" 69 | + pt 70 | + "\t" 71 | + colors[0] 72 | + "\n" 73 | for element in coord_array[coord_array[:, 0] == pred_keys[pt]] 74 | ] 75 | ) 76 | 77 | textfile.close() 78 | i += 1 79 | 80 | 81 | def cont(x, offset=None): 82 | _, im, bb = x 83 | im = np.pad(im.astype(np.uint8), 1, mode="constant", constant_values=0) 84 | 85 | # initial contour finding 86 | cont = cv2.findContours( 87 | im, 88 | mode=cv2.RETR_EXTERNAL, 89 | method=cv2.CHAIN_APPROX_TC89_KCOS, 90 | )[0][0].reshape(-1, 2)[:, [1, 0]] 91 | # since opencv does not do "pixel" contours, we artificially do this for single pixel detections (if they exist) 92 | if cont.shape[0] <= 1: 93 | im = cv2.resize(im, None, fx=2.0, fy=2.0) 94 | cont = ( 95 | cv2.findContours( 96 | 
im, 97 | mode=cv2.RETR_EXTERNAL, 98 | method=cv2.CHAIN_APPROX_TC89_KCOS, 99 | )[0][0].reshape(-1, 2)[:, [1, 0]] 100 | / 2.0 101 | ) 102 | if offset is not None: 103 | cont = (cont + offset + bb[0:2] - 1).tolist() 104 | else: 105 | cont = (cont + bb[0:2] - 1).tolist() 106 | # close polygon: 107 | if cont[0] != cont[-1]: 108 | cont.append(cont[0]) 109 | return cont 110 | 111 | 112 | def create_polygon_output(pinst, pcls_out, params): 113 | # polygon output is slow and unwieldy, TODO 114 | pred_keys = CLASS_LABELS_PANNUKE if params["pannuke"] else CLASS_LABELS_LIZARD 115 | # whole slide regionprops could be avoided to speed up this process... 116 | print("getting all detections...") 117 | props = [(p.label, p.image, p.bbox) for p in regionprops(np.asarray(pinst))] 118 | class_labels = [pcls_out[str(p[0])] for p in props] 119 | print("generating contours...") 120 | res_poly = [cont(i) for i in props] 121 | print("creating output...") 122 | create_geojson( 123 | res_poly, 124 | class_labels, 125 | dict((v, k) for k, v in pred_keys.items()), 126 | params, 127 | ) 128 | --------------------------------------------------------------------------------
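Editor's note: the snippet below is an illustrative usage sketch, not a file from the repository. It only relies on functions defined above (get_model in src/multi_head_unet.py and SpatialAugmenter in src/spatial_augmenter.py); the augmentation parameter values are made-up examples rather than the project's TTA_AUG_PARAMS, and the channel split follows get_model's defaults (5 instance channels followed by 8 class channels, concatenated by MultiHeadModel.forward).

import torch

from src.multi_head_unet import get_model
from src.spatial_augmenter import SpatialAugmenter

device = "cuda" if torch.cuda.is_available() else "cpu"

# shared timm encoder with two U-Net decoders; the two heads are concatenated along dim 1,
# so channels 0-4 are the instance branch and channels 5-12 the class branch
model = get_model(out_channels_cls=8, out_channels_inst=5, pretrained=False)
model = model.to(device).eval()

# hypothetical TTA parameters (the repository keeps its own values in src/constants.py)
aug_params = {
    "mirror": {"prob": 0.5, "prob_x": 0.5, "prob_y": 0.5},
    "rotate": {"rot90": True, "max_degree": 0, "prob": 0.5},
}
aug = SpatialAugmenter(aug_params).to(device)

x = torch.rand(2, 3, 256, 256, device=device)
with torch.inference_mode():
    view = aug.forward_transform(x)           # random geometric view of the batch
    out = aug.inverse_transform(model(view))  # predict, then map back to the original frame
inst_logits, cls_logits = out[:, :5], out[:, 5:]
print(inst_logits.shape, cls_logits.shape)    # torch.Size([2, 5, 256, 256]) torch.Size([2, 8, 256, 256])

Here pretrained=False keeps the sketch runnable without downloading encoder weights; actual inference instead loads the trained checkpoints through load_checkpoint / get_inference_setup as shown in src/inference.py.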