├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── build-docs.yml │ ├── build-master.yml │ ├── publish.yml │ └── test-and-lint.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── data ├── 3500000427_100X_20170120_F05_P27.czi ├── EM_low.tif └── MBP_low.tif ├── docs ├── Makefile ├── conf.py ├── index.rst └── make.bat ├── examples ├── download_and_train.py └── predict.py ├── fnet ├── __init__.py ├── cli │ ├── __init__.py │ ├── init.py │ ├── main.py │ ├── predict.py │ └── train_model.py ├── data │ ├── __init__.py │ ├── bufferedpatchdataset.py │ ├── czidataset.py │ ├── czireader.py │ ├── dummydataset.py │ ├── fnetdataset.py │ ├── multichtiffdataset.py │ └── tiffdataset.py ├── fnet_ensemble.py ├── fnet_model.py ├── fnetlogger.py ├── losses.py ├── metrics.py ├── models.py ├── nn_modules │ ├── __init__.py │ ├── dummy.py │ ├── fnet_nn_2d.py │ ├── fnet_nn_3d.py │ └── fnet_nn_3d_params.py ├── predict_piecewise.py ├── tests │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── dummymodule.py │ │ ├── nn_test.py │ │ ├── testlib.py │ │ ├── train_options_custom.json │ │ └── train_options_test.json │ ├── test_bufferedpatchdataset.py │ ├── test_cli.py │ ├── test_fnet_model.py │ ├── test_multichtiffdataset.py │ ├── test_predict_piecewise.py │ ├── test_tiffdataset.py │ └── test_utils.py ├── transforms.py └── utils │ ├── __init__.py │ ├── general_utils.py │ ├── model_utils.py │ ├── split_dataset.py │ └── viz_utils.py ├── resources ├── PredictingStructures-1.jpg └── multi_pred_b.png ├── setup.cfg ├── setup.py └── tox.ini /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: '"Something''s wrong..."' 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Description 11 | *A clear description of the bug* 12 | 13 | 14 | 15 | 16 | ## Expected Behavior 17 | *What did you expect to happen instead?* 18 | 19 | 20 | 21 | 22 | ## Reproduction 23 | *A minimal example that exhibits the behavior.* 24 | 25 | 26 | 27 | 28 | ## Environment 29 | *Any additional information about your environment* 30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: '"It would be really cool if x did y..."' 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Use Case 11 | *Please provide a use case to help us understand your request in context* 12 | 13 | 14 | 15 | 16 | ## Solution 17 | *Please describe your ideal solution* 18 | 19 | 20 | 21 | 22 | ## Alternatives 23 | *Please describe any alternatives you've considered, even if you've dismissed them* 24 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | **Pull request recommendations:** 2 | - [ ] Name your pull request _your-development-type/short-description_. Ex: _feature/read-tiff-files_ 3 | - [ ] Link to any relevant issue in the PR description. Ex: _Resolves [gh-12], adds tiff file format support_ 4 | - [ ] Provide context of changes. 5 | - [ ] Provide relevant tests for your feature or bug fix. 
6 | - [ ] Provide or update documentation for any feature added by your pull request. 7 | 8 | Thanks for contributing! 9 | -------------------------------------------------------------------------------- /.github/workflows/build-docs.yml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | docs: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v1 13 | - name: Set up Python 14 | uses: actions/setup-python@v1 15 | with: 16 | python-version: 3.7 17 | - name: Install Dependencies 18 | run: | 19 | pip install --upgrade pip 20 | pip install .[dev] 21 | - name: Generate Docs 22 | run: | 23 | make gen-docs 24 | touch docs/_build/html/.nojekyll 25 | - name: Publish Docs 26 | uses: JamesIves/github-pages-deploy-action@releases/v3 27 | with: 28 | ACCESS_TOKEN: ${{ secrets.ACCESS_TOKEN }} 29 | BASE_BRANCH: master # The branch the action should deploy from. 30 | BRANCH: gh-pages # The branch the action should deploy to. 31 | FOLDER: docs/_build/html/ # The folder the action should deploy. 32 | 33 | -------------------------------------------------------------------------------- /.github/workflows/build-master.yml: -------------------------------------------------------------------------------- 1 | name: Build Master 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | schedule: 8 | # 9 | # https://pubs.opengroup.org/onlinepubs/9699919799/utilities/crontab.html#tag_20_25_07 10 | # Run every Monday at 18:00:00 UTC (Monday at 10:00:00 PST) 11 | - cron: '0 18 * * 1' 12 | 13 | jobs: 14 | test: 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | max-parallel: 6 18 | matrix: 19 | python-version: [3.6, 3.7] 20 | os: [ubuntu-latest] 21 | 22 | steps: 23 | - uses: actions/checkout@v1 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v1 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | - name: Install Dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install .[test] 32 | - name: Test with pytest 33 | run: | 34 | pytest fnet/tests/ 35 | 36 | lint: 37 | runs-on: ubuntu-latest 38 | 39 | steps: 40 | - uses: actions/checkout@v1 41 | - name: Set up Python 3.7 42 | uses: actions/setup-python@v1 43 | with: 44 | python-version: 3.7 45 | - name: Install Dependencies 46 | run: | 47 | python -m pip install --upgrade pip 48 | pip install .[test] 49 | - name: Lint with flake8 50 | run: | 51 | flake8 fnet --count --verbose --max-line-length=127 --show-source --statistics 52 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | push: 5 | branches: 6 | - stable 7 | 8 | jobs: 9 | publish: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v1 13 | - name: Set up Python 14 | uses: actions/setup-python@v1 15 | with: 16 | python-version: 3.7 17 | - name: Install Dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | pip install setuptools wheel 21 | - name: Build Package 22 | run: | 23 | python setup.py sdist bdist_wheel 24 | - name: Publish to PyPI 25 | uses: pypa/gh-action-pypi-publish@v1 26 | with: 27 | user: gregj 28 | password: ${{ secrets.PYPI_TOKEN }} 29 | -------------------------------------------------------------------------------- /.github/workflows/test-and-lint.yml: 
-------------------------------------------------------------------------------- 1 | name: Test and Lint 2 | 3 | on: pull_request 4 | 5 | jobs: 6 | test: 7 | runs-on: ${{ matrix.os }} 8 | strategy: 9 | max-parallel: 6 10 | matrix: 11 | python-version: [3.6, 3.7] 12 | os: [ubuntu-latest] 13 | 14 | steps: 15 | - uses: actions/checkout@master 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@master 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install .[dev] 24 | - name: Test with pytest 25 | run: | 26 | pytest fnet/tests/ 27 | lint: 28 | runs-on: ubuntu-latest 29 | 30 | steps: 31 | - uses: actions/checkout@master 32 | - name: Set up Python 3.7 33 | uses: actions/setup-python@master 34 | with: 35 | python-version: 3.7 36 | - name: Install dependencies 37 | run: | 38 | python -m pip install --upgrade pip 39 | pip install .[dev] 40 | - name: Lint with flake8 41 | run: | 42 | flake8 fnet --count --verbose --max-line-length=127 --show-source --statistics -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # OS generated files 29 | .DS_Store 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | docs/fnet.*rst 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # dotenv 88 | .env 89 | 90 | # virtualenv 91 | .venv 92 | venv/ 93 | ENV/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | 108 | # Project specific standalone files 109 | workbench.ipynb 110 | 111 | examples/*/ 112 | examples/*.csv 113 | test_model/ 114 | .editorconfig 115 | data/*/ 116 | !data/csvs/ 117 | data/csvs/*/ 118 | .vscode 119 | 120 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting any of the maintainers of this project and 59 | we will attempt to resolve the issues with respect and dignity. 60 | 61 | Project maintainers who do not follow or enforce the Code of Conduct in good 62 | faith may face temporary or permanent repercussions as determined by other 63 | members of the project's leadership. 64 | 65 | ## Attribution 66 | 67 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 68 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 69 | 70 | [homepage]: https://www.contributor-covenant.org 71 | 72 | For answers to common questions about this code of conduct, see 73 | https://www.contributor-covenant.org/faq 74 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions are welcome, and they are greatly appreciated! Every little bit 4 | helps, and credit will always be given. 5 | 6 | ## Get Started! 7 | Ready to contribute? Here's how to set up `fnet` for local development. 8 | 9 | * Fork the `fnet` repo on GitHub. 10 | * Clone your fork locally: 11 | 12 | ``` 13 | $ git clone --recurse-submodules git@github.com:{your_name_here}/fnet.git 14 | ``` 15 | 16 | * Install the project in editable mode. (It is also recommended to work in a virtualenv or anaconda environment): 17 | 18 | ``` 19 | $ cd fnet/ 20 | $ pip install -e .[dev] 21 | ``` 22 | 23 | * Create a branch for local development: 24 | 25 | ``` 26 | $ git checkout -b {your_development_type}/short-description 27 | ``` 28 | Ex: feature/read-tiff-files or bugfix/handle-file-not-found
29 | Now you can make your changes locally.
30 | 31 | * When you're done making changes, check that your changes pass linting and tests, including testing other Python 32 | versions with make: 33 | 34 | ``` 35 | $ make build 36 | ``` 37 | 38 | * Commit your changes and push your branch to GitHub: 39 | 40 | ``` 41 | $ git add . 42 | $ git commit -m "Resolves gh-###. Your detailed description of your changes." 43 | $ git push origin {your_development_type}/short-description 44 | ``` 45 | 46 | * Submit a pull request through the GitHub website. 47 | 48 | ## Deploying 49 | 50 | A reminder for the maintainers on how to deploy. 51 | Make sure all your changes are committed. 52 | Then run: 53 | 54 | ``` 55 | $ bumpversion patch # possible: major / minor / patch 56 | $ git push 57 | $ git push --tags 58 | ``` 59 | 60 | Make and merge a PR to branch `stable` and GitHub will then deploy to PyPI once merged. 61 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Allen Institute Software License – This software license is the 2-clause BSD 2 | license plus a third clause that prohibits redistribution and use for 3 | commercial purposes without further permission. 4 | 5 | Copyright © 2020 6 | Gregory R Johnson, Allen Institute. All rights reserved. 7 | 8 | Redistribution and use in source and binary forms, with or without 9 | modification, are permitted provided that the following conditions are met: 10 | 11 | 1. Redistributions of source code must retain the above copyright notice, this 12 | list of conditions and the following disclaimer. 13 | 14 | 2. Redistributions in binary form must reproduce the above copyright notice, 15 | this list of conditions and the following disclaimer in the documentation 16 | and/or other materials provided with the distribution. 17 | 18 | 3. Redistributions and use for commercial purposes are not permitted without 19 | the Allen Institute’s written permission. For purposes of this license, 20 | commercial purposes are the incorporation of the Allen Institute's software 21 | into anything for which you will charge fees or other compensation or use of 22 | the software to perform a commercial service for a third party. Contact 23 | terms@alleninstitute.org for commercial licensing opportunities. 24 | 25 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 26 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 27 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 28 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 29 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 30 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 31 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 32 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 33 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 34 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
35 | 36 | 37 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include CONTRIBUTING.md 2 | include LICENSE 3 | include README.md 4 | 5 | recursive-include tests * 6 | recursive-exclude * __pycache__ 7 | recursive-exclude * *.py[co] 8 | 9 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif 10 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean clean-test clean-pyc clean-build docs help 2 | .DEFAULT_GOAL := help 3 | 4 | define BROWSER_PYSCRIPT 5 | import os, webbrowser, sys 6 | 7 | try: 8 | from urllib import pathname2url 9 | except: 10 | from urllib.request import pathname2url 11 | 12 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 13 | endef 14 | export BROWSER_PYSCRIPT 15 | 16 | define PRINT_HELP_PYSCRIPT 17 | import re, sys 18 | 19 | for line in sys.stdin: 20 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 21 | if match: 22 | target, help = match.groups() 23 | print("%-20s %s" % (target, help)) 24 | endef 25 | export PRINT_HELP_PYSCRIPT 26 | 27 | BROWSER := python -c "$$BROWSER_PYSCRIPT" 28 | 29 | help: 30 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 31 | 32 | clean: ## clean all build, python, and testing files 33 | rm -fr build/ 34 | rm -fr dist/ 35 | rm -fr .eggs/ 36 | find . -name '*.egg-info' -exec rm -fr {} + 37 | find . -name '*.egg' -exec rm -f {} + 38 | find . -name '*.pyc' -exec rm -f {} + 39 | find . -name '*.pyo' -exec rm -f {} + 40 | find . -name '*~' -exec rm -f {} + 41 | find . -name '__pycache__' -exec rm -fr {} + 42 | rm -fr .tox/ 43 | rm -fr .coverage 44 | rm -fr coverage.xml 45 | rm -fr htmlcov/ 46 | rm -fr .pytest_cache 47 | 48 | build: ## run tox / run tests and lint 49 | tox 50 | 51 | gen-docs: ## generate Sphinx HTML documentation, including API docs 52 | rm -f docs/fnet*.rst 53 | rm -f docs/modules.rst 54 | sphinx-apidoc -o docs/ fnet **/tests/ 55 | $(MAKE) -C docs html 56 | cp -r ./resources ./docs/_build/html/resources 57 | 58 | docs: ## generate Sphinx HTML documentation, including API docs, and serve to browser 59 | make gen-docs 60 | $(BROWSER) docs/_build/html/index.html 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy 2 | 3 | [![Build Status](https://github.com/AllenCellModeling/pytorch_fnet/workflows/Build%20Master/badge.svg)](https://github.com/AllenCellModeling/pytorch_fnet/actions) 4 | [![Documentation](https://github.com/AllenCellModeling/pytorch_fnet/workflows/Documentation/badge.svg)](https://allencellmodeling.github.io/pytorch_fnet/) 5 | ![Combined outputs](./resources/PredictingStructures-1.jpg?raw=true "Combined outputs") 6 | 7 | ## Support 8 | 9 | This code is in active development and is used within our organization. We are currently not supporting this code for external use and are simply releasing the code to the community AS IS. The community is welcome to submit issues, but you should not expect an active response. 
10 | 
11 | For the code corresponding to our Nature Methods paper, please use the `release_1` branch [here](https://github.com/AllenCellModeling/pytorch_fnet/tree/release_1). 
12 | 
13 | ## System requirements 
14 | 
15 | We recommend installation on Linux and an NVIDIA graphics card with 12+ GB of GPU memory (e.g., NVIDIA Titan X Pascal) with the latest drivers installed. 
16 | 
17 | ## Installation 
18 | 
19 | - We recommend an environment manager such as [Conda](https://docs.conda.io/en/latest/miniconda.html). 
20 | - Install Python 3.6+ if necessary. 
21 | - All commands listed below assume the bash shell. 
22 | - Clone and install the repo: 
23 | 
24 | ```shell 
25 | git clone https://github.com/AllenCellModeling/pytorch_fnet.git 
26 | cd pytorch_fnet 
27 | pip install . 
28 | ``` 
29 | 
30 | - If you would like to instead install for development: 
31 | 
32 | ```shell 
33 | pip install -e .[dev] 
34 | ``` 
35 | 
36 | - If you want to run the demos in the examples directory: 
37 | 
38 | ```shell 
39 | pip install .[examples] 
40 | ``` 
41 | 
42 | ## Demo on Canned AICS Data 
43 | This will download some images from our [Integrated Cell Quilt repository](https://open.quiltdata.com/b/allencell/tree/aics/pipeline_integrated_cell/) and start training a model: 
44 | ```shell 
45 | cd examples 
46 | python download_and_train.py 
47 | ``` 
48 | When training is complete, you can predict on the held-out data with: 
49 | ```shell 
50 | python predict.py 
51 | ``` 
52 | 
53 | ## Command-line tool 
54 | 
55 | Once the package is installed, users can train and use models through the `fnet` command-line tool. To see what commands are available, use the `-h` flag. 
56 | 
57 | ```shell 
58 | fnet -h 
59 | ``` 
60 | 
61 | The `-h` flag is also available for all `fnet` commands. For example, 
62 | 
63 | ```shell 
64 | fnet train -h 
65 | ``` 
66 | 
67 | ## Train a model 
68 | 
69 | Model training is done through the `fnet train` command, which requires a json indicating various training parameters, e.g., what dataset to use, where to save the model, and how the hyperparameters should be set. To create a template json: 
70 | 
71 | ```shell 
72 | fnet train --json /path/to/train_options.json 
73 | ``` 
74 | 
75 | Users are expected to modify this json to suit their needs. At a minimum, users should verify the following json fields and change them if necessary: 
76 | 
77 | - `"dataset_train"`: The name of the training dataset. 
78 | - `"path_save_dir"`: The directory where the model will be saved. We recommend that the model be saved in the same directory as the training options json. 
79 | 
80 | Once any modifications are complete, initiate training by repeating the above command: 
81 | 
82 | ```shell 
83 | fnet train --json /path/to/train_options.json 
84 | ``` 
85 | 
86 | Since this time the json already exists, training should commence. 
87 | 
88 | ## Perform predictions with a trained model 
89 | 
90 | Users can perform predictions using a trained model with the `fnet predict` command. A path to a saved model and a data source must be specified. For example: 
91 | 
92 | ```shell 
93 | fnet predict --json path/to/predict_options.json 
94 | ``` 
95 | 
96 | As above, users are expected to modify this json to suit their needs. At a minimum, populate the following fields and/or copy and paste corresponding dataset values from `/path/to/train_options.json`. 
97 | 
98 | e.g.: 
99 | ```json 
100 | "dataset_kwargs": { 
101 |     "col_index": "Index", 
102 |     "col_signal": "signal", 
103 |     "col_target": "target", 
104 |     "path_csv": "path/to/my/train.csv", 
105 |     ... 
106 | "path_model_dir": [ 
107 |     "models/model_0" 
108 | ], 
109 | "path_save_dir": "path/to/predictions/dir", 
110 | ``` 
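
For reference, a minimal but complete prediction options file might look like the following sketch. The field names and defaults are taken from the template that `fnet predict --json` generates when the json does not yet exist; the dataset columns and paths are placeholders to replace with your own, and fields omitted here fall back to the command-line defaults: 

```json 
{ 
    "dataset": "fnet.data.TiffDataset", 
    "dataset_kwargs": { 
        "col_index": "Index", 
        "col_signal": "signal", 
        "col_target": "target", 
        "path_csv": "path/to/my/train.csv", 
        "transform_signal": ["fnet.transforms.norm_around_center"], 
        "transform_target": ["fnet.transforms.norm_around_center"] 
    }, 
    "gpu_ids": 0, 
    "metric": "fnet.metrics.corr_coef", 
    "n_images": -1, 
    "path_model_dir": ["models/model_0"], 
    "path_save_dir": "path/to/predictions/dir" 
} 
``` 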
106 | "path_model_dir": [ 107 | "models/model_0" 108 | ], 109 | "path_save_dir": "path/to/predictions/dir", 110 | ``` 111 | 112 | This will use the model save `models/dna` to perform predictions on the `some.dataset` dataset. To see additional command options, use `fnet predict -h`. 113 | 114 | Once any modifications are complete, initiate training by repeating the above command: 115 | 116 | ```shell 117 | fnet predict --json path/to/predict_options.json 118 | ``` 119 | 120 | ## Citation 121 | 122 | If you find this code useful in your research, please consider citing our manuscript in Nature Methods: 123 | 124 | ``` 125 | @article{Ounkomol2018, 126 | doi = {10.1038/s41592-018-0111-2}, 127 | url = {https://doi.org/10.1038/s41592-018-0111-2}, 128 | year = {2018}, 129 | month = {sep}, 130 | publisher = {Springer Nature America, Inc}, 131 | volume = {15}, 132 | number = {11}, 133 | pages = {917--920}, 134 | author = {Chawin Ounkomol and Sharmishtaa Seshamani and Mary M. Maleckar and Forrest Collman and Gregory R. Johnson}, 135 | title = {Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy}, 136 | journal = {Nature Methods} 137 | } 138 | ``` 139 | 140 | ## Contact 141 | 142 | Gregory Johnson 143 | E-mail: 144 | 145 | ## Allen Institute Software License 146 | 147 | Allen Institute Software License – This software license is the 2-clause BSD license plus clause a third clause that prohibits redistribution and use for commercial purposes without further permission. 148 | Copyright © 2018. Allen Institute. All rights reserved. 149 | 150 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 151 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 152 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 153 | 3. Redistributions and use for commercial purposes are not permitted without the Allen Institute’s written permission. For purposes of this license, commercial purposes are the incorporation of the Allen Institute's software into anything for which you will charge fees or other compensation or use of the software to perform a commercial service for a third party. Contact terms@alleninstitute.org for commercial licensing opportunities. 154 | 155 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
156 | -------------------------------------------------------------------------------- /data/3500000427_100X_20170120_F05_P27.czi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_fnet/64c53d123df644cebe5e4f7f2ab6efc5c0732f4e/data/3500000427_100X_20170120_F05_P27.czi -------------------------------------------------------------------------------- /data/EM_low.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_fnet/64c53d123df644cebe5e4f7f2ab6efc5c0732f4e/data/EM_low.tif -------------------------------------------------------------------------------- /data/MBP_low.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_fnet/64c53d123df644cebe5e4f7f2ab6efc5c0732f4e/data/MBP_low.tif -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = python_boilerplate 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | import os 16 | import sys 17 | 18 | sys.path.insert(0, os.path.abspath("../../")) 19 | 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = "fnet" 24 | copyright = "2018, Chek" 25 | author = "Chek" 26 | 27 | # The short X.Y version 28 | version = "" 29 | # The full version, including alpha/beta/rc tags 30 | release = "0.1" 31 | 32 | 33 | # -- General configuration --------------------------------------------------- 34 | 35 | # If your documentation needs a minimal Sphinx version, state it here. 36 | # 37 | # needs_sphinx = '1.0' 38 | 39 | # Add any Sphinx extension module names here, as strings. They can be 40 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 41 | # ones. 
42 | extensions = [ 43 | "sphinx.ext.autodoc", 44 | "sphinx.ext.todo", 45 | "sphinx.ext.coverage", 46 | "sphinx.ext.mathjax", 47 | "sphinx.ext.ifconfig", 48 | "sphinx.ext.viewcode", 49 | "sphinx.ext.napoleon", 50 | "m2r", 51 | ] 52 | 53 | # Add any paths that contain templates here, relative to this directory. 54 | templates_path = ["_templates"] 55 | 56 | # The suffix(es) of source filenames. 57 | # You can specify multiple suffix as a list of string: 58 | # 59 | # source_suffix = ['.rst', '.md'] 60 | source_suffix = ".rst" 61 | 62 | # The master toctree document. 63 | master_doc = "index" 64 | 65 | # The language for content autogenerated by Sphinx. Refer to documentation 66 | # for a list of supported languages. 67 | # 68 | # This is also used if you do content translation via gettext catalogs. 69 | # Usually you set "language" from the command line for these cases. 70 | language = None 71 | 72 | # List of patterns, relative to source directory, that match files and 73 | # directories to ignore when looking for source files. 74 | # This pattern also affects html_static_path and html_extra_path . 75 | exclude_patterns = ["predict.rst", "train_model.rst", "evaluate_model.rst"] 76 | 77 | # The name of the Pygments (syntax highlighting) style to use. 78 | pygments_style = "sphinx" 79 | 80 | 81 | # -- Options for HTML output ------------------------------------------------- 82 | 83 | # The theme to use for HTML and HTML Help pages. See the documentation for 84 | # a list of builtin themes. 85 | # 86 | html_theme = "sphinx_rtd_theme" 87 | 88 | # Theme options are theme-specific and customize the look and feel of a theme 89 | # further. For a list of options available for each theme, see the 90 | # documentation. 91 | # 92 | # html_theme_options = {} 93 | 94 | # Add any paths that contain custom static files (such as style sheets) here, 95 | # relative to this directory. They are copied after the builtin static files, 96 | # so a file named "default.css" will overwrite the builtin "default.css". 97 | # html_static_path = ["_static"] 98 | 99 | # Custom sidebar templates, must be a dictionary that maps document names 100 | # to template names. 101 | # 102 | # The default sidebars (for documents that don't match any pattern) are 103 | # defined by theme itself. Builtin themes are using these templates by 104 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 105 | # 'searchbox.html']``. 106 | # 107 | # html_sidebars = {} 108 | 109 | 110 | # -- Options for HTMLHelp output --------------------------------------------- 111 | 112 | # Output file base name for HTML help builder. 113 | htmlhelp_basename = "fnetdoc" 114 | 115 | 116 | # -- Options for LaTeX output ------------------------------------------------ 117 | 118 | latex_elements = { 119 | # The paper size ('letterpaper' or 'a4paper'). 120 | # 121 | # 'papersize': 'letterpaper', 122 | # The font size ('10pt', '11pt' or '12pt'). 123 | # 124 | # 'pointsize': '10pt', 125 | # Additional stuff for the LaTeX preamble. 126 | # 127 | # 'preamble': '', 128 | # Latex figure (float) alignment 129 | # 130 | # 'figure_align': 'htbp', 131 | } 132 | 133 | # Grouping the document tree into LaTeX files. List of tuples 134 | # (source start file, target name, title, 135 | # author, documentclass [howto, manual, or own class]). 
136 | latex_documents = [(master_doc, "fnet.tex", "fnet Documentation", "Chek", "manual")] 137 | 138 | 139 | # -- Options for manual page output ------------------------------------------ 140 | 141 | # One entry per manual page. List of tuples 142 | # (source start file, name, description, authors, manual section). 143 | man_pages = [(master_doc, "fnet", "fnet Documentation", [author], 1)] 144 | 145 | 146 | # -- Options for Texinfo output ---------------------------------------------- 147 | 148 | # Grouping the document tree into Texinfo files. List of tuples 149 | # (source start file, target name, title, author, 150 | # dir menu entry, description, category) 151 | texinfo_documents = [ 152 | ( 153 | master_doc, 154 | "fnet", 155 | "fnet Documentation", 156 | author, 157 | "fnet", 158 | "One line description of project.", 159 | "Miscellaneous", 160 | ) 161 | ] 162 | 163 | 164 | # -- Extension configuration ------------------------------------------------- 165 | 166 | # -- Options for todo extension ---------------------------------------------- 167 | 168 | # If true, `todo` and `todoList` produce output, else they produce nothing. 169 | todo_include_todos = True 170 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. fnet documentation master file, created by 2 | sphinx-quickstart on Tue Sep 18 16:32:49 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Pytorch_fnet 7 | ================================ 8 | 9 | .. toctree:: 10 | :hidden: 11 | :maxdepth: 3 12 | :caption: Contents: 13 | 14 | Overview 15 | fnet 16 | 17 | 18 | 19 | .. mdinclude:: ../README.md 20 | 21 | Indices and tables 22 | ================== 23 | 24 | * :ref:`genindex` 25 | * :ref:`modindex` 26 | * :ref:`search` 27 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=fnet 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 
24 | echo.If you don't have Sphinx installed, grab it from 
25 | echo.http://sphinx-doc.org/ 
26 | exit /b 1 
27 | ) 
28 | 
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 
30 | goto end 
31 | 
32 | :help 
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 
34 | 
35 | :end 
36 | 
37 | popd 
-------------------------------------------------------------------------------- 
/examples/download_and_train.py: 
-------------------------------------------------------------------------------- 
1 | import argparse 
2 | import os 
3 | import json 
4 | from pathlib import Path 
5 | 
6 | import quilt3 
7 | import pandas as pd 
8 | import numpy as np 
9 | 
10 | from fnet.cli.init import save_default_train_options 
11 | 
12 | 
13 | parser = argparse.ArgumentParser() 
14 | 
15 | parser.add_argument("--gpu_id", default=0, type=int, help="GPU to use.") 
16 | parser.add_argument("--n_imgs", default=40, type=int, help="Number of images to use.") 
17 | parser.add_argument( 
18 |     "--n_iterations", default=50000, type=int, help="Number of training iterations." 
19 | ) 
20 | parser.add_argument( 
21 |     "--interval_checkpoint", 
22 |     default=10000, 
23 |     type=int, 
24 |     help="Number of training iterations between checkpoints.", 
25 | ) 
26 | 
27 | args = parser.parse_args() 
28 | 
29 | ################################################### 
30 | # Download the 3D multi-channel tiffs via Quilt/T4 
31 | ################################################### 
32 | 
33 | gpu_id = args.gpu_id 
34 | n_images_to_download = args.n_imgs  # the more images, the better 
35 | train_fraction = 0.75 
36 | 
37 | image_save_dir = "{}/".format(os.getcwd()) 
38 | model_save_dir = "{}/model/".format(os.getcwd()) 
39 | prefs_save_path = "{}/prefs.json".format(model_save_dir) 
40 | 
41 | data_save_path_train = "{}/image_list_train.csv".format(image_save_dir) 
42 | data_save_path_test = "{}/image_list_test.csv".format(image_save_dir) 
43 | 
44 | if not os.path.exists(image_save_dir): 
45 |     os.makedirs(image_save_dir) 
46 | 
47 | 
48 | aics_pipeline = quilt3.Package.browse( 
49 |     "aics/pipeline_integrated_cell", registry="s3://allencell" 
50 | ) 
51 | 
52 | data_manifest = aics_pipeline["metadata.csv"]() 
53 | 
54 | # THE ROWS OF THE MANIFEST CORRESPOND TO CELLS; WE TRIM DOWN TO UNIQUE FOVS 
55 | unique_fov_indices = np.unique(data_manifest['FOVId'], return_index=True)[1] 
56 | data_manifest = data_manifest.iloc[unique_fov_indices] 
57 | 
58 | # SELECT THE FIRST N_IMAGES_TO_DOWNLOAD 
59 | data_manifest = data_manifest.iloc[0:n_images_to_download] 
60 | 
61 | image_source_paths = data_manifest["SourceReadPath"] 
62 | 
63 | image_target_paths = [ 
64 |     "{}/{}".format(image_save_dir, image_source_path) 
65 |     for image_source_path in image_source_paths 
66 | ] 
67 | 
68 | for image_source_path, image_target_path in zip(image_source_paths, image_target_paths): 
69 |     if os.path.exists(image_target_path): 
70 |         continue 
71 | 
72 |     # We only do this because T4 hates our filesystem. It probably won't affect you. 
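    # A fetch that raises OSError is skipped for now; since the file will then
    # not exist locally, it is simply retried on the next run of this script.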
73 | try: 74 | aics_pipeline[image_source_path].fetch(image_target_path) 75 | except OSError: 76 | pass 77 | 78 | ################################################### 79 | # Make a manifest of all of the files in csv form 80 | ################################################### 81 | 82 | df = pd.DataFrame(columns=["path_tiff", "channel_signal", "channel_target"]) 83 | 84 | df["path_tiff"] = image_target_paths 85 | df["channel_signal"] = data_manifest["ChannelNumberBrightfield"] 86 | df["channel_target"] = data_manifest[ 87 | "ChannelNumber405" 88 | ] # this is the DNA channel for all FOVs 89 | 90 | n_train_images = int(n_images_to_download * train_fraction) 91 | df_train = df[:n_train_images] 92 | df_test = df[n_train_images:] 93 | 94 | df_test.to_csv(data_save_path_test, index=False) 95 | df_train.to_csv(data_save_path_train, index=False) 96 | 97 | ################################################ 98 | # Run the label-free stuff (dont change this) 99 | ################################################ 100 | 101 | prefs_save_path = Path(prefs_save_path) 102 | 103 | save_default_train_options(prefs_save_path) 104 | 105 | with open(prefs_save_path, "r") as fp: 106 | prefs = json.load(fp) 107 | 108 | # takes about 16 hours, go up to 250,000 for full training 109 | prefs["n_iter"] = args.n_iterations 110 | prefs["interval_checkpoint"] = args.interval_checkpoint 111 | 112 | prefs["dataset_train"] = "fnet.data.MultiChTiffDataset" 113 | prefs["dataset_train_kwargs"] = {"path_csv": data_save_path_train} 114 | prefs["dataset_val"] = "fnet.data.MultiChTiffDataset" 115 | prefs["dataset_val_kwargs"] = {"path_csv": data_save_path_test} 116 | 117 | # This Fnet call will be updated as a python API becomes available 118 | 119 | with open(prefs_save_path, "w") as fp: 120 | json.dump(prefs, fp) 121 | 122 | command_str = f"fnet train --json {prefs_save_path} --gpu_ids {gpu_id}" 123 | 124 | print(command_str) 125 | os.system(command_str) 126 | -------------------------------------------------------------------------------- /examples/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | ################################################### 5 | # Assume the user already ran download_and_train.py 6 | ################################################### 7 | 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument("--gpu_id", default=0, type=int, help="GPU to use.") 11 | 12 | args = parser.parse_args() 13 | 14 | # Normally this would be run via command-line but this Fnet call will be updated as a python API becomes available 15 | gpu_id = args.gpu_id 16 | 17 | image_save_dir = "{}/images/".format(os.getcwd()) 18 | model_save_dir = "{}/model/".format(os.getcwd()) 19 | 20 | data_save_path_test = "{}/image_list_test.csv".format(os.getcwd()) 21 | 22 | command_str = ( 23 | "fnet predict " 24 | "--path_model_dir {} " 25 | "--dataset fnet.data.MultiChTiffDataset " 26 | '--dataset_kwargs \'{{"path_csv": "{}"}}\' ' 27 | "--gpu_ids {}".format(model_save_dir, data_save_path_test, gpu_id) 28 | ) 29 | 30 | print(command_str) 31 | os.system(command_str) 32 | -------------------------------------------------------------------------------- /fnet/__init__.py: -------------------------------------------------------------------------------- 1 | from fnet import models 2 | from fnet.fnetlogger import FnetLogger 3 | 4 | # Clean these up later - GRJ 2020-02-04 5 | from fnet.cli.train_model import train_model as train 6 | from fnet.cli.predict import 
main as predict 7 | 8 | __author__ = "Gregory R. Johnson" 9 | __email__ = "gregj@alleninstitute.org" 10 | __version__ = "0.2.0" 11 | 12 | 13 | def get_module_version(): 14 | return __version__ 15 | 16 | 17 | __all__ = ["models", "FnetLogger"] 18 | -------------------------------------------------------------------------------- /fnet/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_fnet/64c53d123df644cebe5e4f7f2ab6efc5c0732f4e/fnet/cli/__init__.py -------------------------------------------------------------------------------- /fnet/cli/init.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | from pathlib import Path 3 | import argparse 4 | import json 5 | import logging 6 | import os 7 | import shutil 8 | import sys 9 | 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def save_example_scripts(path_save_dir: str) -> None: 15 | """Save example training and prediction scripts. 16 | 17 | Parameters 18 | ---------- 19 | path_save_dir 20 | Directory in which to save scripts. 21 | 22 | """ 23 | if not os.path.exists(path_save_dir): 24 | os.makedirs(path_save_dir) 25 | path_examples_dir = os.path.join( 26 | os.path.dirname(sys.modules["fnet"].__file__), "cli" 27 | ) 28 | for fname in ["train_model.py", "predict.py"]: 29 | path_src = os.path.join(path_examples_dir, fname) 30 | path_dst = os.path.join(path_save_dir, fname) 31 | if os.path.exists(path_dst): 32 | logger.info(f"Example script already exists: {path_dst}") 33 | continue 34 | shutil.copy(path_src, path_dst) 35 | logger.info(f"Saved: {path_dst}") 36 | 37 | 38 | def save_options_json(path_save: Path, options: Dict) -> None: 39 | """Saves options dictionary as a json. 40 | 41 | Parameters 42 | ---------- 43 | path_save 44 | JSON save path. 45 | options 46 | Options dictionary. 47 | 48 | Returns 49 | ------- 50 | None 51 | 52 | """ 53 | if path_save.exists(): 54 | logger.info(f"Options json already exists: {path_save}") 55 | return 56 | path_save.parent.mkdir(parents=True, exist_ok=True) 57 | with path_save.open("w") as fo: 58 | json.dump(options, fo, indent=4, sort_keys=True) 59 | logger.info(f"Saved: {path_save}") 60 | 61 | 62 | def save_default_train_options(path_save: Path) -> None: 63 | """Save default training options json. 64 | 65 | Parameters 66 | ---------- 67 | path_save 68 | Save path for default training options json. 
69 | 70 | """ 71 | train_options = { 72 | "batch_size": 28, 73 | "bpds_kwargs": { 74 | "buffer_size": 16, 75 | "buffer_switch_interval": 2800, # every 100 updates 76 | "patch_shape": [32, 64, 64], 77 | }, 78 | "dataset_train": "fnet.data.TiffDataset", 79 | "dataset_train_kwargs": { 80 | "path_csv": "some_training_set.csv", 81 | "col_index": "some_id_col", 82 | "col_signal": "some_signal_col", 83 | "col_target": "some_target_col", 84 | "transform_signal": ["fnet.transforms.norm_around_center"], 85 | "transform_target": ["fnet.transforms.norm_around_center"], 86 | }, 87 | "dataset_val": None, 88 | "dataset_val_kwargs": {}, 89 | "fnet_model_class": "fnet.fnet_model.Model", 90 | "fnet_model_kwargs": { 91 | "betas": [0.9, 0.999], 92 | "criterion_class": "fnet.losses.WeightedMSE", 93 | "init_weights": False, 94 | "lr": 0.001, 95 | "nn_class": "fnet.nn_modules.fnet_nn_3d.Net", 96 | "scheduler": None, 97 | }, 98 | "interval_checkpoint": 50000, 99 | "interval_save": 1000, 100 | "iter_checkpoint": [], 101 | "n_iter": 50000, 102 | "path_save_dir": str(path_save.parent), 103 | "seed": None, 104 | } 105 | save_options_json(path_save, train_options) 106 | 107 | 108 | def save_default_predict_options(path_save: Path) -> None: 109 | """Save default prediction options json. 110 | 111 | Parameters 112 | ---------- 113 | path_save 114 | Save path for default prediction options json. 115 | 116 | """ 117 | predict_options = { 118 | "dataset": "fnet.data.TiffDataset", 119 | "dataset_kwargs": { 120 | "col_index": "some_id_col", 121 | "col_signal": "some_signal_col", 122 | "col_target": "some_target_col", 123 | "path_csv": "some_test_set.csv", 124 | "transform_signal": ["fnet.transforms.norm_around_center"], 125 | "transform_target": ["fnet.transforms.norm_around_center"], 126 | }, 127 | "gpu_ids": 0, 128 | "idx_sel": None, 129 | "metric": "fnet.metrics.corr_coef", 130 | "n_images": -1, 131 | "no_prediction": False, 132 | "no_signal": False, 133 | "no_target": False, 134 | "path_model_dir": ["some_model"], 135 | "path_save_dir": str(path_save.parent), 136 | "path_tif": None, 137 | } 138 | save_options_json(path_save, predict_options) 139 | 140 | 141 | def add_parser_arguments(parser: argparse.ArgumentParser) -> None: 142 | """Add init script arguments to parser.""" 143 | parser.add_argument( 144 | "--path_scripts_dir", 145 | default="scripts", 146 | help="Path to where example scripts should be saved.", 147 | ) 148 | parser.add_argument( 149 | "--path_train_template", 150 | default="train_options_templates/default.json", 151 | type=Path, 152 | help="Path to where training options template should be saved.", 153 | ) 154 | 155 | 156 | def main(args: Optional[argparse.Namespace] = None) -> None: 157 | """Install default training options and example model training/prediction 158 | scripts into current directory.""" 159 | if args is None: 160 | parser = argparse.ArgumentParser() 161 | add_parser_arguments(parser) 162 | args = parser.parse_args() 163 | save_example_scripts(args.path_scripts_dir) 164 | save_default_train_options(args.path_train_template) 165 | -------------------------------------------------------------------------------- /fnet/cli/main.py: -------------------------------------------------------------------------------- 1 | """Module for command-line 'fnet' command.""" 2 | 3 | 4 | import argparse 5 | import os 6 | import sys 7 | 8 | from fnet.cli import init 9 | from fnet.cli import predict 10 | from fnet.cli import train_model 11 | from fnet.utils.general_utils import init_fnet_logging 12 | 13 | 14 | 
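# Example invocations once the package is installed (see the README):
#
#   fnet init
#   fnet train --json /path/to/train_options.json
#   fnet predict --json /path/to/predict_options.json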
def main() -> None: 15 | """Main function for command-line 'fnet' command.""" 16 | init_fnet_logging() 17 | parser = argparse.ArgumentParser(prog="fnet") 18 | subparser = parser.add_subparsers(title="command") 19 | parser_init = subparser.add_parser( 20 | "init", 21 | help=( 22 | "Initialize current directory with example fnet scripts and " 23 | "training options template." 24 | ), 25 | ) 26 | parser_train = subparser.add_parser("train", help="Train a model.") 27 | parser_predict = subparser.add_parser("predict", help="Predict using a model.") 28 | init.add_parser_arguments(parser_init) 29 | train_model.add_parser_arguments(parser_train) 30 | predict.add_parser_arguments(parser_predict) 31 | 32 | parser_init.set_defaults(func=init.main) 33 | parser_train.set_defaults(func=train_model.main) 34 | parser_predict.set_defaults(func=predict.main) 35 | args = parser.parse_args() 36 | 37 | # Remove 'func' from args so it is not passed to target script 38 | func = args.func 39 | delattr(args, "func") 40 | sys.path.append(os.getcwd()) 41 | func(args) 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /fnet/cli/predict.py: -------------------------------------------------------------------------------- 1 | """Generates predictions from a model.""" 2 | 3 | 4 | from pathlib import Path 5 | from typing import Any, Callable, Dict, List, Optional, Tuple 6 | import argparse 7 | import json 8 | import logging 9 | import os 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import tifffile 14 | import torch 15 | 16 | from fnet.cli.init import save_default_predict_options 17 | from fnet.data import FnetDataset, TiffDataset 18 | from fnet.models import load_model 19 | from fnet.transforms import norm_around_center 20 | from fnet.utils.general_utils import files_from_dir 21 | from fnet.utils.general_utils import retry_if_oserror 22 | from fnet.utils.general_utils import str_to_object 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def get_dataset(args: argparse.Namespace) -> torch.utils.data.Dataset: 29 | """Returns dataset. 30 | 31 | Returns 32 | ------- 33 | torch.utils.data.Dataset 34 | Dataset object. 
35 | 36 | """ 37 | if sum([args.dataset is not None, args.path_tif is not None]) != 1: 38 | raise ValueError("Must specify one input source type") 39 | if args.dataset is not None: 40 | ds_fn = str_to_object(args.dataset) 41 | if not isinstance(ds_fn, Callable): 42 | raise ValueError(f"{args.dataset} must be callable") 43 | return ds_fn(**args.dataset_kwargs) 44 | if args.path_tif is not None: 45 | if not os.path.exists(args.path_tif): 46 | raise ValueError(f"Path does not exists: {args.path_tif}") 47 | paths_tif = [args.path_tif] 48 | if os.path.isdir(args.path_tif): 49 | paths_tif = files_from_dir(args.path_tif) 50 | ds = TiffDataset( 51 | dataframe=pd.DataFrame({"path_bf": paths_tif, "path_target": None}), 52 | transform_signal=[norm_around_center], 53 | transform_target=[norm_around_center], 54 | col_signal="path_bf", 55 | ) 56 | return ds 57 | raise NotImplementedError 58 | 59 | 60 | def get_indices(args: argparse.Namespace, dataset: Any) -> List[int]: 61 | """Returns indices of dataset items on which to perform predictions.""" 62 | indices = args.idx_sel 63 | if indices is None: 64 | if isinstance(dataset, FnetDataset): 65 | indices = dataset.df.index 66 | else: 67 | indices = list(range(len(dataset))) 68 | if args.n_images > 0: 69 | return indices[: args.n_images] 70 | return indices 71 | 72 | 73 | def item_from_dataset( 74 | dataset: Any, idx: int 75 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 76 | """Returns signal-target image pair from dataset. 77 | 78 | If the dataset is a FnetDataset, it will be indexed using 'loc'-style 79 | indexing. 80 | 81 | Parameters 82 | ---------- 83 | dataset 84 | Object with __getitem__ implemented. 85 | idx 86 | Index of data to be retrieved from dataset. 87 | 88 | Returns 89 | ------- 90 | Tuple[torch.Tensor, Optional[torch.Tensor]] 91 | Signal-target data pair. Target can be None if dataset does not return 92 | a target for the given index. 93 | 94 | """ 95 | if isinstance(dataset, FnetDataset): 96 | item = dataset.loc[idx] 97 | else: 98 | item = dataset[idx] 99 | target = None 100 | 101 | if isinstance(item, Tuple): 102 | signal = item[0] 103 | if len(item) > 1: 104 | target = item[1] 105 | else: 106 | signal = item 107 | return (signal, target) 108 | 109 | 110 | def save_tif(fname: str, ar: np.ndarray, path_root: str) -> str: 111 | """Saves a tif and returns tif save path relative to root save directory. 112 | 113 | Image will be stored at: 'path_root/tifs/fname' 114 | 115 | Parameters 116 | ---------- 117 | fname 118 | Basename of save path. 119 | ar 120 | Array to be saved as tif. 121 | path_root 122 | Root directory of save path. 123 | 124 | Returns 125 | ------- 126 | str 127 | Save path relative to root directory. 
128 | 129 | """ 130 | path_tif_dir = os.path.join(path_root, "tifs") 131 | if not os.path.exists(path_tif_dir): 132 | os.makedirs(path_tif_dir) 133 | logger.info(f"Created: {path_tif_dir}") 134 | path_save = os.path.join(path_tif_dir, fname) 135 | tifffile.imsave(path_save, ar, compress=2) 136 | logger.info(f"Saved: {path_save}") 137 | return os.path.relpath(path_save, path_root) 138 | 139 | 140 | def parse_model(model_str: str) -> Dict: 141 | """Parse model definition string into dictionary.""" 142 | model_def = {} 143 | parts = model_str.split(":") 144 | if len(parts) > 2: 145 | raise ValueError('Multiple ":" in specified model') 146 | name = os.path.basename(parts[0]) 147 | options = [] 148 | if len(parts) == 2: 149 | options.extend(parts[1].split(",")) 150 | model_def["path"] = parts[0] 151 | model_def["options"] = options 152 | model_def["name"] = ".".join([name] + options) 153 | return model_def 154 | 155 | 156 | def save_predictions_csv( 157 | path_csv: Path, pred_records: List[Dict], dataset: Any 158 | ) -> None: 159 | """Saves csv with metadata of predictions. 160 | 161 | Parameters 162 | ---------- 163 | path_csv 164 | CSV save path. 165 | pred_records 166 | List of metadata for each prediction. 167 | dataset 168 | Dataset from where signal-target pairs were retrieved. 169 | 170 | """ 171 | df = pd.DataFrame(pred_records).set_index("index") 172 | if isinstance(dataset, FnetDataset): 173 | # For FnetDataset, add additional metadata 174 | df = df.rename_axis(dataset.df.index.name).join(dataset.df, lsuffix="_pre") 175 | if os.path.exists(path_csv): 176 | df_old = pd.read_csv(path_csv) 177 | col_index = df_old.columns[0] # Assumes first col is index col 178 | df_old = df_old.set_index(col_index) 179 | df = df.combine_first(df_old) 180 | df = df.sort_index(axis=1) 181 | dirname = os.path.dirname(path_csv) 182 | if not os.path.exists(dirname): 183 | os.makedirs(dirname) 184 | logger.info(f"Created: {dirname}") 185 | retry_if_oserror(df.to_csv)(path_csv) 186 | logger.info(f"Saved: {path_csv}") 187 | 188 | 189 | def save_args_as_json(path_save_dir: str, args: argparse.Namespace) -> None: 190 | """Saves script arguments as json in save directory. 191 | 192 | A json is saved only if the "--json" option was not specified. 193 | 194 | By default, this function tries to save arguments as predict_options.json 195 | within the save directory. If that file already exists, appends a digit to 196 | uniquify the save path. 197 | 198 | Parameters 199 | ---------- 200 | path_save_dir 201 | Save directory 202 | args 203 | Script arguments. 
204 | 
205 |     """
206 |     if args.json is not None:
207 |         return
208 |     args.__dict__.pop("json")
209 |     path_json = os.path.join(path_save_dir, "predict_options.json")
210 |     while os.path.exists(path_json):
211 |         number = path_json.split(".")[-2]
212 |         if not number.isdigit():
213 |             number = "-1"
214 |         number = str(int(number) + 1)
215 |         path_json = os.path.join(
216 |             path_save_dir, ".".join(["predict_options", number, "json"])
217 |         )
218 |     with open(path_json, "w") as fo:
219 |         json.dump(vars(args), fo, indent=4, sort_keys=True)
220 |     logger.info(f"Saved: {path_json}")
221 | 
222 | 
223 | def load_from_json(args: argparse.Namespace) -> None:
224 |     """Loads arguments from a json file if one is specified."""
225 |     if args.json is None:
226 |         return
227 |     with args.json.open(mode="r") as fi:
228 |         predict_options = json.load(fi)
229 |     args.__dict__.update(predict_options)
230 | 
231 | 
232 | def add_parser_arguments(parser) -> None:
233 |     """Add prediction script arguments to parser."""
234 |     parser.add_argument("--dataset", help="dataset name")
235 |     parser.add_argument(
236 |         "--dataset_kwargs", type=json.loads, default={}, help="dataset kwargs"
237 |     )
238 |     parser.add_argument("--gpu_ids", type=int, default=0, help="GPU ID")
239 |     parser.add_argument(
240 |         "--idx_sel", nargs="+", type=int, help="specify dataset indices"
241 |     )
242 |     parser.add_argument("--json", type=Path, help="path to prediction options json")
243 |     parser.add_argument(
244 |         "--metric", default="fnet.metrics.corr_coef", help="evaluation metric"
245 |     )
246 |     parser.add_argument(
247 |         "--n_images", type=int, default=-1, help="max number of images to test"
248 |     )
249 |     parser.add_argument(
250 |         "--no_prediction", action="store_true", help="set to not save predicted image"
251 |     )
252 |     parser.add_argument(
253 |         "--no_signal", action="store_true", help="set to not save signal image"
254 |     )
255 |     parser.add_argument(
256 |         "--no_target", action="store_true", help="set to not save target image"
257 |     )
258 |     parser.add_argument(
259 |         "--path_model_dir", nargs="+", help="path(s) to model directory"
260 |     )
261 |     parser.add_argument(
262 |         "--path_save_dir", default="predictions", help="path to output root directory"
263 |     )
264 |     parser.add_argument("--path_tif", help="path(s) to input tif(s)")
265 | 
266 | 
267 | def main(args: Optional[argparse.Namespace] = None) -> None:
268 |     """Predicts using model."""
269 |     if args is None:
270 |         parser = argparse.ArgumentParser()
271 |         add_parser_arguments(parser)
272 |         args = parser.parse_args()
273 |     if args.json and not args.json.exists():
274 |         save_default_predict_options(args.json)
275 |         return
276 |     load_from_json(args)
277 |     metric = str_to_object(args.metric)
278 |     dataset = get_dataset(args)
279 |     entries = []
280 |     model = None
281 |     indices = get_indices(args, dataset)
282 |     for count, idx in enumerate(indices, 1):
283 |         logger.info(f"Processing: {idx:3d} ({count}/{len(indices)})")
284 |         entry = {}
285 |         entry["index"] = idx
286 |         signal, target = item_from_dataset(dataset, idx)
287 |         if not args.no_signal:
288 |             entry["path_signal"] = save_tif(
289 |                 f"{idx}_signal.tif", signal.numpy()[0,], args.path_save_dir
290 |             )
291 |         if not args.no_target and target is not None:
292 |             entry["path_target"] = save_tif(
293 |                 f"{idx}_target.tif", target.numpy()[0,], args.path_save_dir
294 |             )
295 |         for path_model_dir in args.path_model_dir:
296 |             if model is None or len(args.path_model_dir) > 1:
297 |                 model_def = parse_model(path_model_dir)
298 |                 model = load_model(model_def["path"],
no_optim=True) 299 | model.to_gpu(args.gpu_ids) 300 | logger.info(f'Loaded model: {model_def["name"]}') 301 | prediction = model.predict_piecewise( 302 | signal, tta=("no_tta" not in model_def["options"]) 303 | ) 304 | evaluation = metric(target, prediction) 305 | entry[args.metric + f'.{model_def["name"]}'] = evaluation 306 | if not args.no_prediction: 307 | for idx_c in range(prediction.size()[0]): 308 | tag = f'prediction_c{idx_c}.{model_def["name"]}' 309 | pred_c = prediction.numpy()[idx_c,] 310 | entry[f"path_{tag}"] = save_tif( 311 | f"{idx}_{tag}.tif", pred_c, args.path_save_dir 312 | ) 313 | entries.append(entry) 314 | save_predictions_csv( 315 | path_csv=os.path.join(args.path_save_dir, "predictions.csv"), 316 | pred_records=entries, 317 | dataset=dataset, 318 | ) 319 | save_args_as_json(args.path_save_dir, args) 320 | 321 | 322 | if __name__ == "__main__": 323 | main() 324 | -------------------------------------------------------------------------------- /fnet/cli/train_model.py: -------------------------------------------------------------------------------- 1 | """Trains a model.""" 2 | 3 | 4 | from pathlib import Path 5 | from typing import Callable, Dict, List, Optional 6 | import argparse 7 | import copy 8 | import datetime 9 | import inspect 10 | import json 11 | import logging 12 | import os 13 | import pprint 14 | import time 15 | 16 | import numpy as np 17 | import torch 18 | 19 | from fnet.cli.init import save_default_train_options 20 | from fnet.data import BufferedPatchDataset 21 | from fnet.utils.general_utils import add_logging_file_handler 22 | from fnet.utils.general_utils import str_to_object 23 | import fnet 24 | import fnet.utils.viz_utils as vu 25 | 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | def log_training_options(options: Dict) -> None: 31 | """Logs training options.""" 32 | for line in ["*** Training options ***"] + pprint.pformat(options).split( 33 | os.linesep 34 | ): 35 | logger.info(line) 36 | 37 | 38 | def set_seeds(seed: Optional[int]) -> None: 39 | """Sets random seeds""" 40 | if seed is None: 41 | return 42 | np.random.seed(seed) 43 | torch.manual_seed(seed) 44 | torch.cuda.manual_seed_all(seed) 45 | 46 | 47 | def init_cuda(gpu: int) -> None: 48 | """Initialize Pytorch CUDA state.""" 49 | if gpu < 0: 50 | return 51 | try: 52 | torch.cuda.set_device(gpu) 53 | torch.cuda.init() 54 | except RuntimeError: 55 | logger.exception("Failed to init CUDA") 56 | 57 | 58 | def get_bpds_train(args: argparse.Namespace) -> BufferedPatchDataset: 59 | """Creates data provider for training.""" 60 | ds_fn = str_to_object(args.dataset_train) 61 | if not isinstance(ds_fn, Callable): 62 | raise ValueError("Dataset function should be Callable") 63 | ds = ds_fn(**args.dataset_train_kwargs) 64 | return BufferedPatchDataset(dataset=ds, **args.bpds_kwargs) 65 | 66 | 67 | def get_bpds_val(args: argparse.Namespace) -> Optional[BufferedPatchDataset]: 68 | """Creates data provider for validation.""" 69 | if args.dataset_val is None: 70 | return None 71 | bpds_kwargs = copy.deepcopy(args.bpds_kwargs) 72 | ds_fn = str_to_object(args.dataset_val) 73 | if not isinstance(ds_fn, Callable): 74 | raise ValueError("Dataset function should be Callable") 75 | ds = ds_fn(**args.dataset_val_kwargs) 76 | bpds_kwargs["buffer_size"] = min(4, len(ds)) 77 | bpds_kwargs["buffer_switch_interval"] = -1 78 | return BufferedPatchDataset(dataset=ds, **bpds_kwargs) 79 | 80 | 81 | def add_parser_arguments(parser) -> None: 82 | """Add training script arguments to parser.""" 83 | 
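    # Typical invocation (console-script name and paths are assumptions):
    #   fnet train --json models/some_model/train_options.json --gpu_ids 0
    # If the JSON does not exist yet, main() writes default training options
    # to it and returns, so the file can be edited before the real run.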
parser.add_argument( 84 | "--json", type=Path, required=True, help="json with training options" 85 | ) 86 | parser.add_argument("--gpu_ids", nargs="+", default=[0], type=int, help="gpu_id(s)") 87 | 88 | 89 | def main(args: Optional[argparse.Namespace] = None): 90 | """Trains a model.""" 91 | time_start = time.time() 92 | 93 | if args is None: 94 | parser = argparse.ArgumentParser() 95 | add_parser_arguments(parser) 96 | args = parser.parse_args() 97 | 98 | args.path_json = Path(args.json) 99 | 100 | if args.path_json and not args.path_json.exists(): 101 | save_default_train_options(args.path_json) 102 | return None 103 | 104 | with open(args.path_json, "r") as fi: 105 | train_options = json.load(fi) 106 | 107 | args.__dict__.update(train_options) 108 | add_logging_file_handler(Path(args.path_save_dir, "train_model.log")) 109 | logger.info(f"Started training at: {datetime.datetime.now()}") 110 | 111 | set_seeds(args.seed) 112 | log_training_options(vars(args)) 113 | path_model = os.path.join(args.path_save_dir, "model.p") 114 | model = fnet.models.load_or_init_model(path_model, args.path_json) 115 | init_cuda(args.gpu_ids[0]) 116 | model.to_gpu(args.gpu_ids) 117 | logger.info(model) 118 | 119 | path_losses_csv = os.path.join(args.path_save_dir, "losses.csv") 120 | if os.path.exists(path_losses_csv): 121 | fnetlogger = fnet.FnetLogger(path_losses_csv) 122 | logger.info(f"History loaded from: {path_losses_csv}") 123 | else: 124 | fnetlogger = fnet.FnetLogger(columns=["num_iter", "loss_train", "loss_val"]) 125 | 126 | if (args.n_iter - model.count_iter) <= 0: 127 | # Stop if no more iterations needed 128 | return model 129 | 130 | # Get patch pair providers 131 | bpds_train = get_bpds_train(args) 132 | bpds_val = get_bpds_val(args) 133 | 134 | # MAIN LOOP 135 | for idx_iter in range(model.count_iter, args.n_iter): 136 | do_save = ((idx_iter + 1) % args.interval_save == 0) or ( 137 | (idx_iter + 1) == args.n_iter 138 | ) 139 | 140 | loss_train = model.train_on_batch(*bpds_train.get_batch(args.batch_size)) 141 | loss_val = None 142 | if do_save and bpds_val is not None: 143 | loss_val = model.test_on_iterator( 144 | [bpds_val.get_batch(args.batch_size) for _ in range(4)] 145 | ) 146 | 147 | fnetlogger.add( 148 | {"num_iter": idx_iter + 1, "loss_train": loss_train, "loss_val": loss_val} 149 | ) 150 | print( 151 | f'iter: {fnetlogger.data["num_iter"][-1]:6d} | ' 152 | f'loss_train: {fnetlogger.data["loss_train"][-1]:.4f}' 153 | ) 154 | if do_save: 155 | model.save(path_model) 156 | fnetlogger.to_csv(path_losses_csv) 157 | logger.info( 158 | "BufferedPatchDataset buffer history: %s", 159 | bpds_train.get_buffer_history(), 160 | ) 161 | logger.info(f"Loss log saved to: {path_losses_csv}") 162 | logger.info(f"Model saved to: {path_model}") 163 | logger.info(f"Elapsed time: {time.time() - time_start:.1f} s") 164 | if ((idx_iter + 1) in args.iter_checkpoint) or ( 165 | (idx_iter + 1) % args.interval_checkpoint == 0 166 | ): 167 | path_checkpoint = os.path.join( 168 | args.path_save_dir, "checkpoints", "model_{:06d}.p".format(idx_iter + 1) 169 | ) 170 | model.save(path_checkpoint) 171 | logger.info(f"Saved model checkpoint: {path_checkpoint}") 172 | vu.plot_loss( 173 | args.path_save_dir, 174 | path_save=os.path.join(args.path_save_dir, "loss_curves.png"), 175 | ) 176 | 177 | return model 178 | 179 | 180 | def train_model( 181 | batch_size: int = 28, 182 | bpds_kwargs: Optional[Dict] = None, 183 | dataset_train: str = "fnet.data.TiffDataset", 184 | dataset_train_kwargs: Optional[Dict] = None, 185 | 
    dataset_val: Optional[str] = None,
186 |     dataset_val_kwargs: Optional[Dict] = None,
187 |     fnet_model_class: str = "fnet.fnet_model.Model",
188 |     fnet_model_kwargs: Optional[Dict] = None,
189 |     interval_checkpoint: int = 50000,
190 |     interval_save: int = 1000,
191 |     iter_checkpoint: Optional[List] = None,
192 |     n_iter: int = 250000,
193 |     path_save_dir: str = "models/some_model",
194 |     seed: Optional[int] = None,
195 |     json: Optional[str] = None,
196 |     gpu_ids: Optional[List[int]] = None,
197 | ):
198 |     """Python API for training."""
199 | 
200 |     bpds_kwargs = bpds_kwargs or {
201 |         "buffer_size": 16,
202 |         "buffer_switch_interval": 2800,  # every 100 updates
203 |         "patch_shape": [32, 64, 64],
204 |     }
205 |     dataset_train_kwargs = dataset_train_kwargs or {}
206 |     dataset_val_kwargs = dataset_val_kwargs or {}
207 |     fnet_model_kwargs = fnet_model_kwargs or {
208 |         "betas": [0.9, 0.999],
209 |         "criterion_class": "fnet.losses.WeightedMSE",
210 |         "init_weights": False,
211 |         "lr": 0.001,
212 |         "nn_class": "fnet.nn_modules.fnet_nn_3d.Net",
213 |         "scheduler": None,
214 |     }
215 |     iter_checkpoint = iter_checkpoint or []
216 |     gpu_ids = gpu_ids or [0]
217 | 
218 |     json = json or os.path.join(path_save_dir, "train_options.json")
219 | 
220 |     pnames, _, _, locs = inspect.getargvalues(inspect.currentframe())
221 |     train_options = {k: locs[k] for k in pnames}
222 | 
223 |     path_json = Path(json)
224 |     if path_json.exists():
225 |         logger.warning(f"Overwriting existing json: {path_json}")
226 |     if not path_json.parent.exists():
227 |         path_json.parent.mkdir(parents=True)
228 |         logger.info(f"Created: {path_json.parent}")
229 | 
230 |     json = globals()["json"]  # the "json" parameter shadows the json module; recover the module
231 |     with path_json.open("w") as f:
232 |         json.dump(train_options, f, indent=4, sort_keys=True)
233 |     logger.info(f"Saved: {path_json}")
234 | 
235 |     args = argparse.Namespace()
236 |     args.__dict__.update(train_options)
237 | 
238 |     return main(args)
239 | 
--------------------------------------------------------------------------------
/fnet/data/__init__.py:
--------------------------------------------------------------------------------
1 | from fnet.data.bufferedpatchdataset import BufferedPatchDataset
2 | from fnet.data.tiffdataset import TiffDataset
3 | from fnet.data.fnetdataset import FnetDataset
4 | from fnet.data.multichtiffdataset import MultiChTiffDataset
5 | from fnet.data.dummydataset import DummyFnetDataset, DummyCustomFnetDataset
6 | 
7 | 
8 | __all__ = [
9 |     "BufferedPatchDataset",
10 |     "FnetDataset",
11 |     "TiffDataset",
12 |     "MultiChTiffDataset",
13 |     "DummyFnetDataset",
14 |     "DummyCustomFnetDataset",
15 | ]
16 | 
--------------------------------------------------------------------------------
/fnet/data/bufferedpatchdataset.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | from typing import List, Sequence, Union
3 | import collections.abc
4 | import logging
5 | 
6 | from tqdm import tqdm
7 | import numpy as np
8 | import torch
9 | 
10 | 
11 | logger = logging.getLogger(__name__)
12 | ArrayLike = Union[np.ndarray, torch.Tensor]
13 | 
14 | 
15 | class BufferedPatchDataset:
16 |     """Provides patches from items of a dataset.
17 | 
18 |     Parameters
19 |     ----------
20 |     dataset
21 |         Dataset object.
22 |     patch_shape
23 |         Shape of patch to be extracted from dataset items.
24 |     buffer_size
25 |         Size of buffer.
26 |     buffer_switch_interval
27 |         Number of patches provided between buffer item exchanges. Set to -1 to
28 |         disable exchanges.
29 | shuffle_images 30 | Set to randomize order of dataset item insertion into buffer. 31 | 32 | """ 33 | 34 | def __init__( 35 | self, 36 | dataset: collections.abc.Sequence, 37 | patch_shape: Sequence[int] = (32, 64, 64), 38 | buffer_size: int = 1, 39 | buffer_switch_interval: int = -1, 40 | shuffle_images: bool = True, 41 | ): 42 | self.dataset = dataset 43 | self.patch_shape = patch_shape 44 | self.buffer_size = min(len(self.dataset), buffer_size) 45 | self.buffer_switch_interval = buffer_switch_interval 46 | self.shuffle_images = shuffle_images 47 | 48 | self.counter = 0 49 | self.epochs = -1 # incremented to 0 when buffer initially filled 50 | self.buffer = deque() 51 | self.remaining_to_be_in_buffer = deque() 52 | self.buffer_history = [] 53 | for _ in tqdm(range(self.buffer_size), desc="Buffering images"): 54 | self.insert_new_element_into_buffer() 55 | 56 | def __iter__(self): 57 | return self 58 | 59 | def __next__(self): 60 | patch = self.get_random_patch() 61 | self.counter += 1 62 | if (self.buffer_switch_interval > 0) and ( 63 | self.counter % self.buffer_switch_interval == 0 64 | ): 65 | self.insert_new_element_into_buffer() 66 | return patch 67 | 68 | def _check_last_datum(self) -> None: 69 | """Checks last dataset item added to buffer.""" 70 | nd = len(self.patch_shape) 71 | idx_buf = self.buffer_history[-1] 72 | shape_spatial = None 73 | for idx_c, component in enumerate(self.buffer[-1]): 74 | if shape_spatial is None: 75 | shape_spatial = component.shape[-nd:] 76 | elif component.shape[-nd:] != shape_spatial: 77 | raise ValueError( 78 | f"Dataset item {idx_buf}, component {idx_c} shape " 79 | f"{component.shape} incompatible with first component " 80 | f"shape {self.buffer[-1][0].shape}" 81 | ) 82 | if nd > len(component.shape) or any( 83 | self.patch_shape[d] > shape_spatial[d] for d in range(nd) 84 | ): 85 | raise ValueError( 86 | f"Dataset item {idx_buf}, component {idx_c} shape " 87 | f"{component.shape} incompatible with patch_shape " 88 | f"{self.patch_shape}" 89 | ) 90 | 91 | def insert_new_element_into_buffer(self) -> None: 92 | """Inserts new dataset item into buffer. 93 | 94 | Returns 95 | ------- 96 | None 97 | 98 | """ 99 | if len(self.remaining_to_be_in_buffer) == 0: 100 | self.epochs += 1 101 | self.remaining_to_be_in_buffer = deque(range(len(self.dataset))) 102 | if self.shuffle_images: 103 | np.random.shuffle(self.remaining_to_be_in_buffer) 104 | if len(self.buffer) >= self.buffer_size: 105 | self.buffer.popleft() 106 | new_datum_index = self.remaining_to_be_in_buffer.popleft() 107 | self.buffer_history.append(new_datum_index) 108 | self.buffer.append(self.dataset[new_datum_index]) 109 | logger.info(f"Added item {new_datum_index} into buffer") 110 | self._check_last_datum() 111 | 112 | def get_random_patch(self) -> List[ArrayLike]: 113 | """Samples random patch from an item in the buffer. 114 | 115 | Let nd be the number of dimensions of the patch. If the item has more 116 | dimensions than the patch, then sampling will be from the last nd 117 | dimensions of the item. 118 | 119 | Returns 120 | ------- 121 | List[ArrayLike] 122 | Random patch sampled from a dataset item. 
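
        Examples
        --------
        A sketch, assuming ds yields (signal, target) pairs of ZYX arrays:

        >>> bpds = BufferedPatchDataset(ds, patch_shape=(16, 32, 32))
        >>> signal_patch, target_patch = bpds.get_random_patch()
        >>> tuple(signal_patch.shape[-3:])
        (16, 32, 32)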
123 | 124 | """ 125 | nd = len(self.patch_shape) 126 | buffer_index = np.random.randint(len(self.buffer)) 127 | datum = self.buffer[buffer_index] 128 | shape_spatial = datum[0].shape[-nd:] 129 | patch = [] 130 | slices = None 131 | for part in datum: 132 | if slices is None: 133 | starts = np.array( 134 | [ 135 | np.random.randint(0, d - p + 1) 136 | for d, p in zip(shape_spatial, self.patch_shape) 137 | ] 138 | ) 139 | ends = starts + np.array(self.patch_shape) 140 | slices = tuple(slice(s, e) for s, e in zip(starts, ends)) 141 | # Pad slices with "slice(None)" if there are non-spatial dimensions 142 | slices_pad = (slice(None),) * (len(part.shape) - len(shape_spatial)) 143 | patch.append(part[slices_pad + slices]) 144 | return patch 145 | 146 | def get_batch(self, batch_size: int) -> Sequence[torch.Tensor]: 147 | """Returns a batch of patches. 148 | 149 | Parameters 150 | ---------- 151 | batch_size 152 | Number of patches in batch. 153 | 154 | Returns 155 | ------- 156 | Sequence[torch.Tensor] 157 | Batch of patches. 158 | 159 | """ 160 | return tuple( 161 | torch.tensor(np.stack(part)) 162 | for part in zip(*[next(self) for _ in range(batch_size)]) 163 | ) 164 | 165 | def get_buffer_history(self) -> List[int]: 166 | """Returns a list of indices of dataset elements inserted into the 167 | buffer. 168 | 169 | Returns 170 | ------- 171 | List[int] 172 | Indices of dataset elements. 173 | 174 | """ 175 | return self.buffer_history 176 | -------------------------------------------------------------------------------- /fnet/data/czidataset.py: -------------------------------------------------------------------------------- 1 | from fnet.data.czireader import CziReader 2 | from fnet.data.fnetdataset import FnetDataset 3 | import numpy as np 4 | import pdb # noqa: F401 5 | import torch.utils.data 6 | 7 | 8 | class CziDataset(FnetDataset): 9 | """Dataset for CZI files. 
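
    The dataset DataFrame is expected to provide a "path_czi" column plus
    "channel_signal" and "channel_target" columns; "channel_target" may be
    NaN for rows without a target.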
10 | 11 | """ 12 | 13 | def __init__(self, **kwargs): 14 | super().__init__(**kwargs) 15 | 16 | def __getitem__(self, index): 17 | element = self.df.iloc[index, :] 18 | has_target = not np.isnan(element["channel_target"]) 19 | czi = CziReader(element["path_czi"]) 20 | 21 | im_out = list() 22 | im_out.append(czi.get_volume(element["channel_signal"])) 23 | if has_target: 24 | im_out.append(czi.get_volume(element["channel_target"])) 25 | if self.transform_signal is not None: 26 | for t in self.transform_signal: 27 | im_out[0] = t(im_out[0]) 28 | if has_target and self.transform_target is not None: 29 | for t in self.transform_target: 30 | im_out[1] = t(im_out[1]) 31 | im_out = [torch.from_numpy(im.astype(float)).float() for im in im_out] 32 | # unsqueeze to make the first dimension be the channel dimension 33 | im_out = [torch.unsqueeze(im, 0) for im in im_out] 34 | return im_out 35 | 36 | def __len__(self): 37 | return len(self.df) 38 | 39 | def get_information(self, index: int) -> dict: 40 | return self.df.iloc[index, :].to_dict() 41 | -------------------------------------------------------------------------------- /fnet/data/czireader.py: -------------------------------------------------------------------------------- 1 | import czifile 2 | 3 | 4 | def get_czi_metadata(element, tag_list): 5 | """ 6 | element - (xml.etree.ElementTree.Element) 7 | tag_list - list of strings 8 | """ 9 | if len(tag_list) == 0: 10 | return None 11 | if len(tag_list) == 1: 12 | if tag_list[0] == "attrib": 13 | return [element.attrib] 14 | if tag_list[0] == "text": 15 | return [element.text] 16 | values = [] 17 | for sub_ele in element: 18 | if sub_ele.tag == tag_list[0]: 19 | if len(tag_list) == 1: 20 | values.extend([sub_ele]) 21 | else: 22 | retval = get_czi_metadata(sub_ele, tag_list[1:]) 23 | if retval is not None: 24 | values.extend(retval) 25 | if len(values) == 0: 26 | return None 27 | return values 28 | 29 | 30 | def get_shape_from_metadata(metadata): 31 | """Return tuple of CZI's dimensions in order (Z, Y, X).""" 32 | tag_list = "Metadata.Information.Image".split(".") 33 | elements = get_czi_metadata(metadata, tag_list) 34 | if elements is None: 35 | return None 36 | ele_image = elements[0] 37 | dim_tags = ("SizeZ", "SizeY", "SizeX") 38 | shape = [] 39 | for dim_tag in dim_tags: 40 | ele_dim = get_czi_metadata(ele_image, [dim_tag, "text"]) 41 | shape_dim = int(ele_dim[0]) 42 | shape.append(shape_dim) 43 | return tuple(shape) 44 | 45 | 46 | class CziReader: 47 | """Wraps czifile.CziFile. 
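
    Example usage with the sample CZI shipped in the repository's data
    directory (return values depend on the file's metadata):

    >>> czi = CziReader("data/3500000427_100X_20170120_F05_P27.czi")
    >>> vol = czi.get_volume(chan=0)  # ZYX volume of channel 0
    >>> scales = czi.get_scales()  # {"z": ..., "y": ..., "x": ...} in um/px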
48 | 49 | """ 50 | 51 | def __init__(self, path_czi): 52 | with czifile.CziFile(path_czi) as czi: 53 | self.czi_np = czi.asarray() 54 | self.axes = czi.axes 55 | self.metadata = czi.metadata 56 | 57 | def get_size(self, dim_sel): 58 | dim = -1 59 | if isinstance(dim_sel, int): 60 | dim = dim_sel 61 | elif isinstance(dim_sel, str): 62 | dim = self.axes.find(dim_sel) 63 | assert dim >= 0 64 | return self.czi_np.shape[dim] 65 | 66 | def get_scales(self): 67 | tag_list = "Metadata.Scaling.Items.Distance".split(".") 68 | dict_scales = {} 69 | for entry in get_czi_metadata(self.metadata, tag_list): 70 | dim = entry.attrib.get("Id") 71 | if (dim is not None) and (dim.lower() in "zyx"): 72 | # convert from m/px to um/px 73 | scale = 10 ** 6 * float(get_czi_metadata(entry, ["Value"])[0].text) 74 | dict_scales[dim.lower()] = scale 75 | return dict_scales 76 | 77 | def get_volume(self, chan, time_slice=None): 78 | """Returns the image volume for the specified channel.""" 79 | slices = [] 80 | for i in range(len(self.axes)): 81 | dim_label = self.axes[i] 82 | if dim_label in "C": 83 | slices.append(chan) 84 | elif dim_label in "T": 85 | if time_slice is None: 86 | slices.append(0) 87 | else: 88 | slices.append(time_slice) 89 | elif dim_label in "ZYX": 90 | slices.append(slice(None)) 91 | else: 92 | slices.append(0) 93 | slices = tuple(slices) 94 | return self.czi_np[slices] 95 | -------------------------------------------------------------------------------- /fnet/data/dummydataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import tifffile 6 | import torch 7 | 8 | from fnet.data import TiffDataset 9 | from fnet.utils.general_utils import add_augmentations 10 | 11 | 12 | def DummyFnetDataset(train: bool = False) -> TiffDataset: 13 | """Returns a dummy Fnetdataset.""" 14 | df = pd.DataFrame( 15 | { 16 | "path_signal": [os.path.join("data", "EM_low.tif")], 17 | "path_target": [os.path.join("data", "MBP_low.tif")], 18 | } 19 | ).rename_axis("arbitrary") 20 | if not train: 21 | df = add_augmentations(df) 22 | return TiffDataset(dataframe=df) 23 | 24 | 25 | class _CustomDataset: 26 | """Custom, non-FnetDataset.""" 27 | 28 | def __init__(self, df: pd.DataFrame): 29 | self._df = df 30 | 31 | def __len__(self): 32 | return len(self._df) 33 | 34 | def __getitem__(self, idx): 35 | loc = self._df.index[idx] 36 | sig = torch.from_numpy( 37 | tifffile.imread(self._df.loc[loc, "path_signal"])[np.newaxis,] 38 | ) 39 | tar = torch.from_numpy( 40 | tifffile.imread(self._df.loc[loc, "path_target"])[np.newaxis,] 41 | ) 42 | return (sig, tar) 43 | 44 | 45 | def DummyCustomFnetDataset(train: bool = False) -> TiffDataset: 46 | """Returns a dummy custom dataset.""" 47 | df = pd.DataFrame( 48 | { 49 | "path_signal": [os.path.join("data", "EM_low.tif")], 50 | "path_target": [os.path.join("data", "MBP_low.tif")], 51 | } 52 | ) 53 | if not train: 54 | df = add_augmentations(df) 55 | return _CustomDataset(df) 56 | -------------------------------------------------------------------------------- /fnet/data/fnetdataset.py: -------------------------------------------------------------------------------- 1 | from fnet.utils.general_utils import to_objects, whats_my_name 2 | from typing import List, Optional, Union 3 | import pandas as pd 4 | import torch.utils.data 5 | 6 | 7 | def _to_str_list(olist: List) -> Optional[List[str]]: 8 | """Turns a list of objects into a list of the objects' string 9 | representations. 
10 | 11 | """ 12 | if olist is None: 13 | return None 14 | return [whats_my_name(o) for o in olist] 15 | 16 | 17 | class _LocIndexer: 18 | """'Loc' indexer of objects with a 'df' (DataFrame) attribute.""" 19 | 20 | def __init__(self, super_obj): 21 | assert isinstance(super_obj.df, pd.DataFrame) 22 | self.super_obj = super_obj 23 | 24 | def __getitem__(self, idx): 25 | idx_trans = self.super_obj.df.index.get_loc(idx) 26 | return self.super_obj[idx_trans] 27 | 28 | 29 | class _iLocIndexer: 30 | """'iLoc' indexer of objects with a 'df' (DataFrame) attribute.""" 31 | 32 | def __init__(self, super_obj): 33 | assert isinstance(super_obj.df, pd.DataFrame) 34 | self.super_obj = super_obj 35 | 36 | def __getitem__(self, idx): 37 | return self.super_obj[idx] 38 | 39 | 40 | class FnetDataset(torch.utils.data.Dataset): 41 | """Abstract class for fnet datasets. 42 | 43 | Parameters 44 | ---------- 45 | dataframe 46 | DataFrame where rows are dataset elements. Overrides path_csv. 47 | path_csv 48 | Path to csv from which to create DataFrame. 49 | transform_signal 50 | List of transforms to apply to signal image. 51 | transform_target 52 | List of transforms to apply to target image. 53 | 54 | """ 55 | 56 | def __init__( 57 | self, 58 | dataframe: Optional[pd.DataFrame] = None, 59 | path_csv: Optional[str] = None, 60 | transform_signal: Optional[list] = None, 61 | transform_target: Optional[list] = None, 62 | ): 63 | self.path_csv = None 64 | if dataframe is not None: 65 | self.df = dataframe 66 | else: 67 | self.path_csv = path_csv 68 | self.df = pd.read_csv(self.path_csv) 69 | self.transform_signal = to_objects(transform_signal) 70 | self.transform_target = to_objects(transform_target) 71 | self._metadata = None 72 | self.loc = _LocIndexer(self) 73 | self.iloc = _iLocIndexer(self) 74 | 75 | @property 76 | def metadata(self) -> dict: 77 | """Returns metadata about the dataset.""" 78 | if self._metadata is not None: 79 | return self._metadata 80 | self._metadata = {} 81 | if self.path_csv is not None: 82 | self._metadata["path_csv"] = self.path_csv 83 | self._metadata["transform_signal"] = _to_str_list(self.transform_signal) 84 | self._metadata["transform_target"] = _to_str_list(self.transform_target) 85 | return self._metadata 86 | 87 | def get_information(self, index) -> Union[dict, str]: 88 | """Returns information to identify dataset element specified by index. 89 | 90 | """ 91 | raise NotImplementedError 92 | -------------------------------------------------------------------------------- /fnet/data/multichtiffdataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | 5 | from aicsimageio import AICSImage 6 | from fnet.data.fnetdataset import FnetDataset 7 | 8 | 9 | class MultiChTiffDataset(FnetDataset): 10 | """ 11 | Dataset for multi-channel tiff files. 
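
    The dataset DataFrame must contain "path_tiff", "channel_signal", and
    "channel_target" columns. Channel entries may be single integers or
    strings such as "[0, 1]" listing several channel indices.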
12 | """ 13 | 14 | def __init__( 15 | self, 16 | dataframe: pd.DataFrame = None, 17 | path_csv: str = None, 18 | transform_signal=None, 19 | transform_target=None, 20 | ): 21 | 22 | super().__init__(dataframe, path_csv, transform_signal, transform_target) 23 | 24 | # if this column is a string assume it is in "[ind_1, ind_2, ..., ind_n]" format 25 | if isinstance(self.df["channel_signal"][0], str): 26 | self.df["channel_signal"] = [ 27 | np.fromstring(ch[1:-1], sep=", ").astype(int) 28 | for ch in self.df["channel_signal"] 29 | ] 30 | else: 31 | self.df["channel_signal"] = [[int(ch)] for ch in self.df["channel_signal"]] 32 | 33 | if isinstance(self.df["channel_target"][0], str): 34 | self.df["channel_target"] = [ 35 | np.fromstring(ch[1:-1], sep=", ").astype(int) 36 | for ch in self.df["channel_target"] 37 | ] 38 | else: 39 | self.df["channel_target"] = [[int(ch)] for ch in self.df["channel_target"]] 40 | 41 | assert all( 42 | i in self.df.columns 43 | for i in ["path_tiff", "channel_signal", "channel_target"] 44 | ) 45 | 46 | def __getitem__(self, index): 47 | """ 48 | Parameters 49 | ---------- 50 | index: integer 51 | 52 | Returns 53 | ------- 54 | C by torch.Tensor 55 | """ 56 | 57 | element = self.df.iloc[index, :] 58 | has_target = not np.any(np.isnan(element["channel_target"])) 59 | 60 | # aicsimageio.imread loads as STCZYX, so we load only CZYX 61 | with AICSImage(element["path_tiff"]) as img: 62 | im_tmp = img.get_image_data("CZYX", S=0, T=0) 63 | 64 | im_out = list() 65 | im_out.append(im_tmp[element["channel_signal"]]) 66 | 67 | if has_target: 68 | im_out.append(im_tmp[element["channel_target"]]) 69 | 70 | if self.transform_signal is not None: 71 | for t in self.transform_signal: 72 | im_out[0] = t(im_out[0]) 73 | 74 | if has_target and self.transform_target is not None: 75 | for t in self.transform_target: 76 | im_out[1] = t(im_out[1]) 77 | 78 | im_out = [torch.from_numpy(im.astype(float)).float() for im in im_out] 79 | 80 | # unsqueeze to make the first dimension be the channel dimension 81 | # im_out = [torch.unsqueeze(im, 0) for im in im_out] 82 | 83 | return tuple(im_out) 84 | 85 | def __len__(self): 86 | return len(self.df) 87 | 88 | def get_information(self, index: int) -> dict: 89 | return self.df.iloc[index, :].to_dict() 90 | -------------------------------------------------------------------------------- /fnet/data/tiffdataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import numpy as np 4 | import tifffile 5 | import torch 6 | 7 | from fnet.data.fnetdataset import FnetDataset 8 | from fnet.utils.general_utils import add_augmentations 9 | 10 | 11 | def _flip_y(ar): 12 | """Flip array along y axis. 13 | 14 | Array should have dimensions ZYX or YX. 15 | 16 | """ 17 | return np.flip(ar, axis=-2) 18 | 19 | 20 | def _flip_x(ar): 21 | """Flip array along x axis. 22 | 23 | Array should have dimensions ZYX or YX. 24 | 25 | """ 26 | return np.flip(ar, axis=-1) 27 | 28 | 29 | class TiffDataset(FnetDataset): 30 | """Dataset where each row is a signal-target pairing from TIFF files. 31 | 32 | Dataset items will be 2-item or 3-item tuples: 33 | (signal image, target image) or 34 | (signal image, target image, cost map) 35 | 36 | Parameters 37 | ---------- 38 | augment 39 | Set to augment dataset with flips about the x and/or y axis. 
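
    Examples
    --------
    A minimal sketch (the CSV path is hypothetical):

    >>> ds = TiffDataset(path_csv="data/train.csv")
    >>> signal, target = ds[0][:2]
    >>> int(signal.shape[0])  # first dimension is the channel dimension
    1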
40 | 
41 |     """
42 | 
43 |     def __init__(
44 |         self,
45 |         col_index: Optional[str] = None,
46 |         col_signal: str = "path_signal",
47 |         col_target: str = "path_target",
48 |         col_weight_map: str = "path_weight_map",
49 |         augment: bool = False,
50 |         **kwargs,
51 |     ):
52 |         super().__init__(**kwargs)
53 |         self.col_index = col_index
54 |         self.col_signal = col_signal
55 |         self.col_target = col_target
56 |         self.col_weight_map = col_weight_map
57 |         self.augment = augment
58 |         if self.col_index is not None:
59 |             self.df = self.df.set_index(self.col_index)
60 |         if self.augment:
61 |             self.df = add_augmentations(self.df)
62 |         if self.col_weight_map not in self.df.columns:
63 |             self.col_weight_map = None
64 | 
65 |         for col in [self.col_signal, self.col_target, self.col_weight_map]:
66 |             if col is not None and col not in self.df.columns:
67 |                 raise ValueError(f"{col} not a dataset DataFrame column")
68 | 
69 |     def __len__(self):
70 |         return self.df.shape[0]
71 | 
72 |     def __getitem__(self, idx):
73 |         flip_y = self.df.iloc[idx, :].get("flip_y", -1) > 0
74 |         flip_x = self.df.iloc[idx, :].get("flip_x", -1) > 0
75 |         datum = []
76 |         for col, transforms in [
77 |             [self.col_signal, self.transform_signal],
78 |             [self.col_target, self.transform_target],
79 |             [self.col_weight_map, None],  # optional weight maps
80 |         ]:
81 |             if col is None:
82 |                 continue
83 |             path_read = self.df.loc[self.df.index[idx], col]
84 |             if not isinstance(path_read, str):
85 |                 datum.append(None)
86 |                 continue
87 |             ar = tifffile.imread(path_read)
88 |             # Copy so flip augmentations do not accumulate in the shared
89 |             # transform lists across calls
90 |             transforms = list(transforms or [])
91 |             if flip_y:
92 |                 transforms.append(_flip_y)
93 |             if flip_x:
94 |                 transforms.append(_flip_x)
95 |             for transform in transforms:
96 |                 ar = transform(ar)
97 |             datum.append(
98 |                 torch.tensor(ar[np.newaxis,].astype(np.float32), dtype=torch.float32)
99 |             )
100 |         return tuple(datum)
101 | 
102 |     def get_information(self, idx: int) -> dict:
103 |         """Returns information about the dataset item.
104 | 
105 |         Parameters
106 |         ----------
107 |         idx
108 |             Index of dataset item for which to retrieve information.
109 | 
110 |         Returns
111 |         -------
112 |         dict
113 |             Information about dataset item.
114 | 
115 |         """
116 |         return self.df.loc[idx, :].to_dict()
--------------------------------------------------------------------------------
/fnet/fnet_ensemble.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List
2 | import logging
3 | import os
4 | 
5 | import numpy as np
6 | import torch
7 | 
8 | from fnet.fnet_model import Model
9 | from fnet.utils.general_utils import str_to_class
10 | 
11 | 
12 | logger = logging.getLogger(__name__)
13 | 
14 | 
15 | def _load_model(path_model: str) -> Model:
16 |     """Load saved model from path."""
17 |     state = torch.load(path_model)
18 |     fnet_model_class = state["fnet_model_class"]
19 |     fnet_model_kwargs = state["fnet_model_kwargs"]
20 |     model = str_to_class(fnet_model_class)(**fnet_model_kwargs)
21 |     model.load_state(state, no_optim=True)
22 |     return model
23 | 
24 | 
25 | class FnetEnsemble(Model):
26 |     """Ensemble of FnetModels.
27 | 
28 |     Parameters
29 |     ----------
30 |     paths_model
31 |         Path to a directory of saved models or a list of paths to saved models.
32 | 
33 |     Attributes
34 |     ----------
35 |     paths_model : Union[str, List[str]]
36 |         Paths to saved models in the ensemble.
37 |     gpu_ids : List[int]
38 |         GPU(s) used for prediction tasks.
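
    Examples
    --------
    A sketch with illustrative paths; x is an input tensor:

    >>> ensemble = FnetEnsemble("saved_models/ensemble_members")
    >>> y_hat = ensemble.predict(x, tta=True)  # mean of member predictions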
39 | 40 | """ 41 | 42 | def __init__(self, paths_model: Union[str, List[str]]) -> None: 43 | if isinstance(paths_model, str): 44 | assert os.path.isdir(paths_model) 45 | paths_model = sorted( 46 | [ 47 | p.path 48 | for p in os.scandir(os.path.abspath(paths_model)) 49 | if p.path.lower().endswith(".p") 50 | ] 51 | ) 52 | assert len(paths_model) > 0 53 | self.paths_model = paths_model 54 | self.gpu_ids = -1 55 | 56 | def __str__(self): 57 | str_out = [] 58 | str_out.append(f"{len(self.paths_model)}-model ensemble:") 59 | str_out.extend([p for p in self.paths_model]) 60 | return os.linesep.join(str_out) 61 | 62 | def to_gpu(self, gpu_ids: Union[int, list]) -> None: 63 | """Move network to specified GPU(s). 64 | 65 | Parameters 66 | ---------- 67 | gpu_ids 68 | GPU(s) on which to perform training or prediction. 69 | 70 | """ 71 | if isinstance(gpu_ids, int): 72 | gpu_ids = [gpu_ids] 73 | self.gpu_ids = gpu_ids 74 | 75 | def predict( 76 | self, x: Union[torch.Tensor, np.ndarray], tta: bool = False 77 | ) -> torch.Tensor: 78 | """Performs model prediction. 79 | 80 | Parameters 81 | ---------- 82 | x 83 | Batched input. 84 | tta 85 | Set to to use test-time augmentation. 86 | 87 | Returns 88 | ------- 89 | torch.Tensor 90 | Model prediction. 91 | 92 | """ 93 | y_hat_mean = None 94 | for path_model in self.paths_model: 95 | model = _load_model(path_model) 96 | model.to_gpu(self.gpu_ids) 97 | y_hat = model.predict(x=x, tta=tta) 98 | if y_hat_mean is None: 99 | y_hat_mean = torch.zeros(*y_hat.size()) 100 | y_hat_mean += y_hat 101 | return y_hat_mean / len(self.paths_model) 102 | 103 | # Override 104 | def save(self, path_save: str): 105 | """Saves model to disk. 106 | 107 | Parameters 108 | ---------- 109 | path_save 110 | Filename to which model is saved. 111 | 112 | """ 113 | state = { 114 | "fnet_model_class": (self.__module__ + "." + self.__class__.__qualname__), 115 | "fnet_model_kwargs": {"paths_model": self.paths_model}, 116 | } 117 | dirname = os.path.dirname(path_save) 118 | if not os.path.exists(dirname): 119 | os.makedirs(dirname) 120 | torch.save(state, path_save) 121 | logger.info(f"Ensemble model saved to: {path_save}") 122 | 123 | # Override 124 | def load_state(self, state: dict, no_optim: bool = False): 125 | return 126 | -------------------------------------------------------------------------------- /fnet/fnet_model.py: -------------------------------------------------------------------------------- 1 | """Module to define main fnet model wrapper class.""" 2 | 3 | 4 | from pathlib import Path 5 | from typing import Callable, Iterator, List, Optional, Sequence, Tuple, Union 6 | import logging 7 | import math 8 | import os 9 | 10 | from scipy.ndimage import zoom 11 | import numpy as np 12 | import tifffile 13 | import torch 14 | 15 | from fnet.metrics import corr_coef 16 | from fnet.predict_piecewise import predict_piecewise as _predict_piecewise_fn 17 | from fnet.transforms import flip_y, flip_x, norm_around_center 18 | from fnet.utils.general_utils import get_args, retry_if_oserror, str_to_object 19 | from fnet.utils.model_utils import move_optim 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def _weights_init(m): 26 | classname = m.__class__.__name__ 27 | if classname.startswith("Conv"): 28 | m.weight.data.normal_(0.0, 0.02) 29 | elif classname.find("BatchNorm") != -1: 30 | m.weight.data.normal_(1.0, 0.02) 31 | m.bias.data.fill_(0) 32 | 33 | 34 | def get_per_param_options(module, wd): 35 | """Returns list of per parameter group options. 
36 | 37 | Applies the specified weight decay (wd) to parameters except parameters 38 | within batch norm layers and bias parameters. 39 | """ 40 | if wd == 0: 41 | return module.parameters() 42 | with_decay = list() 43 | without_decay = list() 44 | for idx_m, (name_m, module_sub) in enumerate(module.named_modules()): 45 | if list(module_sub.named_children()): 46 | continue # Skip "container" modules 47 | if isinstance(module_sub, torch.nn.modules.batchnorm._BatchNorm): 48 | for param in module_sub.parameters(): 49 | without_decay.append(param) 50 | continue 51 | for name_param, param in module_sub.named_parameters(): 52 | if "weight" in name_param: 53 | with_decay.append(param) 54 | elif "bias" in name_param: 55 | without_decay.append(param) 56 | # Check that no parameters were missed or duplicated 57 | n_param_module = len(list(module.parameters())) 58 | n_param_lists = len(with_decay) + len(without_decay) 59 | n_elem_module = sum([p.numel() for p in module.parameters()]) 60 | n_elem_lists = sum([p.numel() for p in with_decay + without_decay]) 61 | assert n_param_module == n_param_lists 62 | assert n_elem_module == n_elem_lists 63 | per_param_options = [ 64 | {"params": with_decay, "weight_decay": wd}, 65 | {"params": without_decay, "weight_decay": 0.0}, 66 | ] 67 | return per_param_options 68 | 69 | 70 | class Model: 71 | """Class that encompasses a pytorch network and its optimizer. 72 | 73 | """ 74 | 75 | def __init__( 76 | self, 77 | betas=(0.5, 0.999), 78 | criterion_class="fnet.losses.WeightedMSE", 79 | init_weights=True, 80 | lr=0.001, 81 | nn_class="fnet.nn_modules.fnet_nn_3d.Net", 82 | nn_kwargs={}, 83 | scheduler=None, 84 | weight_decay=0, 85 | gpu_ids=-1, 86 | ): 87 | self.betas = betas 88 | self.criterion = str_to_object(criterion_class)() 89 | self.gpu_ids = [gpu_ids] if isinstance(gpu_ids, int) else gpu_ids 90 | self.init_weights = init_weights 91 | self.lr = lr 92 | self.nn_class = nn_class 93 | self.nn_kwargs = nn_kwargs 94 | self.scheduler = scheduler 95 | self.weight_decay = weight_decay 96 | 97 | self.count_iter = 0 98 | self.device = ( 99 | torch.device("cuda", self.gpu_ids[0]) 100 | if self.gpu_ids[0] >= 0 101 | else torch.device("cpu") 102 | ) 103 | self.optimizer = None 104 | self._init_model() 105 | self.fnet_model_kwargs, self.fnet_model_posargs = get_args() 106 | self.fnet_model_kwargs.pop("self") 107 | 108 | def _init_model(self): 109 | self.net = str_to_object(self.nn_class)(**self.nn_kwargs) 110 | if self.init_weights: 111 | self.net.apply(_weights_init) 112 | self.net.to(self.device) 113 | self.optimizer = torch.optim.Adam( 114 | get_per_param_options(self.net, wd=self.weight_decay), 115 | lr=self.lr, 116 | betas=self.betas, 117 | ) 118 | if self.scheduler is not None: 119 | if self.scheduler[0] == "snapshot": 120 | period = self.scheduler[1] 121 | self.scheduler = torch.optim.lr_scheduler.LambdaLR( 122 | self.optimizer, 123 | lambda x: ( 124 | 0.01 125 | + (1 - 0.01) 126 | * (0.5 + 0.5 * math.cos(math.pi * (x % period) / period)) 127 | ), 128 | ) 129 | elif self.scheduler[0] == "step": 130 | step_size = self.scheduler[1] 131 | self.scheduler = torch.optim.lr_scheduler.StepLR( 132 | self.optimizer, step_size 133 | ) 134 | else: 135 | raise NotImplementedError 136 | 137 | def __str__(self): 138 | out_str = [ 139 | f"*** {self.__class__.__name__} ***", 140 | f"{self.nn_class}(**{self.nn_kwargs})", 141 | f"iter: {self.count_iter}", 142 | f"gpu: {self.gpu_ids}", 143 | ] 144 | return os.linesep.join(out_str) 145 | 146 | def get_state(self): 147 | return { 148 | 
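            # The saved state bundles constructor arguments with network and
            # optimizer weights so load_model() can rebuild this exact class.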
"fnet_model_class": (self.__module__ + "." + self.__class__.__qualname__), 149 | "fnet_model_kwargs": self.fnet_model_kwargs, 150 | "fnet_model_posargs": self.fnet_model_posargs, 151 | "nn_state": self.net.state_dict(), 152 | "optimizer_state": self.optimizer.state_dict(), 153 | "count_iter": self.count_iter, 154 | } 155 | 156 | def to_gpu(self, gpu_ids: Union[int, List[int]]) -> None: 157 | """Move network to specified GPU(s). 158 | 159 | Parameters 160 | ---------- 161 | gpu_ids 162 | GPU(s) on which to perform training or prediction. 163 | 164 | """ 165 | if isinstance(gpu_ids, int): 166 | gpu_ids = [gpu_ids] 167 | self.gpu_ids = gpu_ids 168 | self.device = ( 169 | torch.device("cuda", self.gpu_ids[0]) 170 | if self.gpu_ids[0] >= 0 171 | else torch.device("cpu") 172 | ) 173 | self.net.to(self.device) 174 | if self.optimizer is not None: 175 | move_optim(self.optimizer, self.device) 176 | 177 | def save(self, path_save: str): 178 | """Saves model to disk. 179 | 180 | Parameters 181 | ---------- 182 | path_save 183 | Filename to which model is saved. 184 | 185 | """ 186 | dirname = os.path.dirname(path_save) 187 | if not os.path.exists(dirname): 188 | os.makedirs(dirname) 189 | logger.info(f"Created: {dirname}") 190 | curr_gpu_ids = self.gpu_ids 191 | self.to_gpu(-1) 192 | retry_if_oserror(torch.save)(self.get_state(), path_save) 193 | self.to_gpu(curr_gpu_ids) 194 | 195 | def load_state(self, state: dict, no_optim: bool = False): 196 | self.count_iter = state["count_iter"] 197 | self.net.load_state_dict(state["nn_state"]) 198 | if no_optim: 199 | self.optimizer = None 200 | return 201 | self.optimizer.load_state_dict(state["optimizer_state"]) 202 | 203 | def train_on_batch( 204 | self, 205 | x_batch: torch.Tensor, 206 | y_batch: torch.Tensor, 207 | weight_map_batch: Optional[torch.Tensor] = None, 208 | ) -> float: 209 | """Update model using a batch of inputs and targets. 210 | 211 | Parameters 212 | ---------- 213 | x_batch 214 | Batched input. 215 | y_batch 216 | Batched target. 217 | weight_map_batch 218 | Optional batched weight map. 219 | 220 | Returns 221 | ------- 222 | float 223 | Loss as determined by self.criterion. 
224 | 
225 |         """
226 |         if self.scheduler is not None:
227 |             self.scheduler.step()
228 |         self.net.train()
229 |         x_batch = x_batch.to(dtype=torch.float32, device=self.device)
230 |         y_batch = y_batch.to(dtype=torch.float32, device=self.device)
231 |         if len(self.gpu_ids) > 1:
232 |             module = torch.nn.DataParallel(self.net, device_ids=self.gpu_ids)
233 |         else:
234 |             module = self.net
235 | 
236 |         self.optimizer.zero_grad()
237 |         y_hat_batch = module(x_batch)
238 |         args = [y_hat_batch, y_batch]
239 |         if weight_map_batch is not None:
240 |             args.append(weight_map_batch)
241 |         loss = self.criterion(*args)
242 |         loss.backward()
243 |         self.optimizer.step()
244 |         self.count_iter += 1
245 |         return loss.item()
246 | 
247 |     def _predict_on_batch_tta(self, x_batch: torch.Tensor) -> torch.Tensor:
248 |         """Performs model prediction using test-time augmentation."""
249 |         augs = [None, [flip_y], [flip_x], [flip_y, flip_x]]
250 |         x_batch = x_batch.numpy()
251 |         y_hat_batch_mean = None
252 |         for aug in augs:
253 |             x_batch_aug = x_batch.copy()
254 |             if aug is not None:
255 |                 for trans in aug:
256 |                     x_batch_aug = trans(x_batch_aug)
257 |             y_hat_batch = self.predict_on_batch(x_batch_aug.copy()).numpy()
258 |             if aug is not None:
259 |                 for trans in aug:
260 |                     y_hat_batch = trans(y_hat_batch)
261 |             if y_hat_batch_mean is None:
262 |                 y_hat_batch_mean = np.zeros(y_hat_batch.shape, dtype=np.float32)
263 |             y_hat_batch_mean += y_hat_batch
264 |         y_hat_batch_mean /= len(augs)
265 |         return torch.tensor(
266 |             y_hat_batch_mean, dtype=torch.float32, device=torch.device("cpu")
267 |         )
268 | 
269 |     def predict_on_batch(self, x_batch: torch.Tensor) -> torch.Tensor:
270 |         """Performs model prediction on a batch of data.
271 | 
272 |         Parameters
273 |         ----------
274 |         x_batch
275 |             Batch of input data.
276 | 
277 |         Returns
278 |         -------
279 |         torch.Tensor
280 |             Batch of model predictions.
281 | 
282 |         """
283 |         x_batch = torch.tensor(x_batch, dtype=torch.float32, device=self.device)
284 | 
285 |         if len(self.gpu_ids) > 1:
286 |             network = torch.nn.DataParallel(self.net, device_ids=self.gpu_ids)
287 |         else:
288 |             network = self.net
289 | 
290 |         network.eval()
291 |         with torch.no_grad():
292 |             y_hat_batch = network(x_batch).cpu()
293 | 
294 |         network.train()
295 | 
296 |         return y_hat_batch
297 | 
298 |     def predict(
299 |         self, x: Union[torch.Tensor, np.ndarray], tta: bool = False
300 |     ) -> torch.Tensor:
301 |         """Performs model prediction on a single example.
302 | 
303 |         Parameters
304 |         ----------
305 |         x
306 |             Input data.
307 |         tta
308 |             Set to use test-time augmentation.
309 | 
310 |         Returns
311 |         -------
312 |         torch.Tensor
313 |             Model prediction.
314 | 
315 |         """
316 |         x_batch = torch.unsqueeze(torch.tensor(x), 0)
317 |         if tta:
318 |             return self._predict_on_batch_tta(x_batch).squeeze(0)
319 |         return self.predict_on_batch(x_batch).squeeze(0)
320 | 
321 |     def predict_piecewise(
322 |         self, x: Union[torch.Tensor, np.ndarray], **predict_kwargs
323 |     ) -> torch.Tensor:
324 |         """Performs model prediction piecewise on a single example.
325 | 
326 |         Predicts on patches of the input and stitches together the
327 |         predictions.
328 | 
329 |         Parameters
330 |         ----------
331 |         x
332 |             Input data.
333 |         **predict_kwargs
334 |             Kwargs to pass to predict method.
335 | 
336 |         Returns
337 |         -------
338 |         torch.Tensor
339 |             Model prediction.
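
        Notes
        -----
        Patch sizes are capped at 32x512x512 (ZYX) for 4-d CZYX inputs and at
        1024x1024 (YX) for 3-d CYX inputs, with 16-pixel overlaps between
        neighboring patches.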
342 | 343 | """ 344 | if isinstance(x, np.ndarray): 345 | x = torch.from_numpy(x) 346 | if len(x.size()) == 4: 347 | dims_max = [None, 32, 512, 512] 348 | elif len(x.size()) == 3: 349 | dims_max = [None, 1024, 1024] 350 | y_hat = _predict_piecewise_fn( 351 | self, x, dims_max=dims_max, overlaps=16, **predict_kwargs 352 | ) 353 | return y_hat 354 | 355 | def test_on_batch( 356 | self, 357 | x_batch: torch.Tensor, 358 | y_batch: torch.Tensor, 359 | weight_map_batch: Optional[torch.Tensor] = None, 360 | ) -> float: 361 | """Test model on a batch of inputs and targets. 362 | 363 | Parameters 364 | ---------- 365 | x_batch 366 | Batched input. 367 | y_batch 368 | Batched target. 369 | weight_map_batch 370 | Optional batched weight map. 371 | 372 | Returns 373 | ------- 374 | float 375 | Loss as evaluated by self.criterion. 376 | 377 | """ 378 | 379 | y_hat_batch = self.predict_on_batch(x_batch) 380 | 381 | args = [y_hat_batch, y_batch] 382 | 383 | if weight_map_batch is not None: 384 | args.append(weight_map_batch) 385 | 386 | loss = self.criterion(*args) 387 | 388 | return loss.item() 389 | 390 | def test_on_iterator(self, iterator: Iterator, **kwargs: dict) -> float: 391 | """Test model on iterator which has items to be passed to 392 | test_on_batch. 393 | 394 | Parameters 395 | ---------- 396 | iterator 397 | Iterator that generates items to be passed to test_on_batch. 398 | kwargs 399 | Additional keyword arguments to be passed to test_on_batch. 400 | 401 | Returns 402 | ------- 403 | float 404 | Mean loss for items in iterable. 405 | 406 | """ 407 | loss_sum = 0 408 | for item in iterator: 409 | loss_sum += self.test_on_batch(*item, **kwargs) 410 | return loss_sum / len(iterator) 411 | 412 | def evaluate( 413 | self, 414 | x: torch.Tensor, 415 | y: torch.Tensor, 416 | metric: Optional = None, 417 | piecewise: bool = False, 418 | **kwargs, 419 | ) -> Tuple[float, torch.Tensor]: 420 | """Evaluates model output using a metric function. 421 | 422 | Parameters 423 | ---------- 424 | x 425 | Input data. 426 | y 427 | Target data. 428 | metric 429 | Metric function. If None, uses fnet.metrics.corr_coef. 430 | piecewise 431 | Set to perform predictions piecewise. 432 | **kwargs 433 | Additional kwargs to be passed to predict() method. 434 | 435 | Returns 436 | ------- 437 | float 438 | Evaluation as determined by metric function. 439 | torch.Tensor 440 | Model prediction. 441 | 442 | """ 443 | if metric is None: 444 | metric = corr_coef 445 | if piecewise: 446 | y_hat = self.predict_piecewise(x, **kwargs) 447 | else: 448 | y_hat = self.predict(x, **kwargs) 449 | if y is None: 450 | return None, y_hat 451 | evaluation = metric(y, y_hat) 452 | return evaluation, y_hat 453 | 454 | def apply_on_single_zstack( 455 | self, 456 | input_img: Optional[np.ndarray] = None, 457 | filename: Optional[Union[Path, str]] = None, 458 | inputCh: Optional[int] = None, 459 | normalization: Optional[Callable] = None, 460 | already_normalized: bool = False, 461 | ResizeRatio: Optional[Sequence[float]] = None, 462 | cutoff: Optional[float] = None, 463 | ) -> np.ndarray: 464 | """Applies model to a single z-stack input. 465 | 466 | This assumes the loaded network architecture can receive 3d grayscale 467 | images as input. 468 | 469 | Parameters 470 | ---------- 471 | input_img 472 | 3d or 4d image with shape (Z, Y, X) or (C, Z, Y, X) respectively. 473 | filename 474 | Path to input image. Ignored if input_img is supplied. 475 | inputCh 476 | Selected channel if filename is a path to a 4d image. 
477 | normalization 478 | Input image normalization function. 479 | already_normalized 480 | Set to skip input normalization. 481 | ResizeRatio 482 | Resizes each dimension of the the input image by the specified 483 | factor if specified. 484 | cutoff 485 | If specified, converts the output to a binary image with cutoff as 486 | threshold value. 487 | 488 | Returns 489 | ------- 490 | np.ndarray 491 | Predicted image with shape (Z, Y, X). If cutoff is set, dtype will 492 | be numpy.uint8. Otherwise, dtype will be numpy.float. 493 | 494 | Raises 495 | ------ 496 | ValueError 497 | If parameters are invalid. 498 | FileNotFoundError 499 | If specified file does not exist. 500 | IndexError 501 | If inputCh is invalid. 502 | 503 | """ 504 | if input_img is None: 505 | if filename is None: 506 | raise ValueError("input_img or filename must be specified") 507 | input_img = tifffile.imread(str(filename)) 508 | if inputCh is not None: 509 | if input_img.ndim != 4: 510 | raise ValueError("input_img must be 4d if inputCh specified") 511 | input_img = input_img[inputCh,] 512 | if input_img.ndim != 3: 513 | raise ValueError("input_img must be 3d") 514 | normalization = normalization or norm_around_center 515 | if not already_normalized: 516 | input_img = normalization(input_img) 517 | if ResizeRatio is not None: 518 | if len(ResizeRatio) != 3: 519 | raise ValueError("ResizeRatio must be length 3") 520 | input_img = zoom(input_img, zoom=ResizeRatio, mode="nearest") 521 | yhat = ( 522 | self.predict_piecewise(input_img[np.newaxis,], tta=True) 523 | .squeeze(dim=0) 524 | .numpy() 525 | ) 526 | if cutoff is not None: 527 | yhat = (yhat >= cutoff).astype(np.uint8) * 255 528 | return yhat 529 | -------------------------------------------------------------------------------- /fnet/fnetlogger.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | 4 | 5 | class FnetLogger(object): 6 | """Log values in a dict of lists.""" 7 | 8 | def __init__(self, path_csv=None, columns=None): 9 | if path_csv is not None: 10 | df = pd.read_csv(path_csv) 11 | self.columns = list(df.columns) 12 | self.data = df.to_dict(orient="list") 13 | else: 14 | self.columns = columns 15 | self.data = {} 16 | for c in columns: 17 | self.data[c] = [] 18 | 19 | def __repr__(self): 20 | return "FnetLogger({})".format(self.columns) 21 | 22 | def add(self, entry): 23 | if isinstance(entry, dict): 24 | for key, value in entry.items(): 25 | self.data[key].append(value) 26 | else: 27 | assert len(entry) == len(self.columns) 28 | for i, value in enumerate(entry): 29 | self.data[self.columns[i]].append(value) 30 | 31 | def to_csv(self, path_csv): 32 | dirname = os.path.dirname(path_csv) 33 | if not os.path.exists(dirname): 34 | os.makedirs(dirname) 35 | pd.DataFrame(self.data)[self.columns].to_csv(path_csv, index=False) 36 | -------------------------------------------------------------------------------- /fnet/losses.py: -------------------------------------------------------------------------------- 1 | """Loss functions for fnet models.""" 2 | 3 | 4 | from typing import Optional 5 | 6 | import torch 7 | 8 | 9 | class HeteroscedasticLoss(torch.nn.Module): 10 | """Loss function to capture heteroscedastic aleatoric uncertainty.""" 11 | 12 | def forward(self, y_hat_batch: torch.Tensor, y_batch: torch.Tensor): 13 | """Calculates loss. 14 | 15 | Parameters 16 | ---------- 17 | y_hat_batch 18 | Batched, 2-channel model output. 19 | y_batch 20 | Batched, 1-channel target output. 
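
        Notes
        -----
        With per-voxel predicted mean m (channel 0) and log-variance s
        (channel 1), the loss is the mean over all voxels of
        0.5 * exp(-s) * (m - y)**2 + 0.5 * s.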
21 | 22 | """ 23 | mean_batch = y_hat_batch[:, 0:1, :, :, :] 24 | log_var_batch = y_hat_batch[:, 1:2, :, :, :] 25 | loss_batch = ( 26 | 0.5 * torch.exp(-log_var_batch) * (mean_batch - y_batch).pow(2) 27 | + 0.5 * log_var_batch 28 | ) 29 | return loss_batch.mean() 30 | 31 | 32 | class WeightedMSE(torch.nn.Module): 33 | """Criterion for weighted mean-squared error.""" 34 | 35 | def forward( 36 | self, 37 | y_hat_batch: torch.Tensor, 38 | y_batch: torch.Tensor, 39 | weight_map_batch: Optional[torch.Tensor] = None, 40 | ): 41 | """Calculates weighted MSE. 42 | 43 | Parameters 44 | ---------- 45 | y_hat_batch 46 | Batched prediction. 47 | y_batch 48 | Batched target. 49 | weight_map_batch 50 | Optional weight map. 51 | 52 | """ 53 | if weight_map_batch is None: 54 | return torch.nn.functional.mse_loss(y_hat_batch, y_batch) 55 | dim = tuple(range(1, len(weight_map_batch.size()))) 56 | return (weight_map_batch * (y_hat_batch - y_batch) ** 2).sum(dim=dim).mean() 57 | -------------------------------------------------------------------------------- /fnet/metrics.py: -------------------------------------------------------------------------------- 1 | """Model evaluation metrics.""" 2 | 3 | 4 | from typing import Union 5 | 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def corr_coef( 11 | a: Union[np.ndarray, torch.Tensor], b: Union[np.ndarray, torch.Tensor] 12 | ) -> float: 13 | """Calculates the Pearson correlation coefficient between the inputs. 14 | 15 | Parameters 16 | ---------- 17 | a 18 | First input. 19 | b 20 | Second input. 21 | 22 | Returns 23 | ------- 24 | float 25 | Pearson correlation coefficient between the inputs. 26 | 27 | """ 28 | if a is None or b is None: 29 | return None 30 | if isinstance(a, torch.Tensor): 31 | a = a.numpy() 32 | if isinstance(b, torch.Tensor): 33 | b = b.numpy() 34 | assert a.shape == b.shape, "Inputs must be same shape" 35 | mean_a = np.mean(a) 36 | mean_b = np.mean(b) 37 | std_a = np.std(a) 38 | std_b = np.std(b) 39 | cc = np.mean((a - mean_a) * (b - mean_b)) / (std_a * std_b) 40 | return cc 41 | 42 | 43 | def corr_coef_chan0( 44 | a: Union[np.ndarray, torch.Tensor], b: Union[np.ndarray, torch.Tensor] 45 | ) -> float: 46 | """Calculates the Pearson correlation coefficient between channel 0 of the 47 | inputs. 48 | 49 | Assumes the first dimension of the inputs is the channel dimension. 50 | 51 | Parameters 52 | ---------- 53 | a 54 | First input. 55 | b 56 | Second input. 57 | 58 | Returns 59 | ------- 60 | float 61 | Pearson correlation coefficient between channel 0 of the inputs. 62 | 63 | """ 64 | if a is None or b is None: 65 | return None 66 | a = a[0:1,] 67 | b = b[0:1,] 68 | return corr_coef(a, b) 69 | -------------------------------------------------------------------------------- /fnet/models.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | import json 3 | import logging 4 | import os 5 | 6 | import torch 7 | 8 | from fnet.fnet_ensemble import FnetEnsemble 9 | from fnet.fnet_model import Model 10 | from fnet.utils.general_utils import str_to_class 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def _find_model_checkpoint(path_model_dir: str, checkpoint: str): 17 | """Finds path to a specific model checkpoint. 18 | 19 | Parameters 20 | ---------- 21 | path_model_dir 22 | Path to model as a directory. 23 | checkpoint 24 | String that identifies a model checkpoint 25 | 26 | Returns 27 | ------- 28 | str 29 | Path to saved model file. 
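
    Examples
    --------
    With a checkpoint saved at, e.g., "models/run/checkpoints/model_050000.p"
    (an illustrative layout):

    >>> _find_model_checkpoint("models/run", "050000")
    'models/run/checkpoints/model_050000.p'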
30 |
31 |     """
32 |     path_cp_dir = os.path.join(path_model_dir, "checkpoints")
33 |     if not os.path.exists(path_cp_dir):
34 |         raise ValueError(f"Model ({path_cp_dir}) has no checkpoints")
35 |     paths_cp = sorted(
36 |         [p.path for p in os.scandir(path_cp_dir) if p.path.endswith(".p")]
37 |     )
38 |     for path_cp in paths_cp:
39 |         if checkpoint in os.path.basename(path_cp):
40 |             return path_cp
41 |     raise ValueError(f"Model checkpoint not found: {checkpoint}")
42 |
43 |
44 | def load_model(
45 |     path_model: str,
46 |     no_optim: bool = False,
47 |     checkpoint: Optional[str] = None,
48 |     path_options: Optional[str] = None,
49 | ) -> Model:
50 |     """Loads a saved FnetModel.
51 |
52 |     Parameters
53 |     ----------
54 |     path_model
55 |         Path to model as a directory or .p file.
56 |     no_optim
57 |         Set to skip loading the model optimizer state.
58 |     checkpoint
59 |         Optional string that identifies a model checkpoint.
60 |     path_options
61 |         Path to training options json. For legacy saved models where the
62 |         FnetModel class/kwargs are not included in the model save file.
63 |
64 |     Returns
65 |     -------
66 |     Model
67 |         Loaded model.
68 |
69 |     """
70 |     if not os.path.exists(path_model):
71 |         raise ValueError(f"Model path does not exist: {path_model}")
72 |     if os.path.isdir(path_model):
73 |         if checkpoint is None:
74 |             path_model = os.path.join(path_model, "model.p")
75 |             if not os.path.exists(path_model):
76 |                 raise ValueError(f"Default model not found: {path_model}")
77 |         if checkpoint is not None:
78 |             path_model = _find_model_checkpoint(path_model, checkpoint)
79 |     state = torch.load(path_model)
80 |     if "fnet_model_class" not in state:
81 |         if path_options is not None:
82 |             with open(path_options, "r") as fi:
83 |                 train_options = json.load(fi)
84 |             if "fnet_model_class" in train_options:
85 |                 state["fnet_model_class"] = train_options["fnet_model_class"]
86 |                 state["fnet_model_kwargs"] = train_options["fnet_model_kwargs"]
87 |     fnet_model_class = state.get("fnet_model_class", "fnet.models.Model")
88 |     fnet_model_kwargs = state.get("fnet_model_kwargs", {})
89 |     model = str_to_class(fnet_model_class)(**fnet_model_kwargs)
90 |     model.load_state(state, no_optim)
91 |     return model
92 |
93 |
94 | def load_or_init_model(path_model: str, path_options: str):
95 |     """Loads a saved model if it exists; otherwise initializes a new model.
96 |
97 |     Parameters
98 |     ----------
99 |     path_model
100 |         Path to saved model.
101 |     path_options
102 |         Path to json where model training options are saved.
103 |
104 |     Returns
105 |     -------
106 |     FnetModel
107 |         Loaded or new FnetModel instance.
108 |
109 |     """
110 |     if not os.path.exists(path_model):
111 |         with open(path_options, "r") as fi:
112 |             train_options = json.load(fi)
113 |         logger.info("Initializing new model!")
114 |         fnet_model_class = train_options["fnet_model_class"]
115 |         fnet_model_kwargs = train_options["fnet_model_kwargs"]
116 |         return str_to_class(fnet_model_class)(**fnet_model_kwargs)
117 |     return load_model(path_model, path_options=path_options)
118 |
119 |
120 | def create_ensemble(paths_model: Union[str, List[str]], path_save_dir: str) -> None:
121 |     """Create and save an ensemble model.
122 |
123 |     Parameters
124 |     ----------
125 |     paths_model
126 |         Paths to models or model directories. Paths can be specified as items
127 |         in a list or as a string with paths separated by spaces. Any model
128 |         specified as a directory is assumed to be at 'directory/model.p', falling back to all '.p' files in that directory.
129 |     path_save_dir
130 |         Model save path directory. Model will be saved in path_save_dir as
131 |         'model.p'.
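    Example (hypothetical paths):

    >>> create_ensemble("run_a run_b", "models/ensemble")

    This saves an FnetEnsemble whose members are run_a/model.p and
    run_b/model.p to models/ensemble/model.p.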
132 |
133 |     """
134 |     if isinstance(paths_model, str):
135 |         paths_model = paths_model.split(" ")
136 |     paths_member = []
137 |     for path_model in paths_model:
138 |         path_model = os.path.abspath(path_model)
139 |         if os.path.isdir(path_model):
140 |             path_member = os.path.join(path_model, "model.p")
141 |             if os.path.exists(path_member):
142 |                 paths_member.append(path_member)
143 |                 continue
144 |             paths_member.extend(
145 |                 sorted(
146 |                     [p.path for p in os.scandir(path_model) if p.path.endswith(".p")]
147 |                 )
148 |             )
149 |         else:
150 |             paths_member.append(path_model)
151 |     path_save = os.path.join(path_save_dir, "model.p")
152 |     ensemble = FnetEnsemble(paths_model=paths_member)
153 |     ensemble.save(path_save)
154 |
--------------------------------------------------------------------------------
/fnet/nn_modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_fnet/64c53d123df644cebe5e4f7f2ab6efc5c0732f4e/fnet/nn_modules/__init__.py
--------------------------------------------------------------------------------
/fnet/nn_modules/dummy.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class DummyModel(torch.nn.Module):
5 |     def __init__(self, some_param=42):
6 |         super().__init__()
7 |         self.some_param = some_param
8 |         self.network = torch.nn.Conv3d(1, 1, kernel_size=3, padding=1)
9 |
10 |     def forward(self, x):  # forward (rather than __call__) keeps nn.Module hooks intact
11 |         return self.network(x)
12 |
--------------------------------------------------------------------------------
/fnet/nn_modules/fnet_nn_2d.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Net(torch.nn.Module):
5 |     def __init__(self):
6 |         super().__init__()
7 |         mult_chan = 32
8 |         depth = 4
9 |         self.net_recurse = _Net_recurse(
10 |             n_in_channels=1, mult_chan=mult_chan, depth=depth
11 |         )
12 |         self.conv_out = torch.nn.Conv2d(mult_chan, 1, kernel_size=3, padding=1)
13 |
14 |     def forward(self, x):
15 |         x_rec = self.net_recurse(x)
16 |         return self.conv_out(x_rec)
17 |
18 |
19 | class _Net_recurse(torch.nn.Module):
20 |     def __init__(self, n_in_channels, mult_chan=2, depth=0):
21 |         """Class for recursive definition of U-network.
22 |
23 |         Parameters
24 |         ----------
25 |         n_in_channels
26 |             Number of channels for input.
27 |         mult_chan
28 |             Factor to determine number of output channels.
29 |         depth
30 |             If 0, this subnet will only be convolutions that double the channel
31 |             count.
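        Notes
        -----
        Each recursion level wraps a two-convolution block around a
        stride-2 downsampling convolution and a stride-2 transposed
        convolution applied to a subnet one level shallower; the
        torch.cat in forward() forms the U-Net skip connection at that
        level.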
32 |
33 |         """
34 |         super().__init__()
35 |         self.depth = depth
36 |         n_out_channels = n_in_channels * mult_chan
37 |         self.sub_2conv_more = SubNet2Conv(n_in_channels, n_out_channels)
38 |
39 |         if depth > 0:
40 |             self.sub_2conv_less = SubNet2Conv(2 * n_out_channels, n_out_channels)
41 |             self.conv_down = torch.nn.Conv2d(
42 |                 n_out_channels, n_out_channels, 2, stride=2
43 |             )
44 |             self.bn0 = torch.nn.BatchNorm2d(n_out_channels)
45 |             self.relu0 = torch.nn.ReLU()
46 |             self.convt = torch.nn.ConvTranspose2d(
47 |                 2 * n_out_channels, n_out_channels, kernel_size=2, stride=2
48 |             )
49 |             self.bn1 = torch.nn.BatchNorm2d(n_out_channels)
50 |             self.relu1 = torch.nn.ReLU()
51 |             self.sub_u = _Net_recurse(n_out_channels, mult_chan=2, depth=(depth - 1))
52 |
53 |     def forward(self, x):
54 |         if self.depth == 0:
55 |             return self.sub_2conv_more(x)
56 |         else:  # depth > 0
57 |             x_2conv_more = self.sub_2conv_more(x)
58 |             x_conv_down = self.conv_down(x_2conv_more)
59 |             x_bn0 = self.bn0(x_conv_down)
60 |             x_relu0 = self.relu0(x_bn0)
61 |             x_sub_u = self.sub_u(x_relu0)
62 |             x_convt = self.convt(x_sub_u)
63 |             x_bn1 = self.bn1(x_convt)
64 |             x_relu1 = self.relu1(x_bn1)
65 |             x_cat = torch.cat((x_2conv_more, x_relu1), 1)  # concatenate
66 |             x_2conv_less = self.sub_2conv_less(x_cat)
67 |             return x_2conv_less
68 |
69 |
70 | class SubNet2Conv(torch.nn.Module):
71 |     def __init__(self, n_in, n_out):
72 |         super().__init__()
73 |         self.conv1 = torch.nn.Conv2d(n_in, n_out, kernel_size=3, padding=1)
74 |         self.bn1 = torch.nn.BatchNorm2d(n_out)
75 |         self.relu1 = torch.nn.ReLU()
76 |         self.conv2 = torch.nn.Conv2d(n_out, n_out, kernel_size=3, padding=1)
77 |         self.bn2 = torch.nn.BatchNorm2d(n_out)
78 |         self.relu2 = torch.nn.ReLU()
79 |
80 |     def forward(self, x):
81 |         x = self.conv1(x)
82 |         x = self.bn1(x)
83 |         x = self.relu1(x)
84 |         x = self.conv2(x)
85 |         x = self.bn2(x)
86 |         x = self.relu2(x)
87 |         return x
88 |
--------------------------------------------------------------------------------
/fnet/nn_modules/fnet_nn_3d.py:
--------------------------------------------------------------------------------
1 | import fnet.nn_modules.fnet_nn_3d_params
2 |
3 |
4 | class Net(fnet.nn_modules.fnet_nn_3d_params.Net):
5 |     def __init__(self):
6 |         super().__init__(depth=4, mult_chan=32)
7 |
--------------------------------------------------------------------------------
/fnet/nn_modules/fnet_nn_3d_params.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Net(torch.nn.Module):
5 |     def __init__(self, depth=4, mult_chan=32, in_channels=1, out_channels=1):
6 |         super().__init__()
7 |         self.depth = depth
8 |         self.mult_chan = mult_chan
9 |         self.in_channels = in_channels
10 |         self.out_channels = out_channels
11 |
12 |         self.net_recurse = _Net_recurse(
13 |             n_in_channels=self.in_channels, mult_chan=self.mult_chan, depth_parent=self.depth, depth=self.depth
14 |         )
15 |         self.conv_out = torch.nn.Conv3d(
16 |             self.mult_chan, self.out_channels, kernel_size=3, padding=1
17 |         )
18 |
19 |     def forward(self, x):
20 |         x_rec = self.net_recurse(x)
21 |         return self.conv_out(x_rec)
22 |
23 |
24 | class _Net_recurse(torch.nn.Module):
25 |     def __init__(self, n_in_channels, mult_chan=2, depth_parent=0, depth=0):
26 |         """Class for recursive definition of U-network.
27 |
28 |         Parameters
29 |         ----------
30 |         n_in_channels
31 |             Number of channels for input.
32 |         mult_chan
33 |             Factor to determine number of output channels.
34 |         depth
35 |             If 0, this subnet will only be convolutions that double the channel
36 |             count.
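        depth_parent
            Depth of the top-level network. At the top level the block
            outputs mult_chan channels; deeper levels double their input
            channel count (recursive calls fix mult_chan to 2), so
            Net(depth=4, mult_chan=32) produces 32, 64, 128, 256, and 512
            channels going down the U.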
37 | 38 | """ 39 | super().__init__() 40 | 41 | self.depth = depth 42 | 43 | if self.depth == depth_parent: 44 | n_out_channels = mult_chan 45 | else: 46 | n_out_channels = n_in_channels * mult_chan 47 | 48 | self.sub_2conv_more = SubNet2Conv(n_in_channels, n_out_channels) 49 | if depth > 0: 50 | self.sub_2conv_less = SubNet2Conv(2 * n_out_channels, n_out_channels) 51 | self.conv_down = torch.nn.Conv3d( 52 | n_out_channels, n_out_channels, 2, stride=2 53 | ) 54 | self.bn0 = torch.nn.BatchNorm3d(n_out_channels) 55 | self.relu0 = torch.nn.ReLU() 56 | self.convt = torch.nn.ConvTranspose3d( 57 | 2 * n_out_channels, n_out_channels, kernel_size=2, stride=2 58 | ) 59 | self.bn1 = torch.nn.BatchNorm3d(n_out_channels) 60 | self.relu1 = torch.nn.ReLU() 61 | self.sub_u = _Net_recurse(n_out_channels, mult_chan=2, depth_parent=depth_parent, depth=(depth - 1)) 62 | 63 | def forward(self, x): 64 | if self.depth == 0: 65 | return self.sub_2conv_more(x) 66 | else: # depth > 0 67 | x_2conv_more = self.sub_2conv_more(x) 68 | x_conv_down = self.conv_down(x_2conv_more) 69 | x_bn0 = self.bn0(x_conv_down) 70 | x_relu0 = self.relu0(x_bn0) 71 | x_sub_u = self.sub_u(x_relu0) 72 | x_convt = self.convt(x_sub_u) 73 | x_bn1 = self.bn1(x_convt) 74 | x_relu1 = self.relu1(x_bn1) 75 | x_cat = torch.cat((x_2conv_more, x_relu1), 1) # concatenate 76 | x_2conv_less = self.sub_2conv_less(x_cat) 77 | return x_2conv_less 78 | 79 | 80 | class SubNet2Conv(torch.nn.Module): 81 | def __init__(self, n_in, n_out): 82 | super().__init__() 83 | self.conv1 = torch.nn.Conv3d(n_in, n_out, kernel_size=3, padding=1) 84 | self.bn1 = torch.nn.BatchNorm3d(n_out) 85 | self.relu1 = torch.nn.ReLU() 86 | self.conv2 = torch.nn.Conv3d(n_out, n_out, kernel_size=3, padding=1) 87 | self.bn2 = torch.nn.BatchNorm3d(n_out) 88 | self.relu2 = torch.nn.ReLU() 89 | 90 | def forward(self, x): 91 | x = self.conv1(x) 92 | x = self.bn1(x) 93 | x = self.relu1(x) 94 | x = self.conv2(x) 95 | x = self.bn2(x) 96 | x = self.relu2(x) 97 | return x 98 | -------------------------------------------------------------------------------- /fnet/predict_piecewise.py: -------------------------------------------------------------------------------- 1 | from scipy.signal import triang 2 | from typing import Union, List 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def _get_weights(shape): 8 | shape_in = shape 9 | shape = shape[1:] 10 | weights = 1 11 | for idx_d in range(len(shape)): 12 | slicey = [np.newaxis] * len(shape) 13 | slicey[idx_d] = slice(None) 14 | size = shape[idx_d] 15 | weights = weights * triang(size)[tuple(slicey)] 16 | return np.broadcast_to(weights, shape_in).astype(np.float32) 17 | 18 | 19 | def _predict_piecewise_recurse( 20 | predictor, 21 | ar_in: np.ndarray, 22 | dims_max: Union[int, List[int]], 23 | overlaps: Union[int, List[int]], 24 | **predict_kwargs, 25 | ): 26 | """Performs piecewise prediction recursively.""" 27 | if tuple(ar_in.shape[1:]) == tuple(dims_max[1:]): 28 | ar_out = predictor.predict(ar_in, **predict_kwargs).numpy().astype(np.float32) 29 | ar_weight = _get_weights(ar_out.shape) 30 | return ar_out * ar_weight, ar_weight 31 | dim = None 32 | # Find first dim where input > max 33 | for idx_d in range(1, ar_in.ndim): 34 | if ar_in.shape[idx_d] > dims_max[idx_d]: 35 | dim = idx_d 36 | break 37 | # Size of channel dim is unknown until after first prediction 38 | shape_out = [None] + list(ar_in.shape[1:]) 39 | ar_out = None 40 | ar_weight = None 41 | offset = 0 42 | done = False 43 | while not done: 44 | slices = [slice(None)] * ar_in.ndim 45 
| end = offset + dims_max[dim]
46 |         slices[dim] = slice(offset, end)
47 |         slices = tuple(slices)
48 |         ar_in_sub = ar_in[slices]
49 |         pred_sub, pred_weight_sub = _predict_piecewise_recurse(
50 |             predictor, ar_in_sub, dims_max, overlaps, **predict_kwargs
51 |         )
52 |         if ar_out is None or ar_weight is None:
53 |             shape_out[0] = pred_sub.shape[0]  # Set channel dim for output
54 |             ar_out = np.zeros(shape_out, dtype=pred_sub.dtype)
55 |             ar_weight = np.zeros(shape_out, dtype=pred_weight_sub.dtype)
56 |         ar_out[slices] += pred_sub
57 |         ar_weight[slices] += pred_weight_sub
58 |         offset += dims_max[dim] - overlaps[dim]
59 |         if end == ar_in.shape[dim]:
60 |             done = True
61 |         elif offset + dims_max[dim] > ar_in.shape[dim]:
62 |             offset = ar_in.shape[dim] - dims_max[dim]
63 |     return ar_out, ar_weight
64 |
65 |
66 | def predict_piecewise(
67 |     predictor,
68 |     tensor_in: torch.Tensor,
69 |     dims_max: Union[int, List[int]] = 64,
70 |     overlaps: Union[int, List[int]] = 0,
71 |     **predict_kwargs,
72 | ) -> torch.Tensor:
73 |     """Performs piecewise prediction and combines results.
74 |
75 |     Parameters
76 |     ----------
77 |     predictor
78 |         An object with a predict() method.
79 |     tensor_in
80 |         Tensor to be input into predictor piecewise. Should be 3d or 4d, with
81 |         the first dimension being the channel dimension.
82 |     dims_max
83 |         Specifies dimensions of each sub-prediction.
84 |     overlaps
85 |         Specifies overlap along each dimension for sub-predictions.
86 |     **predict_kwargs
87 |         Kwargs to pass to predict method.
88 |
89 |     Returns
90 |     -------
91 |     torch.Tensor
92 |         Prediction with size tensor_in.size().
93 |
94 |     """
95 |     assert isinstance(tensor_in, torch.Tensor)
96 |     assert len(tensor_in.size()) > 2
97 |     shape_in = tuple(tensor_in.size())
98 |     n_dim = len(shape_in)
99 |     if isinstance(dims_max, int):
100 |         dims_max = [dims_max] * n_dim
101 |     for idx_d in range(1, n_dim):
102 |         if dims_max[idx_d] > shape_in[idx_d]:
103 |             dims_max[idx_d] = shape_in[idx_d]
104 |     if isinstance(overlaps, int):
105 |         overlaps = [overlaps] * n_dim
106 |     assert len(dims_max) == len(overlaps) == n_dim
107 |     # Remove restrictions on channel dimension.
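    # The channel dimension is never tiled: each sub-prediction receives all
    # input channels, and the output channel count is discovered from the
    # first sub-prediction inside _predict_piecewise_recurse.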
108 | dims_max[0] = None 109 | overlaps[0] = None 110 | ar_in = tensor_in.numpy() 111 | ar_out, ar_weight = _predict_piecewise_recurse( 112 | predictor, ar_in, dims_max=dims_max, overlaps=overlaps, **predict_kwargs 113 | ) 114 | # tifffile.imsave('debug/ar_sum.tif', ar_out) 115 | mask = ar_weight > 0.0 116 | ar_out[mask] = ar_out[mask] / ar_weight[mask] 117 | # tifffile.imsave('debug/ar_weight.tif', ar_weight) 118 | # tifffile.imsave('debug/ar_out.tif', ar_out) 119 | return torch.tensor(ar_out) 120 | -------------------------------------------------------------------------------- /fnet/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_fnet/64c53d123df644cebe5e4f7f2ab6efc5c0732f4e/fnet/tests/__init__.py -------------------------------------------------------------------------------- /fnet/tests/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_fnet/64c53d123df644cebe5e4f7f2ab6efc5c0732f4e/fnet/tests/data/__init__.py -------------------------------------------------------------------------------- /fnet/tests/data/dummymodule.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import tifffile 6 | import torch 7 | 8 | from fnet.data.tiffdataset import TiffDataset 9 | from fnet.utils.general_utils import add_augmentations 10 | 11 | 12 | def dummy_fnet_dataset(train: bool = False) -> TiffDataset: 13 | """Returns a dummy Fnetdataset.""" 14 | df = pd.DataFrame( 15 | { 16 | "path_signal": [os.path.join("data", "EM_low.tif")], 17 | "path_target": [os.path.join("data", "MBP_low.tif")], 18 | } 19 | ).rename_axis("arbitrary") 20 | if not train: 21 | df = add_augmentations(df) 22 | return TiffDataset(dataframe=df) 23 | 24 | 25 | class _CustomDataset: 26 | """Custom, non-FnetDataset.""" 27 | 28 | def __init__(self, df: pd.DataFrame): 29 | self._df = df 30 | 31 | def __len__(self): 32 | return len(self._df) 33 | 34 | def __getitem__(self, idx): 35 | loc = self._df.index[idx] 36 | sig = torch.from_numpy( 37 | tifffile.imread(self._df.loc[loc, "path_signal"])[np.newaxis,] 38 | ) 39 | tar = torch.from_numpy( 40 | tifffile.imread(self._df.loc[loc, "path_target"])[np.newaxis,] 41 | ) 42 | return (sig, tar) 43 | 44 | 45 | def dummy_custom_dataset(train: bool = False) -> TiffDataset: 46 | """Returns a dummy custom dataset.""" 47 | df = pd.DataFrame( 48 | { 49 | "path_signal": [os.path.join("data", "EM_low.tif")], 50 | "path_target": [os.path.join("data", "MBP_low.tif")], 51 | } 52 | ) 53 | if not train: 54 | df = add_augmentations(df) 55 | return _CustomDataset(df) 56 | -------------------------------------------------------------------------------- /fnet/tests/data/nn_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Net(torch.nn.Module): 5 | def __init__(self, test_param=42): 6 | super().__init__() 7 | self.test_param = test_param 8 | self.conv = torch.nn.Conv3d(1, 1, 3, padding=1) 9 | 10 | def forward(self, x): 11 | return self.conv(x) 12 | -------------------------------------------------------------------------------- /fnet/tests/data/testlib.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import numpy as np 5 | 
import pandas as pd 6 | 7 | import tifffile 8 | from aicsimageio.writers import OmeTiffWriter 9 | 10 | 11 | def create_data_dir(path_root: Path): 12 | path_data = path_root / "data" 13 | path_data.mkdir(exist_ok=True) 14 | 15 | return path_data 16 | 17 | 18 | def create_tif_data( 19 | path_root: Path, shape: Sequence[int], n_items: int, weights: bool 20 | ) -> Path: 21 | path_data = create_data_dir(path_root) 22 | 23 | records = [] 24 | 25 | for idx in range(n_items): 26 | path_x = path_data / f"{idx:02}_x.tif" 27 | path_y = path_data / f"{idx:02}_y.tif" 28 | data_x = np.random.randint(128, size=shape, dtype=np.uint8) 29 | data_y = data_x + idx 30 | 31 | tifffile.imsave(path_x, data_x, compress=2) 32 | tifffile.imsave(path_y, data_y, compress=2) 33 | 34 | records.append({"dummy_id": idx, "path_signal": path_x, "path_target": path_y}) 35 | if weights: 36 | # Create map that covers half of each dim 37 | data_weight_map = np.zeros(shape, dtype=np.float32) 38 | slicey = [slice(shape[d] // 2) for d in range(len(shape))] 39 | slicey = tuple(slicey) 40 | data_weight_map[slicey] = 1 41 | path_weight_map = path_data / f"{idx:02}_weight_map.tif" 42 | tifffile.imsave(path_weight_map, data_weight_map, compress=2) 43 | records[-1]["path_weight_map"] = path_weight_map 44 | 45 | path_csv = path_root / "dummy.csv" 46 | pd.DataFrame(records).set_index("dummy_id").to_csv(path_csv) 47 | return path_csv 48 | 49 | 50 | def create_multichtiff_data( 51 | path_root: Path, dims_zyx: Sequence[int], n_ch_in: int, n_ch_out: int, n_items: int 52 | ) -> Path: 53 | 54 | assert len(dims_zyx) == 3 55 | 56 | path_data = create_data_dir(path_root) 57 | 58 | records = [] 59 | 60 | for idx in range(n_items): 61 | path_x = path_data / f"{idx:02}.tif" 62 | data_x = np.random.randint( 63 | 128, size=[n_ch_in + n_ch_out] + list(dims_zyx), dtype=np.uint8 64 | ) 65 | 66 | with OmeTiffWriter(path_x) as writer: 67 | writer.save(data_x, dimension_order="CZYX") # should be a numpy array 68 | 69 | records.append( 70 | { 71 | "dummy_id": idx, 72 | "path_tiff": path_x, 73 | "channel_signal": list(np.arange(0, n_ch_in)), 74 | "channel_target": list(np.arange(0, n_ch_out) + n_ch_in), 75 | } 76 | ) 77 | 78 | path_csv = path_root / "dummy.csv" 79 | pd.DataFrame(records).set_index("dummy_id").to_csv(path_csv) 80 | 81 | return path_csv 82 | -------------------------------------------------------------------------------- /fnet/tests/data/train_options_custom.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 2, 3 | "bpds_kwargs": { 4 | "buffer_size": 1, 5 | "buffer_switch_interval": -1, 6 | "patch_shape": [ 7 | 16, 8 | 16, 9 | 16 10 | ] 11 | }, 12 | "dataset_train": "fnet.data.DummyCustomFnetDataset", 13 | "dataset_train_kwargs": { 14 | "train": true 15 | }, 16 | "dataset_val": null, 17 | "dataset_val_kwargs": null, 18 | "fnet_model_class": "fnet.fnet_model.Model", 19 | "fnet_model_kwargs": { 20 | "betas": [ 21 | 0.9, 22 | 0.999 23 | ], 24 | "criterion_class": "torch.nn.MSELoss", 25 | "init_weights": false, 26 | "lr": 0.001, 27 | "nn_class": "fnet.nn_modules.fnet_nn_3d_params.Net", 28 | "nn_kwargs": { 29 | "depth": 0, 30 | "mult_chan": 2 31 | }, 32 | "scheduler": null 33 | }, 34 | "interval_checkpoint": 64, 35 | "interval_save": 64, 36 | "iter_checkpoint": [], 37 | "n_iter": 16, 38 | "path_save_dir": "test_model_custom", 39 | "seed": null 40 | } -------------------------------------------------------------------------------- /fnet/tests/data/train_options_test.json: 
--------------------------------------------------------------------------------
1 | {
2 |     "batch_size": 2,
3 |     "bpds_kwargs": {
4 |         "buffer_size": 1,
5 |         "buffer_switch_interval": -1,
6 |         "patch_shape": [
7 |             16,
8 |             16,
9 |             16
10 |         ]
11 |     },
12 |     "dataset_train": "fnet.data.DummyFnetDataset",
13 |     "dataset_train_kwargs": {
14 |         "train": true
15 |     },
16 |     "dataset_val": "fnet.data.DummyFnetDataset",
17 |     "dataset_val_kwargs": {
18 |         "train": false
19 |     },
20 |     "fnet_model_class": "fnet.fnet_model.Model",
21 |     "fnet_model_kwargs": {
22 |         "betas": [
23 |             0.9,
24 |             0.999
25 |         ],
26 |         "criterion_class": "torch.nn.MSELoss",
27 |         "init_weights": false,
28 |         "lr": 0.001,
29 |         "nn_class": "fnet.nn_modules.fnet_nn_3d_params.Net",
30 |         "nn_kwargs": {
31 |             "depth": 0,
32 |             "mult_chan": 2
33 |         },
34 |         "scheduler": null
35 |     },
36 |     "interval_checkpoint": 10,
37 |     "interval_save": 8,
38 |     "iter_checkpoint": [],
39 |     "n_iter": 16,
40 |     "path_save_dir": "test_model",
41 |     "seed": null
42 | }
--------------------------------------------------------------------------------
/fnet/tests/test_bufferedpatchdataset.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | from typing import Tuple
3 |
4 | import numpy as np
5 | import numpy.testing as npt
6 | import pytest
7 | import torch
8 |
9 | from fnet.data import BufferedPatchDataset
10 |
11 |
12 | class _DummyDataset:
13 |     """Dummy dataset class.
14 |
15 |     Parameters
16 |     ----------
17 |     nd
18 |         Number of dimensions of dataset items.
19 |     weights
20 |         Set to include weight map.
21 |
22 |     """
23 |
24 |     def __init__(self, nd: int = 1, weights: bool = False):
25 |         self.data = []
26 |         shape = (8,) * nd
27 |         for idx in range(8):
28 |             x = np.arange(idx, idx + 8 ** nd).reshape(shape)
29 |             y = x ** 2
30 |             datum = [x, y]
31 |             if weights:
32 |                 datum.append(-x)
33 |             self.data.append(tuple(datum))
34 |         self.accessed = []
35 |
36 |     def __len__(self) -> int:
37 |         return len(self.data)
38 |
39 |     def __getitem__(self, index: int) -> Tuple[int, int]:
40 |         self.accessed.append(index)
41 |         return self.data[index]
42 |
43 |
44 | def test_bad_input():
45 |     ds = _DummyDataset()
46 |
47 |     # Too many patch_shape dimensions
48 |     with pytest.raises(ValueError):
49 |         BufferedPatchDataset(ds)
50 |
51 |     # patch_shape too big
52 |     with pytest.raises(ValueError):
53 |         BufferedPatchDataset(ds, patch_shape=(9,))
54 |
55 |     # Inconsistent spatial shape
56 |     bad = [part for part in ds.data[0]]
57 |     bad[0] = bad[0][1:]
58 |     ds.data[0] = tuple(bad)
59 |     with pytest.raises(ValueError):
60 |         BufferedPatchDataset(ds, patch_shape=(4,), shuffle_images=False)
61 |
62 |
63 | @pytest.mark.parametrize("nd", [2, 3])
64 | def test_nd(nd: int):
65 |     """Checks shape of returned item and checks that all dataset elements were
66 |     accessed.
67 |
68 |     Parameters
69 |     ----------
70 |     nd
71 |         Number of spatial dimensions for dataset elements.
72 | 73 | """ 74 | ds = _DummyDataset(nd) 75 | patch_shape = tuple(range(2, nd + 2)) 76 | interval = 3 77 | buffer_size = 4 78 | bpds = BufferedPatchDataset( 79 | ds, 80 | patch_shape=patch_shape, 81 | buffer_size=buffer_size, 82 | buffer_switch_interval=interval, 83 | ) 84 | # Sample enough patches such that the entire dataset is used twice 85 | n_swaps = 2 * len(ds) - buffer_size 86 | for _idx in range(n_swaps * interval): 87 | x, y = next(bpds) 88 | assert x.shape == patch_shape 89 | npt.assert_array_equal(y, x ** 2) 90 | assert bpds.get_buffer_history() == ds.accessed 91 | counts = Counter(ds.accessed) 92 | assert max(counts.values()) == min(counts.values()) == 2 93 | 94 | 95 | def test_sampling(): 96 | """Verifies that samples are pulled from entire range of dataset items.""" 97 | ds = _DummyDataset(nd=3) 98 | x_low, x_hi = float("inf"), float("-inf") 99 | y_low, y_hi = float("inf"), float("-inf") 100 | bpds = BufferedPatchDataset( 101 | ds, 102 | patch_shape=(7, 7, 7), 103 | buffer_size=1, 104 | buffer_switch_interval=-1, 105 | shuffle_images=False, 106 | ) 107 | # Patch locations are randomized, so look at many patches and check that 108 | # the ends of the dataset item are sampled at least once. 109 | for _idx in range(128): 110 | x, y = next(bpds) 111 | x_low = min(x_low, x.min()) 112 | x_hi = max(x_hi, x.max()) 113 | y_low = min(y_low, y.min()) 114 | y_hi = max(y_hi, y.max()) 115 | assert x_low == y_low == 0 116 | assert x_hi == 511 117 | assert x_hi ** 2 == y_hi 118 | 119 | 120 | def test_smaller_patch(): 121 | """Verifies that patches smaller than the dataset item are pulled from the 122 | last dataset item dimensions. 123 | 124 | """ 125 | nd = 4 126 | ds = _DummyDataset(nd=nd) 127 | patch_shape = (4,) * (nd - 1) 128 | bpds = BufferedPatchDataset( 129 | ds, 130 | patch_shape=patch_shape, 131 | buffer_size=1, 132 | buffer_switch_interval=-1, 133 | shuffle_images=False, 134 | ) 135 | x, y = next(bpds) 136 | assert x.shape == y.shape == ((8,) + patch_shape) 137 | 138 | 139 | def test_weights(): 140 | """Tests bpds when underlying dataset includes target weight maps.""" 141 | ds = _DummyDataset(nd=3, weights=True) 142 | patch_shape = (7, 5) 143 | batch_size = 4 144 | bpds = BufferedPatchDataset( 145 | ds, 146 | patch_shape=patch_shape, 147 | buffer_size=len(ds), 148 | buffer_switch_interval=-1, 149 | shuffle_images=False, 150 | ) 151 | shape_exp = (batch_size, 8) + patch_shape 152 | for _ in range(4): 153 | batch = bpds.get_batch(batch_size=batch_size) 154 | assert len(batch) == 3 155 | for part in batch: 156 | assert isinstance(part, torch.Tensor) 157 | assert tuple(part.shape) == shape_exp 158 | npt.assert_array_almost_equal(batch[-1], -batch[0]) 159 | -------------------------------------------------------------------------------- /fnet/tests/test_cli.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | import os 4 | import shutil 5 | import subprocess 6 | import tempfile 7 | 8 | import pytest 9 | 10 | from .data.testlib import create_tif_data 11 | 12 | 13 | def _update_json(path_json: Path, **kwargs): 14 | def helper(some_dict: dict, updates: dict): 15 | """Recursively updates a dictionary with another.""" 16 | for key, val in updates.items(): 17 | if not isinstance(val, dict): 18 | some_dict[key] = val 19 | else: 20 | helper(some_dict[key], val) 21 | 22 | with path_json.open("r") as fi: 23 | options = json.load(fi) 24 | helper(options, kwargs) 25 | with path_json.open("w") as fo: 26 | 
json.dump(options, fo) 27 | 28 | 29 | @pytest.fixture(scope="module") 30 | def project_dir(): 31 | """Creates a mock user directory in which fnet commands would be used. 32 | 33 | Copies over example tifs to be used as test data and a dummy module 34 | containing dataset definitions. 35 | 36 | """ 37 | path_pre = Path.cwd() 38 | path_tmp = Path(tempfile.mkdtemp()) 39 | path_test_dir = Path(__file__).parent 40 | path_data_dir = path_test_dir.parent.parent / "data" 41 | Path.mkdir(path_tmp / "data") 42 | for tif in ["EM_low.tif", "MBP_low.tif"]: 43 | shutil.copy(path_data_dir / tif, path_tmp / "data") 44 | shutil.copy(path_test_dir / "data" / "dummymodule.py", path_tmp) 45 | os.chdir(path_tmp) 46 | yield path_tmp 47 | os.chdir(path_pre) 48 | 49 | 50 | @pytest.mark.usefixtures("project_dir") 51 | def test_init(): 52 | subprocess.run(["fnet init"], shell=True, check=True) 53 | path_json = os.path.join("train_options_templates", "default.json") 54 | path_script_train = os.path.join("scripts", "train_model.py") 55 | path_script_predict = os.path.join("scripts", "predict.py") 56 | assert os.path.exists(path_json) 57 | assert os.path.exists(path_script_train) 58 | assert os.path.exists(path_script_predict) 59 | 60 | 61 | @pytest.mark.usefixtures("project_dir") 62 | def test_train_model_create(): 63 | """Verify that 'fnet train' creates default jsons.""" 64 | path_create = os.path.join("created", "train_options.json") 65 | subprocess.run(["fnet", "train", "--json", path_create], check=True) 66 | assert os.path.exists(path_create) 67 | 68 | 69 | @pytest.mark.usefixtures("project_dir") 70 | def test_train_model_pred(): 71 | """Verify 'fnet train', 'fnet predict' functionality on an FnetDataset.""" 72 | path_test_json = Path(__file__).parent / "data" / "train_options_test.json" 73 | 74 | subprocess.run( 75 | ["fnet", "train", "--json", path_test_json, "--gpu_ids", "-1"], check=True 76 | ) 77 | assert os.path.exists("test_model") 78 | subprocess.run( 79 | [ 80 | "fnet", 81 | "predict", 82 | "--path_model_dir", 83 | "test_model", 84 | "--dataset", 85 | "dummymodule.dummy_fnet_dataset", 86 | "--idx_sel", 87 | "0", 88 | "3", 89 | "--gpu_ids", 90 | "-1", 91 | ], 92 | check=True, 93 | ) 94 | for fname in ["tifs", "predictions.csv", "predict_options.json"]: 95 | assert os.path.exists(os.path.join("predictions", fname)) 96 | 97 | 98 | @pytest.mark.usefixtures("project_dir") 99 | def test_train_model_pred_custom(): 100 | """Verify 'fnet train', 'fnet predict' functionality on a custom dataset. 
101 | 102 | """ 103 | path_test_json = Path(__file__).parent / "data" / "train_options_custom.json" 104 | subprocess.run( 105 | ["fnet", "train", "--json", str(path_test_json), "--gpu_ids", "-1"], check=True 106 | ) 107 | assert os.path.exists("test_model_custom") 108 | subprocess.run( 109 | [ 110 | "fnet", 111 | "predict", 112 | "--path_model_dir", 113 | "test_model_custom", 114 | "--dataset", 115 | "dummymodule.dummy_custom_dataset", 116 | "--idx_sel", 117 | "2", 118 | "--gpu_ids", 119 | "-1", 120 | ], 121 | check=True, 122 | ) 123 | for fname in ["tifs", "predictions.csv", "predict_options.json"]: 124 | assert os.path.exists(os.path.join("predictions", fname)) 125 | 126 | 127 | def train_pred_with_weights(tmp_path): 128 | shape = (8, 16, 32) 129 | n_items = 8 130 | path_ds = create_tif_data(tmp_path, shape=shape, n_items=n_items, weights=True) 131 | path_train_json = tmp_path / "model" / "train_options.json" 132 | subprocess.run( 133 | ["fnet", "train", str(path_train_json), "--gpu_ids", "-1"], check=True 134 | ) 135 | _update_json( 136 | path_train_json, 137 | dataset_train="fnet.data.TiffDataset", 138 | dataset_train_kwargs={"path_csv": str(path_ds)}, 139 | dataset_val="fnet.data.TiffDataset", 140 | dataset_val_kwargs={"path_csv": str(path_ds)}, 141 | bpds_kwargs={"patch_shape": [4, 8, 16]}, 142 | n_iter=16, 143 | interval_save=8, 144 | fnet_model_kwargs={"nn_class": "tests.data.nn_test.Net"}, 145 | ) 146 | subprocess.run( 147 | ["fnet", "train", str(path_train_json), "--gpu_ids", "-1"], check=True 148 | ) 149 | -------------------------------------------------------------------------------- /fnet/tests/test_fnet_model.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | import pytest 6 | import tifffile 7 | import torch 8 | 9 | import fnet 10 | from fnet.fnet_model import Model 11 | 12 | SOME_PARAM_TEST_VAL = 123 13 | 14 | 15 | def get_data(device: Union[int, torch.device]) -> tuple: 16 | if isinstance(device, int): 17 | device = torch.device("cuda", device) if device >= 0 else torch.device("cpu") 18 | x = torch.rand(1, 1, 8, 16, 16, device=device) 19 | y = x * 2 + 1 20 | return x, y 21 | 22 | 23 | def train_new(path_model): 24 | gpu_id = 1 if torch.cuda.is_available() else -1 25 | x, y = get_data(gpu_id) 26 | model = Model( 27 | nn_class="fnet.nn_modules.dummy.DummyModel", 28 | nn_kwargs={"some_param": SOME_PARAM_TEST_VAL}, 29 | ) 30 | model.to_gpu(gpu_id) 31 | for idx in range(4): 32 | _ = model.train_on_batch(x, y) 33 | model.save(path_model) 34 | 35 | 36 | def train_more(path_model): 37 | gpu_id = 0 if torch.cuda.is_available() else -1 38 | x, y = get_data(gpu_id) 39 | model = fnet.models.load_model(path_model) 40 | for idx in range(2): 41 | _ = model.train_on_batch(x, y) 42 | assert model.count_iter == 6 43 | assert model.net.some_param == SOME_PARAM_TEST_VAL 44 | 45 | 46 | def test_resume(tmpdir): 47 | path_model = tmpdir.mkdir("test_model").join("model.p").strpath 48 | train_new(path_model) 49 | train_more(path_model) 50 | 51 | 52 | def test_apply_on_single_zstack(tmp_path): 53 | """Tests the apply_on_single_zstack() method in fnet_model.Model.""" 54 | model = Model(nn_class="fnet.nn_modules.dummy.DummyModel") 55 | 56 | # Test bad inputs 57 | ar_in = np.random.random(size=(3, 32, 64, 128)) 58 | with pytest.raises(ValueError): 59 | model.apply_on_single_zstack() 60 | with pytest.raises(ValueError): 61 | model.apply_on_single_zstack(ar_in) 62 | with 
pytest.raises(ValueError): 63 | model.apply_on_single_zstack(ar_in[0, 1]) # 2d input 64 | 65 | # Test numpy input and file path input 66 | yhat_ch1 = model.apply_on_single_zstack(ar_in, inputCh=1) 67 | ar_in = ar_in[1,] 68 | path_input_save = tmp_path / "input_save.tiff" 69 | tifffile.imsave(str(path_input_save), ar_in, compress=2) 70 | yhat = model.apply_on_single_zstack(ar_in) 71 | yhat_file = model.apply_on_single_zstack(filename=path_input_save) 72 | assert np.issubdtype(yhat.dtype, np.floating) 73 | assert yhat.shape == ar_in.shape 74 | assert np.array_equal(yhat, yhat_ch1) 75 | assert np.array_equal(yhat, yhat_file) 76 | 77 | # Test resized 78 | factors = (1, 0.5, 0.3) 79 | shape_exp = tuple(round(ar_in.shape[i] * factors[i]) for i in range(3)) 80 | yhat_resized = model.apply_on_single_zstack(ar_in, ResizeRatio=factors) 81 | assert yhat_resized.shape == shape_exp 82 | 83 | # Test cutoff 84 | cutoff = 0.1 85 | yhat_exp = (yhat >= cutoff).astype(np.uint8) * 255 86 | yhat_cutoff = model.apply_on_single_zstack(ar_in, cutoff=cutoff) 87 | assert np.issubdtype(yhat_cutoff.dtype, np.unsignedinteger) 88 | assert np.array_equal(yhat_cutoff, yhat_exp) 89 | 90 | 91 | def test_train_on_batch(): 92 | model = Model(nn_class="fnet.tests.data.nn_test.Net", lr=0.01) 93 | shape_item = (1, 2, 4, 8) 94 | batch_size = 9 95 | shape_batch = (batch_size,) + shape_item 96 | x_batch = torch.rand(shape_batch) 97 | y_batch = x_batch * 0.666 + 0.42 98 | cost_prev = float("inf") 99 | for _ in range(8): 100 | cost = model.train_on_batch(x_batch, y_batch) 101 | assert cost < cost_prev 102 | cost_prev = cost 103 | 104 | # Test target weight maps 105 | model = Model(nn_class="fnet.tests.data.nn_test.Net", lr=0.0) # disable learning 106 | cost_norm = model.train_on_batch(x_batch, y_batch) 107 | # Test uniform weight map 108 | weight_map_batch = (torch.ones(shape_item) / np.prod(shape_item)).expand( 109 | shape_batch 110 | ) 111 | cost_weighted = model.train_on_batch(x_batch, y_batch, weight_map_batch) 112 | npt.assert_approx_equal(cost_weighted, cost_norm, significant=6) 113 | # Test all-zero weight map 114 | cost_weighted = model.train_on_batch(x_batch, y_batch, torch.zeros(x_batch.size())) 115 | npt.assert_approx_equal(cost_weighted, 0.0) 116 | # Random weights with first and last examples having zero weight 117 | weight_map_batch = torch.rand(shape_batch) 118 | weight_map_batch[[0, -1]] = 0.0 119 | cost_weighted = model.train_on_batch(x_batch, y_batch, weight_map_batch) 120 | cost_exp = ( 121 | model.train_on_batch(x_batch[1:-1], y_batch[1:-1], weight_map_batch[1:-1]) 122 | * (batch_size - 2) 123 | / batch_size # account for change in batch size 124 | ) 125 | npt.assert_approx_equal(cost_weighted, cost_exp) 126 | 127 | 128 | def test_test_on_batch(): 129 | model = Model(nn_class="fnet.tests.data.nn_test.Net", lr=0.01) 130 | shape_item = (1, 2, 4, 8) 131 | batch_size = 1 132 | shape_batch = (batch_size,) + shape_item 133 | x_batch = torch.rand(shape_batch) 134 | y_batch = x_batch * 0.666 + 0.42 135 | 136 | # Model weights should remain the same so loss should not change 137 | loss_0 = model.test_on_batch(x_batch, y_batch) 138 | loss_1 = model.test_on_batch(x_batch, y_batch) 139 | npt.assert_approx_equal(loss_1 - loss_0, 0.0) 140 | 141 | # Loss should remain the same with uniform weight map 142 | loss_weight_uniform = model.test_on_batch( 143 | x_batch, y_batch, torch.ones(shape_batch) / np.prod(shape_item) 144 | ) 145 | npt.assert_almost_equal(loss_weight_uniform - loss_0, 0.0) 146 | 147 | # Loss should be zero with 
all-zero weight map 148 | loss_weight_zero = model.test_on_batch(x_batch, y_batch, torch.zeros(shape_batch)) 149 | npt.assert_almost_equal(loss_weight_zero, 0.0) 150 | -------------------------------------------------------------------------------- /fnet/tests/test_multichtiffdataset.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from fnet.data import multichtiffdataset 7 | from .data.testlib import create_multichtiff_data 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "n_ch_in, n_ch_out, dims_zyx", 12 | [(1, 1, (64, 128, 32)), (3, 1, (12, 13, 14)), (5, 5, (12, 13, 14))], 13 | ) 14 | def test_MultiTiffDataset(tmp_path, n_ch_in, n_ch_out, dims_zyx): 15 | """Tests TiffDataset class.""" 16 | n_items = 5 17 | path_dummy = create_multichtiff_data( 18 | tmp_path, n_ch_in=n_ch_in, n_ch_out=n_ch_out, dims_zyx=dims_zyx, n_items=n_items 19 | ) 20 | ds = multichtiffdataset.MultiChTiffDataset(path_csv=path_dummy) 21 | 22 | assert len(ds) == n_items 23 | idx = n_items // 2 24 | info = ds.get_information(n_items // 2) 25 | assert isinstance(info, dict) 26 | assert all(col in info for col in ds.df.columns) 27 | data = ds[idx] 28 | len_data = 2 29 | assert len(data) == len_data 30 | 31 | assert tuple(data[0].shape) == (n_ch_in,) + dims_zyx 32 | assert tuple(data[1].shape) == (n_ch_out,) + dims_zyx 33 | -------------------------------------------------------------------------------- /fnet/tests/test_predict_piecewise.py: -------------------------------------------------------------------------------- 1 | from fnet.predict_piecewise import predict_piecewise 2 | import numpy as np 3 | import numpy.testing as npt 4 | import torch 5 | 6 | 7 | class FakePredictor: 8 | def predict(self, x, tta=False): 9 | y_hat = x.copy() + 0.42 10 | return torch.tensor(y_hat) 11 | 12 | 13 | def test_predict_piecewise(): 14 | # Create pretty gradient image as test input 15 | shape = (1, 32, 512, 256) 16 | ar_in = 1 17 | for idx in range(1, len(shape)): 18 | slices = [None] * len(shape) 19 | slices[idx] = slice(None) 20 | ar_in = ar_in * np.linspace(0, 1, num=shape[idx], endpoint=False)[tuple(slices)] 21 | ar_in = torch.tensor(ar_in.astype(np.float32)) 22 | predictor = FakePredictor() 23 | ar_out = predict_piecewise( 24 | predictor, ar_in, dims_max=[None, 32, 128, 64], overlaps=16 25 | ) 26 | got = ar_out.numpy() 27 | expected = ar_in.numpy() + 0.42 28 | npt.assert_almost_equal(got, expected) 29 | 30 | 31 | if __name__ == "__main__": 32 | test_predict_piecewise() 33 | -------------------------------------------------------------------------------- /fnet/tests/test_tiffdataset.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from fnet.data import tiffdataset 7 | from .data.testlib import create_tif_data 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "shape,weights", [((16, 32), False), ((8, 16, 32), False), ((8, 16, 32), True)] 12 | ) 13 | def test_TiffDataset(tmp_path, shape: Sequence[int], weights: bool): 14 | """Tests TiffDataset class.""" 15 | n_items = 5 16 | path_dummy = create_tif_data( 17 | tmp_path, shape=shape, n_items=n_items, weights=weights 18 | ) 19 | ds = tiffdataset.TiffDataset(path_csv=path_dummy, col_index="dummy_id") 20 | assert len(ds) == n_items 21 | idx = n_items // 2 22 | info = ds.get_information(n_items // 2) 23 | assert isinstance(info, dict) 24 | assert all(col in info 
for col in ds.df.columns)
25 |     data = ds[idx]
26 |     len_data = 3 if weights else 2
27 |     assert len(data) == len_data
28 |     shape_exp = (1,) + shape
29 |     for d in data:
30 |         assert tuple(d.shape) == shape_exp
31 |
32 |     factor = int((data[1] - data[0]).numpy().mean())
33 |     assert factor == idx
34 |
35 |     if weights:
36 |         weight_sum_exp = np.prod([d // 2 for d in shape])
37 |         weight_sum_got = int(data[-1].numpy().sum())
38 |         assert weight_sum_got == weight_sum_exp
39 |
--------------------------------------------------------------------------------
/fnet/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from fnet.utils.general_utils import str_to_object
4 |
5 |
6 | def _dummy():
7 |     print("Hi")
8 |
9 |
10 | def test_str_to_object():
11 |     """Test string-to-object conversion."""
12 |     exp = [_dummy, random.randrange]
13 |     for idx_s, as_str in enumerate(["_dummy", "random.randrange"]):
14 |         obj = str_to_object(as_str)
15 |         assert obj is exp[idx_s], f"{obj} is not {exp[idx_s]}"
16 |
--------------------------------------------------------------------------------
/fnet/transforms.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import logging
3 |
4 | import numpy as np
5 | import scipy
6 |
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | class Normalize:
12 |     def __init__(self, per_dim=None):
13 |         """Class version of normalize function."""
14 |         self.per_dim = per_dim
15 |
16 |     def __call__(self, x):
17 |         return normalize(x, per_dim=self.per_dim)
18 |
19 |     def __repr__(self):
20 |         return "Normalize({})".format(self.per_dim)
21 |
22 |
23 | class ToFloat:
24 |     def __call__(self, x):
25 |         return x.astype(np.float32)
26 |
27 |     def __repr__(self):
28 |         return "ToFloat()"
29 |
30 |
31 | def normalize(img, per_dim=None):
32 |     """Subtract mean, set STD to 1.0.
33 |
34 |     Parameters:
35 |         per_dim: if specified, normalize each index along this dimension separately, using statistics computed over the remaining dimensions
36 |     """
37 |     axis = tuple([i for i in range(img.ndim) if i != per_dim])
38 |     slices = tuple(
39 |         [slice(None) if i == per_dim else np.newaxis for i in range(img.ndim)]
40 |     )  # to handle broadcasting
41 |     result = img.astype(np.float32)
42 |     result -= np.mean(result, axis=axis)[slices]
43 |     result /= np.std(result, axis=axis)[slices]
44 |     return result
45 |
46 |
47 | def do_nothing(img):
48 |     return img.astype(float)  # np.float is a deprecated alias of the builtin float
49 |
50 |
51 | class Propper:
52 |     """Padder + Cropper"""
53 |
54 |     def __init__(self, action="-", **kwargs):
55 |         self.action = action
56 |         if self.action in ["+", "pad"]:
57 |             self.transformer = Padder(**kwargs)
58 |         elif self.action in ["-", "crop"]:
59 |             self.transformer = Cropper(**kwargs)
60 |         else:
61 |             raise NotImplementedError
62 |
63 |     def __repr__(self):
64 |         return repr(self.transformer)
65 |
66 |     def __call__(self, x_in):
67 |         return self.transformer(x_in)
68 |
69 |     def undo_last(self, x_in):
70 |         return self.transformer.undo_last(x_in)
71 |
72 |
73 | class Padder(object):
74 |     def __init__(self, padding="+", by=16, mode="constant"):
75 |         """
76 |         padding: '+', int, sequence
77 |             '+': pad dimensions up to multiple of "by"
78 |             int: pad each dimension by this value
79 |             sequence: pad each dimension by the corresponding value in the sequence
80 |         by: int
81 |             for use with '+' padding option
82 |         mode: str
83 |             passed to numpy.pad function
84 |         """
85 |         self.padding = padding
86 |         self.by = by
87 |         self.mode = mode
88 |         self.pads = {}
89 |         self.last_pad = None
90 |
91 |     def
__repr__(self):
92 |         return "Padder{}".format((self.padding, self.by, self.mode))
93 |
94 |     def _calc_pad_width(self, shape_in):
95 |         if isinstance(self.padding, (str, int)):
96 |             paddings = (self.padding,) * len(shape_in)
97 |         else:
98 |             paddings = self.padding
99 |         pad_width = []
100 |         for i in range(len(shape_in)):
101 |             if isinstance(paddings[i], int):
102 |                 pad_width.append((paddings[i],) * 2)
103 |             elif paddings[i] == "+":
104 |                 padding_total = (
105 |                     int(np.ceil(1.0 * shape_in[i] / self.by) * self.by) - shape_in[i]
106 |                 )
107 |                 pad_left = padding_total // 2
108 |                 pad_right = padding_total - pad_left
109 |                 pad_width.append((pad_left, pad_right))
110 |         assert len(pad_width) == len(shape_in)
111 |         return pad_width
112 |
113 |     def undo_last(self, x_in):
114 |         """Crops input so its dimensions match dimensions of last input to __call__."""
115 |         assert x_in.shape == self.last_pad["shape_out"]
116 |         slices = [
117 |             slice(a, -b) if (a, b) != (0, 0) else slice(None)
118 |             for a, b in self.last_pad["pad_width"]
119 |         ]
120 |         return x_in[slices].copy()
121 |
122 |     def __call__(self, x_in):
123 |         shape_in = x_in.shape
124 |         pad_width = self.pads.get(shape_in, self._calc_pad_width(shape_in))
125 |         x_out = np.pad(x_in, pad_width, mode=self.mode)
126 |         if shape_in not in self.pads:
127 |             self.pads[shape_in] = pad_width
128 |         self.last_pad = {
129 |             "shape_in": shape_in,
130 |             "pad_width": pad_width,
131 |             "shape_out": x_out.shape,
132 |         }
133 |         return x_out
134 |
135 |
136 | class Cropper(object):
137 |     def __init__(
138 |         self, cropping="-", by=16, offset="mid", n_max_pixels=9732096, dims_no_crop=None
139 |     ):
140 |         """Crop input array to given shape."""
141 |         self.cropping = cropping
142 |         self.offset = offset
143 |         self.by = by
144 |         self.n_max_pixels = n_max_pixels
145 |         self.dims_no_crop = (
146 |             [dims_no_crop] if isinstance(dims_no_crop, int) else dims_no_crop
147 |         )
148 |         self.crops = {}
149 |         self.last_crop = None
150 |
151 |     def __repr__(self):
152 |         return "Cropper{}".format(
153 |             (self.cropping, self.by, self.offset, self.n_max_pixels, self.dims_no_crop)
154 |         )
155 |
156 |     def _adjust_shape_crop(self, shape_crop):
157 |         shape_crop_new = list(shape_crop)
158 |         prod_shape = np.prod(shape_crop_new)
159 |         idx_dim_reduce = 0
160 |         order_dim_reduce = list(
161 |             range(len(shape_crop))[-2:]
162 |         )  # alternate between last two dimensions
163 |         while prod_shape > self.n_max_pixels:
164 |             dim = order_dim_reduce[idx_dim_reduce]
165 |             if not (dim == 0 and shape_crop_new[dim] <= 64):
166 |                 shape_crop_new[dim] -= self.by
167 |                 prod_shape = np.prod(shape_crop_new)
168 |             idx_dim_reduce += 1
169 |             if idx_dim_reduce >= len(order_dim_reduce):
170 |                 idx_dim_reduce = 0
171 |         value = tuple(shape_crop_new)
172 |         return value
173 |
174 |     def _calc_shape_crop(self, shape_in):
175 |         croppings = (
176 |             (self.cropping,) * len(shape_in)
177 |             if isinstance(self.cropping, (str, int))
178 |             else self.cropping
179 |         )
180 |         shape_crop = []
181 |         for i in range(len(shape_in)):
182 |             if (croppings[i] is None) or (
183 |                 self.dims_no_crop is not None and i in self.dims_no_crop
184 |             ):
185 |                 shape_crop.append(shape_in[i])
186 |             elif isinstance(croppings[i], int):
187 |                 shape_crop.append(shape_in[i] - croppings[i])
188 |             elif croppings[i] == "-":
189 |                 shape_crop.append(shape_in[i] // self.by * self.by)
190 |             else:
191 |                 raise NotImplementedError
192 |         if self.n_max_pixels is not None:
193 |             shape_crop = self._adjust_shape_crop(shape_crop)
194 |         self.crops[shape_in]["shape_crop"] = shape_crop
195 |         return
shape_crop
196 |
197 |     def _calc_offsets_crop(self, shape_in, shape_crop):
198 |         offsets = (
199 |             (self.offset,) * len(shape_in)
200 |             if isinstance(self.offset, (str, int))
201 |             else self.offset
202 |         )
203 |         offsets_crop = []
204 |         for i in range(len(shape_in)):
205 |             offset = (
206 |                 (shape_in[i] - shape_crop[i]) // 2
207 |                 if offsets[i] == "mid"
208 |                 else offsets[i]
209 |             )
210 |             if offset + shape_crop[i] > shape_in[i]:
211 |                 logger.error(
212 |                     f"Cannot crop outside image dimensions ({offset}:{offset + shape_crop[i]} for dim {i})"
213 |                 )
214 |                 raise AttributeError
215 |             offsets_crop.append(offset)
216 |         self.crops[shape_in]["offsets_crop"] = offsets_crop
217 |         return offsets_crop
218 |
219 |     def _calc_slices(self, shape_in):
220 |         shape_crop = self._calc_shape_crop(shape_in)
221 |         offsets_crop = self._calc_offsets_crop(shape_in, shape_crop)
222 |         slices = [
223 |             slice(offsets_crop[i], offsets_crop[i] + shape_crop[i])
224 |             for i in range(len(shape_in))
225 |         ]
226 |         self.crops[shape_in]["slices"] = slices
227 |         return slices
228 |
229 |     def __call__(self, x_in):
230 |         shape_in = x_in.shape
231 |         if shape_in in self.crops:
232 |             slices = self.crops[shape_in]["slices"]
233 |         else:
234 |             self.crops[shape_in] = {}
235 |             slices = self._calc_slices(shape_in)
236 |         x_out = x_in[slices].copy()
237 |         self.last_crop = {
238 |             "shape_in": shape_in,
239 |             "slices": slices,
240 |             "shape_out": x_out.shape,
241 |         }
242 |         return x_out
243 |
244 |     def undo_last(self, x_in):
245 |         """Pads input with zeros so its dimensions match dimensions of last input to __call__."""
246 |         assert x_in.shape == self.last_crop["shape_out"]
247 |         shape_out = self.last_crop["shape_in"]
248 |         slices = self.last_crop["slices"]
249 |         x_out = np.zeros(shape_out, dtype=x_in.dtype)
250 |         x_out[slices] = x_in
251 |         return x_out
252 |
253 |
254 | class Resizer(object):
255 |     def __init__(self, factors, per_dim=None):
256 |         """
257 |         Parameters:
258 |             factors: tuple of resizing factors for each dimension of the input array
259 |             per_dim: if specified, resize each slice along this dimension independently (factors then apply to the remaining dimensions)
260 |         """
261 |         self.factors = factors
262 |         self.per_dim = per_dim
263 |
264 |     def __call__(self, x):
265 |         if self.per_dim is None:
266 |             return scipy.ndimage.zoom(x, (self.factors), mode="nearest")
267 |         ars_resized = list()
268 |         for idx in range(x.shape[self.per_dim]):
269 |             slices = tuple(
270 |                 [idx if i == self.per_dim else slice(None) for i in range(x.ndim)]
271 |             )
272 |             ars_resized.append(
273 |                 scipy.ndimage.zoom(x[slices], self.factors, mode="nearest")
274 |             )
275 |         return np.stack(ars_resized, axis=self.per_dim)
276 |
277 |     def __repr__(self):
278 |         return "Resizer({:s}, {})".format(str(self.factors), self.per_dim)
279 |
280 |
281 | class Capper(object):
282 |     def __init__(self, low=None, hi=None):
283 |         self._low = low
284 |         self._hi = hi
285 |
286 |     def __call__(self, ar):
287 |         result = ar.copy()
288 |         if self._hi is not None:
289 |             result[result > self._hi] = self._hi
290 |         if self._low is not None:
291 |             result[result < self._low] = self._low
292 |         return result
293 |
294 |     def __repr__(self):
295 |         return "Capper({}, {})".format(self._low, self._hi)
296 |
297 |
298 | def flip_y(ar: np.ndarray) -> np.ndarray:
299 |     """Flip array along y axis.
300 |
301 |     Array dimensions should end in YX.
302 |
303 |     Parameters
304 |     ----------
305 |     ar
306 |         Input array to be flipped.
307 |
308 |     Returns
309 |     -------
310 |     np.ndarray
311 |         Flipped array.
312 |
313 |     """
314 |     return np.flip(ar, axis=-2)
315 |
316 |
317 | def flip_x(ar: np.ndarray) -> np.ndarray:
318 |     """Flip array along x axis.
319 |
320 |     Array dimensions should end in YX.
321 |
322 |     Parameters
323 |     ----------
324 |     ar
325 |         Input array to be flipped.
326 |
327 |     Returns
328 |     -------
329 |     np.ndarray
330 |         Flipped array.
331 |
332 |     """
333 |     return np.flip(ar, axis=-1)
334 |
335 |
336 | def norm_around_center(ar: np.ndarray, z_center: Optional[int] = None):
337 |     """Returns normalized version of input array.
338 |
339 |     The array will be normalized with respect to the mean, std pixel intensity
340 |     of the sub-array of length 32 in the z-dimension centered around the
341 |     array's "z_center".
342 |
343 |     Parameters
344 |     ----------
345 |     ar
346 |         Input 3d array to be normalized.
347 |     z_center
348 |         Z-index of cell centers.
349 |
350 |     Returns
351 |     -------
352 |     np.ndarray
353 |         Normalized array, dtype = float32
354 |
355 |     """
356 |     if ar.ndim != 3:
357 |         raise ValueError("Input array must be 3d")
358 |     if ar.shape[0] < 32:
359 |         raise ValueError("Input array must be at least length 32 in first dimension")
360 |     if z_center is None:
361 |         z_center = ar.shape[0] // 2
362 |     chunk_zlen = 32
363 |     z_start = z_center - chunk_zlen // 2
364 |     if z_start < 0:
365 |         z_start = 0
366 |         logger.warning(f"Warning: z_start set to {z_start}")
367 |     if (z_start + chunk_zlen) > ar.shape[0]:
368 |         z_start = ar.shape[0] - chunk_zlen
369 |         logger.warning(f"Warning: z_start set to {z_start}")
370 |     chunk = ar[z_start : z_start + chunk_zlen, :, :]
371 |     ar = ar - chunk.mean()
372 |     ar = ar / chunk.std()
373 |     return ar.astype(np.float32)
374 |
--------------------------------------------------------------------------------
/fnet/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/pytorch_fnet/64c53d123df644cebe5e4f7f2ab6efc5c0732f4e/fnet/utils/__init__.py
--------------------------------------------------------------------------------
/fnet/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Callable, List, Optional, Sequence
3 | import importlib
4 | import inspect
5 | import logging
6 | import os
7 | import sys
8 | import time
9 |
10 | import pandas as pd
11 |
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def files_from_dir(
17 |     path_dir: str, extensions: Optional[Sequence[str]] = None
18 | ) -> List[str]:
19 |     """Returns sorted list of files in a directory with optional extension(s).
20 |
21 |     Parameters
22 |     ----------
23 |     path_dir
24 |         Input directory.
25 |     extensions
26 |         Optional file extensions.
27 |
28 |     """
29 |     if extensions is None:
30 |         extensions = [""]  # Allows for all extensions
31 |     paths = []
32 |     for entry in os.scandir(path_dir):
33 |         if any(entry.path.endswith(ext) for ext in extensions):
34 |             paths.append(entry.path)
35 |     return sorted(paths)
36 |
37 |
38 | def str_to_object(str_o: str):
39 |     """Get object from string.
40 |
41 |     Parameters
42 |     ----------
43 |     str_o
44 |         Fully qualified object name.
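    Example:

    >>> str_to_object("fnet.metrics.corr_coef")
    <function corr_coef at ...>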
45 | 46 | """ 47 | parts = str_o.split(".") 48 | if len(parts) > 1: 49 | module = importlib.import_module(".".join(parts[:-1])) 50 | return getattr(module, parts[-1]) 51 | return inspect.currentframe().f_back.f_globals[str_o] 52 | 53 | 54 | def to_objects(slist): 55 | """Get a list of objects from list of object __repr__s.""" 56 | if slist is None: 57 | return None 58 | olist = list() 59 | for s in slist: 60 | if not isinstance(s, str): 61 | if s is None: 62 | continue 63 | olist.append(s) 64 | continue 65 | if s.lower() == "none": 66 | continue 67 | s_split = s.split(".") 68 | for idx_part, part in enumerate(s_split): 69 | if not part.isidentifier(): 70 | break 71 | importee = ".".join(s_split[:idx_part]) 72 | so = ".".join(s_split[idx_part:]) 73 | if len(importee) > 0: 74 | module = importlib.import_module(importee) # noqa: F841 75 | so = "module." + so 76 | olist.append(eval(so)) 77 | return olist 78 | 79 | 80 | def retry_if_oserror(fn: Callable): 81 | """Retries input function if an OSError is encountered.""" 82 | 83 | def wrapper(*args, **kwargs): 84 | count = 0 85 | while True: 86 | count += 1 87 | try: 88 | fn(*args, **kwargs) 89 | break 90 | except OSError as err: 91 | wait = 2 ** min(count, 5) 92 | logger.info(f"Attempt {count} failed: {err}. Waiting {wait} seconds.") 93 | time.sleep(wait) 94 | 95 | return wrapper 96 | 97 | 98 | def get_args(): 99 | """Returns the arguments passed to the calling function. 100 | 101 | Example: 102 | 103 | >>> def foo(a, b, *args, **kwargs): 104 | ... print(get_args()) 105 | ... 106 | >>> foo(1, 2, 3, 'bar', fizz='buzz') 107 | ({'b': 2, 'a': 1, 'fizz': 'buzz'}, (3, 'bar')) 108 | 109 | References: 110 | kbyanc.blogspot.com/2007/07/python-aggregating-function-arguments.html 111 | 112 | Returns 113 | ------- 114 | dict 115 | Named arguments 116 | list 117 | Unnamed positional arguments 118 | 119 | """ 120 | frame = inspect.stack()[1].frame # Look at caller 121 | _, varargs, kwargs, named_args = inspect.getargvalues(frame) 122 | named_args = dict(named_args) 123 | named_args.update(named_args.pop(kwargs, [])) 124 | pos_args = named_args.pop(varargs, []) 125 | return named_args, pos_args 126 | 127 | 128 | def str_to_class(string: str): 129 | """Return class from string representation.""" 130 | idx_dot = string.rfind(".") 131 | if idx_dot < 0: 132 | module_str = "fnet.nn_modules" 133 | class_str = string 134 | else: 135 | module_str = string[:idx_dot] 136 | class_str = string[idx_dot + 1 :] 137 | module = importlib.import_module(module_str) 138 | return getattr(module, class_str) 139 | 140 | 141 | def add_augmentations(df: pd.DataFrame) -> pd.DataFrame: 142 | """Adds augmented versions of dataframe rows. 143 | 144 | This is intended to be used on dataframes that represent datasets. Two 145 | columns will be added: flip_y, flip_x. Each dataframe row will be 146 | replicated 3 more times with flip_y, flip_x, or both columns set to 1. 147 | 148 | Parameters 149 | ---------- 150 | df 151 | Dataset dataframe to be augmented. 152 | 153 | Returns 154 | ------- 155 | pd.DataFrame 156 | Augmented dataset dataframe. 157 | 158 | """ 159 | df_flip_y = df.assign(flip_y=1) 160 | df_flip_x = df.assign(flip_x=1) 161 | df_both = df.assign(flip_y=1, flip_x=1) 162 | name_index = df.index.name 163 | df_aug = pd.concat( 164 | [df, df_flip_y, df_flip_x, df_both], ignore_index=True, sort=False 165 | ).rename_axis(name_index) 166 | return df_aug 167 | 168 | 169 | def whats_my_name(obj: object): 170 | """Returns object's name.""" 171 | return obj.__module__ + "." 
+ obj.__qualname__ 172 | 173 | 174 | def create_formatter(): 175 | """Creates a default logging Formatter.""" 176 | return logging.Formatter("%(levelname)s:%(name)s: %(message)s") 177 | 178 | 179 | def add_logging_file_handler(path_save: Path) -> None: 180 | """Adds a file handler to the fnet logger. 181 | 182 | Parameters 183 | ---------- 184 | path_save 185 | Location to save logging records. 186 | 187 | Returns 188 | ------- 189 | None 190 | 191 | """ 192 | path_save.parent.mkdir(parents=True, exist_ok=True) 193 | fh = logging.FileHandler(path_save, mode="a") 194 | fh.setFormatter(create_formatter()) 195 | logging.getLogger("fnet").addHandler(fh) 196 | 197 | 198 | def init_fnet_logging() -> None: 199 | """Initializes logging for fnet. 200 | 201 | Removes any root-logger handlers potentially set by third-party 202 | packages, then attaches a stdout stream handler to the "fnet" 203 | logger. Does nothing if the "fnet" logger already has handlers, 204 | which avoids adding redundant handlers on repeated calls. 205 | 206 | Returns 207 | ------- 208 | None 209 | 210 | """ 211 | # Remove root logger handlers potentially set by third-party packages 212 | logger_root = logging.getLogger() 213 | for handler in logger_root.handlers: 214 | logger_root.removeHandler(handler) 215 | # Init fnet logger 216 | logger_fnet = logging.getLogger("fnet") 217 | logger_fnet.setLevel(logging.INFO) 218 | if logger_fnet.hasHandlers(): # avoids redundant handlers 219 | return 220 | sh = logging.StreamHandler(sys.stdout) 221 | sh.setFormatter(create_formatter()) 222 | logger_fnet.addHandler(sh) 223 | -------------------------------------------------------------------------------- /fnet/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def move_optim(optim: torch.optim.Optimizer, device: torch.device): 5 | """Moves optimizer state tensors to the specified device. 6 | 7 | """ 8 | for g_state in optim.state.values(): 9 | for k, v in g_state.items(): 10 | if torch.is_tensor(v): 11 | g_state[k] = v.to(device) 12 | -------------------------------------------------------------------------------- /fnet/utils/split_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import pandas as pd 5 | 6 | 7 | def int_or_float(x): 8 | try: 9 | val = int(x) 10 | assert val >= 0 11 | except ValueError: 12 | val = float(x) 13 | assert 0.0 <= val <= 1.0 14 | return val 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("src_csv", help="path to dataset CSV") 20 | parser.add_argument("dst_dir", help="destination directory of dataset split") 21 | parser.add_argument( 22 | "--train_size", 23 | type=int_or_float, 24 | default=0.8, 25 | help="training set size as an int or fraction of total dataset size", 26 | ) 27 | parser.add_argument("--seed", type=int, default=42, help="random seed") 28 | parser.add_argument("--no_shuffle", action="store_true", help="do not shuffle the dataset before splitting") 29 | parser.add_argument("-v", "--verbose", action="store_true", help="verbose output") 30 | opts = parser.parse_args() 31 | vprint = print if opts.verbose else lambda *a, **kw: None 32 | 33 | name = os.path.basename(opts.src_csv).split(".")[0] 34 | path_store_split = os.path.join(opts.dst_dir, name) 35 | path_train_csv = os.path.join(path_store_split, "train.csv") 36 | path_test_csv = os.path.join(path_store_split, "test.csv") 37 | if os.path.exists(path_train_csv) and os.path.exists(path_test_csv): 38 | vprint("Using existing train/test split.") 39 | return 40 | rng = np.random.RandomState(opts.seed) 41 | df_all = pd.read_csv(opts.src_csv) 42 | if not
opts.no_shuffle: 43 | df_all = df_all.sample(frac=1.0, random_state=rng).reset_index(drop=True) 44 | if opts.train_size == 0: 45 | df_test = df_all 46 | df_train = df_all[0:0] # empty DataFrame but with columns intact 47 | else: 48 | if isinstance(opts.train_size, int): 49 | idx_split = opts.train_size 50 | elif isinstance(opts.train_size, float): 51 | idx_split = round(len(df_all) * opts.train_size) 52 | else: 53 | raise TypeError("train_size must be an int or a float") 54 | df_train = df_all[:idx_split] 55 | df_test = df_all[idx_split:] 56 | vprint("train/test sizes: {:d}/{:d}".format(len(df_train), len(df_test))) 57 | if not os.path.exists(path_store_split): 58 | os.makedirs(path_store_split) 59 | df_train.to_csv(path_train_csv, index=False) 60 | df_test.to_csv(path_test_csv, index=False) 61 | vprint("saved:", path_train_csv) 62 | vprint("saved:", path_test_csv) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /fnet/utils/viz_utils.py: -------------------------------------------------------------------------------- 1 | """Visualization tools.""" 2 | 3 | 4 | from typing import List, Optional, Union 5 | import logging 6 | import os 7 | 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | import pandas as pd 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | plt.style.use("seaborn") 15 | COLORS = matplotlib.rcParams["axes.prop_cycle"].by_key()["color"] 16 | 17 | 18 | def _plot_df(df, ax, model_label, colors, **kwargs): 19 | """Plot dataframe columns on axes.""" 20 | for idx_c, col in enumerate(df.columns): 21 | label = (f"{model_label}:" if model_label is not None else "") + f"{col}" 22 | key = model_label, "_".join(col.split("_")[:-1]) 23 | if key not in colors: 24 | colors[key] = COLORS[colors["idx"]] 25 | colors["idx"] = (colors["idx"] + 1) % len(COLORS) 26 | color = colors[key] 27 | ax.plot( 28 | df.index.to_numpy(), df[col].to_numpy(), color=color, label=label, **kwargs 29 | ) 30 | 31 | 32 | def plot_loss( 33 | paths_model: Union[List[str], str], 34 | path_save: Optional[str] = None, 35 | train: bool = True, 36 | val: bool = True, 37 | title: Optional[str] = None, 38 | ymin: Optional[float] = None, 39 | ymax: Optional[float] = None, 40 | ) -> None: 41 | """Plots model loss curve(s). 42 | 43 | Parameters 44 | ---------- 45 | paths_model 46 | Paths to model directories, specified as a list or as a single 47 | string of space-separated paths. 48 | path_save 49 | If given, the figure is saved to this path instead of being 50 | displayed. 51 | train 52 | Whether to plot the training loss curve. 53 | val 54 | Whether to plot the validation loss curve. 55 | title 56 | Plot title. 57 | ymin 58 | Y-axis minimum value. 59 | ymax 60 | Y-axis maximum value.
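Example (illustrative sketch only; the model directory names below are hypothetical, and each directory is expected to contain the losses.csv written during training):

>>> plot_loss("models/dna models/membrane", path_save="loss.png")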
61 | 62 | """ 63 | if isinstance(paths_model, str): 64 | paths_model = paths_model.split(" ") 65 | if path_save is not None: 66 | plt.switch_backend("Agg") 67 | window_train = 128 68 | window_val = 32 69 | colors = {"idx": 0} # maps model-content to colors; idx is COLORS index 70 | fig, ax = plt.subplots() 71 | for idx_m, path_model in enumerate(paths_model): 72 | name_model = os.path.basename(os.path.normpath(path_model)) 73 | model_label = None if len(paths_model) == 1 else name_model 74 | path_loss = os.path.join(path_model, "losses.csv") 75 | df = pd.read_csv(path_loss, index_col="num_iter") 76 | if train: 77 | cols_train = [col for col in df.columns if col.lower().endswith("_train")] 78 | df_train = df.loc[:, cols_train].dropna(axis=1, thresh=1).dropna() 79 | df_train_rmean = df_train.rolling(window=window_train).mean() 80 | _plot_df(df_train_rmean, ax, model_label, colors, linestyle="-") 81 | if val: 82 | cols_val = [col for col in df.columns if col.lower().endswith("_val")] 83 | df_val = df.loc[:, cols_val].dropna(axis=1, thresh=1).dropna() 84 | df_val_rmean = df_val.rolling(window=window_val).mean() 85 | _plot_df(df_val_rmean, ax, model_label, colors, linestyle="--") 86 | if title is not None: 87 | ax.set_title(title) 88 | ax.set_ylim([ymin, ymax]) 89 | ax.set_xlabel("Training iterations") 90 | ax.set_ylabel("Rolling mean squared error") 91 | ax.legend() 92 | if path_save is not None: 93 | fig.savefig(path_save, bbox_inches="tight") 94 | logger.info(f"Saved: {path_save}") 95 | return 96 | plt.show() 97 | 98 | 99 | def plot_metric( 100 | path_csv: str, 101 | metric: str, 102 | path_save: Optional[str] = None, 103 | title: Optional[str] = None, 104 | ymin: Optional[float] = None, 105 | ymax: Optional[float] = None, 106 | ) -> None: 107 | """Plots a box-plot of model performance according to a given metric. 108 | 109 | Parameters 110 | ---------- 111 | path_csv 112 | Path to a CSV where each row is a dataset item. 113 | metric 114 | Name of the metric; must appear within one or more CSV column names. 115 | path_save 116 | If given, the figure is saved to this path instead of being 117 | displayed. 118 | title 119 | Plot title. 120 | ymin 121 | Y-axis minimum value. 122 | ymax 123 | Y-axis maximum value.
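Example (illustrative sketch only; the CSV path and metric name below are hypothetical placeholders):

>>> plot_metric("predictions.csv", metric="corr_coef", path_save="corr.png")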
124 | 125 | """ 126 | if path_save is not None: 127 | plt.switch_backend("Agg") 128 | df = pd.read_csv(path_csv) 129 | cols = [c for c in df.columns if metric in c] 130 | cols_rename = {c: c.split(metric)[-1] for c in cols} 131 | df = df.loc[:, cols].rename(columns=cols_rename) 132 | fig, ax = plt.subplots() 133 | df.boxplot(ax=ax) 134 | if title is not None: 135 | ax.set_title(title) 136 | ax.set_ylim([ymin, ymax]) 137 | ax.set_ylabel("Pearson correlation coefficient (r)") # label assumes the usual Pearson-r metric 138 | if path_save is not None: 139 | fig.savefig(path_save, bbox_inches="tight") 140 | logger.info(f"Saved: {path_save}") 141 | return 142 | plt.show() 143 | -------------------------------------------------------------------------------- /resources/PredictingStructures-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_fnet/64c53d123df644cebe5e4f7f2ab6efc5c0732f4e/resources/PredictingStructures-1.jpg -------------------------------------------------------------------------------- /resources/multi_pred_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCellModeling/pytorch_fnet/64c53d123df644cebe5e4f7f2ab6efc5c0732f4e/resources/multi_pred_b.png -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.2.0 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:setup.py] 7 | search = {current_version} 8 | replace = {new_version} 9 | 10 | [bumpversion:file:fnet/__init__.py] 11 | search = {current_version} 12 | replace = {new_version} 13 | 14 | [bdist_wheel] 15 | universal = 1 16 | 17 | [flake8] 18 | exclude = docs 19 | max-line-length = 88 20 | ignore = 21 | E203 22 | W291 23 | W503 24 | 25 | [aliases] 26 | test = pytest 27 | 28 | [tool:pytest] 29 | collect_ignore = ['setup.py'] 30 | 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """The setup script.""" 5 | 6 | from setuptools import find_packages, setup 7 | 8 | with open("README.md") as readme_file: 9 | readme = readme_file.read() 10 | 11 | test_requirements = [ 12 | "codecov", 13 | "flake8", 14 | "black", 15 | "pytest", 16 | "pytest-cov", 17 | "pytest-raises", 18 | "quilt3==3.1.5", 19 | "python-dateutil==2.8.0", 20 | ] 21 | 22 | setup_requirements = [ 23 | "pytest-runner", 24 | ] 25 | 26 | examples_requirements = [ 27 | "quilt3==3.1.10", 28 | "python-dateutil==2.8.0", 29 | ] 30 | 31 | dev_requirements = [ 32 | "bumpversion>=0.5.3", 33 | "coverage>=5.0a4", 34 | "flake8>=3.7.7", 35 | "ipython>=7.5.0", 36 | "m2r>=0.2.1", 37 | "pytest>=4.3.0", 38 | "pytest-cov==2.6.1", 39 | "pytest-raises>=0.10", 40 | "pytest-runner>=4.4", 41 | "Sphinx>=2.0.0b1", 42 | "sphinx_rtd_theme>=0.1.2", 43 | "tox>=3.5.2", 44 | "twine>=1.13.0", 45 | "wheel>=0.33.1", 46 | ] 47 | 48 | interactive_requirements = [ 49 | "altair", 50 | "jupyterlab", 51 | "matplotlib", 52 | ] 53 | 54 | requirements = [ 55 | "matplotlib", 56 | "numpy", 57 | "pandas", 58 | "scipy", 59 | "tifffile==0.15.1", 60 | "torch>=1.0", 61 | "tqdm", 62 | "scikit-image>=0.15.0", 63 | "aicsimageio==3.0.7", 64 | ] 65 | 66 | extra_requirements = { 67 | "test": test_requirements, 68 | "setup": setup_requirements, 69 |
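# Note: each key in this dict is installable as a pip extra, e.g., pip install .[dev];
# "all" bundles the base requirements plus every extra except "examples".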
"dev": dev_requirements, 70 | "interactive": interactive_requirements, 71 | "examples": examples_requirements, 72 | "all": [ 73 | *requirements, 74 | *test_requirements, 75 | *setup_requirements, 76 | *dev_requirements, 77 | *interactive_requirements 78 | ] 79 | } 80 | 81 | setup( 82 | author="Ounkomol, Chek and Fernandes, Daniel A. and Seshamani, Sharmishtaa and " 83 | "Maleckar, Mary M. and Collman, Forrest and Johnson, Gregory R.", 84 | author_email="gregj@alleninstitute.org", 85 | classifiers=[ 86 | "Development Status :: 2 - Pre-Alpha", 87 | "Intended Audience :: Developers", 88 | "License :: Free for non-commercial use", 89 | "Natural Language :: English", 90 | "Programming Language :: Python :: 3.6", 91 | "Programming Language :: Python :: 3.7", 92 | ], 93 | description="A machine learning model for transforming microsocpy images between " 94 | "modalities", 95 | entry_points={ 96 | "console_scripts": ["fnet = fnet.cli.main:main"], 97 | }, 98 | install_requires=requirements, 99 | license="Allen Institute Software License", 100 | long_description=readme, 101 | long_description_content_type="text/markdown", 102 | include_package_data=True, 103 | keywords="fnet", 104 | name="fnet", 105 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*"]), 106 | python_requires=">=3.6", 107 | setup_requires=setup_requirements, 108 | test_suite="fnet/tests", 109 | tests_require=test_requirements, 110 | extras_require=extra_requirements, 111 | url="https://github.com/AllenCellModeling/pytorch_fnet", 112 | # Do not edit this string manually, always use bumpversion 113 | # Details in CONTRIBUTING.rst 114 | version="0.2.0", 115 | zip_safe=False, 116 | ) 117 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = fnet/vendor/* 3 | ignore = E203, E266, E501, W503, F403, F401, E231 4 | 5 | [tox] 6 | skipsdist = True 7 | envlist = py36, py37, lint 8 | 9 | [pytest] 10 | markers = 11 | raises 12 | 13 | [testenv:lint] 14 | deps = 15 | .[test] 16 | commands = 17 | flake8 fnet --count --verbose --max-line-length=127 --show-source --statistics 18 | 19 | [testenv] 20 | setenv = 21 | PYTHONPATH = {toxinidir} 22 | deps = 23 | .[test] 24 | commands = 25 | pytest --basetemp={envtmpdir} --cov-report html --cov=fnet fnet/tests/ 26 | --------------------------------------------------------------------------------