├── .github
│   └── workflows
│       └── build.yml
├── .gitignore
├── LICENSE
├── README.md
├── environment.yml
├── main.py
├── requirements.txt
├── sample_analysis.ipynb
├── setup.py
└── src
    ├── augmentations.py
    ├── constants.py
    ├── data_utils.py
    ├── inference.py
    ├── multi_head_unet.py
    ├── post_process.py
    ├── post_process_utils.py
    ├── spatial_augmenter.py
    └── viz_utils.py
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Build and Check
2 | # This specifies when the workflow should run. It's set to trigger on any push
3 | # and any pull request to the main branch.
4 | on:
5 | push:
6 | branches: [main]
7 | pull_request:
8 | branches: [main]
9 |
10 | # This ensures that only the latest run for a given branch or workflow is active,
11 | # canceling any in-progress runs if a new one is triggered.
12 | concurrency:
13 | group: ${{ github.workflow }}-${{ github.ref }}
14 | cancel-in-progress: true
15 |
16 | # Defines the job called 'build'.
17 | jobs:
18 | build:
19 | # Specifies the type of runner that the job will execute on.
20 | runs-on: ubuntu-latest
21 |
22 | # A matrix to run jobs across multiple versions of Python.
23 | strategy:
24 | matrix:
25 | python-version: [3.8, 3.9, 3.10.13, 3.11, 3.12.2]
26 |
27 | # Steps define a sequence of tasks that will be executed as part of the job.
28 | steps:
29 | # Checks out the repository under $GITHUB_WORKSPACE, so the workflow can access it.
30 | - uses: actions/checkout@v3
31 |
32 | # Sets up a Python environment with the version specified in the matrix,
33 | # allowing the workflow to execute actions with Python.
34 | - name: Set up Python ${{ matrix.python-version }}
35 | uses: actions/setup-python@v4
36 | with:
37 | python-version: ${{ matrix.python-version }}
38 |
39 | # Installs the necessary dependencies to build and check the Python package.
40 | # Includes pip, wheel, twine, and the build module.
41 | - name: Install dependencies
42 | run: python -m pip install --upgrade pip wheel twine build
43 |
44 | # Builds the package using the Python build module, which creates both source
45 | # distribution and wheel distribution files in the dist/ directory.
46 | - name: Build package
47 | run: python -m build
48 |
49 | # Uses Twine to check the built packages (.whl files) in the dist/ directory,
50 | # ensuring compliance with PyPI standards.
51 | - name: Check package
52 | run: twine check --strict dist/*.whl
53 |
54 | # Uploads the built wheel files as artifacts, which can be downloaded
55 | # after the workflow run completes.
56 | - name: Upload artifacts
57 | uses: actions/upload-artifact@v2
58 | with:
59 | name: wheels
60 | path: dist/*.whl
61 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | convnextv2_tiny_focal_fulldata_0/*
162 | *.out
163 | logs/*
164 | convnextv2_large_focal_fulldata_0/*
165 | tmp.ipynb
166 | convnextv2_base_focal_fulldata_0/*
167 | pannuke_convnextv2_tiny_1/*
168 | get_wsi_sizes.ipynb
169 | main_debug.py
170 | figures/*
171 | *.txt
172 | model_weights/*
173 | cpu_pp_pannuke.sh
174 | cpu_pp.sh
175 | debug_inference.sh
176 | dummy.sh
177 | run_eos_inference.sh
178 | run_inference_container_2.sh
179 | run_inference_container_pannuke.sh
180 | run_inference_container.sh
181 | run_wsi_validation_set_lizard_testo.sh
182 | run_wsi_validation_set_lizard.sh
183 | run_wsi_validation_set_pannuke.sh
184 | get_sizes.py
185 | image_loader.ipynb
186 | lizard_convnextv2_base.zip
187 | lizard_convnextv2_large.zip
188 | lizard_convnextv2_tiny.zip
189 | pannuke_convnextv2_tiny_1.zip
190 | pannuke_convnextv2_tiny_2.zip
191 | pannuke_convnextv2_tiny_3.zip
192 | run_inference_container_gpu.sh
193 | run_inference_container_gpu2.sh
194 | sizes.csv
195 | lizard_convnextv2_base/*
196 | lizard_convnextv2_large/*
197 | lizard_convnextv2_tiny/*
198 | pannuke_convnextv2_tiny_2/*
199 | pannuke_convnextv2_tiny_3/*
200 | sample/*
201 | testo.py
202 | test_out.csv
203 | cpu_pp_pannuke_debug.sh
204 | run_inference_container_pannuke_debug.sh
205 | run_mit_inference.sh
206 | sample_cls.bmp
207 | sample_cls.jpg
208 | sample_he.jpg
209 | testo.sh
210 | debug.py
211 | direct_cpu_pp.sh
212 | run_inference_container_jc.sh
213 | sample_analysis.ipynb
214 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year> <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program> Copyright (C) <year> <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HoVer-NeXt Inference
2 | HoVer-NeXt is a fast and efficient nuclei segmentation and classification pipeline.
3 |
4 | A variety of data formats are supported, including all OpenSlide-compatible formats, `.npy` NumPy array dumps, and common image formats such as JPEG and PNG.
5 | If you are having trouble using this repository, please create an issue and we will be happy to help!
6 |
7 | For training code, please check the [hover-next training repository](https://github.com/digitalpathologybern/hover_next_train)
8 |
9 | Find the Publication here: [https://openreview.net/pdf?id=3vmB43oqIO](https://openreview.net/pdf?id=3vmB43oqIO)
10 |
11 | ## Setup
12 |
13 | The training and inference environments are identical, so if you have already set up the environment for training, you can use it for inference as well.
14 |
15 | Otherwise:
16 |
17 | ```bash
18 | conda env create -f environment.yml
19 | conda activate hovernext
20 | pip install torch==2.1.1 torchvision==0.16.1 --index-url https://download.pytorch.org/whl/cu118
21 | ```
22 |
23 | Alternatively, use the predefined [Docker/Singularity container](#docker-and-apptainersingularity-container).
24 |
25 | ## Model Weights
26 |
27 | Weights are hosted on [Zenodo](https://zenodo.org/records/10635618).
28 | When one of the IDs listed below is specified via `--cp`, the corresponding weights are **automatically** downloaded and loaded.
29 |
30 | | Dataset | ID | Weights |
31 | |--------------|--------|-----|
32 | | Lizard-Mitosis | "lizard_convnextv2_large" | [Large](https://zenodo.org/records/10635618/files/lizard_convnextv2_large.zip?download=1) |
33 | | | "lizard_convnextv2_base" |[Base](https://zenodo.org/records/10635618/files/lizard_convnextv2_base.zip?download=1) |
34 | | | "lizard_convnextv2_tiny" |[Tiny](https://zenodo.org/records/10635618/files/lizard_convnextv2_tiny.zip?download=1) |
35 | | PanNuke | "pannuke_convnextv2_tiny_1" | [Tiny Fold 1](https://zenodo.org/records/10635618/files/pannuke_convnextv2_tiny_1.zip?download=1) |
36 | | | "pannuke_convnextv2_tiny_2" | [Tiny Fold 2](https://zenodo.org/records/10635618/files/pannuke_convnextv2_tiny_2.zip?download=1) |
37 | | | "pannuke_convnextv2_tiny_3" | [Tiny Fold 3](https://zenodo.org/records/10635618/files/pannuke_convnextv2_tiny_3.zip?download=1) |
38 |
39 | If you are downloading weights manually, unzip them so that the checkpoint folder (e.g. ```lizard_convnextv2_large```) sits in the same directory as ```main.py```.
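
For example, to download and unpack the large Lizard checkpoint manually (URL taken from the table above; this assumes the archive contains the `lizard_convnextv2_large` folder at its top level):

```bash
cd /path-to-repo  # the directory containing main.py
wget -O lizard_convnextv2_large.zip \
    "https://zenodo.org/records/10635618/files/lizard_convnextv2_large.zip?download=1"
unzip lizard_convnextv2_large.zip
```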
40 |
41 | ## WSI Inference
42 |
43 | This pipeline uses OpenSlide to read whole-slide images and therefore supports all formats that OpenSlide supports.
44 | If you want to run this pipeline on custom OME-TIFF files, ensure that the necessary metadata, such as resolution, downsample factors, and dimensions, is available.
45 | Additionally, CZI files are supported via pylibCZIrw.
46 | Before running a slide, choose [appropriate parameters for your machine](#optimizing-inference-for-your-machine).
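
A quick way to verify that a file exposes the required metadata is to inspect it with OpenSlide before running the pipeline (a minimal sketch; the file path is a placeholder and the properties shown are the standard OpenSlide keys):

```python
import openslide

slide = openslide.OpenSlide("/path-to-wsi/wsi.ome.tif")
print("dimensions:", slide.dimensions)                                 # full-resolution size
print("levels / downsamples:", slide.level_count, slide.level_downsamples)
print("mpp-x:", slide.properties.get(openslide.PROPERTY_NAME_MPP_X))   # microns per pixel
print("mpp-y:", slide.properties.get(openslide.PROPERTY_NAME_MPP_Y))
slide.close()
```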
47 |
48 | To run a single slide:
49 |
50 | ```bash
51 | python3 main.py \
52 | --input "/path-to-wsi/wsi.svs" \
53 | --output_root "results/" \
54 | --cp "lizard_convnextv2_large" \
55 | --tta 4 \
56 | --inf_workers 16 \
57 | --pp_tiling 10 \
58 | --pp_workers 16
59 | ```
60 |
61 | To run multiple slides, specify a glob pattern such as `"/path-to-folder/*.mrxs"` or provide a list of paths as a `.txt` file.
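
A `.txt` file simply lists one slide path per line, for example (paths below are placeholders):

```bash
cat > slides.txt << EOF
/data/slides/case_001.svs
/data/slides/case_002.mrxs
EOF

python3 main.py \
    --input "slides.txt" \
    --output_root "results/" \
    --cp "lizard_convnextv2_large" \
    --tta 4
```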
62 |
63 | ### Slurm
64 |
65 | If you are running on a Slurm cluster, consider separating inference and post-processing to improve GPU utilization.
66 | Run a GPU job with the `--only_inference` flag, then submit a second job with the same parameters but without `--only_inference` to run post-processing (a sketch of such a two-job submission is shown below).
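
A minimal sketch of such a two-job submission with `sbatch` (partition names, resources, and paths are placeholders; note that `main.py` exits with a non-zero code after `--only_inference`, so the dependency uses `afterany` rather than `afterok`):

```bash
# GPU job: inference only (exits after writing the raw predictions)
INF_JOB=$(sbatch --parsable --partition=gpu --gres=gpu:1 --cpus-per-task=16 \
  --wrap="python3 main.py --input '/path-to-wsi/wsi.svs' --output_root 'results/' --cp 'lizard_convnextv2_large' --tta 4 --inf_workers 16 --only_inference")

# CPU job: post-processing, started once the inference job has finished
sbatch --dependency=afterany:${INF_JOB} --partition=cpu --cpus-per-task=16 \
  --wrap="python3 main.py --input '/path-to-wsi/wsi.svs' --output_root 'results/' --cp 'lizard_convnextv2_large' --tta 4 --pp_workers 16 --pp_tiling 10"
```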
67 |
68 | ## NPY / Image inference
69 |
70 | NPY and image inference work the same way as WSI inference; however, the output files are only Zarr arrays.
71 |
72 | ```bash
73 | python3 main.py \
74 | --input "/path-to-file/file.npy" \
75 | --output_root "/results/" \
76 | --cp "lizard_convnextv2_large" \
77 | --tta 4 \
78 | --inf_workers 16 \
79 | --pp_tiling 10 \
80 | --pp_workers 16
81 | ```
82 |
83 | Support for other data types is easy to implement. Check the NPYDataloader for reference.
84 |
85 | ## Optimizing inference for your machine:
86 |
87 | 1. Make sure the WSI is stored on the local machine or in a fast-access network location.
88 | 2. If you have multiple machines, e.g. a CPU-only machine, you can move post-processing to that machine.
89 | 3. `--tta 4` yields robust results at very high speed.
90 | 4. `--inf_workers` should be set to the number of available cores.
91 | 5. `--pp_workers` should be set to the number of available cores minus 1, with `--pp_tiling` set to a low number at which the machine does not run out of memory. E.g. on a 16-core machine, `--pp_workers 16 --pp_tiling 8` works well. If you are running out of memory, increase `--pp_tiling`; see the example command below.
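
For example, on a machine with one GPU and 16 CPU cores, a reasonable starting point might be (paths are placeholders):

```bash
python3 main.py \
    --input "/path-to-wsi/wsi.svs" \
    --output_root "results/" \
    --cp "lizard_convnextv2_large" \
    --tta 4 \
    --inf_workers 16 \
    --pp_workers 16 \
    --pp_tiling 8
```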
92 |
93 | ## Using the output files for downstream analysis:
94 |
95 | By default, the pipeline produces an instance map, a class lookup with centroids, and a number of `.tsv` files that can be loaded in QuPath.
96 | `sample_analysis.ipynb` shows examples of how to use these files.
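
A minimal sketch of loading the post-processed outputs in Python (the output directory and the `.tsv` file name below are placeholders; see the notebook for the actual names produced for your slide):

```python
import zarr
import pandas as pd

# Post-processed instance map, stored as a zipped Zarr array (pinst_pp.zip).
store = zarr.ZipStore("results/wsi/pinst_pp.zip", mode="r")
pinst = zarr.open(store, mode="r")
print("instance map:", pinst.shape, pinst.dtype)

# QuPath-compatible detections are tab-separated; the file name here is a placeholder.
detections = pd.read_csv("results/wsi/detections.tsv", sep="\t")
print(detections.head())
```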
97 |
98 | ## Docker and Apptainer/Singularity Container:
99 |
100 | Download the Singularity image from [Zenodo](https://zenodo.org/records/10649470/files/hover_next.sif).
101 |
102 | ```bash
103 | # don't forget to mount your local directory
104 | export APPTAINER_BINDPATH="/storage"
105 | apptainer exec --nv /path-to-container/hover_next.sif \
106 | python3 /path-to-repo/main.py \
107 | --input "/path-to-wsi/*.svs" \
108 | --output_root "results/" \
109 | --cp "lizard_convnextv2_large" \
110 | --tta 4
111 | ```
112 | ## License
113 |
114 | This repository is licensed under the GNU General Public License v3.0 (see the LICENSE file).
115 | If you intend to use this repository for commercial use cases, please check the licenses of all Python packages referenced in the Setup section and listed in `requirements.txt` and `environment.yml`.
116 |
117 | ## Citation
118 |
119 | If you are using this code, please cite:
120 | ```
121 | @inproceedings{baumann2024hover,
122 | title={HoVer-NeXt: A Fast Nuclei Segmentation and Classification Pipeline for Next Generation Histopathology},
123 | author={Baumann, Elias and Dislich, Bastian and Rumberger, Josef Lorenz and Nagtegaal, Iris D and Martinez, Maria Rodriguez and Zlobec, Inti},
124 | booktitle={Medical Imaging with Deep Learning},
125 | year={2024}
126 | }
127 | ```
128 | and
129 | ```
130 | @INPROCEEDINGS{rumberger2022panoptic,
131 | author={Rumberger, Josef Lorenz and Baumann, Elias and Hirsch, Peter and Janowczyk, Andrew and Zlobec, Inti and Kainmueller, Dagmar},
132 | booktitle={2022 IEEE International Symposium on Biomedical Imaging Challenges (ISBIC)},
133 | title={Panoptic segmentation with highly imbalanced semantic labels},
134 | year={2022},
135 | pages={1-4},
136 | doi={10.1109/ISBIC56247.2022.9854551}}
137 | ```
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: hovernext
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - python=3.11.5
7 | - openslide
8 | - pip:
9 | - -r file:requirements.txt
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import sys
4 | from timeit import default_timer as timer
5 | from datetime import timedelta
6 | import torch
7 | from glob import glob
8 | from src.inference import inference_main, get_inference_setup
9 | from src.post_process import post_process_main
10 | from src.data_utils import copy_img
11 |
12 | torch.backends.cudnn.benchmark = True
13 | print(torch.cuda.device_count(), " cuda devices")
14 |
15 |
16 | def prepare_input(params):
17 | """
18 | Check if input is a text file, glob pattern, or a directory, and return a list of input files
19 |
20 | Parameters
21 | ----------
22 | params: dict
23 | input parameters from argparse
24 |
25 | """
26 | print("input specified: ", params["input"])
27 | if params["input"].endswith(".txt"):
28 | if os.path.exists(params["input"]):
29 | with open(params["input"], "r") as f:
30 | input_list = f.read().splitlines()
31 | else:
32 | raise FileNotFoundError("input file not found")
33 | else:
34 | input_list = sorted(glob(params["input"].rstrip()))
35 | return input_list
36 |
37 |
38 | def get_input_type(params):
39 | """
40 | Check if input is an image, numpy array, or whole slide image, and return the input type
41 | If you are trying to process other images that are supported by opencv (e.g. tiff), you can add the extension to the list
42 |
43 | Parameters
44 | ----------
45 | params: dict
46 | input parameters from argparse
47 | """
48 | params["ext"] = os.path.splitext(params["p"])[-1]
49 | if params["ext"] == ".npy":
50 | params["input_type"] = "npy"
51 | elif params["ext"] in [".jpg", ".png", ".jpeg", ".bmp"]:
52 | params["input_type"] = "img"
53 | else:
54 | params["input_type"] = "wsi"
55 | return params
56 |
57 |
58 | def main(params: dict):
59 | """
60 | Start nuclei segmentation and classification pipeline using specified parameters from argparse
61 |
62 | Parameters
63 | ----------
64 | params: dict
65 | input parameters from argparse
66 | """
67 |
68 | if params["metric"] not in ["mpq", "f1", "pannuke"]:
69 | params["metric"] = "f1"
70 | print("invalid metric, falling back to f1")
71 | else:
72 | print("optimizing postprocessing for: ", params["metric"])
73 |
74 | params["root"] = os.path.dirname(__file__)
75 | params["data_dirs"] = [
76 | os.path.join(params["root"], c) for c in params["cp"].split(",")
77 | ]
78 |
79 | print("saving results to:", params["output_root"])
80 | print("loading model from:", params["data_dirs"])
81 |
82 | # Run per tile inference and store results
83 | params, models, augmenter, color_aug_fn = get_inference_setup(params)
84 |
85 | input_list = prepare_input(params)
86 | print("Running inference on", len(input_list), "file(s)")
87 |
88 | for inp in input_list:
89 | start_time = timer()
90 | params["p"] = inp.rstrip()
91 | params = get_input_type(params)
92 | print("Processing ", params["p"])
93 | if params["cache"] is not None:
94 | print("Caching input at:")
95 | params["p"] = copy_img(params["p"], params["cache"])
96 | print(params["p"])
97 |
98 | params, z = inference_main(params, models, augmenter, color_aug_fn)
99 | print(
100 | "::: finished or skipped inference after",
101 | timedelta(seconds=timer() - start_time),
102 | )
103 | process_timer = timer()
104 | if params["only_inference"]:
105 | try:
106 | z[0].store.close()
107 | z[1].store.close()
108 | except TypeError:
109 | # if z is None, z cannot be indexed -> throws a TypeError
110 | pass
111 | print("Exiting after inference")
112 | sys.exit(2)
113 | # Stitch tiles together and postprocess to get instance segmentation
114 | if not os.path.exists(os.path.join(params["output_dir"], "pinst_pp.zip")):
115 | print("running post-processing")
116 |
117 | z_pp = post_process_main(
118 | params,
119 | z,
120 | )
121 | if not params["keep_raw"]:
122 | try:
123 | os.remove(params["model_out_p"] + "_inst.zip")
124 | os.remove(params["model_out_p"] + "_cls.zip")
125 | except FileNotFoundError:
126 | pass
127 | else:
128 | z_pp = None
129 | print(
130 | "::: postprocessing took",
131 | timedelta(seconds=timer() - process_timer),
132 | "total elapsed time",
133 | timedelta(seconds=timer() - start_time),
134 | )
135 | if z_pp is not None:
136 | z_pp.store.close()
137 | print("done")
138 | sys.exit(0)
139 |
140 |
141 | if __name__ == "__main__":
142 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
143 | print(device)
144 |
145 | parser = argparse.ArgumentParser()
146 | parser.add_argument(
147 | "--input",
148 | type=str,
149 | default=None,
150 | help="path to wsi, glob pattern or text file containing paths",
151 | required=True,
152 | )
153 | parser.add_argument(
154 | "--output_root", type=str, default=None, help="output directory", required=True
155 | )
156 | parser.add_argument(
157 | "--cp",
158 | type=str,
159 | default=None,
160 | help="comma separated list of checkpoint folders to consider",
161 | )
162 | parser.add_argument(
163 | "--only_inference",
164 | action="store_true",
165 | help="split inference to gpu and cpu node/ only run inference",
166 | )
167 | parser.add_argument(
168 | "--metric", type=str, default="f1", help="metric to optimize for pp"
169 | )
170 | parser.add_argument("--batch_size", type=int, default=64, help="batch size")
171 | parser.add_argument(
172 | "--tta",
173 | type=int,
174 | default=4,
175 |         help="test time augmentations, number of views (4 = results from 4 different augmentations are averaged for each sample)",
176 | )
177 | parser.add_argument(
178 | "--save_polygon",
179 | action="store_true",
180 | help="save output as polygons to load in qupath",
181 | )
182 | parser.add_argument(
183 | "--tile_size",
184 | type=int,
185 | default=256,
186 | help="tile size, models are trained on 256x256",
187 | )
188 | parser.add_argument(
189 | "--overlap",
190 | type=float,
191 | default=0.96875,
192 | help="overlap between tiles, at 0.5mpp, 0.96875 is best, for 0.25mpp use 0.9375 for better results",
193 | )
194 | parser.add_argument(
195 | "--inf_workers",
196 | type=int,
197 | default=4,
198 | help="number of workers for inference dataloader, maximally set this to number of cores",
199 | )
200 | parser.add_argument(
201 | "--inf_writers",
202 | type=int,
203 | default=2,
204 | help="number of writers for inference dataloader, default 2 should be sufficient"
205 |         + ", tune based on core availability and delay between final inference step and inference finalization",
206 | )
207 | parser.add_argument(
208 | "--pp_tiling",
209 | type=int,
210 | default=8,
211 | help="tiling factor for post processing, number of tiles per dimension, 8 = 64 tiles",
212 | )
213 | parser.add_argument(
214 | "--pp_overlap",
215 | type=int,
216 | default=256,
217 |         help="overlap for postprocessing tiles, set to around tile_size",
218 | )
219 | parser.add_argument(
220 | "--pp_workers",
221 | type=int,
222 | default=16,
223 | help="number of workers for postprocessing, maximally set this to number of cores",
224 | )
225 | parser.add_argument(
226 | "--keep_raw",
227 | action="store_true",
228 |         help="keep raw predictions (can be large files, particularly for pannuke)",
229 | )
230 | parser.add_argument("--cache", type=str, default=None, help="cache path")
231 | params = vars(parser.parse_args())
232 | main(params)
233 |
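234 | # Illustrative invocation (not executed by the pipeline): the flags mirror the argparse
235 | # options above; input/output paths are placeholders and the checkpoint identifier must be
236 | # one of VALID_WEIGHTS in src/constants.py (it is downloaded automatically if missing).
237 | #
238 | #   python3 main.py --input "/data/wsis/*.svs" --output_root results/ \
239 | #       --cp lizard_convnextv2_large --tta 4 --batch_size 32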
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | openslide-python
2 | scikit-learn
3 | scikit-image
4 | scipy
5 | opencv-python
6 | pandas
7 | tqdm
8 | itk
9 | matplotlib
10 | mahotas
11 | pandas
12 | jupyterlab
13 | zarr
14 | tifffile
15 | h5py
16 | segmentation-models-pytorch
17 | networkx==2.8.7
18 | libpysal
19 | Pillow
20 | shapely
21 | staintools
22 | albumentations
23 | spams-bin
24 | toml
25 | numcodecs
26 | imagecodecs
27 | timm==0.9.6
28 | geojson
29 | pylibCZIrw
--------------------------------------------------------------------------------
/sample_analysis.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import zarr\n",
10 | "import numpy as np\n",
11 | "import json\n",
12 | "from main import main\n",
13 | "\n",
14 | "'''\n",
15 | "Run inference on a small sample image from TCGA\n",
16 | "'''\n",
17 | "params = {\n",
18 |     "    # fill in with the same keys as the argparse options in main.py (input, output_root, cp, ...)\n",
19 | "}\n",
20 | "\n",
21 | "main(params)"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "'''\n",
31 |     "Instance map: 2D full-size matrix where each pixel's value corresponds to the associated instance (value>0) or background (value=0)\n",
32 | "'''\n",
33 | "\n",
34 | "# open: file-like interaction with zarr-array\n",
35 | "instance_map = zarr.open(\"pinst_pp.zip\", mode=\"r\")\n",
36 | "# selecting a ROI will yield a numpy array\n",
37 | "roi = instance_map[10000:20000,10000:20000]\n",
38 | "# or with [:] to load the entire array\n",
39 | "full_instance_map = instance_map[:]\n",
40 | "# alternatively, use load, which will directly create a numpy array:\n",
41 | "full_instance_map = zarr.load(\"pinst_pp.zip\") \n",
42 | "\n",
43 | "'''\n",
44 | "Class dictionary: Lookup for the instance map, also contains centroid coordinates. If only centroid coordinates are of interest, you can skip loading the instance map.\n",
45 | "'''\n",
46 | "\n",
47 | "# load the dictionary\n",
48 | "with open(\"class_inst.json\",\"r\") as f:\n",
49 | " class_info = json.load(f)\n",
50 | "# create a centroid info array\n",
51 | "centroid_array = np.array([[int(k),v[0],*v[1]] for k,v in class_info.items()])\n",
52 | "# [instance_id, class_id, y, x]\n",
53 | "\n",
54 | "# or alternatively create a lookup for the instance map to get a corresponding class map\n",
55 | "pcls_list = np.array([0] + [v[0] for v in class_info.values()])\n",
56 | "pcls_keys = np.array([\"0\"] + list(class_info.keys())).astype(int)\n",
57 | "lookup = np.zeros(pcls_keys.max() + 1,dtype=np.uint8)\n",
58 | "lookup[pcls_keys] = pcls_list\n",
59 | "cls_map = lookup[full_instance_map]"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": []
68 | }
69 | ],
70 | "metadata": {
71 | "language_info": {
72 | "name": "python"
73 | }
74 | },
75 | "nbformat": 4,
76 | "nbformat_minor": 2
77 | }
78 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | # Read the content of README file
4 | with open('README.md', encoding='utf-8') as f:
5 | long_description = f.read()
6 |
7 | setup(
8 | name="hover_next_inference",
9 | version="0.1",
10 | packages=find_packages(),
11 | long_description=long_description,
12 | long_description_content_type='text/markdown',
13 | )
14 |
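15 | # For local development, an editable install from the repository root is the usual route
16 | # (standard pip usage): pip install -e .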
--------------------------------------------------------------------------------
/src/augmentations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torchvision.transforms.transforms import ColorJitter, RandomApply, GaussianBlur
4 |
5 |
6 | rgb_from_hed = np.array(
7 | [[0.65, 0.70, 0.29], [0.07, 0.99, 0.11], [0.27, 0.57, 0.78]], dtype=np.float32
8 | )
9 | hed_from_rgb = np.linalg.inv(rgb_from_hed)
10 |
11 |
12 | def torch_rgb2hed(img: torch.Tensor, hed_t: torch.Tensor, e: float):
13 | """
14 |     convert rgb torch tensor to hed torch tensor (adapted from skimage)
15 |
16 | Parameters
17 | ----------
18 | img : torch.Tensor
19 | rgb image tensor (B, C, H, W) or (C, H, W)
20 | hed_t : torch.Tensor
21 | hed transform tensor (3, 3)
22 | e : float
23 | epsilon
24 |
25 | Returns
26 | -------
27 | torch.Tensor
28 | hed image tensor (B, C, H, W) or (C, H, W)
29 | """
30 | img = img.movedim(-3, -1)
31 |
32 | img = torch.clamp(img, min=e)
33 | img = torch.log(img) / torch.log(e)
34 | img = torch.matmul(img, hed_t)
35 | return img.movedim(-1, -3)
36 |
37 |
38 | def torch_hed2rgb(img: torch.Tensor, rgb_t: torch.Tensor, e: float):
39 | """
40 |     convert hed torch tensor to rgb torch tensor (adapted from skimage)
41 |
42 | Parameters
43 | ----------
44 | img : torch.Tensor
45 | hed image tensor (B, C, H, W) or (C, H, W)
46 |     rgb_t : torch.Tensor
47 | hed inverse transform tensor (3, 3)
48 | e : float
49 | epsilon
50 |
51 | Returns
52 | -------
53 | torch.Tensor
54 | RGB image tensor (B, C, H, W) or (C, H, W)
55 | """
56 | e = -torch.log(e)
57 | img = img.movedim(-3, -1)
58 | img = torch.matmul(-(img * e), rgb_t)
59 | img = torch.exp(img)
60 | img = torch.clamp(img, 0, 1)
61 | return img.movedim(-1, -3)
62 |
63 |
64 | class Hed2Rgb(torch.nn.Module):
65 | """
66 | Pytorch module to convert hed image tensors to rgb
67 | """
68 |
69 | def __init__(self, rank):
70 | super().__init__()
71 | self.e = torch.tensor(1e-6).to(rank)
72 | self.rgb_t = torch.from_numpy(rgb_from_hed).to(rank)
73 | self.rank = rank
74 |
75 | def forward(self, img):
76 | return torch_hed2rgb(img, self.rgb_t, self.e)
77 |
78 |
79 | class Rgb2Hed(torch.nn.Module):
80 | """
81 | Pytorch module to convert rgb image tensors to hed
82 | """
83 |
84 | def __init__(self, rank):
85 | super().__init__()
86 | self.e = torch.tensor(1e-6).to(rank)
87 | self.hed_t = torch.from_numpy(hed_from_rgb).to(rank)
88 | self.rank = rank
89 |
90 | def forward(self, img):
91 | return torch_rgb2hed(img, self.hed_t, self.e)
92 |
93 |
94 | class HedNormalizeTorch(torch.nn.Module):
95 | """
96 | Pytorch augmentation module to apply HED stain augmentation
97 |
98 | Parameters
99 | ----------
100 | sigma : float
101 | sigma for linear scaling of HED channels
102 | bias : float
103 | bias for additive scaling of HED channels
104 | """
105 |
106 | def __init__(self, sigma, bias, rank, *args, **kwargs) -> None:
107 | super().__init__(*args, **kwargs)
108 | self.sigma = sigma
109 | self.bias = bias
110 | self.rank = rank
111 | self.rgb2hed = Rgb2Hed(rank=rank)
112 | self.hed2rgb = Hed2Rgb(rank=rank)
113 |
114 | def rng(self, val, batch_size):
115 | return torch.empty(batch_size, 3).uniform_(-val, val).to(self.rank)
116 |
117 | def color_norm_hed(self, img):
118 | B = img.shape[0]
119 | sigmas = self.rng(self.sigma, B)
120 | biases = self.rng(self.bias, B)
121 | return (img * (1 + sigmas.view(*sigmas.shape, 1, 1))) + biases.view(
122 | *biases.shape, 1, 1
123 | )
124 |
125 | def forward(self, img):
126 | if img.dim() == 3:
127 | img = img.view(1, *img.shape)
128 | hed = self.rgb2hed(img)
129 | hed = self.color_norm_hed(hed)
130 | return self.hed2rgb(hed)
131 |
132 |
133 | class GaussianNoise(torch.nn.Module):
134 | """
135 | Pytorch augmentation module to apply gaussian noise
136 |
137 | Parameters
138 | ----------
139 | sigma : float
140 | sigma for uniform distribution to sample from
141 | rank : str or int or torch.device
142 | device to put the module to
143 | """
144 |
145 | def __init__(self, sigma, rank):
146 | super().__init__()
147 | self.sigma = sigma
148 | self.rank = rank
149 |
150 | def forward(self, img):
151 | noise = torch.empty(img.shape).uniform_(-self.sigma, self.sigma).to(self.rank)
152 | return img + noise
153 |
154 |
155 | def color_augmentations(train, sigma=0.05, bias=0.03, s=0.2, rank=0):
156 | """
157 |     Color augmentation function (in theory, train can be set to True to get more
158 |     variance with high test time augmentations)
159 |
160 | Parameters
161 | ----------
162 | train : bool
163 | during training, the model uses more augmentation than during inference,
164 | set to true for more variance in colors
165 | sigma: float
166 | parameter for hed augmentation
167 | bias: float
168 | parameter for hed augmentation
169 | s: float
170 | parameter for color jitter
171 | rank: int or torch.device or str
172 | device to use for augmentation
173 |
174 | Returns
175 | -------
176 | torch.nn.Sequential
177 | sequential augmentation module
178 | """
179 | if train:
180 | color_jitter = ColorJitter(
181 | 0.8 * s, 0.0 * s, 0.8 * s, 0.2 * s
182 | ) # brightness, contrast, saturation, hue
183 |
184 | data_transforms = torch.nn.Sequential(
185 | RandomApply([HedNormalizeTorch(sigma, bias, rank=rank)], p=0.75),
186 | RandomApply([color_jitter], p=0.3),
187 | RandomApply([GaussianNoise(0.02, rank)], p=0.3),
188 | RandomApply([GaussianBlur(kernel_size=15, sigma=(0.1, 0.1))], p=0.3),
189 | )
190 | else:
191 | data_transforms = torch.nn.Sequential(HedNormalizeTorch(sigma, bias, rank=rank))
192 | return data_transforms
193 |
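194 | # Minimal usage sketch (illustrative only, not called by the pipeline): the module returned
195 | # by color_augmentations expects float RGB tensors in [0, 1] with shape (B, C, H, W), as
196 | # produced by the inference dataloaders after permuting BHWC -> BCHW.
197 | if __name__ == "__main__":
198 |     dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
199 |     aug_fn = color_augmentations(False, rank=dev)
200 |     batch = torch.rand(2, 3, 256, 256, device=dev)
201 |     out = torch.clamp(aug_fn(batch), 0, 1)  # same shape as the input batch
202 |     print(out.shape)  # torch.Size([2, 3, 256, 256])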
--------------------------------------------------------------------------------
/src/constants.py:
--------------------------------------------------------------------------------
1 | ### Size thresholds for nuclei (in pixels), pannuke is less conservative
2 | # These have been optimized for the conic challenge, but can be changed
3 | # to get more small nuclei (e.g. by setting all min_threshs to 0)
4 | MIN_THRESHS_LIZARD = [30, 30, 20, 20, 30, 30, 15]
5 | MAX_THRESHS_LIZARD = [5000, 5000, 5000, 5000, 5000, 5000, 5000]
6 | MIN_THRESHS_PANNUKE = [10, 10, 10, 10, 10]
7 | MAX_THRESHS_PANNUKE = [20000, 20000, 20000, 3000, 10000]
8 |
9 | # Maximal size of holes to remove from a nucleus
10 | MAX_HOLE_SIZE = 128
11 |
12 | # Colors for geojson output
13 | COLORS_LIZARD = [
14 | [0, 255, 0], # neu
15 | [255, 0, 0], # epi
16 | [0, 0, 255], # lym
17 | [0, 128, 0], # pla
18 | [0, 255, 255], # eos
19 | [255, 179, 102], # con
20 | [255, 0, 255], # mitosis
21 | ]
22 |
23 | COLORS_PANNUKE = [
24 | [255, 0, 0], # neo
25 | [0, 127, 255], # inf
26 | [255, 179, 102], # con
27 | [0, 0, 0], # dead
28 | [0, 255, 0], # epi
29 | ]
30 |
31 | # text labels for lizard
32 | CLASS_LABELS_LIZARD = {
33 | "neutrophil": 1,
34 | "epithelial-cell": 2,
35 | "lymphocyte": 3,
36 | "plasma-cell": 4,
37 | "eosinophil": 5,
38 | "connective-tissue-cell": 6,
39 | "mitosis": 7,
40 | }
41 |
42 | # text labels for pannuke
43 | CLASS_LABELS_PANNUKE = {
44 | "neoplastic": 1,
45 | "inflammatory": 2,
46 | "connective": 3,
47 | "dead": 4,
48 | "epithelial": 5,
49 | }
50 |
51 | # magnification and resolutions for WSI dataloader
52 | LUT_MAGNIFICATION_X = [10, 20, 40, 80]
53 | LUT_MAGNIFICATION_MPP = [0.97, 0.485, 0.2425, 0.124]
54 |
55 | CONIC_MPP = 0.5
56 | PANNUKE_MPP = 0.25
57 |
58 | # parameters for test time augmentations, do not change
59 | TTA_AUG_PARAMS = {
60 | "mirror": {"prob_x": 0.5, "prob_y": 0.5, "prob": 0.75},
61 | "translate": {"max_percent": 0.03, "prob": 0.0},
62 | "scale": {"min": 0.8, "max": 1.2, "prob": 0.0},
63 | "zoom": {"min": 0.8, "max": 1.2, "prob": 0.0},
64 | "rotate": {"rot90": True, "prob": 0.75},
65 | "shear": {"max_percent": 0.1, "prob": 0.0},
66 | "elastic": {"alpha": [120, 120], "sigma": 8, "prob": 0.0},
67 | }
68 |
69 | # current valid pre-trained weights to be automatically downloaded and used in HoVer-NeXt
70 | VALID_WEIGHTS = [
71 | "lizard_convnextv2_large",
72 | "lizard_convnextv2_base",
73 | "lizard_convnextv2_tiny",
74 | "pannuke_convnextv2_tiny_1",
75 | "pannuke_convnextv2_tiny_2",
76 | "pannuke_convnextv2_tiny_3",
77 | ]
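78 | 
79 | # Worked example (illustrative): a slide scanned at ~0.25 MPP is closest to
80 | # LUT_MAGNIFICATION_MPP[2] = 0.2425, so its base level is treated as
81 | # LUT_MAGNIFICATION_X[2] = 40x and subsequent pyramid levels as 20x, 10x, ...
82 | # (see WholeSlideDataset._get_magnifications in src/data_utils.py).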
--------------------------------------------------------------------------------
/src/data_utils.py:
--------------------------------------------------------------------------------
1 | import openslide
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import Dataset
5 | from typing import Optional, List, Tuple, Callable
6 | from skimage.morphology import remove_small_objects, disk, dilation
7 | import PIL
8 | import pathlib
9 | import cv2
10 | from src.constants import LUT_MAGNIFICATION_MPP, LUT_MAGNIFICATION_X
11 | from shutil import copy2, copytree
12 | import os
13 | from pylibCZIrw import czi as pyczi
14 |
15 |
16 | def copy_img(im_path, cache_dir):
17 | """
18 | Helper function to copy WSI to cache directory
19 |
20 | Parameters
21 | ----------
22 | im_path : str
23 | path to the WSI
24 | cache_dir : str
25 | path to the cache directory
26 |
27 | Returns
28 | -------
29 | str
30 | path to the copied WSI
31 | """
32 | file, ext = os.path.splitext(im_path)
33 | if ext == ".mrxs":
34 | copy2(im_path, cache_dir)
35 | copytree(
36 | file, os.path.join(cache_dir, os.path.split(file)[-1]), dirs_exist_ok=True
37 | )
38 | else:
39 | copy2(im_path, cache_dir)
40 | return os.path.join(cache_dir, os.path.split(im_path)[-1])
41 |
42 |
43 | def normalize_min_max(x: np.ndarray, mi, ma, clip=False, eps=1e-20, dtype=np.float32):
44 | """
45 | Min max scaling for input array
46 |
47 | Parameters
48 | ----------
49 | x : np.ndarray
50 | input array
51 | mi : float or int
52 | minimum value
53 | ma : float or int
54 | maximum value
55 | clip : bool, optional
56 |         clip values to be between 0 and 1, False by default
57 | eps : float
58 | epsilon value to avoid division by zero
59 | dtype : type
60 | data type of the output array
61 |
62 | Returns
63 | -------
64 | np.ndarray
65 | normalized array
66 | """
67 | if mi is None:
68 | mi = np.min(x)
69 | if ma is None:
70 | ma = np.max(x)
71 | if dtype is not None:
72 | x = x.astype(dtype, copy=False)
73 | mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype, copy=False)
74 | ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype, copy=False)
75 | eps = dtype(eps)
76 |
77 | x = (x - mi) / (ma - mi + eps)
78 |
79 | if clip:
80 | x = np.clip(x, 0, 1)
81 | return x
82 |
83 |
84 | def center_crop(t, croph, cropw):
85 | """
86 | Center crop input tensor in last two axes to height and width
87 | """
88 | h, w = t.shape[-2:]
89 | startw = w // 2 - (cropw // 2)
90 | starth = h // 2 - (croph // 2)
91 | return t[..., starth : starth + croph, startw : startw + cropw]
92 |
93 |
94 | class czi_wrapper:
95 | def __init__(self, path, levels=11, sharpen_img=True):
96 | """
97 | Wrapper to load czi files without openslide, but with the same endpoints
98 |
99 | Parameters
100 | ------------
101 | path: str
102 | Path to the wsi (.czi)
103 | levels: int, optional
104 | number of artificially created levels
105 | sharpen_img: bool, optional
106 | whether to sharpen the image (cohort dependent)
107 |
108 | Examples
109 | -----------
110 | Use as a replacement for openslide.open_slide:
111 | >>> sl = czi_wrapper(path)
112 | >>> sl.read_region(...)
113 | """
114 | self.path = path
115 | self.levels = levels
116 | self.sharpen_img = sharpen_img
117 | self.level_dimensions = None
118 | self.level_downsamples = None
119 | self.properties = {}
120 | self.associated_images = {}
121 | try:
122 | self._generate_dictionaries()
123 |         except Exception:
124 | raise RuntimeError(f"issue with {self.path}")
125 |
126 | @staticmethod
127 | def _convert_rect_to_tuple(rect):
128 | return rect.x, rect.y, rect.w, rect.h
129 |
130 | @staticmethod
131 | def _sharpen(img_o):
132 | img_b = cv2.GaussianBlur(img_o, ksize=[3, 3], sigmaX=1, sigmaY=1)
133 | img_s = cv2.addWeighted(img_o, 3.0, img_b, -2.0, 0)
134 | return img_s
135 |
136 | def _generate_dictionaries(self):
137 | with pyczi.open_czi(self.path) as sl:
138 | total_bounding_rectangle = sl.total_bounding_rectangle
139 | meta = sl.metadata["ImageDocument"]["Metadata"]
140 |
141 | self.associated_images["thumbnail"] = PIL.Image.fromarray(
142 | cv2.cvtColor(sl.read(zoom=0.005), cv2.COLOR_BGR2RGB)
143 | )
144 |
145 | x, y, w, h = self._convert_rect_to_tuple(total_bounding_rectangle)
146 | self.level_dimensions = tuple(
147 | (int(w / (2**i)), int(h / (2.0**i))) for i in range(self.levels)
148 | )
149 | self.level_downsamples = tuple(2.0**i for i in range(self.levels))
150 | mpp = {
151 | m["@Id"]: float(m["Value"]) * 1e6
152 | for m in meta["Scaling"]["Items"]["Distance"]
153 | }
154 | self.properties["openslide.mpp-x"] = mpp["X"]
155 | self.properties["openslide.mpp-y"] = mpp["Y"]
156 | self.tx = x
157 | self.ty = y
158 |
159 | def read_region(self, crds, level, size):
160 | with pyczi.open_czi(self.path) as sl:
161 | img = sl.read(
162 | # plane={"T": 0, "Z": 0, "C": 0},
163 | zoom=1.0 / (2**level),
164 | roi=(
165 | self.tx + crds[0],
166 | self.ty + crds[1],
167 | size[0] * (2**level),
168 | size[1] * (2**level),
169 | ),
170 | )
171 |
172 | if self.sharpen_img:
173 | img = self._sharpen(img)
174 | return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
175 |
176 |
177 | # Adapted from https://github.com/christianabbet/SRA
178 | # Original Author: Christian Abbet
179 | class WholeSlideDataset(Dataset):
180 | def __init__(
181 | self,
182 | path: str,
183 | crop_sizes_px: Optional[List[int]] = None,
184 | crop_magnifications: Optional[List[float]] = None,
185 | transform: Optional[Callable] = None,
186 | padding_factor: Optional[float] = 0.5,
187 | remove_background: Optional[bool] = True,
188 | remove_oob: Optional[bool] = True,
189 | remove_alpha: Optional[bool] = True,
190 | ratio_object_thresh: Optional[float] = 1e-3,
191 | ) -> None:
192 | """
193 | Load a crop as a dataset format. The object is iterable.
194 | Parameters
195 | ----------
196 | path: str
197 |             Path to the whole slide in *.tif, *.svs, *.mrxs, or *.czi format
198 | crop_sizes_px: list of int, optional
199 | List of crops output size in pixel, default value is [224].
200 | crop_magnifications: list of float, optional
201 | List of crops magnification level, default value is [20].
202 | transform: callable, optional
203 | Transformation to apply to crops, default value is None. So far, only one augmentation for all crops
204 | is possible.
205 | padding_factor: float, optional
206 | Padding value when creating reference grid. Distance between two consecutive crops as a proportion of the
207 | first listed crop size. Default value is 0.5.
208 | remove_background: bool, optional
209 |             Remove background crops, i.e. crops whose center falls outside the foreground mask (intensity threshold 240). Default value
210 | is True.
211 | remove_oob: bool, optional
212 | Remove all crops where its representation at a specific magnification is out of bound (out of the scanned
213 | image). Default value is True.
214 | remove_alpha: bool, optional
215 | Remove alpha channel when extracting patches to create a RGB image (instead of RGBA). More suitable to ML
216 | input transforms. Default value is True.
217 | ratio_object_thresh: float, optional
218 |             Size of the objects to remove. The value is expressed as a ratio with respect to the area of the whole slide.
219 |             Default value is 1e-3 (i.e., 0.1%).
220 | Raises
221 | ------
222 | WholeSlideError
223 | If it is not possible to load the WSIs.
224 | Examples
225 | --------
226 | Load a slide at 40x with a crop size of 256px:
227 | >>> wsi = WholeSlideDataset(
228 | path="/path/to/slide/.mrxs",
229 | crop_sizes_px=[256],
230 | crop_magnifications=[40.],
231 | )
232 | """
233 |
234 | extension = pathlib.Path(path).suffix
235 | if (
236 | extension != ".svs"
237 | and extension != ".mrxs"
238 | and extension != ".tif"
239 | and extension != ".czi"
240 | ):
241 | raise NotImplementedError(
242 | "Only *.svs, *.tif, *.czi, and *.mrxs files supported"
243 | )
244 |
245 | # Load and create slide and affect default values
246 | self.path = path
247 | self.s = (
248 | openslide.open_slide(self.path)
249 | if extension != ".czi"
250 | else czi_wrapper(self.path)
251 | )
252 | self.crop_sizes_px = crop_sizes_px
253 | self.crop_magnifications = crop_magnifications
254 | self.transform = transform
255 | self.padding_factor = padding_factor
256 | self.remove_alpha = remove_alpha
257 | self.mask = None
258 |
259 | if self.crop_sizes_px is None:
260 | self.crop_sizes_px = [224]
261 |
262 | if self.crop_magnifications is None:
263 | self.crop_magnifications = [20]
264 |
265 | # Dimension of the slide at different levels
266 | self.level_dimensions = self.s.level_dimensions
267 | # Down sampling factor at each level
268 | self.level_downsamples = self.s.level_downsamples
269 | # Get average micro meter per pixel (MPP) for the slide
270 | try:
271 | self.mpp = 0.5 * (
272 | float(self.s.properties[openslide.PROPERTY_NAME_MPP_X])
273 | + float(self.s.properties[openslide.PROPERTY_NAME_MPP_Y])
274 | )
275 | except KeyError:
276 |             print("No resolution found in WSI metadata, using default 0.2425")
277 | self.mpp = 0.2425
278 | # raise IndexError('No resolution found in WSI metadata. Impossible to build pyramid.')
279 |
280 | # Extract level magnifications
281 | self.level_magnifications = self._get_magnifications(
282 | self.mpp, self.level_downsamples
283 | )
284 | # Consider reference level as the level with highest resolution
285 | self.crop_reference_level = 0
286 |
287 | # Build reference grid / crop centers
288 | self.crop_reference_cxy = self._build_reference_grid(
289 | crop_size_px=self.crop_sizes_px[0],
290 | crop_magnification=self.crop_magnifications[0],
291 | padding_factor=padding_factor,
292 | level_magnification=self.level_magnifications[self.crop_reference_level],
293 | level_shape=self.level_dimensions[self.crop_reference_level],
294 | )
295 |
296 | # Assume the whole slide has an associated image
297 | if remove_background and "thumbnail" in self.s.associated_images:
298 | # Extract image thumbnail from slide metadata
299 | img_thumb = self.s.associated_images["thumbnail"]
300 | # Get scale factor compared to reference size
301 | mx = img_thumb.size[0] / self.level_dimensions[self.crop_reference_level][0]
302 | my = img_thumb.size[1] / self.level_dimensions[self.crop_reference_level][1]
303 | # Compute foreground mask
304 | self.mask = self._foreground_mask(
305 | img_thumb, ratio_object_thresh=ratio_object_thresh
306 | )
307 | # pad with 1, to avoid rounding error:
308 | self.mask = np.pad(
309 | self.mask, ((0, 1), (0, 1)), mode="constant", constant_values=False
310 | )
311 | # Select subset of point that are part of the foreground
312 | id_valid = self.mask[
313 | np.round(my * self.crop_reference_cxy[:, 1]).astype(int),
314 | np.round(mx * self.crop_reference_cxy[:, 0]).astype(int),
315 | ]
316 | self.crop_reference_cxy = self.crop_reference_cxy[id_valid]
317 |
318 | # Build grid for all levels
319 | self.crop_metadatas = self._build_crop_metadatas(
320 | self.crop_sizes_px,
321 | self.crop_magnifications,
322 | self.level_magnifications,
323 | self.crop_reference_cxy,
324 | self.crop_reference_level,
325 | )
326 |
327 | # Remove samples that are oob from sampling
328 | if remove_oob:
329 |             # Compute oob samples
330 | oob_id = self._oob_id(
331 | self.crop_metadatas, self.level_dimensions[self.crop_reference_level]
332 | )
333 |             # Select only samples that are within bounds.
334 | self.crop_reference_cxy = self.crop_reference_cxy[~oob_id]
335 | self.crop_metadatas = self.crop_metadatas[:, ~oob_id]
336 |
337 | @staticmethod
338 | def _pil_rgba2rgb(
339 | image: PIL.Image, default_background: Optional[List[int]] = None
340 | ) -> PIL.Image:
341 | """
342 | Convert RGBA image to RGB format using default background color.
343 | From https://stackoverflow.com/questions/9166400/convert-rgba-png-to-rgb-with-pil/9459208#9459208
344 | Parameters
345 | ----------
346 | image: PIL.Image
347 |             Input RGBA image to convert.
348 | default_background: list of int, optional
349 |             Value to use as background when alpha channel is not 255. Default value is white (255, 255, 255).
350 | Returns
351 | -------
352 | Image with alpha channel removed.
353 | """
354 | if default_background is None:
355 | default_background = (255, 255, 255)
356 | if type(image) == np.ndarray:
357 | if image.shape[-1] == 3:
358 | return image
359 | else:
360 | return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
361 | else:
362 | image.load()
363 | background = PIL.Image.new("RGB", image.size, default_background)
364 | background.paste(image, mask=image.split()[3])
365 | return background
366 |
367 | @staticmethod
368 | def _oob_id(
369 | crop_grid: np.ndarray,
370 | level_shape: List[int],
371 | ) -> np.ndarray:
372 | """
373 |         Check if the samples are within bounds.
374 | Parameters
375 | ----------
376 | crop_grid: array_like
377 |             Input crop metadata with C elements, where C is the number of crops.
378 | level_shape: list of int
379 | Dimension of the image
380 | Returns
381 | -------
382 | """
383 |         # Extract top left coordinates
384 | tx, ty = crop_grid[:, :, 2], crop_grid[:, :, 3]
385 |         # Extract bottom right coordinates
386 | bx, by = crop_grid[:, :, 4], crop_grid[:, :, 5]
387 | # Check for boundaries
388 | oob_id = (tx < 0) | (ty < 0) | (bx > level_shape[0]) | (by > level_shape[1])
389 | return np.any(oob_id, axis=0)
390 |
391 | @staticmethod
392 | def _build_crop_metadatas(
393 | crop_sizes_px: List[int],
394 | crop_magnifications: List[float],
395 | level_magnifications: List[float],
396 | crop_reference_grid: np.ndarray,
397 | crop_reference_level: int,
398 | ) -> np.ndarray:
399 | """
400 |         Build metadata for each crop definition.
401 | Parameters
402 | ----------
403 | crop_sizes_px: list of int, optional
404 | List of crops output size in pixel, default value is [224].
405 | crop_magnifications: list of float, optional
406 | List of crops magnification level, default value is [20].
407 | level_magnifications: list of float
408 | List of available magnifications (one for each level)
409 | crop_reference_grid:
410 |             Reference grid with shape [Nx2] where N is the number of samples. The columns represent the x and y
411 |             coordinates of the centers of the crops respectively.
412 | crop_reference_level: int
413 | Reference level used to compute the reference grid.
414 | Returns
415 | -------
416 | metas: array_like
417 |             Metadata where each entry corresponds to the metadata of a crop: [mag, level, tx, ty, cx, cy, bx,
418 |             by, s_src, s_tar], with mag = magnification of the crop, level = level at which the crop was extracted,
419 |             (tx, ty) = top left coordinates of the crop, (cx, cy) = center coordinates of the crop, (bx, by) = bottom
420 |             right coordinates of the crop, s_src = size of the crop at that level, s_tar = size of the crop after
421 |             applying rescaling.
422 | """
423 |
424 | crop_grids = []
425 | for t_size, t_mag in zip(crop_sizes_px, crop_magnifications):
426 | # Level that we use to extract current slide region
427 | t_level = WholeSlideDataset._get_optimal_level(t_mag, level_magnifications)
428 | # Scale factor between the reference magnification and the magnification used
429 | t_scale = (
430 | level_magnifications[t_level]
431 | / level_magnifications[crop_reference_level]
432 | )
433 | # Final image size at the current level / magnification
434 | t_level_size = t_size / (t_mag / level_magnifications[t_level])
435 | # Offset to recenter image
436 | t_shift = (t_level_size / t_scale) // 2
437 |             # Return grid as format: [mag, level, tx, ty, cx, cy, bx, by, level_size, size]
438 | grid_ = np.concatenate(
439 | (
440 | t_mag
441 | * np.ones(len(crop_reference_grid))[:, np.newaxis], # Magnification
442 | t_level * np.ones(len(crop_reference_grid))[:, np.newaxis], # Level
443 | crop_reference_grid - t_shift, # (tx, ty) coordinates
444 | crop_reference_grid, # (cx, cy) coordinates values
445 | crop_reference_grid + t_shift, # (bx, by) coordinates
446 | t_level_size
447 | * np.ones(len(crop_reference_grid))[
448 | :, np.newaxis
449 | ], # original images size
450 | t_size
451 | * np.ones(len(crop_reference_grid))[
452 | :, np.newaxis
453 | ], # target image size
454 | ),
455 | axis=1,
456 | )
457 | crop_grids.append(grid_)
458 | return np.array(crop_grids)
459 |
460 | @staticmethod
461 | def _get_optimal_level(
462 | magnification: float, level_magnifications: List[float]
463 | ) -> int:
464 | """
465 |         Estimate the optimal level to extract the crop. If the wanted level does not exist, use a level with higher
466 |         resolution (lower level) and resize the crop.
467 | Parameters
468 | ----------
469 | magnification: float
470 | Wanted output magnification
471 | level_magnifications: list of float
472 | List of available magnifications (one for each level)
473 | Returns
474 | -------
475 | optimal_level: int
476 | Estimated optimal level for crop extraction
477 | """
478 |
479 |         # Get the highest level that is at least as high resolution as the wanted target.
480 | if magnification <= np.max(level_magnifications):
481 | optimal_level = np.nonzero(np.array(level_magnifications) >= magnification)[
482 | 0
483 | ][-1]
484 | else:
485 | # If no suitable candidates are found, use max resolution
486 | optimal_level = 0
487 | print(
488 |                 "Target magnification {} does not match available slide magnifications {}".format(
489 | magnification, level_magnifications
490 | )
491 | )
492 |
493 | return optimal_level
494 |
495 | @staticmethod
496 | def _get_magnifications(
497 | mpp: float,
498 | level_downsamples: List[float],
499 | error_max: Optional[float] = 1e-1,
500 | ) -> List[float]:
501 | """
502 |         Compute estimated magnification for each level. The computation relies on the mapping between LUT_MAGNIFICATION_X
503 |         and LUT_MAGNIFICATION_MPP. For example, the assumption is 20x -> ~0.5 MPP and 40x -> ~0.25 MPP.
504 | Parameters
505 | ----------
506 | mpp: float
507 | Resolution of the slide (and the scanner).
508 |         level_downsamples: list of float
509 | Down sampling factors for each level as a list of floats.
510 | error_max: float, optional
511 | Maximum relative error accepted when trying to match magnification to predefined factors. Default value
512 | is 1e-1.
513 | Returns
514 | -------
515 | level_magnifications: list of float
516 | Return the estimated magnifications for each level.
517 | """
518 |
519 | error_mag = np.abs((np.array(LUT_MAGNIFICATION_MPP) - mpp) / mpp)
520 | # if np.min(error_mag) > error_max:
521 | # print('Error too large for mpp matching: mpp={}, error={}'.format(mpp, np.min(error_mag)))
522 |
523 | return LUT_MAGNIFICATION_X[np.argmin(error_mag)] / np.round(
524 | level_downsamples
525 | ).astype(int)
526 |
527 | @staticmethod
528 | def _foreground_mask(
529 | img: PIL.Image.Image,
530 | intensity_thresh: Optional[int] = 240,
531 | ratio_object_thresh: Optional[float] = 1e-4,
532 | ) -> np.ndarray:
533 | """
534 |         Compute the foreground mask of the slide based on the input image. Usually the embedded thumbnail is used.
535 | Parameters
536 | ----------
537 | img: PIL.Image.Image
538 | Downscaled version of the slide as a PIL image
539 | intensity_thresh: int
540 |             Intensity threshold applied on the grayscale version of the image to distinguish background from foreground.
541 | The default value is 240.
542 | ratio_object_thresh: float
543 | Minimal ratio of the object to consider as a relevant region. Ratio is applied on the area of the object.
544 | Returns
545 | -------
546 | mask: np.ndarray
547 | Masked version of the input image where '0', '1' indicates regions belonging to background and foreground
548 | respectively.
549 | """
550 |
551 | # Convert image to grayscale
552 | mask = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
553 |         # Blur image to remove high frequencies
554 | mask = cv2.blur(mask, (5, 5))
555 | # Apply threshold on background intensity
556 | mask = mask < intensity_thresh
557 |         # Remove small objects, with minimum size given as a ratio of the original image size
558 | mask = remove_small_objects(
559 | mask, min_size=ratio_object_thresh * np.prod(mask.shape)
560 | )
561 | # Add final margin to avoid cutting edges
562 | disk_edge = np.ceil(np.max(mask.shape) * ratio_object_thresh).astype(int)
563 | mask = dilation(mask, disk(max(1, disk_edge)))
564 |
565 | return mask
566 |
567 | @staticmethod
568 | def _build_reference_grid(
569 | crop_size_px: int,
570 | crop_magnification: float,
571 | padding_factor: float,
572 | level_magnification: float,
573 | level_shape: List[int],
574 | ) -> np.ndarray:
575 | """
576 | Build reference grid for cropping location. The grid is usually computed at the lowest magnification.
577 | Parameters
578 | ----------
579 | crop_size_px: int
580 | Output size in pixel.
581 | crop_magnification: float
582 | Magnification value.
583 | padding_factor: float
584 | Padding factor to use. Define the interval between two consecutive crops.
585 | level_magnification: float
586 | Selected magnification.
587 | level_shape: list of int
588 | Size of the image at the selected level.
589 | Returns
590 | -------
591 |         (cx, cy): np.ndarray
592 |             Center coordinates of the crops as an [N, 2] array.
593 | """
594 |
595 | # Define the size of the crop at the selected level
596 | level_crop_size_px = int(
597 | (level_magnification / crop_magnification) * crop_size_px
598 | )
599 |
600 | # Compute the number of crops for each dimensions (rows and columns)
601 | n_w = np.floor(
602 | (1 / padding_factor) * (level_shape[0] / level_crop_size_px - 1)
603 | ).astype(int)
604 | n_h = np.floor(
605 | (1 / padding_factor) * (level_shape[1] / level_crop_size_px - 1)
606 | ).astype(int)
607 |
608 | # Compute the residual margin at each side of the image
609 | margin_w = (
610 | int(level_shape[0] - padding_factor * (n_w - 1) * level_crop_size_px) // 2
611 | )
612 | margin_h = (
613 | int(level_shape[1] - padding_factor * (n_h - 1) * level_crop_size_px) // 2
614 | )
615 |
616 | # Compute the final center for the cropping
617 | c_x = (np.arange(n_w) * level_crop_size_px * padding_factor + margin_w).astype(
618 | int
619 | )
620 | c_y = (np.arange(n_h) * level_crop_size_px * padding_factor + margin_h).astype(
621 | int
622 | )
623 | c_x, c_y = np.meshgrid(c_x, c_y)
624 |
625 | return np.array([c_x.flatten(), c_y.flatten()]).T
626 |
627 | def __len__(self) -> int:
628 | return len(self.crop_reference_cxy)
629 |
630 | def __getitem__(self, idx: int) -> Tuple[List[object], List[object]]:
631 | """
632 | Get slide element as a function of the index idx.
633 | Parameters
634 | ----------
635 | idx: int
636 | Index of the crop
637 | Returns
638 | -------
639 |         img: torch.Tensor
640 |             Extracted crop for this index, normalized to [0, 1].
641 |         metas: List of float
642 |             Metadata where each entry corresponds to the metadata of the crop: [mag, level, tx, ty, cx, cy, bx,
643 |             by, s_src, s_tar], with mag = magnification of the crop, level = level at which the crop was extracted,
644 |             (tx, ty) = top left coordinates of the crop, (cx, cy) = center coordinates of the crop, (bx, by) = bottom
645 |             right coordinates of the crop, s_src = size of the crop at that level, s_tar = size of the crop after
646 |             applying rescaling.
647 | """
648 | # Extract metadata for crops
649 | mag, level, tx, ty, cx, cy, bx, by, s_src, s_tar = self.crop_metadatas[0][idx]
650 | # Extract crop
651 | img = self.s.read_region(
652 | (int(tx), int(ty)), int(level), size=(int(s_src), int(s_src))
653 | )
654 | # If needed, resize crop to match output shape
655 | if s_src != s_tar:
656 | img = img.resize((int(s_tar), int(s_tar)))
657 |         # Remove alpha channel, apply optional transform and normalize
658 | if self.remove_alpha:
659 | img = self._pil_rgba2rgb(img)
660 | if self.transform is not None:
661 | img = self.transform(img)
662 |
663 | img = normalize_min_max(np.array(img), 0, 255)
664 |
665 | return torch.Tensor(np.array(img)), [
666 | mag,
667 | level,
668 | tx,
669 | ty,
670 | cx,
671 | cy,
672 | bx,
673 | by,
674 | s_src,
675 | s_tar,
676 | ]
677 |
678 |
679 | class NpyDataset(Dataset):
680 | def __init__(
681 | self,
682 | path,
683 | crop_size_px,
684 | padding_factor=0.5,
685 | remove_bg=True,
686 | ratio_object_thresh=5e-1,
687 | min_tiss=0.1,
688 | ):
689 | """
690 | Torch Dataset to load from NPY files.
691 |
692 | Parameters
693 | ----------
694 | path : str
695 | Path to the NPY file.
696 | crop_size_px : int
697 |             Size of the extracted tiles in pixels, e.g. 256 -> 256x256 tiles
698 | padding_factor : float, optional
699 | Padding value when creating reference grid. Distance between two consecutive crops as a proportion of the
700 | first listed crop size.
701 | remove_bg : bool, optional
702 |             Remove background crops, i.e. crops whose fraction of pixels with saturation above 5 is below min_tiss. Default value is True.
703 | ratio_object_thresh : float, optional
704 | Objects are removed if they are smaller than ratio*largest object
705 | min_tiss : float, optional
706 | Threshold value to consider a crop as tissue. Default value is 0.1.
707 | """
708 | self.path = path
709 | self.crop_size_px = crop_size_px
710 | self.padding_factor = padding_factor
711 | self.ratio_object_thresh = ratio_object_thresh
712 | self.min_tiss = min_tiss
713 | self.remove_bg = remove_bg
714 | self.store = np.load(path)
715 | if self.store.ndim == 3:
716 | self.store = self.store[np.newaxis, :]
717 | if self.store.dtype != np.uint8:
718 | print("converting input dtype to uint8")
719 | self.store = self.store.astype(np.uint8)
720 | self.orig_shape = self.store.shape
721 | self.store = np.pad(
722 | self.store,
723 | [
724 | (0, 0),
725 | (self.crop_size_px, self.crop_size_px),
726 | (self.crop_size_px, self.crop_size_px),
727 | (0, 0),
728 | ],
729 | "constant",
730 | constant_values=255,
731 | )
732 | self.msks, self.fg_amount = self._foreground_mask()
733 |
734 | self.grid = self._calc_grid()
735 | self.idx = self._create_idx()
736 |
737 | # TODO No idea what kind of exceptions could happen.
738 | # If you are having issues with this dataloader, create an issue.
739 |
740 | def _foreground_mask(self, h_tresh=5):
741 | # print("computing fg masks")
742 | ret = []
743 | fg_amount = []
744 | for im in self.store:
745 | msk = (
746 | cv2.blur(cv2.cvtColor(im, cv2.COLOR_RGB2HSV)[..., 1], (50, 50))
747 | > h_tresh
748 | )
749 | comp, labl, size, cent = cv2.connectedComponentsWithStats(
750 | msk.astype(np.uint8) * 255
751 | )
752 | selec = size[1:, -1] / size[1:, -1].max() > self.ratio_object_thresh
753 | ids = np.arange(1, comp)[selec]
754 | fin_msk = np.isin(labl, ids)
755 | ret.append(fin_msk)
756 | fg_amount.append(np.mean(fin_msk))
757 |
758 | return ret, fg_amount
759 |
760 | def _calc_grid(self):
761 | _, h, w, _ = self.store.shape
762 | n_w = np.floor(
763 | (w - self.crop_size_px) / (self.crop_size_px * self.padding_factor)
764 | )
765 | n_h = np.floor(
766 | (h - self.crop_size_px) / (self.crop_size_px * self.padding_factor)
767 | )
768 | margin_w = (
769 | int(w - (self.padding_factor * n_w * self.crop_size_px + self.crop_size_px))
770 | // 2
771 | )
772 | margin_h = (
773 | int(h - (self.padding_factor * n_h * self.crop_size_px + self.crop_size_px))
774 | // 2
775 | )
776 | c_x = (
777 | np.arange(n_w + 1) * self.crop_size_px * self.padding_factor + margin_w
778 | ).astype(int)
779 | c_y = (
780 | np.arange(n_h + 1) * self.crop_size_px * self.padding_factor + margin_h
781 | ).astype(int)
782 | c_x, c_y = np.meshgrid(c_x, c_y)
783 | return np.array([c_y.flatten(), c_x.flatten()]).T
784 |
785 | def _create_idx(self):
786 | crd_list = []
787 | for i, msk in enumerate(self.msks):
788 | if self.remove_bg:
789 | valid_crd = [
790 | np.mean(
791 | msk[
792 | crd[0] : crd[0] + self.crop_size_px,
793 | crd[1] : crd[1] + self.crop_size_px,
794 | ]
795 | )
796 | > self.min_tiss
797 | for crd in self.grid
798 | ]
799 | crd_subset = self.grid[valid_crd, :]
800 | crd_list.append(
801 | np.concatenate(
802 | [np.repeat(i, crd_subset.shape[0]).reshape(-1, 1), crd_subset],
803 | -1,
804 | )
805 | )
806 | else:
807 | crd_list.append(
808 | np.concatenate(
809 | [np.repeat(i, self.grid.shape[0]).reshape(-1, 1), self.grid], -1
810 | )
811 | )
812 | return np.vstack(crd_list)
813 |
814 | def __len__(self) -> int:
815 | return self.idx.shape[0]
816 |
817 | def __getitem__(self, idx):
818 | c, x, y = self.idx[idx]
819 | out_img = self.store[c, x : x + self.crop_size_px, y : y + self.crop_size_px]
820 | out_img = normalize_min_max(out_img, 0, 255)
821 | return out_img, (c, x, y)
822 |
823 |
824 | class ImageDataset(NpyDataset):
825 | """
826 |     Torch Dataset to load from image files supported by OpenCV.
827 |
828 | Parameters
829 | ----------
830 | path : str
831 | Path to the Image, needs to be supported by opencv
832 | crop_size_px : int
833 |         Size of the extracted tiles in pixels, e.g. 256 -> 256x256 tiles
834 | padding_factor : float, optional
835 | Padding value when creating reference grid. Distance between two consecutive crops as a proportion of the
836 | first listed crop size.
837 | remove_bg : bool, optional
838 |         Remove background crops, i.e. crops whose fraction of pixels with saturation above 5 is below min_tiss. Default value is True.
839 | ratio_object_thresh : float, optional
840 | Objects are removed if they are smaller than ratio*largest object
841 | min_tiss : float, optional
842 | Threshold value to consider a crop as tissue. Default value is 0.1.
843 | """
844 |
845 | def __init__(
846 | self,
847 | path,
848 | crop_size_px,
849 | padding_factor=0.5,
850 | remove_bg=True,
851 | ratio_object_thresh=5e-1,
852 | min_tiss=0.1,
853 | ):
854 | self.path = path
855 | self.crop_size_px = crop_size_px
856 | self.padding_factor = padding_factor
857 | self.ratio_object_thresh = ratio_object_thresh
858 | self.min_tiss = min_tiss
859 | self.remove_bg = remove_bg
860 | self.store = self._load_image()
861 |
862 | self.orig_shape = self.store.shape
863 | self.store = np.pad(
864 | self.store,
865 | [
866 | (0, 0),
867 | (self.crop_size_px, self.crop_size_px),
868 | (self.crop_size_px, self.crop_size_px),
869 | (0, 0),
870 | ],
871 | "constant",
872 | constant_values=255,
873 | )
874 | self.msks, self.fg_amount = self._foreground_mask()
875 | self.grid = self._calc_grid()
876 | self.idx = self._create_idx()
877 |
878 | def _load_image(self):
879 | img = cv2.imread(self.path)
880 | if img.shape[-1] == 4:
881 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
882 | elif img.shape[-1] == 3:
883 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
884 | else:
885 | raise NotImplementedError("Image is neither RGBA nor RGB")
886 | return img[np.newaxis, ...]
887 |
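888 | # Minimal usage sketch (illustrative; the image path is a placeholder): the datasets above
889 | # yield (tile, coordinates) pairs and can be wrapped in a standard DataLoader, which is how
890 | # src/inference.py consumes them.
891 | if __name__ == "__main__":
892 |     from torch.utils.data import DataLoader
893 | 
894 |     ds = ImageDataset("example_tissue.png", crop_size_px=256)
895 |     loader = DataLoader(ds, batch_size=8, shuffle=False, num_workers=0)
896 |     for tiles, (c, x, y) in loader:
897 |         print(tiles.shape)  # (B, 256, 256, 3), float32 scaled to [0, 1]
898 |         break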
--------------------------------------------------------------------------------
/src/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import toml
4 | import requests
5 | from concurrent.futures import ThreadPoolExecutor
6 | import concurrent.futures
7 | from typing import List, Union, Tuple
8 | import torch
9 | import numpy as np
10 | import zarr
11 | import zipfile
12 | from numcodecs import Blosc
13 | from torch.utils.data import DataLoader
14 | from tqdm.auto import tqdm
15 | from scipy.special import softmax
16 | from src.multi_head_unet import get_model, load_checkpoint
17 | from src.data_utils import WholeSlideDataset, NpyDataset, ImageDataset
18 | from src.augmentations import color_augmentations
19 | from src.spatial_augmenter import SpatialAugmenter
20 | from src.constants import TTA_AUG_PARAMS, VALID_WEIGHTS
21 |
22 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23 |
24 | def inference_main(
25 | params: dict,
26 | models,
27 | augmenter,
28 | color_aug_fn,
29 | ):
30 | """
31 | Inference function for a single input file.
32 |
33 | Parameters
34 | ----------
35 | params: dict
36 | Parameter store, defined in initial main
37 | models: List[torch.nn.Module]
38 | list of models to run inference with, e.g. multiple folds or a single model in a list
39 | augmenter: SpatialAugmenter
40 | Augmentation module for geometric transformations
41 | color_aug_fn: torch.nn.Sequential
42 | Color Augmentation module
43 |
44 | Returns
45 | ----------
46 | params: dict
47 | Parameter store, defined in initial main and modified by this function
48 |     z: Union[Tuple[zarr.ZipStore, zarr.ZipStore], None]
49 | instance and class segmentation results as zarr stores, kept open for further processing. None if inference was skipped.
50 | """
51 | # print(repr(params["p"]))
52 | fn = params["p"].split(os.sep)[-1].split(params["ext"])[0]
53 | params["output_dir"] = os.path.join(params["output_root"], fn)
54 | if not os.path.isdir(params["output_dir"]):
55 | os.makedirs(params["output_dir"])
56 | params["model_out_p"] = os.path.join(
57 | params["output_dir"], fn + "_raw_" + str(params["tile_size"])
58 | )
59 | prog_path = os.path.join(params["output_dir"], "progress.txt")
60 |
61 | if os.path.exists(os.path.join(params["output_dir"], "pinst_pp.zip")):
62 | print(
63 | "inference and postprocessing already completed, delete output or specify different output path to re-run"
64 | )
65 | return params, None
66 |
67 | if (
68 | os.path.exists(params["model_out_p"] + "_inst.zip")
69 | & (os.path.exists(params["model_out_p"] + "_cls.zip"))
70 | & (not os.path.exists(prog_path))
71 | ):
72 | try:
73 | z_inst = zarr.open(params["model_out_p"] + "_inst.zip", mode="r")
74 | z_cls = zarr.open(params["model_out_p"] + "_cls.zip", mode="r")
75 | print("Inference already completed", z_inst.shape, z_cls.shape)
76 | return params, (z_inst, z_cls)
77 | except (KeyError, zipfile.BadZipFile):
78 | z_inst = None
79 | z_cls = None
80 | print(
81 | "something went wrong with previous output files, rerunning inference"
82 | )
83 |
84 | z_inst = None
85 | z_cls = None
86 |
87 | if not torch.cuda.is_available():
88 | print("trying to run inference on CPU, aborting...")
89 | print("if this is intended, remove this check")
90 | raise Exception("No GPU available")
91 |
92 | # create datasets from specified input
93 |
94 | if params["input_type"] == "npy":
95 | dataset = NpyDataset(
96 | params["p"],
97 | params["tile_size"],
98 | padding_factor=params["overlap"],
99 | ratio_object_thresh=0.3,
100 | min_tiss=0.1,
101 | )
102 | elif params["input_type"] == "img":
103 | dataset = ImageDataset(
104 | params["p"],
105 | params["tile_size"],
106 | padding_factor=params["overlap"],
107 | ratio_object_thresh=0.3,
108 | min_tiss=0.1,
109 | )
110 | else:
111 | level = 40 if params["pannuke"] else 20
112 | dataset = WholeSlideDataset(
113 | params["p"],
114 | crop_sizes_px=[params["tile_size"]],
115 | crop_magnifications=[level],
116 | padding_factor=params["overlap"],
117 | remove_background=True,
118 | ratio_object_thresh=0.0001,
119 | )
120 |
121 |     # setup output files to write to, also create dummy file to resume inference if interrupted
122 |
123 | z_inst = zarr.open(
124 | params["model_out_p"] + "_inst.zip",
125 | mode="w",
126 | shape=(len(dataset), 3, params["tile_size"], params["tile_size"]),
127 | chunks=(params["batch_size"], 3, params["tile_size"], params["tile_size"]),
128 | dtype="f4",
129 | compressor=Blosc(cname="lz4", clevel=3, shuffle=Blosc.SHUFFLE),
130 | )
131 | z_cls = zarr.open(
132 | params["model_out_p"] + "_cls.zip",
133 | mode="w",
134 | shape=(
135 | len(dataset),
136 | params["out_channels_cls"],
137 | params["tile_size"],
138 | params["tile_size"],
139 | ),
140 | chunks=(
141 | params["batch_size"],
142 | params["out_channels_cls"],
143 | params["tile_size"],
144 | params["tile_size"],
145 | ),
146 | dtype="u1",
147 | compressor=Blosc(cname="lz4", clevel=3, shuffle=Blosc.BITSHUFFLE),
148 | )
149 | # creating progress file to restart inference if it was interrupted
150 | with open(prog_path, "w") as f:
151 | f.write("0")
152 | inf_start = 0
153 |
154 | dataloader = DataLoader(
155 | dataset,
156 | batch_size=params["batch_size"],
157 | shuffle=False,
158 | num_workers=params["inf_workers"],
159 | pin_memory=True,
160 | )
161 |
162 | # IO thread to write output in parallel to inference
163 | def dump_results(res, z_cls, z_inst, prog_path):
164 | cls_, inst_, zc_ = res
165 | if cls_ is None:
166 | return
167 | cls_ = (softmax(cls_.astype(np.float32), axis=1) * 255).astype(np.uint8)
168 | z_cls[zc_ : zc_ + cls_.shape[0]] = cls_
169 | z_inst[zc_ : zc_ + inst_.shape[0]] = inst_.astype(np.float32)
170 | with open(prog_path, "w") as f:
171 | f.write(str(zc_))
172 | return
173 |
174 | # Separate thread for IO
175 | with ThreadPoolExecutor(max_workers=params["inf_writers"]) as executor:
176 | futures = []
177 | # run inference
178 | zc = inf_start
179 | for raw, _ in tqdm(dataloader):
180 | raw = raw.to(device, non_blocking=True).float()
181 | raw = raw.permute(0, 3, 1, 2) # BHWC -> BCHW
182 | with torch.inference_mode():
183 | ct, inst = batch_pseudolabel_ensemb(
184 | raw, models, params["tta"], augmenter, color_aug_fn
185 | )
186 | futures.append(
187 | executor.submit(
188 | dump_results,
189 | (ct.cpu().detach().numpy(), inst.cpu().detach().numpy(), zc),
190 | z_cls,
191 | z_inst,
192 | prog_path,
193 | )
194 | )
195 |
196 | zc += params["batch_size"]
197 |
198 | # Block until all data is written
199 | for _ in concurrent.futures.as_completed(futures):
200 | pass
201 | # clean up
202 | if os.path.exists(prog_path):
203 | os.remove(prog_path)
204 | return params, (z_inst, z_cls)
205 |
206 |
207 | def batch_pseudolabel_ensemb(
208 | raw: torch.Tensor,
209 | models: List[torch.nn.Module],
210 | nviews: int,
211 | aug: SpatialAugmenter,
212 | color_aug_fn: torch.nn.Sequential,
213 | ):
214 | """
215 | Run inference step on batch of images with test time augmentations
216 |
217 | Parameters
218 | ----------
219 |
220 | raw: torch.Tensor
221 | batch of input images
222 | models: List[torch.nn.Module]
223 | list of models to run inference with, e.g. multiple folds or a single model in a list
224 | nviews: int
225 | Number of test-time augmentation views to aggregate
226 | aug: SpatialAugmenter
227 | Augmentation module for geometric transformations
228 | color_aug_fn: torch.nn.Sequential
229 | Color Augmentation module
230 |
231 | Returns
232 | ----------
233 |
234 | ct: torch.Tensor
235 | Per pixel class predictions as a tensor of shape (batch_size, n_classes+1, tilesize, tilesize)
236 | inst: torch.Tensor
237 | Per pixel 3 class prediction map with boundary, background and foreground classes, shape (batch_size, 3, tilesize, tilesize)
238 | """
239 | tmp_3c_view = []
240 | tmp_ct_view = []
241 |     # if no TTA views are requested (nviews <= 0), run a single non-augmented pass per model
242 | if nviews <= 0:
243 | out_fast = []
244 | with torch.inference_mode():
245 | for mod in models:
246 | with torch.autocast(device_type="cuda", dtype=torch.float16):
247 | out_fast.append(mod(raw))
248 | out_fast = torch.stack(out_fast, axis=0).nanmean(0)
249 | ct = out_fast[:, 5:].softmax(1)
250 | inst = out_fast[:, 2:5].softmax(1)
251 | else:
252 | for _ in range(nviews):
253 | aug.interpolation = "bilinear"
254 | view_aug = aug.forward_transform(raw)
255 | aug.interpolation = "nearest"
256 | view_aug = torch.clamp(color_aug_fn(view_aug), 0, 1)
257 | out_fast = []
258 | with torch.inference_mode():
259 | for mod in models:
260 | with torch.autocast(device_type="cuda", dtype=torch.float16):
261 | out_fast.append(aug.inverse_transform(mod(view_aug)))
262 | out_fast = torch.stack(out_fast, axis=0).nanmean(0)
263 | tmp_3c_view.append(out_fast[:, 2:5].softmax(1))
264 | tmp_ct_view.append(out_fast[:, 5:].softmax(1))
265 | ct = torch.stack(tmp_ct_view).nanmean(0)
266 | inst = torch.stack(tmp_3c_view).nanmean(0)
267 | return ct, inst
268 |
269 |
270 | def get_inference_setup(params):
271 | """
272 |     get model(s) and load checkpoint(s), create augmentation functions and set up parameters for inference
273 | """
274 | models = []
275 | for pth in params["data_dirs"]:
276 | if not os.path.exists(pth):
277 | pth = download_weights(os.path.split(pth)[-1])
278 |
279 | checkpoint_path = f"{pth}/train/best_model"
280 | mod_params = toml.load(f"{pth}/params.toml")
281 | params["out_channels_cls"] = mod_params["out_channels_cls"]
282 | params["inst_channels"] = mod_params["inst_channels"]
283 | model = get_model(
284 | enc=mod_params["encoder"],
285 | out_channels_cls=params["out_channels_cls"],
286 | out_channels_inst=params["inst_channels"],
287 | ).to(device)
288 | model = load_checkpoint(model, checkpoint_path, device)
289 | model.eval()
290 | models.append(copy.deepcopy(model))
291 | # create augmentation functions on device
292 | augmenter = SpatialAugmenter(TTA_AUG_PARAMS).to(device)
293 | color_aug_fn = color_augmentations(False, rank=device)
294 |
295 | if mod_params["dataset"] == "pannuke":
296 | params["pannuke"] = True
297 | else:
298 | params["pannuke"] = False
299 | print(
300 | "processing input using",
301 | "pannuke" if params["pannuke"] else "lizard",
302 | "trained model",
303 | )
304 |
305 | return params, models, augmenter, color_aug_fn
306 |
307 | def download_weights(model_code):
308 | if model_code in VALID_WEIGHTS:
309 | url = f"https://zenodo.org/records/10635618/files/{model_code}.zip"
310 |         print("downloading", model_code, "weights to", os.getcwd())
311 |         try:
312 |             response = requests.get(url, stream=True, timeout=15.0)
313 |         except requests.exceptions.Timeout as e:
314 |             raise RuntimeError("timed out while downloading model weights") from e
315 | total_size = int(response.headers.get("content-length", 0))
316 | block_size = 1024 # 1 Kibibyte
317 | with tqdm(total=total_size, unit="iB", unit_scale=True) as t:
318 | with open("cache.zip", "wb") as f:
319 | for data in response.iter_content(block_size):
320 | t.update(len(data))
321 | f.write(data)
322 | with zipfile.ZipFile("cache.zip", "r") as zip:
323 | zip.extractall("")
324 | os.remove("cache.zip")
325 | return model_code
326 | else:
327 |         raise ValueError(f"Model id not found in valid identifiers, please select one of {VALID_WEIGHTS}")
328 |
--------------------------------------------------------------------------------
/src/multi_head_unet.py:
--------------------------------------------------------------------------------
1 | import segmentation_models_pytorch as smp
2 |
3 | # from segmentation_models_pytorch.encoders import get_encoder
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from collections import OrderedDict
8 | from segmentation_models_pytorch.base import modules as md
9 | import segmentation_models_pytorch.base.initialization as init
10 | import timm
11 |
12 |
13 | def load_checkpoint(model, cp_path, device):
14 | """
15 | load checkpoint and fix DataParallel/DistributedDataParallel
16 | """
17 |
18 | cp = torch.load(cp_path, map_location=device)
19 | try:
20 | model.load_state_dict(cp["model_state_dict"])
21 |
22 |         print("successfully loaded model weights")
23 |     except (KeyError, RuntimeError):
24 | print("trying secondary checkpoint loading")
25 | state_dict = cp["model_state_dict"]
26 | new_state_dict = OrderedDict()
27 | for k, v in state_dict.items():
28 | name = k[7:] # remove 'module.' of DataParallel/DistributedDataParallel
29 | new_state_dict[name] = v
30 |
31 | model.load_state_dict(new_state_dict)
32 |         print("successfully loaded model weights")
33 | return model
34 |
35 |
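# --- Usage sketch (illustrative only) -----------------------------------------
# Checkpoints are expected to be a dict containing a "model_state_dict" entry;
# load_checkpoint also strips the "module." prefix that (Distributed)DataParallel
# adds to parameter names. The checkpoint path below is a placeholder.
#
# model = get_model(out_channels_cls=8, out_channels_inst=5, pretrained=False)
# model = load_checkpoint(model, "path/to/train/best_model", device="cpu")
# model.eval()
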
36 | class TimmEncoderFixed(nn.Module):
37 | """
38 | Modified version of timm encoder.
39 | Original from: https://github.com/huggingface/pytorch-image-models
40 |
41 | """
42 |
43 | def __init__(
44 | self,
45 | name,
46 | pretrained=True,
47 | in_channels=3,
48 | depth=5,
49 | output_stride=32,
50 | drop_rate=0.5,
51 | drop_path_rate=0.25,
52 | ):
53 | super().__init__()
54 | kwargs = dict(
55 | in_chans=in_channels,
56 | features_only=True,
57 | pretrained=pretrained,
58 | out_indices=tuple(range(depth)),
59 | drop_rate=drop_rate,
60 | drop_path_rate=drop_path_rate,
61 | )
62 |
63 | self.model = timm.create_model(name, **kwargs)
64 |
65 | self._in_channels = in_channels
66 | self._out_channels = [
67 | in_channels,
68 | ] + self.model.feature_info.channels()
69 | self._depth = depth
70 | self._output_stride = output_stride
71 |
72 | def forward(self, x):
73 | features = self.model(x)
74 | features = [
75 | x,
76 | ] + features
77 | return features
78 |
79 | @property
80 | def out_channels(self):
81 | return self._out_channels
82 |
83 | @property
84 | def output_stride(self):
85 | return min(self._output_stride, 2**self._depth)
86 |
87 |
88 | def get_model(
89 | enc="convnextv2_tiny.fcmae_ft_in22k_in1k",
90 | out_channels_cls=8,
91 | out_channels_inst=5,
92 | pretrained=True,
93 | ):
94 | depth = 4 if "next" in enc else 5
95 | encoder = TimmEncoderFixed(
96 | name=enc,
97 | pretrained=pretrained,
98 | in_channels=3,
99 | depth=depth,
100 | output_stride=32,
101 | drop_rate=0.5,
102 | drop_path_rate=0.0,
103 | )
104 |
105 | decoder_channels = (256, 128, 64, 32, 16)[:depth]
106 | decoder_inst = UnetDecoder(
107 | encoder_channels=encoder.out_channels,
108 | decoder_channels=decoder_channels,
109 | n_blocks=len(decoder_channels),
110 | use_batchnorm=False,
111 | center=False,
112 | attention_type=None,
113 | next="next" in enc,
114 | )
115 | decoder_ct = UnetDecoder(
116 | encoder_channels=encoder.out_channels,
117 | decoder_channels=decoder_channels,
118 | n_blocks=len(decoder_channels),
119 | use_batchnorm=False,
120 | center=False,
121 | attention_type=None,
122 | next="next" in enc,
123 | )
124 | head_inst = smp.base.SegmentationHead(
125 | in_channels=decoder_inst.blocks[-1].conv2[0].out_channels,
126 | out_channels=out_channels_inst, # instance channels
127 | activation=None,
128 | kernel_size=1,
129 | )
130 | head_ct = smp.base.SegmentationHead(
131 | in_channels=decoder_ct.blocks[-1].conv2[0].out_channels,
132 | out_channels=out_channels_cls,
133 | activation=None,
134 | kernel_size=1,
135 | )
136 |
137 | decoders = [decoder_inst, decoder_ct]
138 | heads = [head_inst, head_ct]
139 | model = MultiHeadModel(encoder, decoders, heads)
140 | return model
141 |
142 |
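# --- Usage sketch (illustrative only) -----------------------------------------
# Building the two-decoder model and inspecting the output layout. MultiHeadModel
# concatenates [head_inst, head_ct] along dim 1, so the first out_channels_inst
# channels belong to the instance branch and the remaining out_channels_cls
# channels to the class branch. pretrained=False keeps the example offline;
# channel counts are example values.
#
# model = get_model(
#     enc="convnextv2_tiny.fcmae_ft_in22k_in1k",
#     out_channels_cls=8,
#     out_channels_inst=5,
#     pretrained=False,
# ).eval()
# with torch.inference_mode():
#     out = model(torch.rand(1, 3, 256, 256))
# inst_logits, cls_logits = out[:, :5], out[:, 5:]
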
143 | class Conv2dReLU(nn.Sequential):
144 | def __init__(
145 | self,
146 | in_channels,
147 | out_channels,
148 | kernel_size,
149 | padding=0,
150 | stride=1,
151 | use_batchnorm=True,
152 | ):
153 | conv = nn.Conv2d(
154 | in_channels,
155 | out_channels,
156 | kernel_size,
157 | stride=stride,
158 | padding=padding,
159 | bias=not (use_batchnorm),
160 | )
161 | relu = nn.ReLU()
162 |
163 | if use_batchnorm:
164 | bn = nn.BatchNorm2d(out_channels)
165 |
166 | else:
167 | bn = nn.Identity()
168 |
169 | super(Conv2dReLU, self).__init__(conv, bn, relu)
170 |
171 |
172 | class DecoderBlock(nn.Module):
173 | def __init__(
174 | self,
175 | in_channels,
176 | skip_channels,
177 | out_channels,
178 | use_batchnorm=True,
179 | attention_type=None,
180 | ):
181 | super().__init__()
182 | self.conv1 = md.Conv2dReLU(
183 | in_channels + skip_channels,
184 | out_channels,
185 | kernel_size=3,
186 | padding=1,
187 | use_batchnorm=use_batchnorm,
188 | )
189 | self.attention1 = md.Attention(
190 | attention_type, in_channels=in_channels + skip_channels
191 | )
192 | self.conv2 = md.Conv2dReLU(
193 | out_channels,
194 | out_channels,
195 | kernel_size=3,
196 | padding=1,
197 | use_batchnorm=use_batchnorm,
198 | )
199 | self.attention2 = md.Attention(attention_type, in_channels=out_channels)
200 |
201 | def forward(self, x, skip=None):
202 | x = F.interpolate(x, scale_factor=2, mode="nearest")
203 | if skip is not None:
204 | x = torch.cat([x, skip], dim=1)
205 | x = self.attention1(x)
206 | x = self.conv1(x)
207 | x = self.conv2(x)
208 | x = self.attention2(x)
209 | return x
210 |
211 |
212 | class CenterBlock(nn.Sequential):
213 | def __init__(self, in_channels, out_channels, use_batchnorm=True):
214 | conv1 = md.Conv2dReLU(
215 | in_channels,
216 | out_channels,
217 | kernel_size=3,
218 | padding=1,
219 | use_batchnorm=use_batchnorm,
220 | )
221 | conv2 = md.Conv2dReLU(
222 | out_channels,
223 | out_channels,
224 | kernel_size=3,
225 | padding=1,
226 | use_batchnorm=use_batchnorm,
227 | )
228 | super().__init__(conv1, conv2)
229 |
230 |
231 | class UnetDecoder(nn.Module):
232 | def __init__(
233 | self,
234 | encoder_channels,
235 | decoder_channels,
236 | n_blocks=5,
237 | use_batchnorm=False,
238 | attention_type=None,
239 | center=False,
240 | next=False,
241 | ):
242 | super().__init__()
243 |
244 | if n_blocks != len(decoder_channels):
245 | raise ValueError(
246 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
247 | n_blocks, len(decoder_channels)
248 | )
249 | )
250 |
251 | # remove first skip with same spatial resolution
252 | encoder_channels = encoder_channels[1:]
253 | # reverse channels to start from head of encoder
254 | encoder_channels = encoder_channels[::-1]
255 |
256 | # computing blocks input and output channels
257 | head_channels = encoder_channels[0]
258 | in_channels = [head_channels] + list(decoder_channels[:-1])
259 | skip_channels = list(encoder_channels[1:]) + [0]
260 | out_channels = decoder_channels
261 |
262 | if center:
263 | self.center = CenterBlock(
264 | head_channels, head_channels, use_batchnorm=use_batchnorm
265 | )
266 | else:
267 | self.center = nn.Identity()
268 |
269 | # combine decoder keyword arguments
270 | kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
271 | blocks = [
272 | DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
273 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
274 | ]
275 | if next:
276 | blocks.append(
277 | DecoderBlock(out_channels[-1], 0, out_channels[-1] // 2, **kwargs)
278 | )
279 | self.blocks = nn.ModuleList(blocks)
280 |
281 | def forward(self, *features):
282 | features = features[1:] # remove first skip with same spatial resolution
283 | features = features[::-1] # reverse channels to start from head of encoder
284 |
285 | head = features[0]
286 | skips = features[1:]
287 |
288 | x = self.center(head)
289 | for i, decoder_block in enumerate(self.blocks):
290 | skip = skips[i] if i < len(skips) else None
291 | x = decoder_block(x, skip)
292 |
293 | return x
294 |
295 |
296 | class MultiHeadModel(torch.nn.Module):
297 | def __init__(self, encoder, decoder_list, head_list):
298 | super(MultiHeadModel, self).__init__()
299 | self.encoder = nn.ModuleList([encoder])[0]
300 | self.decoders = nn.ModuleList(decoder_list)
301 | self.heads = nn.ModuleList(head_list)
302 | self.initialize()
303 |
304 | def initialize(self):
305 | for decoder in self.decoders:
306 | init.initialize_decoder(decoder)
307 | for head in self.heads:
308 | init.initialize_head(head)
309 |
310 | def check_input_shape(self, x):
311 | h, w = x.shape[-2:]
312 | output_stride = self.encoder.output_stride
313 | if h % output_stride != 0 or w % output_stride != 0:
314 | new_h = (
315 | (h // output_stride + 1) * output_stride
316 | if h % output_stride != 0
317 | else h
318 | )
319 | new_w = (
320 | (w // output_stride + 1) * output_stride
321 | if w % output_stride != 0
322 | else w
323 | )
324 | raise RuntimeError(
325 | f"Wrong input shape height={h}, width={w}. Expected image height and width "
326 |                 f"divisible by {output_stride}. Consider padding your images to shape ({new_h}, {new_w})."
327 | )
328 |
329 | def forward(self, x):
330 |         """Sequentially pass `x` through the model's encoder, decoders and heads"""
331 |
332 | # self.check_input_shape(x)
333 |
334 | features = self.encoder(x)
335 | decoder_outputs = []
336 | for decoder in self.decoders:
337 | decoder_outputs.append(decoder(*features))
338 |
339 | masks = []
340 | for head, decoder_output in zip(self.heads, decoder_outputs):
341 | masks.append(head(decoder_output))
342 |
343 | return torch.cat(masks, 1)
344 |
345 | @torch.no_grad()
346 | def predict(self, x):
347 | """Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()`
348 | Args:
349 | x: 4D torch tensor with shape (batch_size, channels, height, width)
350 | Return:
351 | prediction: 4D torch tensor with shape (batch_size, classes, height, width)
352 | """
353 | if self.training:
354 | self.eval()
355 |
356 | x = self.forward(x)
357 |
358 | return x
359 |
--------------------------------------------------------------------------------
/src/post_process.py:
--------------------------------------------------------------------------------
1 | from src.post_process_utils import (
2 | work,
3 | write,
4 | get_pp_params,
5 | get_shapes,
6 | get_tile_coords,
7 | )
8 | from src.viz_utils import create_tsvs, create_polygon_output
9 | from src.data_utils import NpyDataset, ImageDataset
10 | from typing import List, Tuple
11 | import zarr
12 | from numcodecs import Blosc
13 | from concurrent.futures import ProcessPoolExecutor
14 | import concurrent.futures
15 | import json
16 | import os
17 | from typing import Union
18 | from tqdm.auto import tqdm
19 | from src.viz_utils import create_geojson
20 | from src.constants import (
21 | CLASS_LABELS_LIZARD,
22 | CLASS_LABELS_PANNUKE,
23 | )
24 |
25 |
26 | def post_process_main(
27 | params: dict,
28 | z: Union[Tuple[zarr.ZipStore, zarr.ZipStore], None] = None,
29 | ):
30 |     """
31 |     Post-processing for inference results: stitches the output maps, refines the predictions and produces instance and class maps.
32 | 
33 |     Parameters
34 |     ----------
35 |     params: dict
36 |         Parameter store, defined in initial main
37 |     z: Tuple[zarr.ZipStore, zarr.ZipStore], optional
38 |         Open zarr stores with the raw instance and class model outputs (in that order); read from disk if None
39 | 
40 |     Returns
41 |     ----------
42 |     pinst_out: zarr.Array
43 |         Instance segmentation results as a zarr array, kept open for further processing
44 |     """
45 | # get best parameters for respective evaluation metric
46 |
47 | params = get_pp_params(params, True)
48 | params, ds_coord = get_shapes(params, len(params["best_fg_thresh_cl"]))
49 |
50 | tile_crds = get_tile_coords(
51 | params["out_img_shape"],
52 | params["pp_tiling"],
53 | pad_size=params["pp_overlap"],
54 | npy=params["input_type"] != "wsi",
55 | )
56 | if params["input_type"] == "wsi":
57 | pinst_out = zarr.zeros(
58 | shape=(
59 | params["out_img_shape"][-1],
60 | params["out_img_shape"][-2],
61 | ),
62 | dtype="i4",
63 | compressor=Blosc(cname="lz4", clevel=3, shuffle=Blosc.SHUFFLE),
64 | )
65 |
66 | else:
67 | pinst_out = zarr.zeros(
68 | shape=(params["orig_shape"][0], *params["orig_shape"][-2:]),
69 | dtype="i4",
70 | compressor=Blosc(cname="lz4", clevel=3, shuffle=Blosc.SHUFFLE),
71 | )
72 |
73 | executor = ProcessPoolExecutor(max_workers=params["pp_workers"])
74 | tile_processors = [
75 | executor.submit(work, tcrd, ds_coord, z, params) for tcrd in tile_crds
76 | ]
77 | pcls_out = {}
78 | running_max = 0
79 | class_labels = []
80 | res_poly = []
81 | for future in tqdm(
82 | concurrent.futures.as_completed(tile_processors), total=len(tile_processors)
83 | ):
84 | pinst_out, pcls_out, running_max, class_labels, res_poly = write(
85 | pinst_out, pcls_out, running_max, future.result(), params, class_labels, res_poly
86 | )
87 | executor.shutdown(wait=False)
88 |
89 | if params["output_dir"] is not None:
90 | print("saving final output")
91 | zarr.save(os.path.join(params["output_dir"], "pinst_pp.zip"), pinst_out)
92 | print("storing class dictionary...")
93 | with open(os.path.join(params["output_dir"], "class_inst.json"), "w") as fp:
94 | json.dump(pcls_out, fp)
95 |
96 | if params["input_type"] == "wsi":
97 | print("saving geojson coordinates for qupath...")
98 | create_tsvs(pcls_out, params)
99 |         # TODO this is way too slow for large images
100 | if params["save_polygon"]:
101 | pred_keys = CLASS_LABELS_PANNUKE if params["pannuke"] else CLASS_LABELS_LIZARD
102 | create_geojson(
103 | res_poly,
104 | class_labels,
105 | dict((v, k) for k, v in pred_keys.items()),
106 | params,
107 | )
108 |
109 | return pinst_out
110 |
--------------------------------------------------------------------------------
/src/post_process_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import zarr
4 | import gc
5 | import json
6 | import os
7 | import time
8 | import openslide
9 | from skimage.segmentation import watershed
10 | from scipy.ndimage import find_objects
11 | from numcodecs import Blosc
12 | from src.viz_utils import cont
13 | from skimage.measure import regionprops
14 | from src.constants import (
15 | MIN_THRESHS_LIZARD,
16 | MIN_THRESHS_PANNUKE,
17 | MAX_THRESHS_LIZARD,
18 | MAX_THRESHS_PANNUKE,
19 | MAX_HOLE_SIZE,
20 | LUT_MAGNIFICATION_MPP,
21 | LUT_MAGNIFICATION_X,
22 | )
23 | from src.data_utils import center_crop, WholeSlideDataset, NpyDataset, ImageDataset
24 |
25 |
26 | def update_dicts(pinst_, pcls_, pcls_out, t_, old_ids, initial_ids):
27 | props = [(p.label, p.centroid) for p in regionprops(pinst_)]
28 | pcls_new = {}
29 | for id_, cen in props:
30 | try:
31 | pcls_new[str(id_)] = (pcls_[str(id_)], (cen[0] + t_[2], cen[1] + t_[0]))
32 | except KeyError:
33 | pcls_new[str(id_)] = (pcls_out[str(id_)], (cen[0] + t_[2], cen[1] + t_[0]))
34 |
35 | new_ids = [p[0] for p in props]
36 |
37 | for i in np.setdiff1d(old_ids, new_ids):
38 | try:
39 | del pcls_out[str(i)]
40 | except KeyError:
41 | pass
42 | for i in np.setdiff1d(new_ids, initial_ids):
43 | try:
44 | del pcls_new[str(i)]
45 | except KeyError:
46 | pass
47 | return pcls_out | pcls_new
48 |
49 |
50 | def write(pinst_out, pcls_out, running_max, res, params, class_labels, res_poly):
51 | pinst_, pcls_, max_, t_, skip = res
52 | if not skip:
53 | if params["input_type"] != "wsi":
54 | pinst_.vindex[pinst_[:] != 0] += running_max
55 | pcls_ = {str(int(k) + running_max): v for k, v in pcls_.items()}
56 | props = [(p.label, p.centroid) for p in regionprops(pinst_)]
57 | pcls_new = {}
58 | for id_, cen in props:
59 | pcls_new[str(id_)] = (pcls_[str(id_)], (t_[-1], cen[0], cen[1]))
60 |
61 | running_max += max_
62 | pcls_out |= pcls_new
63 | pinst_out[t_[-1]] = np.asarray(pinst_, dtype=np.int32)
64 |
65 | else:
66 | pinst_ = np.asarray(pinst_, dtype=np.int32)
67 | ov_regions, local_regions, which = get_overlap_regions(
68 | t_, params["pp_overlap"], pinst_out.shape
69 | )
70 | msk = pinst_ != 0
71 | pinst_[msk] += running_max
72 | pcls_ = {str(int(k) + running_max): v for k, v in pcls_.items()}
73 | running_max += max_
74 | initial_ids = np.unique(pinst_[msk])
75 | old_ids = []
76 |
77 | for reg, loc, whi in zip(ov_regions, local_regions, which):
78 | if reg is None:
79 | continue
80 |
81 | written = np.array(
82 | pinst_out[reg[2] : reg[3], reg[0] : reg[1]], dtype=np.int32
83 | )
84 | old_ids.append(np.unique(written[written != 0]))
85 |
86 | small, large = get_subregions(whi, written.shape)
87 | subregion = written[
88 | small[0] : small[1], small[2] : small[3]
89 | ] # 1/4 of the region
90 | larger_subregion = written[
91 | large[0] : large[1], large[2] : large[3]
92 | ] # 1/2 of the region
93 | keep = np.unique(subregion[subregion != 0])
94 | if len(keep) == 0:
95 | continue
96 |
97 | keep_objects = find_objects(
98 | larger_subregion, max_label=max(keep)
99 | ) # [keep-1]
100 | pinst_reg = pinst_[loc[2] : loc[3], loc[0] : loc[1]][
101 | large[0] : large[1], large[2] : large[3]
102 | ]
103 |
104 | for id_ in keep:
105 | obj = keep_objects[id_ - 1]
106 | if obj is None:
107 | continue
108 | written_mask = larger_subregion[obj] == id_
109 | pinst_reg[obj][written_mask] = id_
110 |
111 | old_ids = np.concatenate(old_ids)
112 | pcls_out = update_dicts(pinst_, pcls_, pcls_out, t_, old_ids, initial_ids)
113 | pinst_out[t_[2] : t_[3], t_[0] : t_[1]] = pinst_
114 | if params["save_polygon"]:
115 | props = [(p.label, p.image, p.bbox) for p in regionprops(np.asarray(pinst_))]
116 | class_labels_partial = [pcls_out[str(p[0])] for p in props]
117 | res_poly_partial = [cont(i, [t_[2], t_[0]]) for i in props]
118 | class_labels.extend(class_labels_partial)
119 | res_poly.extend(res_poly_partial)
120 | # res.task_done()
121 |
122 | return pinst_out, pcls_out, running_max, class_labels, res_poly
123 |
124 |
125 | def work(tcrd, ds_coord, z, params):
126 | out_img = gen_tile_map(
127 | tcrd,
128 | ds_coord,
129 | params["ccrop"],
130 | model_out_p=params["model_out_p"],
131 | which="_inst",
132 | dim=params["out_img_shape"][-3],
133 | z=z,
134 | npy=params["input_type"] != "wsi",
135 | )
136 | out_cls = gen_tile_map(
137 | tcrd,
138 | ds_coord,
139 | params["ccrop"],
140 | model_out_p=params["model_out_p"],
141 | which="_cls",
142 | dim=params["out_cls_shape"][-3],
143 | z=z,
144 | npy=params["input_type"] != "wsi",
145 | )
146 | if params["input_type"] != "wsi":
147 | out_img = out_img[
148 | :,
149 | params["tile_size"] : -params["tile_size"],
150 | params["tile_size"] : -params["tile_size"],
151 | ]
152 | out_cls = out_cls[
153 | :,
154 | params["tile_size"] : -params["tile_size"],
155 | params["tile_size"] : -params["tile_size"],
156 | ]
157 | best_min_threshs = MIN_THRESHS_PANNUKE if params["pannuke"] else MIN_THRESHS_LIZARD
158 | best_max_threshs = MAX_THRESHS_PANNUKE if params["pannuke"] else MAX_THRESHS_LIZARD
159 |
160 |     # watershed-based instance segmentation on the stitched instance and class maps
161 | pred_inst, skip = faster_instance_seg(
162 | out_img, out_cls, params["best_fg_thresh_cl"], params["best_seed_thresh_cl"]
163 | )
164 | del out_img
165 | gc.collect()
166 | max_hole_size = MAX_HOLE_SIZE if params["pannuke"] else (MAX_HOLE_SIZE // 4)
167 | if skip:
168 | pred_inst = zarr.array(
169 | pred_inst, compressor=Blosc(cname="zstd", clevel=3, shuffle=Blosc.SHUFFLE)
170 | )
171 |
172 | return (pred_inst, {}, 0, tcrd, skip)
173 | pred_inst = post_proc_inst(
174 | pred_inst,
175 | max_hole_size,
176 | )
177 | pred_ct = make_ct(out_cls, pred_inst)
178 | del out_cls
179 | gc.collect()
180 |
181 | processed = remove_obj_cls(pred_inst, pred_ct, best_min_threshs, best_max_threshs)
182 | # TODO why is this here?
183 | pred_inst, pred_ct = processed
184 | max_inst = np.max(pred_inst)
185 | pred_inst = zarr.array(
186 | pred_inst.astype(np.int32),
187 | compressor=Blosc(cname="zstd", clevel=3, shuffle=Blosc.SHUFFLE),
188 | )
189 | return (pred_inst, pred_ct, max_inst, tcrd, skip)
190 |
191 |
192 | def get_overlap_regions(tcrd, pad_size, out_img_shape):
193 | top = [tcrd[0], tcrd[0] + 2 * pad_size, tcrd[2], tcrd[3]] if tcrd[0] != 0 else None
194 | bottom = (
195 | [tcrd[1] - 2 * pad_size, tcrd[1], tcrd[2], tcrd[3]]
196 | if tcrd[1] != out_img_shape[-2]
197 | else None
198 | )
199 | left = [tcrd[0], tcrd[1], tcrd[2], tcrd[2] + 2 * pad_size] if tcrd[2] != 0 else None
200 | right = (
201 | [tcrd[0], tcrd[1], tcrd[3] - 2 * pad_size, tcrd[3]]
202 | if tcrd[3] != out_img_shape[-1]
203 | else None
204 | )
205 | d_top = [0, 2 * pad_size, 0, tcrd[3] - tcrd[2]]
206 | d_bottom = [
207 | tcrd[1] - tcrd[0] - 2 * pad_size,
208 | tcrd[1] - tcrd[0],
209 | 0,
210 | tcrd[3] - tcrd[2],
211 | ]
212 | d_left = [0, tcrd[1] - tcrd[0], 0, 2 * pad_size]
213 | d_right = [
214 | 0,
215 | tcrd[1] - tcrd[0],
216 | tcrd[3] - tcrd[2] - 2 * pad_size,
217 | tcrd[3] - tcrd[2],
218 | ]
219 | return (
220 | [top, bottom, left, right],
221 | [d_top, d_bottom, d_left, d_right],
222 | ["top", "bottom", "left", "right"],
223 | ) #
224 |
225 |
226 | def get_subregions(which, shape):
227 | """
228 |     Note: the region names do not match the axes due to an x/y coordinate swap; to be made consistent in a future fix
229 | """
230 | if which == "top":
231 | return [0, shape[0], 0, shape[1] // 4], [0, shape[0], 0, shape[1] // 2]
232 | elif which == "bottom":
233 | return [0, shape[0], (shape[1] * 3) // 4, shape[1]], [
234 | 0,
235 | shape[0],
236 | shape[1] // 2,
237 | shape[1],
238 | ]
239 | elif which == "left":
240 | return [0, shape[0] // 4, 0, shape[1]], [0, shape[0] // 2, 0, shape[1]]
241 | elif which == "right":
242 | return [(shape[0] * 3) // 4, shape[0], 0, shape[1]], [
243 | shape[0] // 2,
244 | shape[0],
245 | 0,
246 | shape[1],
247 | ]
248 |
249 | else:
250 | raise ValueError("Invalid which")
251 |
252 |
253 | def expand_bbox(bbox, pad_size, img_size):
254 | return [
255 | max(0, bbox[0] - pad_size),
256 | max(0, bbox[1] - pad_size),
257 | min(img_size[0], bbox[2] + pad_size),
258 | min(img_size[1], bbox[3] + pad_size),
259 | ]
260 |
261 |
262 | def get_tile_coords(shape, splits, pad_size, npy):
263 | if npy:
264 | tile_crds = [[0, shape[-2], 0, shape[-1], i] for i in range(shape[0])]
265 | return tile_crds
266 |
267 | else:
268 | shape = shape[-2:]
269 | tile_crds = []
270 | ts_1 = np.array_split(np.arange(0, shape[0]), splits)
271 | ts_2 = np.array_split(np.arange(0, shape[1]), splits)
272 | for i in ts_1:
273 | for j in ts_2:
274 | x_start = 0 if i[0] < pad_size else i[0] - pad_size
275 | x_end = shape[0] if i[-1] + pad_size > shape[0] else i[-1] + pad_size
276 | y_start = 0 if j[0] < pad_size else j[0] - pad_size
277 | y_end = shape[1] if j[-1] + pad_size > shape[1] else j[-1] + pad_size
278 | tile_crds.append([x_start, x_end, y_start, y_end])
279 | return tile_crds
280 |
281 |
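# --- Usage sketch (illustrative only) -----------------------------------------
# For WSI input (npy=False), get_tile_coords splits the (H, W) plane into a
# splits x splits grid and extends each tile by pad_size on every side (clipped at
# the image border), so neighbouring tiles overlap; write() later resolves objects
# that fall into these overlap strips. Shape and values below are arbitrary.
#
# crds = get_tile_coords((2, 5000, 4000), splits=2, pad_size=64, npy=False)
# # -> [[0, 2563, 0, 2063], [0, 2563, 1936, 4000],
# #     [2436, 5000, 0, 2063], [2436, 5000, 1936, 4000]]
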
282 | def proc_tile(t, ccrop, which="_cls"):
283 | t = center_crop(t, ccrop, ccrop)
284 | if which == "_cls":
285 | t = t[1:]
286 | t = t.reshape(t.shape[0], -1)
287 | out = np.zeros(t.shape, dtype=bool)
288 | out[t.argmax(axis=0), np.arange(t.shape[1])] = 1
289 | t = out.reshape(-1, ccrop, ccrop)
290 |
291 | else:
292 | t = t[:2].astype(np.float16)
293 | return t
294 |
295 |
296 | def gen_tile_map(
297 | tile_crd,
298 | ds_coord,
299 | ccrop,
300 | model_out_p="",
301 | which="_cls",
302 | dim=5,
303 | z=None,
304 | npy=False,
305 | ):
306 | if z is None:
307 | z = zarr.open(model_out_p + f"{which}.zip", mode="r")
308 | else:
309 | if which == "_cls":
310 | z = z[1]
311 | else:
312 | z = z[0]
313 | cadj = (z.shape[-1] - ccrop) // 2
314 | tx, ty, tz = None, None, None
315 | dtype = bool if which == "_cls" else np.float16
316 |
317 | if npy:
318 | # TODO fix npy
319 | coord_filter = ds_coord[:, 0] == tile_crd[-1]
320 | ds_coord_subset = ds_coord[coord_filter]
321 | zero_map = np.zeros(
322 | (dim, tile_crd[1] - tile_crd[0], tile_crd[3] - tile_crd[2]), dtype=dtype
323 | )
324 | else:
325 | zero_map = np.zeros(
326 | (dim, tile_crd[3] - tile_crd[2], tile_crd[1] - tile_crd[0]), dtype=dtype
327 | )
328 | coord_filter = (
329 | ((ds_coord[:, 0]) < tile_crd[1])
330 | & ((ds_coord[:, 0] + ccrop) > tile_crd[0])
331 | & ((ds_coord[:, 1]) < tile_crd[3])
332 | & ((ds_coord[:, 1] + ccrop) > tile_crd[2])
333 | )
334 | ds_coord_subset = ds_coord[coord_filter] - np.array([tile_crd[0], tile_crd[2]])
335 |
336 | z_address = np.arange(ds_coord.shape[0])[coord_filter]
337 | for _, (crd, tile) in enumerate(zip(ds_coord_subset, z[z_address])):
338 | if npy:
339 | tz, ty, tx = crd
340 | else:
341 | tx, ty = crd
342 | tx = tx
343 | ty = ty
344 | p_shift = [abs(i) if i < 0 else 0 for i in [ty, tx]]
345 | n_shift = [
346 | crd - (i + ccrop) if (i + ccrop) > crd else 0
347 | for i, crd in zip([ty, tx], zero_map.shape[1:3])
348 | ]
349 | try:
350 | zero_map[
351 | :,
352 | ty + p_shift[0] : ty + ccrop + n_shift[0],
353 | tx + p_shift[1] : tx + ccrop + n_shift[1],
354 | ] = proc_tile(tile, ccrop, which)[
355 | ...,
356 | p_shift[0] : ccrop + n_shift[0],
357 | p_shift[1] : ccrop + n_shift[1],
358 | ]
359 |
360 |         except Exception:
361 | print(zero_map.shape)
362 | print(tx)
363 | print(ty)
364 | print(ccrop)
365 | print(tile.shape)
366 |             raise ValueError("model output tile does not fit into the stitched tile map")
367 | return zero_map
368 |
369 |
370 | def faster_instance_seg(out_img, out_cls, best_fg_thresh_cl, best_seed_thresh_cl):
371 | _, rois = cv2.connectedComponents((out_img[0] > 0).astype(np.uint8), connectivity=8)
372 | bboxes = find_objects(rois)
373 | del rois
374 | gc.collect()
375 | skip = False
376 | labelling = zarr.zeros(
377 | out_cls.shape[1:],
378 | dtype=np.int32,
379 | compressor=Blosc(cname="lz4", clevel=3, shuffle=Blosc.BITSHUFFLE),
380 | )
381 | if len(bboxes) == 0:
382 | skip = True
383 | return labelling, skip
384 | max_inst = 0
385 | for bb in bboxes:
386 | bg_pred = out_img[(slice(0, 1, None), *bb)].squeeze()
387 | if (
388 | (np.array(bg_pred.shape[-2:]) <= 2).any()
389 | | (np.array(bg_pred.shape).sum() <= 64)
390 | | (len(bg_pred.shape) < 2)
391 | ):
392 | continue
393 | fg_pred = out_img[(slice(1, 2, None), *bb)].squeeze()
394 | sem = out_cls[(slice(0, len(best_fg_thresh_cl), None), *bb)]
395 | ws_surface = 1.0 - fg_pred # .astype(np.float32)
396 | fg = np.zeros_like(ws_surface, dtype="bool")
397 | seeds = np.zeros_like(ws_surface, dtype="bool")
398 |
399 | for cl, fg_t in enumerate(best_fg_thresh_cl):
400 | mask = sem[cl]
401 | fg[mask] |= (1.0 - bg_pred[mask]) > fg_t
402 | seeds[mask] |= fg_pred[mask] > best_seed_thresh_cl[cl]
403 |
404 | del fg_pred, bg_pred, sem, mask
405 | gc.collect()
406 | _, markers = cv2.connectedComponents((seeds).astype(np.uint8), connectivity=8)
407 | del seeds
408 | gc.collect()
409 | bb_ws = watershed(ws_surface, markers, mask=fg, connectivity=2)
410 | del ws_surface, markers, fg
411 | gc.collect()
412 | bb_ws[bb_ws != 0] += max_inst
413 | labelling[bb] = bb_ws
414 | max_inst = np.max(bb_ws)
415 | del bb_ws
416 | gc.collect()
417 | return labelling, skip
418 |
419 |
420 | def get_wsi(wsi_path, read_ds=32, pannuke=False, tile_size=256, padding_factor=0.96875):
421 | # TODO change this so it works with non-rescaled version as well
422 | ccrop = int(tile_size * padding_factor)
423 | level = 40 if pannuke else 20
424 | crop_adj = int((tile_size - ccrop) // 2)
425 |
426 | ws_ds = WholeSlideDataset(
427 | wsi_path,
428 | crop_sizes_px=[tile_size],
429 | crop_magnifications=[level],
430 | padding_factor=padding_factor,
431 | ratio_object_thresh=0.0001,
432 | )
433 | sl = ws_ds.s # openslide.open_slide(wsi_path)
434 | sl_info = get_openslide_info(sl)
435 | target_level = np.argwhere(np.isclose(sl_info["level_downsamples"], read_ds)).item()
436 | ds_coord = ws_ds.crop_metadatas[0]
437 | ds_coord[:, 2:4] -= np.array([sl_info["bounds_x"], sl_info["bounds_y"]])
438 |
439 | ds_coord[:, 2:4] += tile_size - ccrop
440 | w, h = np.max(ds_coord[:, 2:4], axis=0)
441 |
442 | raw = np.asarray(
443 | sl.read_region(
444 | (
445 | sl_info["bounds_x"] + crop_adj,
446 | sl_info["bounds_y"] + crop_adj,
447 | ),
448 | target_level,
449 | (
450 | int((w + ccrop) // (sl_info["level_downsamples"][target_level])),
451 | int((h + ccrop) // (sl_info["level_downsamples"][target_level])),
452 | ),
453 | )
454 | )
455 | raw = raw[..., :3]
456 | sl.close()
457 | return raw
458 |
459 |
460 | def post_proc_inst(
461 | pred_inst,
462 | hole_size=50,
463 | ):
464 | pshp = pred_inst.shape
465 | pred_inst = np.asarray(pred_inst)
466 | init = find_objects(pred_inst)
467 | init_large = []
468 | adj = 8
469 | for i, sl in enumerate(init):
470 | if sl:
471 | slx1 = sl[0].start - adj if (sl[0].start - adj) > 0 else 0
472 | slx2 = sl[0].stop + adj if (sl[0].stop + adj) < pshp[0] else pshp[0]
473 | sly1 = sl[1].start - adj if (sl[1].start - adj) > 0 else 0
474 | sly2 = sl[1].stop + adj if (sl[1].stop + adj) < pshp[1] else pshp[1]
475 | init_large.append(
476 | (i + 1, (slice(slx1, slx2, None), slice(sly1, sly2, None)))
477 | )
478 | out = np.zeros(pshp, dtype=np.int32)
479 | i = 1
480 | for sl in init_large:
481 | rm_small_hole = remove_small_holescv2(pred_inst[sl[1]] == (sl[0]), hole_size)
482 | out[sl[1]][rm_small_hole > 0] = i
483 | i += 1
484 |
485 | del pred_inst
486 | gc.collect()
487 |
488 | after_sh = find_objects(out)
489 | out_ = np.zeros(out.shape, dtype=np.int32)
490 | i_ = 1
491 | for i, sl in enumerate(after_sh):
492 | i += 1
493 | if sl:
494 | nr_objects, relabeled = cv2.connectedComponents(
495 | (out[sl] == i).astype(np.uint8), connectivity=8
496 | )
497 | for new_lab in range(1, nr_objects):
498 | out_[sl] += (relabeled == new_lab) * i_
499 | i_ += 1
500 | return out_
501 |
502 |
503 | def make_ct(pred_class, instance_map):
504 | if type(pred_class) != np.ndarray:
505 | pred_class = pred_class[:]
506 | slices = find_objects(instance_map)
507 | pred_class = np.rollaxis(pred_class, 0, 3)
508 | # pred_class = softmax(pred_class,0)
509 | out = []
510 | out.append((0, 0))
511 | for i, sl in enumerate(slices):
512 | i += 1
513 | if sl:
514 | inst = instance_map[sl] == i
515 | i_cls = pred_class[sl][inst]
516 | i_cls = np.sum(i_cls, axis=0).argmax() + 1
517 | out.append((i, i_cls))
518 | out_ = np.array(out)
519 | pred_ct = {str(k): int(v) for k, v in out_ if v != 0}
520 | return pred_ct
521 |
522 |
523 | def remove_obj_cls(pred_inst, pred_cls_dict, best_min_threshs, best_max_threshs):
524 | out_oi = np.zeros_like(pred_inst, dtype=np.int64)
525 | i_ = 1
526 | out_oc = []
527 | out_oc.append((0, 0))
528 | slices = find_objects(pred_inst)
529 |
530 | for i, sl in enumerate(slices):
531 | i += 1
532 | px = np.sum([pred_inst[sl] == i])
533 | cls_ = pred_cls_dict[str(i)]
534 | if (px > best_min_threshs[cls_ - 1]) & (px < best_max_threshs[cls_ - 1]):
535 | out_oc.append((i_, cls_))
536 | out_oi[sl][pred_inst[sl] == i] = i_
537 | i_ += 1
538 | out_oc = np.array(out_oc)
539 | out_dict = {str(k): int(v) for k, v in out_oc if v != 0}
540 | return out_oi, out_dict
541 |
542 |
543 | def remove_small_holescv2(img, sz):
544 |     # this is still pretty slow, but at least it's a bit faster than other approaches
545 | img = np.logical_not(img).astype(np.uint8)
546 |
547 | nb_blobs, im_with_separated_blobs, stats, _ = cv2.connectedComponentsWithStats(img)
548 | # stats (and the silenced output centroids) gives some information about the blobs. See the docs for more information.
549 | # here, we're interested only in the size of the blobs, contained in the last column of stats.
550 | sizes = stats[1:, -1]
551 | nb_blobs -= 1
552 | im_result = np.zeros((img.shape), dtype=np.uint16)
553 | for blob in range(nb_blobs):
554 | if sizes[blob] >= sz:
555 | im_result[im_with_separated_blobs == blob + 1] = 1
556 |
557 | im_result = np.logical_not(im_result)
558 | return im_result
559 |
560 |
561 | def get_pp_params(params, mit_eval=False):
562 | eval_metric = params["metric"]
563 | fg_threshs = []
564 | seed_threshs = []
565 | for exp in params["data_dirs"]:
566 | mod_path = os.path.join(params["root"], exp)
567 | if "pannuke" in exp:
568 | with open(
569 | os.path.join(mod_path, "pannuke_test_param_dict.json"), "r"
570 | ) as js:
571 | dt = json.load(js)
572 | fg_threshs.append(dt[f"best_fg_{eval_metric}"])
573 | seed_threshs.append(dt[f"best_seed_{eval_metric}"])
574 | elif mit_eval:
575 | with open(os.path.join(mod_path, "liz_test_param_dict.json"), "r") as js:
576 | dt = json.load(js)
577 | fg_tmp = dt[f"best_fg_{eval_metric}"]
578 | seed_tmp = dt[f"best_seed_{eval_metric}"]
579 | with open(os.path.join(mod_path, "mit_test_param_dict.json"), "r") as js:
580 | dt = json.load(js)
581 | fg_tmp[-1] = dt[f"best_fg_{eval_metric}"][-1]
582 | seed_tmp[-1] = dt[f"best_seed_{eval_metric}"][-1]
583 | fg_threshs.append(fg_tmp)
584 | seed_threshs.append(seed_tmp)
585 | else:
586 | with open(os.path.join(mod_path, "param_dict.json"), "r") as js:
587 | dt = json.load(js)
588 | fg_threshs.append(dt[f"best_fg_{eval_metric}"])
589 | seed_threshs.append(dt[f"best_seed_{eval_metric}"])
590 | params["best_fg_thresh_cl"] = np.mean(fg_threshs, axis=0)
591 | params["best_seed_thresh_cl"] = np.mean(seed_threshs, axis=0)
592 | print(params["best_fg_thresh_cl"], params["best_seed_thresh_cl"])
593 |
594 | return params
595 |
596 |
597 | def get_shapes(params, nclasses):
598 | padding_factor = params["overlap"]
599 | tile_size = params["tile_size"]
600 | ds_factor = 1
601 | if params["input_type"] in ["img", "npy"]:
602 | if params["input_type"] == "npy":
603 | dataset = NpyDataset(
604 | params["p"],
605 | tile_size,
606 | padding_factor=padding_factor,
607 | ratio_object_thresh=0.3,
608 | min_tiss=0.1,
609 | )
610 | else:
611 | dataset = ImageDataset(
612 | params["p"],
613 | params["tile_size"],
614 | padding_factor=params["overlap"],
615 | ratio_object_thresh=0.3,
616 | min_tiss=0.1,
617 | )
618 | params["orig_shape"] = dataset.orig_shape[:-1]
619 | ds_coord = np.array(dataset.idx).astype(int)
620 | shp = dataset.store.shape
621 |
622 | ccrop = int(dataset.padding_factor * dataset.crop_size_px)
623 | coord_adj = (dataset.crop_size_px - ccrop) // 2
624 | ds_coord[:, 1:] += coord_adj
625 | out_img_shape = (shp[0], 2, shp[1], shp[2])
626 | out_cls_shape = (shp[0], nclasses, shp[1], shp[2])
627 | else:
628 | level = 40 if params["pannuke"] else 20
629 | dataset = WholeSlideDataset(
630 | params["p"],
631 | crop_sizes_px=[tile_size],
632 | crop_magnifications=[level],
633 | padding_factor=padding_factor,
634 | ratio_object_thresh=0.0001,
635 | )
636 |
637 | print("getting coords:")
638 | ds_coord = dataset.crop_metadatas[0][:, 2:4].copy()
639 | try:
640 | sl = dataset.s
641 | bounds_x = int(sl.properties["openslide.bounds-x"]) # 158208
642 | bounds_y = int(sl.properties["openslide.bounds-y"]) # 28672
643 | except KeyError:
644 | bounds_x = 0
645 | bounds_y = 0
646 |
647 | ds_coord -= np.array([bounds_x, bounds_y])
648 |
649 | ccrop = int(tile_size * padding_factor)
650 | rel_res = np.isclose(dataset.mpp, LUT_MAGNIFICATION_MPP, rtol=0.2)
651 | if sum(rel_res) != 1:
652 | raise NotImplementedError(
653 |                 "Currently no support for images at this resolution. Check LUT_MAGNIFICATION_MPP and LUT_MAGNIFICATION_X in src.constants to add the resolution-downsampling pair"
654 | )
655 | else:
656 | ds_factor = LUT_MAGNIFICATION_X[rel_res.argmax()] / level
657 | # if ds_factor < 1:
658 | # raise NotImplementedError(
659 | # "The specified model does not support images at this resolution. Consider supplying a higher resolution image"
660 | # )
661 | ds_coord /= ds_factor
662 |
663 | ds_coord += (tile_size - ccrop) // 2
664 | ds_coord = ds_coord.astype(int)
665 | h, w = np.max(ds_coord, axis=0)
666 | out_img_shape = (2, int(h + ccrop), int(w + ccrop))
667 | out_cls_shape = (nclasses, int(h + ccrop), int(w + ccrop))
668 | params["ds_factor"] = ds_factor
669 | params["out_img_shape"] = out_img_shape
670 | params["out_cls_shape"] = out_cls_shape
671 | params["ccrop"] = ccrop
672 |
673 | return params, ds_coord
674 |
675 |
676 | def get_openslide_info(sl: openslide.OpenSlide):
677 | level_count = len(sl.level_downsamples)
678 | try:
679 | mpp_x = float(sl.properties[openslide.PROPERTY_NAME_MPP_X])
680 | mpp_y = float(sl.properties[openslide.PROPERTY_NAME_MPP_Y])
681 | except KeyError:
682 |         print("No resolution found in WSI metadata, using default 0.2425")
683 | mpp_x = 0.2425
684 | mpp_y = 0.2425
685 | try:
686 | bounds_x, bounds_y = (
687 | int(sl.properties["openslide.bounds-x"]),
688 | int(sl.properties["openslide.bounds-y"]),
689 | )
690 | except KeyError:
691 | bounds_x = 0
692 | bounds_y = 0
693 | level_downsamples = sl.level_downsamples
694 |
695 | level_mpp_x = [mpp_x * i for i in level_downsamples]
696 | level_mpp_y = [mpp_y * i for i in level_downsamples]
697 | return {
698 | "level_count": level_count,
699 | "mpp_x": mpp_x,
700 | "mpp_y": mpp_y,
701 | "bounds_x": bounds_x,
702 | "bounds_y": bounds_y,
703 | "level_downsamples": level_downsamples,
704 | "level_mpp_x": level_mpp_x,
705 | "level_mpp_y": level_mpp_y,
706 | }
707 |
--------------------------------------------------------------------------------
/src/spatial_augmenter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from torchvision.transforms.transforms import GaussianBlur
5 | import math
6 |
7 |
8 | class SpatialAugmenter(
9 | torch.nn.Module,
10 | ):
11 |
12 | def __init__(self, params, interpolation="bilinear", padding_mode="zeros"):
13 | """
14 | params= {
15 | 'mirror': {'prob': float [0,1], 'prob_x': float [0,1],'prob_y': float [0,1]},
16 | 'translate': {'max_percent':float [0,1], 'prob': float [0,1]},
17 | 'scale': {'min': float, 'max':float, 'prob': float [0,1]},
18 | 'zoom': {'min': float, 'max':float, 'prob': float [0,1]},
19 | 'rotate': {'rot90': bool, 'max_degree': int [0,360], 'prob': float [0,1]},
20 | 'shear': {'max_percent': float [0,1], 'prob': float [0,1]},
21 | 'elastic': {'alpha': list[float|int], 'sigma': float|int, 'prob': float [0,1]}}
22 | """
23 | super(SpatialAugmenter, self).__init__()
24 | self.params = params
25 | self.mode = "forward"
26 | self.random_state = {}
27 | # fill dict so that augmentation functions can be tested
28 | for key in self.params.keys():
29 | self.random_state[key] = {}
30 | self.interpolation = interpolation
31 | self.padding_mode = padding_mode
32 |
33 | def forward_transform(self, img, label=None, random_state=None):
34 | self.mode = "forward"
35 | self.device = img.device
36 | if random_state:
37 | self.random_state = random_state
38 | else:
39 | for key in self.params.keys():
40 | self.random_state[key] = {
41 | "prob": bool(np.random.binomial(1, self.params[key]["prob"]))
42 | }
43 | for key in list(self.params.keys()):
44 | if self.random_state[key]["prob"]:
45 | # print('Do transform: ', key)
46 | func = getattr(self, key)
47 | img, label = func(img, label=label, random_state=random_state)
48 | if label is not None:
49 | return img, label
50 | else:
51 | return img
52 |
53 | def inverse_transform(self, img, label=None, random_state=None):
54 | self.mode = "inverse"
55 | self.device = img.device
56 | keylist = list(self.params.keys())
57 | keylist.reverse()
58 | if random_state:
59 | self.random_state = random_state
60 | for key in keylist:
61 | if self.random_state[key]["prob"]:
62 | # print('Do inverse transform: ', key)
63 | func = getattr(self, key)
64 | img, label = func(img, label=label)
65 | if label is not None:
66 | return img, label
67 | else:
68 | return img
69 |
70 | def mirror(self, img, label, random_state=None):
71 | if self.mode == "forward" and not random_state:
72 | self.random_state["mirror"]["x"] = bool(
73 | np.random.binomial(1, self.params["mirror"]["prob_x"])
74 | )
75 | self.random_state["mirror"]["y"] = bool(
76 | np.random.binomial(1, self.params["mirror"]["prob_y"])
77 | )
78 | #
79 | x = self.random_state["mirror"]["x"]
80 | y = self.random_state["mirror"]["y"]
81 | if x:
82 | x = -1
83 | else:
84 | x = 1
85 | if y:
86 | y = -1
87 | else:
88 | y = 1
89 | theta = torch.tensor(
90 | [[[x, 0.0, 0.0], [0.0, y, 0.0]]], device=self.device, dtype=img.dtype
91 | )
92 | grid = F.affine_grid(
93 | theta.repeat(img.size()[0], 1, 1), img.size(), align_corners=False
94 | )
95 | if label is not None:
96 | return F.grid_sample(
97 | img,
98 | grid,
99 | mode=self.interpolation,
100 | padding_mode=self.padding_mode,
101 | align_corners=False,
102 | ), F.grid_sample(
103 | label,
104 | grid,
105 | mode="nearest",
106 | padding_mode=self.padding_mode,
107 | align_corners=False,
108 | )
109 | else:
110 | return (
111 | F.grid_sample(
112 | img,
113 | grid,
114 | mode=self.interpolation,
115 | padding_mode=self.padding_mode,
116 | align_corners=False,
117 | ),
118 | None,
119 | )
120 |
121 | def translate(self, img, label, random_state=None):
122 | if self.mode == "forward" and not random_state:
123 | x = np.random.uniform(
124 | -self.params["translate"]["max_percent"],
125 | self.params["translate"]["max_percent"],
126 | )
127 | y = np.random.uniform(
128 | -self.params["translate"]["max_percent"],
129 | self.params["translate"]["max_percent"],
130 | )
131 | self.random_state["translate"]["x"] = x
132 | self.random_state["translate"]["y"] = y
133 | elif self.mode == "inverse":
134 | x = -1 * self.random_state["translate"]["x"]
135 | y = -1 * self.random_state["translate"]["y"]
136 | else:
137 | x = self.random_state["translate"]["x"]
138 | y = self.random_state["translate"]["y"]
139 | theta = torch.tensor(
140 | [[[1.0, 0.0, x], [0.0, 1.0, y]]], device=self.device, dtype=img.dtype
141 | )
142 | grid = F.affine_grid(
143 | theta.repeat(img.size()[0], 1, 1), img.size(), align_corners=False
144 | )
145 | if label is not None:
146 | return F.grid_sample(
147 | img,
148 | grid,
149 | mode=self.interpolation,
150 | padding_mode=self.padding_mode,
151 | align_corners=False,
152 | ), F.grid_sample(
153 | label,
154 | grid,
155 | mode="nearest",
156 | padding_mode=self.padding_mode,
157 | align_corners=False,
158 | )
159 | else:
160 | return (
161 | F.grid_sample(
162 | img,
163 | grid,
164 | mode=self.interpolation,
165 | padding_mode=self.padding_mode,
166 | align_corners=False,
167 | ),
168 | None,
169 | )
170 |
171 | def zoom(self, img, label, random_state=None):
172 | if self.mode == "forward" and not random_state:
173 | zoom_factor = np.random.uniform(
174 |                 self.params["zoom"]["min"], self.params["zoom"]["max"]
175 | )
176 | self.random_state["zoom"]["factor"] = zoom_factor
177 | elif self.mode == "inverse":
178 | zoom_factor = 1 / self.random_state["zoom"]["factor"]
179 | else:
180 | zoom_factor = self.random_state["zoom"]["factor"]
181 | theta = torch.tensor(
182 | [[[zoom_factor, 0.0, 0.0], [0.0, zoom_factor, 0.0]]],
183 | device=self.device,
184 | dtype=img.dtype,
185 | )
186 | grid = F.affine_grid(
187 | theta.repeat(img.size()[0], 1, 1), img.size(), align_corners=False
188 | )
189 | if label is not None:
190 | return F.grid_sample(
191 | img,
192 | grid,
193 | mode=self.interpolation,
194 | padding_mode=self.padding_mode,
195 | align_corners=False,
196 | ), F.grid_sample(
197 | label,
198 | grid,
199 | mode="nearest",
200 | padding_mode=self.padding_mode,
201 | align_corners=False,
202 | )
203 | else:
204 | return (
205 | F.grid_sample(
206 | img,
207 | grid,
208 | mode=self.interpolation,
209 | padding_mode=self.padding_mode,
210 | align_corners=False,
211 | ),
212 | None,
213 | )
214 |
215 | def scale(self, img, label, random_state=None):
216 | if self.mode == "forward" and not random_state:
217 | x = np.random.uniform(
218 | self.params["scale"]["min"], self.params["scale"]["max"]
219 | )
220 | y = np.random.uniform(
221 | self.params["scale"]["min"], self.params["scale"]["max"]
222 | )
223 | self.random_state["scale"]["x"] = x
224 | self.random_state["scale"]["y"] = y
225 | elif self.mode == "inverse":
226 | x = 1 / self.random_state["scale"]["x"]
227 | y = 1 / self.random_state["scale"]["y"]
228 | else:
229 | x = self.random_state["scale"]["x"]
230 | y = self.random_state["scale"]["y"]
231 | theta = torch.tensor(
232 | [[[x, 0.0, 0.0], [0.0, y, 0.0]]], device=self.device, dtype=img.dtype
233 | )
234 | grid = F.affine_grid(
235 | theta.repeat(img.size()[0], 1, 1), img.size(), align_corners=False
236 | )
237 | if label is not None:
238 | return F.grid_sample(
239 | img,
240 | grid,
241 | mode=self.interpolation,
242 | padding_mode=self.padding_mode,
243 | align_corners=False,
244 | ), F.grid_sample(
245 | label,
246 | grid,
247 | mode="nearest",
248 | padding_mode=self.padding_mode,
249 | align_corners=False,
250 | )
251 | else:
252 | return (
253 | F.grid_sample(
254 | img,
255 | grid,
256 | mode=self.interpolation,
257 | padding_mode=self.padding_mode,
258 | align_corners=False,
259 | ),
260 | None,
261 | )
262 |
263 | def rotate(self, img, label, random_state=None):
264 | if self.mode == "forward" and not random_state:
265 | if (
266 | "rot90" in self.params["rotate"].keys()
267 | and self.params["rotate"]["rot90"]
268 | ):
269 | degree = np.random.choice([-270, -180, -90, 90, 180, 270])
270 | else:
271 | degree = np.random.uniform(
272 | -self.params["rotate"]["max_degree"],
273 | self.params["rotate"]["max_degree"],
274 | )
275 | self.random_state["rotate"]["degree"] = degree
276 | elif self.mode == "inverse":
277 | degree = -1 * self.random_state["rotate"]["degree"]
278 | else:
279 | degree = self.random_state["rotate"]["degree"]
280 | rad = math.radians(degree)
281 | theta = torch.tensor(
282 | [
283 | [
284 | [math.cos(rad), -math.sin(rad), 0.0],
285 | [math.sin(rad), math.cos(rad), 0.0],
286 | ]
287 | ],
288 | device=self.device,
289 | dtype=img.dtype,
290 | )
291 | grid = F.affine_grid(
292 | theta.repeat(img.size()[0], 1, 1), img.size(), align_corners=False
293 | )
294 | if label is not None:
295 | return F.grid_sample(
296 | img,
297 | grid,
298 | mode=self.interpolation,
299 | padding_mode=self.padding_mode,
300 | align_corners=False,
301 | ), F.grid_sample(
302 | label,
303 | grid,
304 | mode="nearest",
305 | padding_mode=self.padding_mode,
306 | align_corners=False,
307 | )
308 | else:
309 | return (
310 | F.grid_sample(
311 | img,
312 | grid,
313 | mode=self.interpolation,
314 | padding_mode=self.padding_mode,
315 | align_corners=False,
316 | ),
317 | None,
318 | )
319 |
320 | def shear(self, img, label, random_state=None):
321 | if self.mode == "forward" and not random_state:
322 | x = np.random.uniform(
323 | -self.params["shear"]["max_percent"],
324 | self.params["shear"]["max_percent"],
325 | )
326 | y = np.random.uniform(
327 | -self.params["shear"]["max_percent"],
328 | self.params["shear"]["max_percent"],
329 | )
330 | self.random_state["shear"]["x"] = x
331 | self.random_state["shear"]["y"] = y
332 | elif self.mode == "inverse":
333 | x = -self.random_state["shear"]["x"]
334 | y = -self.random_state["shear"]["y"]
335 | else:
336 | x = self.random_state["shear"]["x"]
337 | y = self.random_state["shear"]["y"]
338 | theta = torch.tensor(
339 | [[[1.0, x, 0.0], [y, 1.0, 0.0]]], device=self.device, dtype=img.dtype
340 | )
341 | grid = F.affine_grid(
342 | theta.repeat(img.size()[0], 1, 1), img.size(), align_corners=False
343 | )
344 | if label is not None:
345 | return F.grid_sample(
346 | img,
347 | grid,
348 | mode=self.interpolation,
349 | padding_mode=self.padding_mode,
350 | align_corners=False,
351 | ), F.grid_sample(
352 | label,
353 | grid,
354 | mode="nearest",
355 | padding_mode=self.padding_mode,
356 | align_corners=False,
357 | )
358 | else:
359 | return (
360 | F.grid_sample(
361 | img,
362 | grid,
363 | mode=self.interpolation,
364 | padding_mode=self.padding_mode,
365 | align_corners=False,
366 | ),
367 | None,
368 | )
369 |
370 | def identity_grid(self, img):
371 | theta = torch.tensor(
372 | [[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], device=self.device, dtype=img.dtype
373 | )
374 | return F.affine_grid(
375 | theta.repeat(img.size()[0], 1, 1), img.size(), align_corners=False
376 | )
377 |
378 | def elastic(self, img, label, random_state=None):
379 | if self.mode == "forward" and not random_state:
380 | displacement = self.create_elastic_transformation(
381 | shape=list(img.shape[-2:]),
382 | alpha=self.params["elastic"]["alpha"],
383 | sigma=self.params["elastic"]["sigma"],
384 | )
385 | self.random_state["elastic"]["displacement"] = displacement
386 | elif self.mode == "inverse":
387 | displacement = -1 * self.random_state["elastic"]["displacement"]
388 | else:
389 | displacement = self.random_state["elastic"]["displacement"]
390 | identity_grid = self.identity_grid(img)
391 | grid = identity_grid + displacement
392 | if label is not None:
393 | return F.grid_sample(
394 | img,
395 | grid,
396 | mode=self.interpolation,
397 | padding_mode=self.padding_mode,
398 | align_corners=False,
399 | ), F.grid_sample(
400 | label,
401 | grid,
402 | mode="nearest",
403 | padding_mode=self.padding_mode,
404 | align_corners=False,
405 | )
406 | else:
407 | return (
408 | F.grid_sample(
409 | img,
410 | grid,
411 | mode=self.interpolation,
412 | padding_mode=self.padding_mode,
413 | align_corners=False,
414 | ),
415 | None,
416 | )
417 |
418 | def create_elastic_transformation(self, shape, alpha=[80, 80], sigma=8):
419 |
420 | blur = GaussianBlur(kernel_size=int(8 * sigma + 1), sigma=sigma)
421 | dx = (
422 | blur(
423 | torch.rand(*shape, device=self.device).unsqueeze(0).unsqueeze(0) * 2 - 1
424 | )
425 | * alpha[0]
426 | / shape[0]
427 | )
428 | dy = (
429 | blur(
430 | torch.rand(*shape, device=self.device).unsqueeze(0).unsqueeze(0) * 2 - 1
431 | )
432 | * alpha[1]
433 | / shape[1]
434 | )
435 |
436 | displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1])
437 | return displacement
438 |
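# --- Usage sketch (illustrative only) -----------------------------------------
# During test-time augmentation the same stored random state is used for the
# forward and inverse transforms, so a prediction made on the augmented view can
# be mapped back into the original frame. The parameter values below are
# arbitrary examples, not the values in src.constants.TTA_AUG_PARAMS.
#
# aug_params = {
#     "mirror": {"prob": 0.5, "prob_x": 0.5, "prob_y": 0.5},
#     "rotate": {"rot90": True, "max_degree": 0, "prob": 0.5},
#     "translate": {"max_percent": 0.05, "prob": 0.5},
# }
# aug = SpatialAugmenter(aug_params)
# img = torch.rand(2, 3, 256, 256)
# aug.interpolation = "bilinear"
# view = aug.forward_transform(img)        # draws and stores the random state
# aug.interpolation = "nearest"
# pred = view                              # stand-in for a model prediction on the view
# restored = aug.inverse_transform(pred)   # same random state, transforms undone in reverse order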
--------------------------------------------------------------------------------
/src/viz_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import geojson
4 | import openslide
5 | import cv2
6 | from skimage.measure import regionprops
7 | from src.constants import (
8 | CLASS_LABELS_LIZARD,
9 | CLASS_LABELS_PANNUKE,
10 | COLORS_LIZARD,
11 | COLORS_PANNUKE,
12 | CONIC_MPP,
13 | PANNUKE_MPP,
14 | )
15 |
16 |
17 | def create_geojson(polygons, classids, lookup, params):
18 | features = []
19 | colors = COLORS_PANNUKE if params["pannuke"] else COLORS_LIZARD
20 | if isinstance(classids[0], (list, tuple)):
21 | classids = [cid[0] for cid in classids]
22 | for i, (poly, cid) in enumerate(zip(polygons, classids)):
23 | poly = np.array(poly)
24 | poly = poly[:, [1, 0]] * params["ds_factor"]
25 | poly = poly.tolist()
26 |
27 | geom = geojson.Polygon([poly], precision=2)
28 | if not geom.is_valid:
29 | print(f"Polygon {i}:{[poly]} is not valid, skipping...")
30 | continue
31 | # poly.append(poly[0])
32 | measurements = {classifications: 0 for classifications in lookup.values()}
33 | measurements[lookup[cid]] = 1
34 | feature = geojson.Feature(
35 | geometry=geojson.Polygon([poly], precision=2),
36 | properties={
37 | "Name": f"Nuc {i}",
38 | "Type": "Polygon",
39 | "color": colors[cid - 1],
40 | "classification": lookup[cid],
41 | "measurements": measurements,
42 | "objectType": "tile"
43 | },
44 | )
45 | features.append(feature)
46 | feature_collection = geojson.FeatureCollection(features)
47 | with open(params["output_dir"] + "/poly.geojson", "w") as outfile:
48 | geojson.dump(feature_collection, outfile)
49 |
50 |
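# --- Usage sketch (illustrative only) -----------------------------------------
# The GeoJSON written above is a FeatureCollection of closed polygons with the
# predicted class stored under properties["classification"]; it can be read back
# (or imported into QuPath) like this. The path is a placeholder.
#
# with open("output_dir/poly.geojson") as f:
#     fc = geojson.load(f)
# first = fc["features"][0]
# print(first["properties"]["classification"], len(first["geometry"]["coordinates"][0]))
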
51 | def create_tsvs(pcls_out, params):
52 | pred_keys = CLASS_LABELS_PANNUKE if params["pannuke"] else CLASS_LABELS_LIZARD
53 |
54 | coord_array = np.array([[i[0], *i[1]] for i in pcls_out.values()])
55 | classes = list(pred_keys.keys())
56 | colors = ["-256", "-65536"]
57 | i = 0
58 | for pt in classes:
59 | file = os.path.join(params["output_dir"], "pred_" + pt + ".tsv")
60 | textfile = open(file, "w")
61 |
62 | textfile.write("x" + "\t" + "y" + "\t" + "name" + "\t" + "color" + "\n")
63 | textfile.writelines(
64 | [
65 | str(element[2] * params["ds_factor"])
66 | + "\t"
67 | + str(element[1] * params["ds_factor"])
68 | + "\t"
69 | + pt
70 | + "\t"
71 | + colors[0]
72 | + "\n"
73 | for element in coord_array[coord_array[:, 0] == pred_keys[pt]]
74 | ]
75 | )
76 |
77 | textfile.close()
78 | i += 1
79 |
80 |
81 | def cont(x, offset=None):
82 | _, im, bb = x
83 | im = np.pad(im.astype(np.uint8), 1, mode="constant", constant_values=0)
84 |
85 | # initial contour finding
86 | cont = cv2.findContours(
87 | im,
88 | mode=cv2.RETR_EXTERNAL,
89 | method=cv2.CHAIN_APPROX_TC89_KCOS,
90 | )[0][0].reshape(-1, 2)[:, [1, 0]]
91 |     # opencv does not return contours for single-pixel objects, so we upscale, trace, and rescale for single-pixel detections (if they exist)
92 | if cont.shape[0] <= 1:
93 | im = cv2.resize(im, None, fx=2.0, fy=2.0)
94 | cont = (
95 | cv2.findContours(
96 | im,
97 | mode=cv2.RETR_EXTERNAL,
98 | method=cv2.CHAIN_APPROX_TC89_KCOS,
99 | )[0][0].reshape(-1, 2)[:, [1, 0]]
100 | / 2.0
101 | )
102 | if offset is not None:
103 | cont = (cont + offset + bb[0:2] - 1).tolist()
104 | else:
105 | cont = (cont + bb[0:2] - 1).tolist()
106 | # close polygon:
107 | if cont[0] != cont[-1]:
108 | cont.append(cont[0])
109 | return cont
110 |
111 |
112 | def create_polygon_output(pinst, pcls_out, params):
113 | # polygon output is slow and unwieldy, TODO
114 | pred_keys = CLASS_LABELS_PANNUKE if params["pannuke"] else CLASS_LABELS_LIZARD
115 | # whole slide regionprops could be avoided to speed up this process...
116 | print("getting all detections...")
117 | props = [(p.label, p.image, p.bbox) for p in regionprops(np.asarray(pinst))]
118 | class_labels = [pcls_out[str(p[0])] for p in props]
119 | print("generating contours...")
120 | res_poly = [cont(i) for i in props]
121 | print("creating output...")
122 | create_geojson(
123 | res_poly,
124 | class_labels,
125 | dict((v, k) for k, v in pred_keys.items()),
126 | params,
127 | )
128 |
--------------------------------------------------------------------------------