├── .github └── workflows │ └── python-app.yml ├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── TODO.md ├── assets └── imgs │ └── fusionlab_banner.png ├── configs ├── Install on Macbook M1.md ├── requirements-m1.txt └── tf-apple-m1-conda.yaml ├── docs ├── requirements.txt └── source │ ├── _templates │ ├── custom-class-template.rst │ └── custom-module-template.rst │ ├── conf.py │ ├── datasets.rst │ ├── encoders.rst │ ├── index.rst │ ├── layers.rst │ ├── losses.rst │ ├── metrics.rst │ ├── segmentation.rst │ └── utils.rst ├── fusionlab ├── __init__.py ├── __version__.py ├── classification │ ├── __init__.py │ ├── base.py │ ├── lstm.py │ └── vgg.py ├── configs.py ├── datasets │ ├── __init__.py │ ├── a12lead.py │ ├── cinc2017.py │ ├── csv_sample.csv │ ├── csvread.py │ ├── ludb.py │ ├── muse.py │ └── utils.py ├── encoders │ ├── README.md │ ├── __init__.py │ ├── alexnet │ │ ├── __init__.py │ │ ├── alexnet.py │ │ └── tfalexnet.py │ ├── convnext │ │ ├── __init__.py │ │ └── convnext.py │ ├── efficientnet │ │ ├── __init__.py │ │ └── efficientnet.py │ ├── inceptionv1 │ │ ├── __init__.py │ │ ├── inceptionv1.py │ │ └── tfinceptionv1.py │ ├── mit │ │ ├── __init__.py │ │ └── mit.py │ ├── resnetv1 │ │ ├── __init__.py │ │ ├── resnetv1.py │ │ └── tfresnetv1.py │ ├── vgg │ │ ├── __init__.py │ │ ├── tfvgg.py │ │ └── vgg.py │ └── vit │ │ ├── __init__.py │ │ └── vit.py ├── functional │ ├── __init__.py │ ├── dice.py │ ├── iou.py │ ├── tfdice.py │ └── tfiou.py ├── layers │ ├── __init__.py │ ├── base.py │ ├── factories.py │ ├── patch_embed │ │ ├── __init__.py │ │ └── patch_embedding.py │ ├── selfattention │ │ ├── __init__.py │ │ └── selfattention.py │ └── squeeze_excitation │ │ ├── __init__.py │ │ ├── se.py │ │ └── tfse.py ├── losses │ ├── README.md │ ├── __init__.py │ ├── diceloss │ │ ├── __init__.py │ │ ├── dice.py │ │ └── tfdice.py │ ├── iouloss │ │ ├── __init__.py │ │ ├── iou.py │ │ └── tfiou.py │ └── tversky │ │ ├── __init__.py │ │ ├── tftversky.py │ │ └── tversky.py ├── metrics │ ├── __init__.py │ ├── dicescore │ │ ├── __init__.py │ │ └── dice.py │ └── iouscore │ │ ├── __init__.py │ │ └── iou.py ├── segmentation │ ├── README.md │ ├── __init__.py │ ├── base.py │ ├── resunet │ │ ├── __init__.py │ │ ├── resunet.py │ │ └── tfresunet.py │ ├── segformer │ │ ├── __init__.py │ │ └── segformer.py │ ├── tfbase.py │ ├── transunet │ │ ├── __init__.py │ │ └── transunet.py │ ├── unet │ │ ├── __init__.py │ │ ├── tfunet.py │ │ └── unet.py │ ├── unet2plus │ │ ├── __init__.py │ │ ├── tfunet2plus.py │ │ └── unet2plus.py │ └── unetr │ │ ├── __init__.py │ │ └── unetr.py ├── trainers │ ├── __init__.py │ ├── dcgan.py │ ├── test.py │ └── trainer.py └── utils │ ├── __init__.py │ ├── basic.py │ ├── labelme.py │ ├── plots.py │ ├── trace.py │ └── trunc_normal │ ├── __init__.py │ └── trunc_normal.py ├── make_init.sh ├── release_logs.md ├── requirements.txt ├── scripts └── build_pip.sh ├── setup.py └── tests ├── __init__.py ├── test_classification.py ├── test_datasets.py ├── test_encoders.py ├── test_factories.py ├── test_layers.py ├── test_losses.py ├── test_metrics.py ├── test_seg.py └── test_utils.py /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: FusionLab:build 5 | 6 | on: 7 | push: 8 | branches: [ "main", "dev"] 
9 | pull_request: 10 | branches: [ "main", "dev"] 11 | types: [ "opened", "synchronize", "reopened", "edited" ] 12 | 13 | permissions: 14 | contents: read 15 | 16 | jobs: 17 | build: 18 | 19 | runs-on: ubuntu-latest 20 | 21 | steps: 22 | - uses: actions/checkout@main 23 | - name: Set up Python 3.8 24 | uses: actions/setup-python@main 25 | with: 26 | python-version: "3.8" 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | pip install flake8 pytest 31 | pip install -r requirements.txt 32 | pip install torchvision 33 | pip install tensorflow 34 | pip install torchinfo 35 | - name: Lint with flake8 36 | run: | 37 | # stop the build if there are Python syntax errors or undefined names 38 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 39 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 40 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 41 | - name: Test with pytest 42 | run: | 43 | pytest 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea/ 6 | .venv* 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 | .pytest_cache/
51 | 
52 | # Translations
53 | *.mo
54 | *.pot
55 | 
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 | 
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 | 
65 | # Scrapy stuff:
66 | .scrapy
67 | 
68 | # Sphinx documentation
69 | docs/_build/
70 | 
71 | # PyBuilder
72 | target/
73 | 
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 | 
77 | # pyenv
78 | .python-version
79 | 
80 | # celery beat schedule file
81 | celerybeat-schedule
82 | 
83 | # SageMath parsed files
84 | *.sage.py
85 | 
86 | # Environments
87 | .env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 | 
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 | 
99 | # Rope project settings
100 | .ropeproject
101 | 
102 | # mkdocs documentation
103 | /site
104 | 
105 | # mypy
106 | .mypy_cache/
107 | 
108 | # Dataset
109 | dataset/
110 | data/
111 | 
112 | # VSCode
113 | .vscode
114 | 
115 | # Training Files
116 | *mlruns/
117 | *results/
118 | tryout_notebooks
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 | 
5 | # Required
6 | version: 2
7 | 
8 | # Set the OS, Python version and other tools you might need
9 | build:
10 |   os: ubuntu-22.04
11 |   tools:
12 |     python: "3.9"
13 |   # You can also specify other tool versions:
14 |   # nodejs: "19"
15 |   # rust: "1.64"
16 |   # golang: "1.19"
17 | 
18 | # Build documentation in the "docs/" directory with Sphinx
19 | sphinx:
20 |   configuration: docs/source/conf.py
21 | 
22 | # Optionally build your docs in additional formats such as PDF and ePub
23 | # formats:
24 | #   - pdf
25 | #   - epub
26 | 
27 | # Optional but recommended, declare the Python requirements required
28 | # to build your documentation
29 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
30 | python:
31 |   install:
32 |     - requirements: docs/requirements.txt
33 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 SAI
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FusionLab
2 | 
3 | 

4 | ![FusionLab banner](assets/imgs/fusionlab_banner.png)
5 | 
6 | 
7 | 

8 | 
9 | [![PyPI version](https://badge.fury.io/py/fusionlab.svg)](https://badge.fury.io/py/fusionlab) ![Test](https://github.com/taipingeric/fusionlab/actions/workflows/python-app.yml/badge.svg) [![Downloads](https://static.pepy.tech/badge/fusionlab)](https://pepy.tech/project/fusionlab)
10 | 
11 | [![Documentation](https://img.shields.io/badge/view-Documentation-blue?style=for-the-badge)](https://fusionlab.readthedocs.io/)
12 | 
13 | FusionLab is an open-source framework for Deep Learning research, written in PyTorch and TensorFlow. The code is easy to read and modify,
14 | especially for newcomers. Feel free to send pull requests :D
15 | 
16 | * [What's New](#news)
17 | * [Installation](#installation)
18 | * [How to use](#how-to-use)
19 | * [Encoders](#encoders)
20 | * [Losses](#losses)
21 | * [Segmentation](#segmentation)
22 | * [1D, 2D, 3D Model](#n-dimensional-model)
23 | * [Acknowledgements](#acknowledgements)
24 | 
25 | ## Installation
26 | 
27 | ### With pip
28 | 
29 | ```bash
30 | pip install fusionlab
31 | ```
32 | 
33 | #### For Mac M1 chip users
34 | [Install on Macbook M1 chip](./configs/Install%20on%20Macbook%20M1.md)
35 | 
36 | ## How to use
37 | 
38 | ```python
39 | import fusionlab as fl
40 | 
41 | # PyTorch
42 | encoder = fl.encoders.VGG16()
43 | # Tensorflow
44 | encoder = fl.encoders.TFVGG16()
45 | 
46 | ```
47 | 
48 | ## Documentation
49 | 
50 | [Doc](https://fusionlab.readthedocs.io/en/latest/encoders.html)
51 | 
52 | ## Encoders
53 | 
54 | [encoder list](fusionlab/encoders/README.md)
55 | 
56 | ## Losses
57 | 
58 | [Loss func list](fusionlab/losses/README.md)
59 | * Dice Loss
60 | * Tversky Loss
61 | * IoU Loss
62 | 
63 | 
64 | ```python
65 | # Dice Loss (Multiclass)
66 | import torch
67 | import tensorflow as tf
68 | import fusionlab as fl
69 | 
70 | # PyTorch
71 | pred = torch.randn(1, 3, 4, 4) # (N, C, *)
72 | target = torch.randint(0, 3, (1, 4, 4)) # (N, *)
73 | loss_fn = fl.losses.DiceLoss()
74 | loss = loss_fn(pred, target)
75 | 
76 | # Tensorflow
77 | pred = tf.random.normal((1, 4, 4, 3), 0., 1.) # (N, *, C)
78 | target = tf.random.uniform((1, 4, 4), 0, 3) # (N, *)
79 | loss_fn = fl.losses.TFDiceLoss("multiclass")
80 | loss = loss_fn(target, pred)
81 | 
82 | 
83 | # Dice Loss (Binary)
84 | 
85 | # PyTorch
86 | pred = torch.randn(1, 1, 4, 4) # (N, 1, *)
87 | target = torch.randint(0, 2, (1, 4, 4)) # (N, *)
88 | loss_fn = fl.losses.DiceLoss("binary")
89 | loss = loss_fn(pred, target)
90 | 
91 | # Tensorflow
92 | pred = tf.random.normal((1, 4, 4, 1), 0., 1.) # (N, *, 1)
93 | target = tf.random.uniform((1, 4, 4), 0, 2) # (N, *)
94 | loss_fn = fl.losses.TFDiceLoss("binary")
95 | loss = loss_fn(target, pred)
96 | 
97 | 
98 | ```
99 | 
100 | ## Segmentation
101 | 
102 | ```python
103 | import tensorflow as tf
104 | import fusionlab as fl
105 | # PyTorch UNet
106 | unet = fl.segmentation.UNet(cin=3, num_cls=10)
107 | 
108 | # Tensorflow UNet
109 | # Multiclass Segmentation
110 | unet = tf.keras.Sequential([
111 |     fl.segmentation.TFUNet(num_cls=10, base_dim=64),
112 |     tf.keras.layers.Activation(tf.nn.softmax),
113 | ])
114 | 
115 | # Binary Segmentation
116 | unet = tf.keras.Sequential([
117 |     fl.segmentation.TFUNet(num_cls=1, base_dim=64),
118 |     tf.keras.layers.Activation(tf.nn.sigmoid),
119 | ])
120 | ```
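121 | 
122 | A quick shape check for the PyTorch model above (a sketch; the 224x224 input size and the same-size logit output are assumptions, not taken from the docs):
123 | 
124 | ```python
125 | import torch
126 | import fusionlab as fl
127 | 
128 | unet = fl.segmentation.UNet(cin=3, num_cls=10)
129 | x = torch.randn(1, 3, 224, 224)  # (N, C, H, W)
130 | y = unet(x)                      # expected shape: (1, 10, 224, 224)
131 | print(y.shape)
132 | ```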
133 | 
134 | [Segmentation model list](fusionlab/segmentation/README.md)
135 | 
136 | * UNet
137 | * ResUNet
138 | * UNet2plus
139 | 
140 | ## N Dimensional Model
141 | 
142 | Some models can be used in 1D, 2D, and 3D.
143 | 
144 | ```python
145 | import fusionlab as fl
146 | 
147 | resnet1d = fl.encoders.ResNet50V1(cin=3, spatial_dims=1)
148 | resnet2d = fl.encoders.ResNet50V1(cin=3, spatial_dims=2)
149 | resnet3d = fl.encoders.ResNet50V1(cin=3, spatial_dims=3)
150 | 
151 | unet1d = fl.segmentation.UNet(cin=3, num_cls=10, spatial_dims=1)
152 | unet2d = fl.segmentation.UNet(cin=3, num_cls=10, spatial_dims=2)
153 | unet3d = fl.segmentation.UNet(cin=3, num_cls=10, spatial_dims=3)
154 | ```
155 | 
156 | ## News
157 | 
158 | [Release logs](./release_logs.md)
159 | 
160 | ## Acknowledgements
161 | 
162 | * [BloodAxe/pytorch-toolbelt](https://github.com/BloodAxe/pytorch-toolbelt)
163 | 
--------------------------------------------------------------------------------
/TODO.md:
--------------------------------------------------------------------------------
1 | ### Encoders
2 | 
3 | * SENet
4 | * DenseNet
5 | 
6 | ### Segmentation
7 | 
8 | * SegFormer
9 | * nnUNet
10 | * Swin-Unetr
11 | * AttUNet
12 | * EnsDiff
13 | * SegDiff
14 | * FCN
15 | * DeepLab
16 | * FPN
17 | * PSPNet
18 | 
19 | ### Metrics
20 | 
21 | * Dice, IoU score module
22 | * HD95
23 | * Normalized Surface Distance (NSD)
24 | 
25 | ### Loss
26 | 
27 | * Focal loss
28 | 
29 | ### Modifications
30 | 
31 | * Replace StochasticDepth in torchvision with timm DropPath
32 | 
33 | ### Dataset
34 | 
35 | 
36 | 
--------------------------------------------------------------------------------
/assets/imgs/fusionlab_banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/assets/imgs/fusionlab_banner.png
--------------------------------------------------------------------------------
/configs/Install on Macbook M1.md:
--------------------------------------------------------------------------------
1 | # Install on Macbook M1 chip
2 | 
3 | Ref: [https://github.com/jeffheaton/t81_558_deep_learning/install/tensorflow-install-conda-mac-metal-jan-2023.ipynb](https://github.com/jeffheaton/t81_558_deep_learning/install/tensorflow-install-conda-mac-metal-jan-2023.ipynb)
4 | 
5 | [video link](https://www.youtube.com/watch?v=5DgWvU0p2bk)
6 | 
7 | **NOTE: It has been tested on a Macbook Air 2021 with the M1 chip only**
8 | 
9 | Requirements:
10 | * Apple Mac with an M1 chip
11 | * macOS > 12.6 (Monterey)
12 | 
13 | Follow these steps:
14 | 1. Clone this repo and change into its directory
15 | ```bash
16 | git clone https://github.com/taipingeric/fusionlab.git
17 | cd fusionlab
18 | ```
19 | 2. Uninstall Anaconda (Optional): https://docs.anaconda.com/anaconda/install/uninstall/
20 | 3. Install [Miniconda](https://docs.conda.io/en/latest/miniconda.html)
21 |     1. Miniconda3 macOS Apple M1 64-bit pkg
22 |     2. Miniconda3 macOS Apple M1 64-bit bash
23 | 4. Install the xcode-select command-line tools
24 | ```bash
25 | xcode-select --install
26 | ```
27 | 5. Deactivate the base environment
28 | ```bash
29 | conda deactivate
30 | ```
31 | 6. Create the conda environment using the [config](./tf-apple-m1-conda.yaml)
32 | ```bash
33 | conda env create -f ./configs/tf-apple-m1-conda.yaml -n fusionlab
34 | ```
35 | 7. Replace [requirements.txt](../requirements.txt) with [requirements-m1.txt](requirements-m1.txt)
36 | 8. Install the requirements with pip
37 | ```bash
38 | pip install -r requirements-m1.txt
39 | ```
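40 | 
41 | To sanity-check the install, ask TensorFlow for its visible devices (a quick check, not part of the original guide; with `tensorflow-metal` installed it should list one GPU device):
42 | ```bash
43 | python -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
44 | ```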
--------------------------------------------------------------------------------
/configs/requirements-m1.txt:
--------------------------------------------------------------------------------
1 | # Mac M1
2 | # ref: https://www.youtube.com/watch?v=5DgWvU0p2bk
3 | # ref: https://github.com/jeffheaton/t81_558_deep_learning/blob/master/install/tensorflow-install-conda-mac-metal-jul-2022.ipynb
4 | 
5 | torchvision>=0.5.0
6 | tensorflow-macos>=2.9.2
7 | tensorflow-metal>=0.5.0
--------------------------------------------------------------------------------
/configs/tf-apple-m1-conda.yaml:
--------------------------------------------------------------------------------
1 | name: fusionlab
2 | channels:
3 |   - apple
4 |   - conda-forge
5 | dependencies:
6 |   - python=3.9
7 |   - pip>=19.0
8 |   - jupyter
9 |   - tensorflow-deps
10 |   - scikit-learn
11 |   - scipy
12 |   - pandas
13 |   - pandas-datareader
14 |   - matplotlib
15 |   - pillow
16 |   - tqdm
17 |   - requests
18 |   - h5py
19 |   - pyyaml
20 |   - flask
21 |   - boto3
22 |   - ipykernel
23 |   - pip:
24 |     - tensorflow-macos
25 |     - tensorflow-metal
26 |     - bayesian-optimization
27 |     - gym
28 |     - kaggle
29 | 
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | einops
2 | tqdm
3 | traceback2
4 | 
5 | pandas>=1.5.3
6 | xmltodict
7 | xlwt
8 | scipy
9 | wfdb
10 | #
11 | torch>=1.9
12 | torchvision
13 | tensorflow
14 | # Doc
15 | Sphinx
16 | pydata-sphinx-theme
17 | sphinxcontrib-applehelp
18 | sphinxcontrib-devhelp
19 | sphinxcontrib-htmlhelp
20 | sphinxcontrib-jsmath
21 | sphinxcontrib-qthelp
22 | sphinxcontrib-serializinghtml
23 | sphinx-autodoc-typehints==1.11.1
--------------------------------------------------------------------------------
/docs/source/_templates/custom-class-template.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 | 
3 | .. currentmodule:: {{ module }}
4 | 
5 | .. autoclass:: {{ objname }}
6 |    :members:
7 |    :show-inheritance:
8 |    :inherited-members:
9 | 
10 |    {% block methods %}
11 |    .. automethod:: __init__
12 | 
13 |    {% if methods %}
14 |    .. rubric:: {{ _('Methods') }}
15 | 
16 |    .. autosummary::
17 |    {% for item in methods %}
18 |       ~{{ name }}.{{ item }}
19 |    {%- endfor %}
20 |    {% endif %}
21 |    {% endblock %}
22 | 
23 |    {% block attributes %}
24 |    {% if attributes %}
25 |    .. rubric:: {{ _('Attributes') }}
26 | 
27 |    ..
autosummary:: 28 | {% for item in attributes %} 29 | ~{{ name }}.{{ item }} 30 | {%- endfor %} 31 | {% endif %} 32 | {% endblock %} -------------------------------------------------------------------------------- /docs/source/_templates/custom-module-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | 5 | {% block attributes %} 6 | {% if attributes %} 7 | .. rubric:: Module Attributes 8 | 9 | .. autosummary:: 10 | :toctree: 11 | {% for item in attributes %} 12 | {{ item }} 13 | {%- endfor %} 14 | {% endif %} 15 | {% endblock %} 16 | 17 | {% block functions %} 18 | {% if functions %} 19 | .. rubric:: {{ _('Functions') }} 20 | 21 | .. autosummary:: 22 | :toctree: 23 | {% for item in functions %} 24 | {{ item }} 25 | {%- endfor %} 26 | {% endif %} 27 | {% endblock %} 28 | 29 | {% block classes %} 30 | {% if classes %} 31 | .. rubric:: {{ _('Classes') }} 32 | 33 | .. autosummary:: 34 | :toctree: 35 | :template: custom-class-template.rst 36 | {% for item in classes %} 37 | {{ item }} 38 | {%- endfor %} 39 | {% endif %} 40 | {% endblock %} 41 | 42 | {% block exceptions %} 43 | {% if exceptions %} 44 | .. rubric:: {{ _('Exceptions') }} 45 | 46 | .. autosummary:: 47 | :toctree: 48 | {% for item in exceptions %} 49 | {{ item }} 50 | {%- endfor %} 51 | {% endif %} 52 | {% endblock %} 53 | 54 | {% block modules %} 55 | {% if modules %} 56 | .. rubric:: Modules 57 | 58 | .. autosummary:: 59 | :toctree: 60 | :template: custom-module-template.rst 61 | :recursive: 62 | {% for item in modules %} 63 | {{ item }} 64 | {%- endfor %} 65 | {% endif %} 66 | {% endblock %} -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import subprocess 15 | import sys 16 | 17 | sys.path.insert(0, os.path.abspath("..")) 18 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) 19 | print(sys.path) 20 | 21 | import fusionlab # noqa: E402 22 | 23 | # -- Project information ----------------------------------------------------- 24 | project = "FusionLab" 25 | copyright = "taipingeric" 26 | author = "fusionlab Contributors" 27 | 28 | # The full version, including alpha/beta/rc tags 29 | short_version = fusionlab.__version__ #.split("+")[0] 30 | release = short_version 31 | version = short_version 32 | 33 | # List of patterns, relative to source directory, that match files and 34 | # directories to ignore when looking for source files. 35 | # This pattern also affects html_static_path and html_extra_path. 
36 | exclude_patterns = [
37 |     "transforms",
38 |     "networks",
39 |     "metrics",
40 |     "engines",
41 |     "data",
42 |     "apps",
43 |     "fl",
44 |     "bundle",
45 |     "config",
46 |     "handlers",
47 |     "losses",
48 |     "visualize",
49 |     "utils",
50 |     "inferers",
51 |     "optimizers",
52 |     "auto3dseg",
53 |     "**/tf*",
54 | ]
55 | 
56 | 
57 | def generate_apidocs(*args):
58 |     """Generate API docs automatically by trawling the available modules"""
59 |     module_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "fusionlab"))
60 |     output_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "apidocs"))
61 |     apidoc_command_path = "sphinx-apidoc"
62 |     if hasattr(sys, "real_prefix"):  # called from a virtualenv
63 |         apidoc_command_path = os.path.join(sys.prefix, "bin", "sphinx-apidoc")
64 |     apidoc_command_path = os.path.abspath(apidoc_command_path)
65 |     print(f"output_path {output_path}")
66 |     print(f"module_path {module_path}")
67 |     subprocess.check_call(
68 |         [apidoc_command_path, "-e"]
69 |         + ["-o", output_path]
70 |         + [module_path]
71 |         + [os.path.join(module_path, p) for p in exclude_patterns]
72 |     )
73 | 
74 | 
75 | # -- General configuration ---------------------------------------------------
76 | 
77 | # Add any Sphinx extension module names here, as strings. They can be
78 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
79 | # ones.
80 | source_suffix = {".rst": "restructuredtext", ".txt": "restructuredtext", ".md": "markdown"}
81 | 
82 | extensions = [
83 |     "recommonmark",
84 |     "sphinx.ext.intersphinx",
85 |     "sphinx.ext.mathjax",
86 |     "sphinx.ext.napoleon",
87 |     "sphinx.ext.autodoc",
88 |     "sphinx.ext.viewcode",
89 |     "sphinx.ext.autosectionlabel",
90 |     "sphinx.ext.autosummary",
91 |     "sphinx_autodoc_typehints",
92 | ]
93 | # "sphinx.ext.autosummary"
94 | autosummary_generate = True
95 | 
96 | autoclass_content = "class"
97 | add_module_names = True
98 | source_encoding = "utf-8"
99 | autosectionlabel_prefix_document = True
100 | napoleon_use_param = True
101 | napoleon_include_init_with_doc = True
102 | set_type_checking_flag = True
103 | 
104 | # Add any paths that contain templates here, relative to this directory.
105 | templates_path = ["_templates"]
106 | 
107 | # -- Options for HTML output -------------------------------------------------
108 | 
109 | # The theme to use for HTML and HTML Help pages. See the documentation for
110 | # a list of builtin themes.
111 | # 112 | html_theme = "pydata_sphinx_theme" 113 | # html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 114 | html_theme_options = { 115 | "external_links": [{"url": "https://github.com/taipingeric/fusionlab", "name": "Tutorials"}], 116 | # "external_links": [{"url": "https://github.com/Project-MONAI/tutorials", "name": "Tutorials"}], 117 | # "icon_links": [ 118 | # {"name": "GitHub", "url": "https://github.com/project-monai/monai", "icon": "fab fa-github-square"}, 119 | # {"name": "Twitter", "url": "https://twitter.com/projectmonai", "icon": "fab fa-twitter-square"}, 120 | # ], 121 | "collapse_navigation": True, 122 | "navigation_depth": 1, 123 | "show_toc_level": 1, 124 | "footer_start": ["copyright"], 125 | "navbar_align": "content", 126 | # "logo": {"image_light": "MONAI-logo-color.png", "image_dark": "MONAI-logo-color.png"}, 127 | } 128 | html_context = { 129 | "github_user": "taipingeric", 130 | "github_repo": "fusionlab", 131 | "github_version": "dev", 132 | "doc_path": "docs/", 133 | "conf_py_path": "/docs/", 134 | "VERSION": version, 135 | } 136 | html_scaled_image_link = False 137 | html_show_sourcelink = True 138 | # html_favicon = "../images/favicon.ico" 139 | # html_logo = "../images/MONAI-logo-color.png" 140 | html_sidebars = {"**": ["search-field", "sidebar-nav-bs"]} 141 | pygments_style = "sphinx" 142 | 143 | # Add any paths that contain custom static files (such as style sheets) here, 144 | # relative to this directory. They are copied after the builtin static files, 145 | # so a file named "default.css" will overwrite the builtin "default.css". 146 | html_static_path = ["../_static"] 147 | html_css_files = ["custom.css"] 148 | html_title = f"{project} {version} Documentation" 149 | 150 | # -- Auto-convert markdown pages to demo -------------------------------------- 151 | 152 | 153 | def setup(app): 154 | # Hook to allow for automatic generation of API docs 155 | # before doc deployment begins. 156 | app.connect("builder-inited", generate_apidocs) -------------------------------------------------------------------------------- /docs/source/datasets.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/taipingeric/fusionlab 2 | 3 | .. _datasets: 4 | 5 | Datasets 6 | ============== 7 | 8 | .. automodule:: fusionlab.datasets 9 | .. currentmodule:: fusionlab.datasets 10 | 11 | `GE Muse XML Reader` 12 | ~~~~~~~~~~ 13 | .. autoclass:: GEMuseXMLReader 14 | :members: 15 | 16 | `ECG Classification Dataset` 17 | ~~~~~~~~~~ 18 | .. autoclass:: ECGClassificationDataset 19 | :members: 20 | 21 | Cinc2017 Dataset 22 | ---------------- 23 | 24 | `ECG CSV Classification Dataset` 25 | ~~~~~~~~~~ 26 | .. autoclass:: ECGCSVClassificationDataset 27 | :members: 28 | 29 | `Validate CinC 2017 Dataset` 30 | ~~~~~~~~~~ 31 | .. autofunction:: validate_data 32 | 33 | `.mat to .csv` 34 | ~~~~~~~~~~ 35 | .. autofunction:: convert_mat_to_csv 36 | 37 | LUDB Dataset 38 | ---------------- 39 | 40 | `LUDB Dataset` 41 | ~~~~~~~~~~ 42 | .. autoclass:: LUDBDataset 43 | :members: 44 | 45 | .. autofunction:: plot 46 | 47 | Utils 48 | ---------------- 49 | 50 | `Download file` 51 | ~~~~~~~~~~ 52 | .. autofunction:: download_file 53 | 54 | `HuggingFace Dataset` 55 | ~~~~~~~~~~ 56 | .. autoclass:: HFDataset 57 | :members: 58 | 59 | `LabelStudio Time series Segmentation Dataset` 60 | ~~~~~~~~~~ 61 | .. autoclass:: LSTimeSegDataset 62 | :members: 63 | 64 | `Read csv` 65 | ~~~~~~~~~~ 66 | .. 
autofunction:: read_csv 67 | 68 | -------------------------------------------------------------------------------- /docs/source/encoders.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/taipingeric/fusionlab 2 | 3 | .. _encoders: 4 | 5 | Encoders 6 | ============== 7 | 8 | PyTorch Encoders 9 | ------------------- 10 | 11 | .. automodule:: fusionlab.encoders 12 | .. currentmodule:: fusionlab.encoders 13 | 14 | `AlexNet` 15 | ~~~~~~~~~~ 16 | .. autoclass:: AlexNet 17 | :members: 18 | 19 | `VGG` 20 | ~~~~~~~~~~ 21 | .. autoclass:: VGG16 22 | :members: 23 | .. autoclass:: VGG19 24 | :members: 25 | 26 | `InceptionNet` 27 | ~~~~~~~~~~ 28 | .. autoclass:: InceptionNetV1 29 | :members: 30 | 31 | `ResNet` 32 | ~~~~~~~~~~ 33 | .. autoclass:: ResNet 34 | :members: 35 | .. autoclass:: ResNet18 36 | :members: 37 | .. autoclass:: ResNet34 38 | :members: 39 | .. autoclass:: ResNet50 40 | :members: 41 | .. autoclass:: ResNet101 42 | :members: 43 | .. autoclass:: ResNet152 44 | :members: 45 | 46 | `EfficientNet` 47 | ~~~~~~~~~~ 48 | .. autoclass:: EfficientNet 49 | :members: 50 | .. autoclass:: EfficientNetB0 51 | :members: 52 | .. autoclass:: EfficientNetB1 53 | :members: 54 | .. autoclass:: EfficientNetB2 55 | :members: 56 | .. autoclass:: EfficientNetB3 57 | :members: 58 | .. autoclass:: EfficientNetB4 59 | :members: 60 | .. autoclass:: EfficientNetB5 61 | :members: 62 | .. autoclass:: EfficientNetB6 63 | :members: 64 | .. autoclass:: EfficientNetB7 65 | :members: 66 | 67 | `ConvNeXt` 68 | ~~~~~~~~~~ 69 | .. autoclass:: ConvNeXt 70 | :members: 71 | .. autoclass:: ConvNeXtTiny 72 | :members: 73 | .. autoclass:: ConvNeXtSmall 74 | :members: 75 | .. autoclass:: ConvNeXtBase 76 | :members: 77 | .. autoclass:: ConvNeXtLarge 78 | :members: 79 | .. autoclass:: ConvNeXtXLarge 80 | :members: 81 | 82 | `Vision Transformer` 83 | ~~~~~~~~~~ 84 | .. autoclass:: ViT 85 | :members: 86 | .. autoclass:: VisionTransformer 87 | :members: 88 | 89 | `Mix Transformer` 90 | ~~~~~~~~~~ 91 | .. autoclass:: MiT 92 | :members: 93 | .. autoclass:: MiTB0 94 | :members: 95 | .. autoclass:: MiTB1 96 | :members: 97 | .. autoclass:: MiTB2 98 | :members: 99 | .. autoclass:: MiTB3 100 | :members: 101 | .. autoclass:: MiTB4 102 | :members: 103 | .. autoclass:: MiTB5 104 | :members: 105 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | FusionLab API Document 2 | ============= 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | encoders 8 | losses 9 | segmentation 10 | layers 11 | metrics 12 | datasets 13 | utils -------------------------------------------------------------------------------- /docs/source/layers.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/taipingeric/fusionlab 2 | 3 | .. _layers: 4 | 5 | Layers 6 | ============== 7 | Custom layers 8 | 9 | .. automodule:: fusionlab.layers 10 | .. currentmodule:: fusionlab.layers 11 | 12 | N-dimensional layer 13 | -------- 14 | 15 | `Convolution ND` 16 | ~~~~~~~~~~ 17 | .. autoclass:: ConvND 18 | :members: 19 | 20 | `Transposed Convolution ND` 21 | ~~~~~~~~~~ 22 | .. autoclass:: ConvT 23 | :members: 24 | 25 | `Upsample ND` 26 | ~~~~~~~~~~ 27 | .. autoclass:: Upsample 28 | :members: 29 | 30 | `BatchNorm ND` 31 | ~~~~~~~~~~ 32 | .. 
autoclass:: BatchNorm 33 | :members: 34 | 35 | `InstanceNorm ND` 36 | ~~~~~~~~~~ 37 | .. autoclass:: InstanceNorm 38 | :members: 39 | 40 | `Maxpool ND` 41 | ~~~~~~~~~~ 42 | .. autoclass:: MaxPool 43 | :members: 44 | 45 | `AvgPool ND` 46 | ~~~~~~~~~~ 47 | .. autoclass:: AvgPool 48 | :members: 49 | 50 | `Global Max Pool (Adaptive Max Pool) ND` 51 | ~~~~~~~~~~ 52 | .. autoclass:: AdaptiveMaxPool 53 | :members: 54 | 55 | `Global Avg Pool(Adaptive Avg Pool) ND` 56 | ~~~~~~~~~~ 57 | .. autoclass:: AdaptiveAvgPool 58 | :members: 59 | 60 | `Replication Padding ND` 61 | ~~~~~~~~~~ 62 | .. autoclass:: ReplicationPad 63 | :members: 64 | 65 | `Constant Padding ND` 66 | ~~~~~~~~~~ 67 | .. autoclass:: ConstantPad 68 | :members: 69 | 70 | `Conv|Norm|Act` 71 | ~~~~~~~~~~ 72 | .. autoclass:: ConvNormAct 73 | :members: 74 | 75 | `Rearrange` 76 | ~~~~~~~~~~ 77 | .. autoclass:: Rearrange 78 | :members: 79 | 80 | `Patch Embedding` 81 | ~~~~~~~~~~ 82 | .. autoclass:: PatchEmbedding 83 | :members: 84 | 85 | `Squeeze-Excitation` 86 | ~~~~~~~~~~ 87 | .. autoclass:: SEModule 88 | :members: 89 | .. autoclass:: TFSEModule 90 | :members: -------------------------------------------------------------------------------- /docs/source/losses.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/taipingeric/fusionlab 2 | 3 | .. _losses: 4 | 5 | Loss Function 6 | ============== 7 | This module contains the implementation of loss functions for semantic segmentation. 8 | 9 | .. automodule:: fusionlab.losses 10 | .. currentmodule:: fusionlab.losses 11 | 12 | `Dice Loss` 13 | ~~~~~~~~~~ 14 | .. autoclass:: DiceLoss 15 | :members: 16 | .. autoclass:: TFDiceLoss 17 | :members: 18 | 19 | `Dice Cross Entropy Loss` 20 | ~~~~~~~~~~ 21 | .. autoclass:: DiceCELoss 22 | :members: 23 | .. autoclass:: TFDiceCE 24 | :members: 25 | 26 | `IoU Loss` 27 | ~~~~~~~~~~ 28 | .. autoclass:: IoULoss 29 | :members: 30 | .. autoclass:: TFIoULoss 31 | :members: 32 | 33 | `Tversky Loss` 34 | ~~~~~~~~~~ 35 | .. autoclass:: TverskyLoss 36 | :members: 37 | .. autoclass:: TFTverskyLoss 38 | :members: 39 | -------------------------------------------------------------------------------- /docs/source/metrics.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/taipingeric/fusionlab 2 | 3 | .. _metrics: 4 | 5 | Metrics 6 | ============== 7 | This module contains the implementation of the metrics 8 | 9 | .. automodule:: fusionlab.metrics 10 | .. currentmodule:: fusionlab.metrics 11 | 12 | `Dice` 13 | ~~~~~~~~~~ 14 | .. autoclass:: DiceScore 15 | :members: 16 | .. autoclass:: JaccardScore 17 | :members: 18 | 19 | `IoU` 20 | ~~~~~~~~~~ 21 | .. autoclass:: IoUScore 22 | :members: -------------------------------------------------------------------------------- /docs/source/segmentation.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/taipingeric/fusionlab 2 | 3 | .. _segmentation: 4 | 5 | Segmentation Model 6 | ============== 7 | 8 | 9 | .. automodule:: fusionlab.segmentation 10 | .. currentmodule:: fusionlab.segmentation 11 | 12 | `UNet` 13 | ~~~~~~~~~~ 14 | 15 | .. autoclass:: UNet 16 | :members: 17 | .. autoclass:: TFUNet 18 | :members: 19 | 20 | `ResUNet` 21 | ~~~~~~~~~~ 22 | 23 | .. autoclass:: ResUNet 24 | :members: 25 | .. autoclass:: TFResUNet 26 | :members: 27 | 28 | `UNet++` 29 | ~~~~~~~~~~ 30 | 31 | .. autoclass:: UNet2plus 32 | :members: 33 | .. 
autoclass:: TFUNet2plus
34 |    :members:
35 | 
36 | `TransUNet`
37 | ~~~~~~~~~~
38 | 
39 | .. autoclass:: TransUNet
40 |    :members:
41 | 
42 | `UNETR`
43 | ~~~~~~~~~~
44 | 
45 | .. autoclass:: UNETR
46 |    :members:
47 | 
48 | `SegFormer`
49 | ~~~~~~~~~~
50 | 
51 | .. autoclass:: SegFormer
52 |    :members:
53 | 
54 | `HuggingFace Segmentation Model`
55 | ~~~~~~~~~~
56 | 
57 | .. autoclass:: HFSegmentationModel
58 |    :members:
--------------------------------------------------------------------------------
/docs/source/utils.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/taipingeric/fusionlab
2 | 
3 | .. _utils:
4 | 
5 | Utils
6 | ==============
7 | This module contains utility functions used in the project.
8 | 
9 | .. automodule:: fusionlab.utils
10 | .. currentmodule:: fusionlab.utils
11 | 
12 | `Auto Padding`
13 | ~~~~~~~~~~
14 | .. autofunction:: autopad
15 | 
16 | `Make N tuple`
17 | ~~~~~~~~~~
18 | .. autofunction:: make_ntuple
19 | 
20 | `Show Class Tree`
21 | ~~~~~~~~~~
22 | .. autofunction:: show_classtree
23 | 
24 | `Plot channels`
25 | ~~~~~~~~~~
26 | .. autofunction:: plot_channels
27 | 
28 | `Convert LabelMe json to Mask`
29 | ~~~~~~~~~~
30 | .. autofunction:: convert_labelme_json2mask
--------------------------------------------------------------------------------
/fusionlab/__init__.py:
--------------------------------------------------------------------------------
1 | from .__version__ import __version__
2 | 
3 | # Check if PyTorch or TensorFlow is installed
4 | try:
5 |     import torch
6 |     torch_installed = True
7 | except ImportError:
8 |     torch_installed = False
9 | 
10 | try:
11 |     import tensorflow as tf
12 |     tf_installed = True
13 | except ImportError:
14 |     tf_installed = False
15 | 
16 | print(f"PyTorch installed: {torch_installed}")
17 | print(f"TensorFlow installed: {tf_installed}")
18 | 
19 | BACKEND = {
20 |     "torch": torch_installed,
21 |     "tf": tf_installed
22 | }
23 | 
24 | # check if no backend installed
25 | if not any(BACKEND.values()):
26 |     print("No supported backend is installed")
27 | 
28 | from . import (
29 |     functional,
30 |     encoders,
31 |     utils,
32 |     datasets,
33 |     layers,
34 |     classification,
35 |     segmentation,
36 |     losses,
37 |     trainers,
38 |     configs,
39 |     metrics,
40 | )
--------------------------------------------------------------------------------
/fusionlab/__version__.py:
--------------------------------------------------------------------------------
1 | VERSION = (0, 1, 13)
2 | 
3 | __version__ = ".".join(map(str, VERSION))
--------------------------------------------------------------------------------
/fusionlab/classification/__init__.py:
--------------------------------------------------------------------------------
1 | from fusionlab import BACKEND
2 | 
3 | if BACKEND['torch']:
4 |     from .base import (
5 |         CNNClassificationModel,
6 |         RNNClassificationModel,
7 |         HFClassificationModel
8 |     )
9 |     from .lstm import LSTMClassifier
10 |     from .vgg import (
11 |         VGG16Classifier,
12 |         VGG19Classifier
13 |     )
--------------------------------------------------------------------------------
/fusionlab/classification/base.py:
--------------------------------------------------------------------------------
1 | # Base classes shared by the classification models (encoder + head).
2 | # You normally subclass these rather than instantiate them directly.
3 | 
4 | import torch.nn as nn
5 | 
6 | class CNNClassificationModel(nn.Module):
7 |     """
8 |     Base PyTorch class of the classification model with Encoder, Head for CNN
9 |     """
10 |     def forward(self, x):
11 |         # 1D signal x => [BATCH, CHANNEL, TIME]
12 |         # 1D spectrum x => [BATCH, FREQUENCY, TIME] (single channel)
13 |         # 2D spectrum x => [BATCH, CHANNEL, FREQUENCY, TIME] (multi channel)
14 |         # 2D image x => [BATCH, CHANNEL, HEIGHT, WIDTH]
15 |         # 3D volumetric x => [BATCH, CHANNEL, HEIGHT, WIDTH, DEPTH]
16 |         features = self.encoder(x) # => [BATCH, 512, ...]
17 |         features_agg = self.globalpooling(features) # => [BATCH, 512, 1, (1, (1))]
18 |         output = self.head(features_agg.view(x.shape[0], -1)) # => [BATCH, NUM_CLS]
19 |         return output
20 | 
21 | class RNNClassificationModel(nn.Module):
22 |     """
23 |     Base PyTorch class of the classification model with Encoder, Head for RNN
24 |     """
25 |     def forward(self, x):
26 |         # 1D signal x => [BATCH, CHANNEL, TIME]
27 |         x = x.transpose(1, 2)
28 |         features, _ = self.encoder(x) # RNN will output features and states
29 |         output = self.head(features[:, -1, :])
30 |         return output
31 | 
32 | class HFClassificationModel(nn.Module):
33 |     """
34 |     Base Hugging Face PyTorch wrapper class for the classification models
35 |     """
36 |     def __init__(self, model,
37 |                  num_cls=None,
38 |                  loss_fct=nn.CrossEntropyLoss()):
39 |         super().__init__()
40 |         self.net = model
41 |         if 'num_cls' in model.__dict__.keys():
42 |             self.num_cls = model.num_cls
43 |         else:
44 |             self.num_cls = num_cls
45 |         assert self.num_cls is not None, "num_cls is not defined"
46 |         self.loss_fct = loss_fct
47 |     def forward(self, x, labels=None):
48 |         logits = self.net(x) # Forward pass the model
49 |         if labels is not None:
50 |             # logits => [BATCH, NUM_CLS]
51 |             # labels => [BATCH]
52 |             loss = self.loss_fct(logits.view(-1, self.num_cls), labels.view(-1)) # Calculate loss
53 |         else:
54 |             loss = None
55 |         # return dictionary for the Hugging Face trainer
56 |         return {'loss': loss, 'logits': logits, 'hidden_states': None}
57 | 
58 | # Test the function
59 | if __name__ == '__main__':
60 |     import torch
61 |     from fusionlab.classification import VGG16Classifier
62 |     from fusionlab.classification import LSTMClassifier
63 | 
64 |     H = W = 224
65 |     cout = 5
66 |     inputs = torch.normal(0, 1, (1, 3, W))
67 |     # Test CNNClassification
68 |     model = VGG16Classifier(3, cout, spatial_dims=1)
69 |     hf_model = HFClassificationModel(model, cout)
70 |     output = hf_model(inputs)
71 |     print(output['logits'].shape)
72 |     assert list(output.keys()) == ['loss', 'logits', 'hidden_states']
73 | 
74 |     inputs = torch.normal(0, 1, (1, 3, H, W))
75 |     # Test CNNClassification
76 |     model = VGG16Classifier(3, cout, spatial_dims=2)
77 |     hf_model = HFClassificationModel(model, cout)
78 |     output = hf_model(inputs)
79 |     print(output['logits'].shape)
80 |     assert list(output.keys()) == ['loss', 'logits', 'hidden_states']
81 | 
82 |     inputs = torch.normal(0, 1, (1, 3, H))
83 |     model = LSTMClassifier(3, cout)
84 |     hf_model = HFClassificationModel(model, cout)
85 |     output = hf_model(inputs)
86 |     print(output['logits'].shape)
87 |     assert list(output.keys()) == ['loss', 'logits', 'hidden_states']
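88 | 
89 | # A minimal sketch (not part of the original file) of how HFClassificationModel
90 | # plugs into the Hugging Face Trainer: forward() returns the
91 | # {'loss', 'logits', 'hidden_states'} dict that Trainer expects, so training
92 | # looks roughly like this, assuming `transformers` is installed and `train_ds`
93 | # yields dicts whose keys match forward()'s arguments, i.e. {'x': ..., 'labels': ...}:
94 | #
95 | # from transformers import Trainer, TrainingArguments
96 | # trainer = Trainer(model=hf_model,
97 | #                   args=TrainingArguments(output_dir='out'),  # hypothetical dir
98 | #                   train_dataset=train_ds)
99 | # trainer.train()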
--------------------------------------------------------------------------------
/fusionlab/classification/lstm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from fusionlab.classification.base import RNNClassificationModel
4 | 
5 | 
6 | class LSTMClassifier(RNNClassificationModel):
7 |     def __init__(self, cin, cout, hidden_size=512):
8 |         super().__init__()
9 |         self.encoder = nn.LSTM(input_size=cin, hidden_size=hidden_size, batch_first=True) # define LSTM layer
10 |         self.head = nn.Linear(hidden_size, cout) # define output head layer
11 | 
12 | if __name__ == '__main__':
13 |     inputs = torch.randn(1, 5, 3) # create random input tensor
14 |     model = LSTMClassifier(cin=3, hidden_size=4, cout=2) # create model instance
15 |     outputs = model(inputs) # pass input through model
16 |     assert list(outputs.shape) == [1, 2] # check output shape is correct
--------------------------------------------------------------------------------
/fusionlab/classification/vgg.py:
--------------------------------------------------------------------------------
1 | # VGG Classifier
2 | import torch
3 | from torch import nn
4 | from fusionlab.classification.base import CNNClassificationModel
5 | from fusionlab.encoders import VGG16, VGG19
6 | from fusionlab.layers import AdaptiveAvgPool
7 | 
8 | class VGG16Classifier(CNNClassificationModel):
9 |     def __init__(self, cin, num_cls, spatial_dims=2):
10 |         super().__init__()
11 |         self.num_cls = num_cls
12 |         self.encoder = VGG16(cin, spatial_dims) # Create VGG16 instance
13 |         self.globalpooling = AdaptiveAvgPool(spatial_dims, 1)
14 |         self.head = nn.Linear(512, num_cls)
15 | 
16 | class VGG19Classifier(CNNClassificationModel):
17 |     def __init__(self, cin, num_cls, spatial_dims=2):
18 |         super().__init__()
19 |         self.num_cls = num_cls
20 |         self.encoder = VGG19(cin, spatial_dims) # Create VGG19 instance
21 |         self.globalpooling = AdaptiveAvgPool(spatial_dims, 1)
22 |         self.head = nn.Linear(512, num_cls)
23 | 
24 | 
25 | if __name__ == '__main__':
26 |     inputs = torch.randn(1, 3, 224) # create random input tensor
27 |     model = VGG16Classifier(cin=3, num_cls=2, spatial_dims=1) # create model instance
28 |     outputs = model(inputs) # pass input through model
29 |     assert list(outputs.shape) == [1, 2] # check output shape is correct
30 | 
31 |     inputs = torch.randn(1, 3, 224) # create random input tensor
32 |     model = VGG19Classifier(cin=3, num_cls=2, spatial_dims=1) # create model instance
33 | 
34 |     outputs = model(inputs) # pass input through model
35 |     assert list(outputs.shape) == [1, 2] # check output shape is correct
--------------------------------------------------------------------------------
/fusionlab/configs.py:
--------------------------------------------------------------------------------
1 | EPS = 1e-07 # epsilon
--------------------------------------------------------------------------------
/fusionlab/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .muse import GEMuseXMLReader
2 | from .csvread import read_csv
3 | 
4 | from fusionlab import BACKEND
5 | if BACKEND['torch']:
6 |     from .a12lead import ECGClassificationDataset
7 |     from .cinc2017 import (
8 |         ECGCSVClassificationDataset,
9 |         convert_mat_to_csv,
10 |         validate_data
11 |     )
12 |     from .ludb import (
13 |         LUDBDataset,
14 |         plot
15 |     )
16 |     from .utils import (
17 |         download_file,
18 |         HFDataset,
19 |         LSTimeSegDataset,
20 |     )
--------------------------------------------------------------------------------
/fusionlab/datasets/a12lead.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from torch.utils.data import Dataset # PyTorch dataset utilities
3 | from fusionlab.datasets import csvread # reads the ECG signal files
4 | 
5 | 
6 | # A simple Dataset class
7 | class ECGClassificationDataset(Dataset):
8 |     def __init__(self,
9 |                  annotation_file, # path to the annotation file
10 |                  data_dir, # path to the ECG data
11 |                  transform=None, # data transform function
12 |                  channels=None,
13 |                  class_names=None):
14 |         self.signal_paths = data_dir
15 |         self.ecg_signals = pd.read_csv(annotation_file) # read the annotation file
16 |         self.transform = transform
17 |         self.channels = channels
18 |         self.class_map = {n: i for i, n in enumerate(class_names)}
19 |     def __len__(self):
20 |         return len(self.ecg_signals) # return the number of samples
21 |     def __getitem__(self, idx):
22 |         entry = self.ecg_signals.iloc[idx] # get the row at the given index
23 |         data = csvread.read_csv(self.signal_paths+entry['filename']) # read the ECG data
24 |         label = entry['label'] # get the label
25 |         if self.transform:
26 |             data = self.transform(data) # apply the transform
27 |         return data, label # return the data and its label
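28 | 
29 | 
30 | if __name__ == '__main__':
31 |     # A minimal usage sketch; "annotations.csv" and "data/" are hypothetical
32 |     # paths. The annotation file is expected to contain "filename" and "label"
33 |     # columns, matching what __getitem__ reads above.
34 |     ds = ECGClassificationDataset(annotation_file='annotations.csv',
35 |                                   data_dir='data/',
36 |                                   class_names=['N', 'A'])
37 |     data, label = ds[0]
38 |     print(data.shape, label)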
--------------------------------------------------------------------------------
/fusionlab/datasets/cinc2017.py:
--------------------------------------------------------------------------------
1 | from typing import * # Importing all types from typing module
2 | import os # Importing os module
3 | from glob import glob # Importing glob function from glob module
4 | from tqdm.auto import tqdm # Importing tqdm function from tqdm.auto module
5 | from fusionlab.datasets.utils import download_file # Importing the download_file helper
6 | 
7 | from scipy import io # Importing scipy module
8 | import pandas as pd # Importing pandas module and aliasing it as pd
9 | 
10 | import torch # Importing torch module
11 | from torch.utils.data import Dataset # Importing Dataset class from torch.utils.data module
12 | 
13 | URL_ECG = "https://physionet.org/files/challenge-2017/1.0.0/training2017.zip"
14 | # URL_LABEL = "https://physionet.org/content/challenge-2017/1.0.0/REFERENCE-v3.csv"
15 | URL_LABEL = "https://physionet.org/files/challenge-2017/1.0.0/REFERENCE-v3.csv"
16 | 
17 | 
18 | class ECGCSVClassificationDataset(Dataset):
19 | 
20 |     def __init__(self,
21 |                  data_root,
22 |                  label_filename="REFERENCE-v3.csv",
23 |                  channels=["lead"],
24 |                  class_names=["N", "O", "A", "~"]):
25 |         """
26 |         Args:
27 |             data_root (str): root directory of the dataset
28 |             label_filename (str): filename of the label file
29 |             channels (list): list of target lead names
30 |             class_names (list): list of class names for mapping class name to class id
31 |         """
32 |         # read the label file and store it in a pandas dataframe
33 |         self.df_label = pd.read_csv(os.path.join(data_root, label_filename),
34 |                                     header=None,
35 |                                     names=["pat", "label"])
36 |         # set the directory where the signal files are stored
37 |         self.signal_dir = os.path.join(data_root, "csv")
38 |         # get the paths of all the signal files and sort them
39 |         self.signal_paths = sorted(glob(os.path.join(self.signal_dir, "*.csv")))
40 |         # create a dictionary to map class names to class ids
41 |         self.class_map = {n: i for i, n in enumerate(class_names)}
42 |         # set the target leads for the dataset
43 |         self.channels = channels
44 |         print("dataset class map: ", self.class_map)
45 | 
46 |     def __len__(self):
47 |         return len(self.signal_paths)
48 | 
49 |     def __getitem__(self, idx):
50 |         row = self.df_label.iloc[idx]
51 |         signal_filename = row["pat"] + ".csv"
52 |         signal_path = os.path.join(self.signal_dir, signal_filename)
53 |         df_csv = pd.read_csv(signal_path)
54 |         signal = df_csv["lead"].values
55 | 
56 |         class_name = row["label"]
57 |         class_id = self.class_map[class_name]
58 | 
59 |         signal = torch.tensor(signal, dtype=torch.float)
60 |         class_id = torch.tensor(class_id, dtype=torch.long)
61 | 
62 |         # preprocess
63 |         signal = signal.unsqueeze(0)
64 |         return signal, class_id
65 | 
66 |     def _check_validate(self):
67 |         assert len(self.df_label) == len(
68 |             self.signal_paths), "csv files and label files are not matched"
69 | 
70 | 
71 | def convert_mat_to_csv(root, target_dir="csv"):
72 |     paths = glob(os.path.join(root, "training2017",
73 |                               "*.mat"))  # get all paths of .mat files in the training2017 folder
74 |     os.makedirs(os.path.join(root, target_dir),
75 |                 exist_ok=True)  # create a new directory named target_dir in the root directory
76 |     print("mat files: ", len(paths))  # print the number of .mat files found
77 |     print("start to convert mat files to csv files"
78 |           )  # print a message indicating the start of the conversion process
79 |     for path in tqdm(paths):  # iterate through each path in paths
80 |         filename = os.path.basename(path)  # get the filename from the path
81 |         file_id = filename.split(".")[
82 |             0]  # get the file ID by splitting the filename at the "." and taking the first part
83 |         target_filename = file_id + ".csv"  # create the target filename by appending ".csv" to the file ID
84 |         signal = io.loadmat(path)["val"][0]  # load the .mat file and extract the "val" array
85 |         df = pd.DataFrame(columns=["lead"
86 |                                    ])  # create a new DataFrame with a single column named "lead"
87 |         df["lead"] = signal  # set the "lead" column to the "val" array
88 |         df.to_csv(
89 |             os.path.join(root, target_dir, target_filename)
90 |         )  # save the DataFrame as a CSV file in the target directory with the target filename
91 | 
92 | 
93 | def validate_data(csv_dir, label_path):
94 |     """
95 |     check if the number of csv files and label files are matched
96 |     """
97 |     csv_paths = glob(os.path.join(csv_dir, "*.csv"))  # get all csv files in the directory
98 |     df_label = pd.read_csv(label_path, header=None,
99 |                            names=["pat", "label"])  # read the label file as a dataframe
100 |     print("csv files: ", len(csv_paths))  # print the number of csv files
101 |     print("label files: ", len(df_label))  # print the number of label files
102 |     assert len(csv_paths) == len(
103 |         df_label
104 |     ), "csv files and label files are not matched"  # check if the number of csv files and label files are equal
105 |     return  # return nothing
106 | 
107 | 
108 | if __name__ == "__main__":
109 |     root = "data"
110 |     try:
111 |         validate_data("./data/csv", "./data/REFERENCE-v3.csv")
112 |     except Exception:
113 |         print("validation failed, start to download and convert data")
114 |         download_file(URL_ECG, root, extract=True)
115 |         download_file(URL_LABEL, root, extract=False)
116 |         convert_mat_to_csv(root)
117 | 
118 |     ds = ECGCSVClassificationDataset("./data", label_filename="REFERENCE-v3.csv")
119 | 
120 |     sig, label = ds[0]
121 |     print(sig.shape, label)
122 | 
--------------------------------------------------------------------------------
/fusionlab/datasets/csv_sample.csv:
--------------------------------------------------------------------------------
1 | time, I, II, III
2 | 0.0, 1, 1, 1
3 | 0.01, 2, 2, 2
4 | 0.02, 3, 3, 3
--------------------------------------------------------------------------------
/fusionlab/datasets/csvread.py:
--------------------------------------------------------------------------------
1 | # Reader for csv files
2 | import numpy as np
3 | 
4 | def read_csv(fname):
5 |     # Skip the header row, then drop the first (time) column and
6 |     # return the remaining channels as float32
7 |     readout = np.genfromtxt(fname, dtype=np.float32, skip_header=1, delimiter=",")
8 |     return readout[:, 1:]
9 | 
10 | if __name__ == "__main__":
11 |     signal = read_csv('csv_sample.csv')
12 |     assert list(signal.shape) == [3, 3]
13 |     assert list(signal[:,0]) == [1., 2., 3.]
-------------------------------------------------------------------------------- /fusionlab/encoders/README.md: -------------------------------------------------------------------------------- 1 | # Encoders 2 | 3 | | Name | PyTorch | Tensorflow | Paper| 4 | | :----------- | :----------------------------------------------- | :---------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------- | 5 | | ConvNeXt | ConvNeXtTiny \| Small \| Base \| Large \| XLarge | X | [A ConvNet for the 2020s (2020)](https://arxiv.org/abs/2201.03545)| 6 | | EfficientNet | EfficientNetB0~7 | X | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks (2019)](https://arxiv.org/abs/1905.11946)| 7 | | ResNet | ResNet50 | TFResNet50 | [Deep Residual Learning for Image Recognition (2015)](https://arxiv.org/abs/1512.03385)| 8 | | InceptionV1 | InceptionNetV1 | TFInceptionNetV1 | [Going Deeper with Convolutions (2014)](https://arxiv.org/abs/1409.4842)| 9 | | VGG | VGG16, VGG19 | TFVGG16, TFVGG19 | [Very Deep Convolutional Networks for Large-Scale Image Recognition (2014)](https://arxiv.org/abs/1409.1556)| 10 | | AlexNet | AlexNet | TFAlexNet | [ImageNet Classification with Deep Convolutional Neural Networks (2012)](https://proceedings.neurips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)| 11 | -------------------------------------------------------------------------------- /fusionlab/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from fusionlab import BACKEND 2 | if BACKEND['torch']: 3 | from .alexnet.alexnet import AlexNet 4 | from .vgg.vgg import VGG16, VGG19 5 | from .inceptionv1.inceptionv1 import InceptionNetV1 6 | from .resnetv1.resnetv1 import * 7 | from .efficientnet.efficientnet import ( 8 | EfficientNet, 9 | EfficientNetB0, 10 | EfficientNetB1, 11 | EfficientNetB2, 12 | EfficientNetB3, 13 | EfficientNetB4, 14 | EfficientNetB5, 15 | EfficientNetB6, 16 | EfficientNetB7 17 | ) 18 | from .convnext.convnext import ( 19 | ConvNeXt, 20 | ConvNeXtTiny, 21 | ConvNeXtSmall, 22 | ConvNeXtBase, 23 | ConvNeXtLarge, 24 | ConvNeXtXLarge 25 | ) 26 | from .vit.vit import ( 27 | ViT, 28 | VisionTransformer 29 | ) 30 | from .mit.mit import ( 31 | MiT, 32 | MiTB0, 33 | MiTB1, 34 | MiTB2, 35 | MiTB3, 36 | MiTB4, 37 | MiTB5, 38 | ) 39 | if BACKEND['tf']: 40 | from .alexnet.tfalexnet import TFAlexNet 41 | from .vgg.tfvgg import TFVGG16, TFVGG19 42 | from .inceptionv1.tfinceptionv1 import TFInceptionNetV1 43 | from .resnetv1.tfresnetv1 import TFResNet50V1 44 | -------------------------------------------------------------------------------- /fusionlab/encoders/alexnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/encoders/alexnet/__init__.py -------------------------------------------------------------------------------- /fusionlab/encoders/alexnet/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from fusionlab.layers.factories import ConvND, MaxPool 4 | 5 | # Official pytorch ref: https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py 6 | class AlexNet(nn.Module): 7 | def __init__(self, cin=3, spatial_dims=2): 8 | super().__init__() 9 | self.features = 
nn.Sequential( 10 | ConvND(spatial_dims, cin, 64, kernel_size=11, stride=4, padding=2), 11 | nn.ReLU(inplace=True), 12 | MaxPool(spatial_dims, kernel_size=3, stride=2), 13 | ConvND(spatial_dims, 64, 192, kernel_size=5, padding=2), 14 | nn.ReLU(inplace=True), 15 | MaxPool(spatial_dims, kernel_size=3, stride=2), 16 | ConvND(spatial_dims, 192, 384, kernel_size=3, padding=1), 17 | nn.ReLU(inplace=True), 18 | ConvND(spatial_dims, 384, 256, kernel_size=3, padding=1), 19 | nn.ReLU(inplace=True), 20 | ConvND(spatial_dims, 256, 256, kernel_size=3, padding=1), 21 | nn.ReLU(inplace=True), 22 | MaxPool(spatial_dims, kernel_size=3, stride=2), 23 | ) 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.features(x) 27 | 28 | 29 | 30 | if __name__ == '__main__': 31 | img_size = 224 32 | inputs = torch.normal(0, 1, (1, 3, img_size, img_size)) 33 | output = AlexNet(3, spatial_dims=2)(inputs) 34 | print(output.shape) 35 | assert list(output.shape) == [1, 256, 6, 6] 36 | 37 | inputs = torch.normal(0, 1, (1, 3, img_size)) 38 | output = AlexNet(3, spatial_dims=1)(inputs) 39 | print(output.shape) 40 | assert list(output.shape) == [1, 256, 6] 41 | 42 | img_size = 128 43 | inputs = torch.normal(0, 1, (1, 3, img_size, img_size, img_size)) 44 | output = AlexNet(3, spatial_dims=3)(inputs) 45 | print(output.shape) 46 | assert list(output.shape) == [1, 256, 3, 3, 3] 47 | -------------------------------------------------------------------------------- /fusionlab/encoders/alexnet/tfalexnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class TFAlexNet(tf.keras.Model): 5 | def __init__(self): 6 | super().__init__() 7 | self.features = tf.keras.Sequential([ 8 | tf.keras.layers.ZeroPadding2D(2), 9 | tf.keras.layers.Conv2D(64, kernel_size=11, strides=4), 10 | tf.keras.layers.ReLU(), 11 | tf.keras.layers.MaxPool2D(pool_size=3, strides=2), 12 | tf.keras.layers.Conv2D(192, kernel_size=5, padding='same'), 13 | tf.keras.layers.ReLU(), 14 | tf.keras.layers.MaxPool2D(pool_size=3, strides=2), 15 | tf.keras.layers.Conv2D(384, kernel_size=3, padding='same'), 16 | tf.keras.layers.ReLU(), 17 | tf.keras.layers.Conv2D(256, kernel_size=3, padding='same'), 18 | tf.keras.layers.ReLU(), 19 | tf.keras.layers.Conv2D(256, kernel_size=3, padding='same'), 20 | tf.keras.layers.ReLU(), 21 | tf.keras.layers.MaxPool2D(pool_size=3, strides=2) 22 | ]) 23 | 24 | def call(self, inputs): 25 | return self.features(inputs) 26 | 27 | 28 | if __name__ == '__main__': 29 | inputs = tf.random.normal((1, 224, 224, 3)) 30 | output = TFAlexNet()(inputs) 31 | shape = output.shape 32 | print(shape) 33 | assert shape[1:3] == [6, 6] 34 | -------------------------------------------------------------------------------- /fusionlab/encoders/convnext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/encoders/convnext/__init__.py -------------------------------------------------------------------------------- /fusionlab/encoders/efficientnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/encoders/efficientnet/__init__.py -------------------------------------------------------------------------------- /fusionlab/encoders/inceptionv1/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/encoders/inceptionv1/__init__.py -------------------------------------------------------------------------------- /fusionlab/encoders/inceptionv1/inceptionv1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from fusionlab.layers.factories import ConvND, MaxPool 5 | from fusionlab.utils import autopad 6 | 7 | # ref: https://arxiv.org/abs/1409.4842 8 | # Going Deeper with Convolutions 9 | class ConvBlock(nn.Module): 10 | def __init__(self, cin, cout, kernel_size=3, spatial_dims=2, stride=1): 11 | super().__init__() 12 | self.conv = ConvND(spatial_dims, cin, cout, kernel_size, stride, padding=autopad(kernel_size)) 13 | self.act = nn.ReLU(inplace=True) 14 | 15 | def forward(self, x): 16 | x = self.conv(x) 17 | x = self.act(x) 18 | return x 19 | 20 | 21 | class InceptionBlock(nn.Module): 22 | def __init__(self, cin, dim0, dim1, dim2, dim3, spatial_dims=2): 23 | super().__init__() 24 | self.branch1 = ConvBlock(cin, dim0, 1, spatial_dims)  # 1x1 branch, per the paper and the TF twin 25 | self.branch3 = nn.Sequential(ConvBlock(cin, dim1[0], 1, spatial_dims), 26 | ConvBlock(dim1[0], dim1[1], 3, spatial_dims)) 27 | self.branch5 = nn.Sequential(ConvBlock(cin, dim2[0], 1, spatial_dims), 28 | ConvBlock(dim2[0], dim2[1], 5, spatial_dims)) 29 | self.pool = nn.Sequential(MaxPool(spatial_dims, 3, 1, autopad(3)), 30 | ConvBlock(cin, dim3, 1, spatial_dims))  # 1x1 pool projection 31 | 32 | def forward(self, x): 33 | x0 = self.branch1(x) 34 | x1 = self.branch3(x) 35 | x2 = self.branch5(x) 36 | x3 = self.pool(x) 37 | x = torch.cat((x0, x1, x2, x3), 1) 38 | return x 39 | 40 | 41 | class InceptionNetV1(nn.Module): 42 | def __init__(self, cin=3, spatial_dims=2): 43 | super().__init__() 44 | self.stem = nn.Sequential( 45 | ConvBlock(cin, 64, 7, spatial_dims, stride=2), 46 | MaxPool(spatial_dims, 3, 2, padding=autopad(3)), 47 | ConvBlock(64, 192, 3, spatial_dims), 48 | MaxPool(spatial_dims, 3, 2, padding=autopad(3)), 49 | ) 50 | self.incept3a = InceptionBlock(192, 64, (96, 128), (16, 32), 32, spatial_dims) 51 | self.incept3b = InceptionBlock(256, 128, (128, 192), (32, 96), 64, spatial_dims) 52 | self.pool3 = MaxPool(spatial_dims, 3, 2, padding=autopad(3)) 53 | self.incept4a = InceptionBlock(480, 192, (96, 208), (16, 48), 64, spatial_dims) 54 | self.incept4b = InceptionBlock(512, 160, (112, 224), (24, 64), 64, spatial_dims) 55 | self.incept4c = InceptionBlock(512, 128, (128, 256), (24, 64), 64, spatial_dims) 56 | self.incept4d = InceptionBlock(512, 112, (144, 288), (32, 64), 64, spatial_dims) 57 | self.incept4e = InceptionBlock(528, 256, (160, 320), (32, 128), 128, spatial_dims) 58 | self.pool4 = MaxPool(spatial_dims, 3, 2, padding=autopad(3)) 59 | self.incept5a = InceptionBlock(832, 256, (160, 320), (32, 128), 128, spatial_dims) 60 | self.incept5b = InceptionBlock(832, 384, (192, 384), (48, 128), 128, spatial_dims) 61 | 62 | def forward(self, x): 63 | x = self.stem(x) 64 | x = self.incept3a(x) 65 | x = self.incept3b(x) 66 | x = self.pool3(x) 67 | x = self.incept4a(x) 68 | x = self.incept4b(x) 69 | x = self.incept4c(x) 70 | x = self.incept4d(x) 71 | x = self.incept4e(x) 72 | x = self.pool4(x) 73 | x = self.incept5a(x) 74 | x = self.incept5b(x) 75 | return x 76 | 77 | if __name__ == "__main__": 78 | inputs = torch.normal(0, 1, (1, 3, 224, 224)) 79 | outputs = InceptionBlock(3, 64, (96, 128), (16, 32), 32)(inputs)
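# Channel check: the four branch outputs are concatenated on the channel axis, so 64 + 128 + 32 + 32 = 256 channels; autopad preserves the 224x224 spatial size.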
80 | print(outputs.shape) 81 | assert list(outputs.shape) == [1, 256, 224, 224] 82 | 83 | outputs = InceptionNetV1()(inputs) 84 | print(outputs.shape) 85 | assert list(outputs.shape) == [1, 1024, 7, 7] 86 | 87 | inputs = torch.normal(0, 1, (1, 3, 224)) 88 | outputs = InceptionNetV1(spatial_dims=1)(inputs) 89 | print(outputs.shape) 90 | assert list(outputs.shape) == [1, 1024, 7] 91 | -------------------------------------------------------------------------------- /fusionlab/encoders/inceptionv1/tfinceptionv1.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model, layers, Sequential 3 | 4 | # ref: https://arxiv.org/abs/1409.4842 5 | # Going Deeper with Convolutions 6 | 7 | 8 | class ConvBlock(Model): 9 | def __init__(self, cout, kernel_size=3, stride=1): 10 | super().__init__() 11 | self.conv = layers.Conv2D(cout, kernel_size, stride, padding="same") 12 | self.act = layers.ReLU() 13 | 14 | def call(self, inputs): 15 | x = self.conv(inputs) 16 | x = self.act(x) 17 | return x 18 | 19 | 20 | class InceptionBlock(Model): 21 | def __init__(self, dim0, dim1, dim2, dim3): 22 | super().__init__() 23 | self.branch1 = ConvBlock(dim0, 1) 24 | self.branch3 = Sequential([ 25 | ConvBlock(dim1[0], 1), 26 | ConvBlock(dim1[1], 3) 27 | ]) 28 | self.branch5 = Sequential([ 29 | ConvBlock(dim2[0], 1), 30 | ConvBlock(dim2[1], 5)] 31 | ) 32 | self.pool = Sequential([ 33 | layers.MaxPool2D(3, 1, padding='same'), 34 | ConvBlock(dim3, 1) 35 | ]) 36 | 37 | def call(self, inputs): 38 | x = inputs 39 | x0 = self.branch1(x) 40 | x1 = self.branch3(x) 41 | x2 = self.branch5(x) 42 | x3 = self.pool(x) 43 | x = tf.concat([x0, x1, x2, x3], axis=-1) 44 | return x 45 | 46 | 47 | class TFInceptionNetV1(Model): 48 | def __init__(self): 49 | super().__init__() 50 | self.stem = Sequential([ 51 | ConvBlock(64, 7, stride=2), 52 | layers.MaxPool2D(3, 2, padding='same'), 53 | ConvBlock(192, 3), 54 | layers.MaxPool2D(3, 2, padding='same'), 55 | ]) 56 | self.incept3a = InceptionBlock(64, (96, 128), (16, 32), 32) 57 | self.incept3b = InceptionBlock(128, (128, 192), (32, 96), 64) 58 | self.pool3 = layers.MaxPool2D(3, 2, padding='same') 59 | self.incept4a = InceptionBlock(192, (96, 208), (16, 48), 64) 60 | self.incept4b = InceptionBlock(160, (112, 224), (24, 64), 64) 61 | self.incept4c = InceptionBlock(128, (128, 256), (24, 64), 64) 62 | self.incept4d = InceptionBlock(112, (144, 288), (32, 64), 64) 63 | self.incept4e = InceptionBlock(256, (160, 320), (32, 128), 128) 64 | self.pool4 = layers.MaxPool2D(3, 2, padding='same') 65 | self.incept5a = InceptionBlock(256, (160, 320), (32, 128), 128) 66 | self.incept5b = InceptionBlock(384, (192, 384), (48, 128), 128) 67 | 68 | def call(self, inputs): 69 | x = self.stem(inputs) 70 | x = self.incept3a(x) 71 | x = self.incept3b(x) 72 | x = self.pool3(x) 73 | x = self.incept4a(x) 74 | x = self.incept4b(x) 75 | x = self.incept4c(x) 76 | x = self.incept4d(x) 77 | x = self.incept4e(x) 78 | x = self.pool4(x) 79 | x = self.incept5a(x) 80 | x = self.incept5b(x) 81 | return x 82 | 83 | 84 | if __name__ == "__main__": 85 | inputs = tf.random.normal((1, 224, 224, 3)) 86 | output = InceptionBlock(64, (96, 128), (16, 32), 32)(inputs) 87 | shape = output.shape 88 | print("InceptionBlock", shape) 89 | assert shape == (1, 224, 224, 256) 90 | 91 | output = TFInceptionNetV1()(inputs) 92 | shape = output.shape 93 | print("TFInceptionNetV1", shape) 94 | assert shape == (1, 7, 7, 1024) 95 | 
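Both InceptionV1 implementations above stop at the last inception block and return backbone feature maps rather than class logits. A minimal usage sketch, not part of the repository, showing one way to attach a classification head in PyTorch: it assumes fusionlab is installed with the torch backend, the 10-class head and pooling choices are illustrative only, and the (N, 1024, 7, 7) feature shape is taken from the asserts in the `__main__` blocks above.

```python
import torch
import torch.nn as nn
from fusionlab.encoders import InceptionNetV1

# Backbone defined in this repo: (N, 3, 224, 224) -> (N, 1024, 7, 7)
backbone = InceptionNetV1(cin=3, spatial_dims=2)
head = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),  # (N, 1024, 7, 7) -> (N, 1024, 1, 1)
    nn.Flatten(),             # -> (N, 1024)
    nn.Linear(1024, 10),      # -> (N, 10) logits; 10 classes is a placeholder
)
model = nn.Sequential(backbone, head)

logits = model(torch.randn(2, 3, 224, 224))
assert logits.shape == (2, 10)
```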
-------------------------------------------------------------------------------- /fusionlab/encoders/mit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/encoders/mit/__init__.py -------------------------------------------------------------------------------- /fusionlab/encoders/mit/mit.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | import torch 3 | import torch.nn as nn 4 | from fusionlab.layers import ( 5 | SRAttention, 6 | DropPath, 7 | ) 8 | 9 | class PatchEmbed(nn.Module): 10 | def __init__( 11 | self, 12 | in_channels=3, 13 | out_channels=32, 14 | patch_size=7, 15 | stride=4 16 | ): 17 | super().__init__() 18 | self.proj = nn.Conv2d( 19 | in_channels, 20 | out_channels, 21 | patch_size, stride, 22 | patch_size//2) 23 | self.norm = nn.LayerNorm(out_channels) 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | x = self.proj(x) 27 | _, _, h, w = x.shape 28 | x = x.flatten(2).transpose(1, 2) 29 | x = self.norm(x) 30 | return x, h, w 31 | 32 | class MiTBlock(nn.Module): 33 | def __init__(self, dim, head, spatio_reduction_ratio=1, drop_path_rate=0.): 34 | super().__init__() 35 | self.norm1 = nn.LayerNorm(dim) 36 | self.attn = SRAttention(dim, head, spatio_reduction_ratio) 37 | self.drop_path = DropPath(drop_path_rate, "row") if drop_path_rate > 0. else nn.Identity() 38 | self.norm2 = nn.LayerNorm(dim) 39 | self.mlp = MLP(dim, int(dim*4)) 40 | 41 | def forward(self, x: torch.Tensor, h, w) -> torch.Tensor: 42 | x = x + self.drop_path(self.attn(self.norm1(x), h, w)) 43 | x = x + self.drop_path(self.mlp(self.norm2(x), h, w)) 44 | return x 45 | 46 | class MLP(nn.Module): 47 | def __init__(self, c1, c2): 48 | super().__init__() 49 | self.fc1 = nn.Linear(c1, c2) 50 | self.dwconv = nn.Conv2d(c2, c2, 3, 1, 1, groups=c2) 51 | self.fc2 = nn.Linear(c2, c1) 52 | self.act = nn.GELU() 53 | 54 | def forward(self, x: torch.Tensor, h, w) -> torch.Tensor: 55 | x = self.fc1(x) 56 | B, _, C = x.shape 57 | x = x.transpose(1, 2).view(B, C, h, w) 58 | x = self.dwconv(x) 59 | x = x.flatten(2).transpose(1, 2) 60 | x = self.act(x) 61 | x = self.fc2(x) 62 | return x 63 | 64 | class MiT(nn.Module): 65 | """ 66 | Mix Transformer 67 | 68 | source code: https://github.com/sithu31296/semantic-segmentation/blob/main/semseg/models/backbones/mit.py 69 | """ 70 | def __init__( 71 | self, 72 | in_channels: int = 3, 73 | embed_dims: Sequence[int] = [32, 64, 160, 256], 74 | depths: Sequence[int] = [2, 2, 2, 2] 75 | ): 76 | super().__init__() 77 | drop_path_rate = 0.1 78 | self.channels = embed_dims 79 | 80 | # patch_embed 81 | self.patch_embed1 = PatchEmbed(in_channels, embed_dims[0], 7, 4) 82 | self.patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2) 83 | self.patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2) 84 | self.patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2) 85 | 86 | drop_path_rate = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 87 | 88 | cur = 0 89 | self.block1 = nn.ModuleList([MiTBlock(embed_dims[0], 1, 8, drop_path_rate[cur+i]) for i in range(depths[0])]) 90 | self.norm1 = nn.LayerNorm(embed_dims[0]) 91 | 92 | cur += depths[0] 93 | self.block2 = nn.ModuleList([MiTBlock(embed_dims[1], 2, 4, drop_path_rate[cur+i]) for i in range(depths[1])]) 94 | self.norm2 = nn.LayerNorm(embed_dims[1]) 95 | 96 | cur += depths[1] 97 | 
self.block3 = nn.ModuleList([MiTBlock(embed_dims[2], 5, 2, drop_path_rate[cur+i]) for i in range(depths[2])]) 98 | self.norm3 = nn.LayerNorm(embed_dims[2]) 99 | 100 | cur += depths[2] 101 | self.block4 = nn.ModuleList([MiTBlock(embed_dims[3], 8, 1, drop_path_rate[cur+i]) for i in range(depths[3])]) 102 | self.norm4 = nn.LayerNorm(embed_dims[3]) 103 | 104 | 105 | def forward(self, x: torch.Tensor, return_features=False) -> torch.Tensor: 106 | bs = x.shape[0] 107 | 108 | # stage 1 109 | x, h, w = self.patch_embed1(x) 110 | for blk in self.block1: 111 | x = blk(x, h, w) 112 | x1 = self.norm1(x).reshape(bs, h, w, -1).permute(0, 3, 1, 2) 113 | 114 | # stage 2 115 | x, h, w = self.patch_embed2(x1) 116 | for blk in self.block2: 117 | x = blk(x, h, w) 118 | x2 = self.norm2(x).reshape(bs, h, w, -1).permute(0, 3, 1, 2) 119 | 120 | # stage 3 121 | x, h, w = self.patch_embed3(x2) 122 | for blk in self.block3: 123 | x = blk(x, h, w) 124 | x3 = self.norm3(x).reshape(bs, h, w, -1).permute(0, 3, 1, 2) 125 | 126 | # stage 4 127 | x, h, w = self.patch_embed4(x3) 128 | for blk in self.block4: 129 | x = blk(x, h, w) 130 | x4 = self.norm4(x).reshape(bs, h, w, -1).permute(0, 3, 1, 2) 131 | 132 | if return_features: 133 | return x4, [x1, x2, x3, x4] 134 | else: 135 | return x4 136 | 137 | class MiTB0(MiT): 138 | def __init__(self, in_channels: int = 3): 139 | super().__init__(in_channels, [32, 64, 160, 256], [2, 2, 2, 2]) 140 | 141 | class MiTB1(MiT): 142 | def __init__(self, in_channels: int = 3): 143 | super().__init__(in_channels, [64, 128, 320, 512], [2, 2, 2, 2]) 144 | 145 | class MiTB2(MiT): 146 | def __init__(self, in_channels: int = 3): 147 | super().__init__(in_channels, [64, 128, 320, 512], [3, 4, 6, 3]) 148 | 149 | class MiTB3(MiT): 150 | def __init__(self, in_channels: int = 3): 151 | super().__init__(in_channels, [64, 128, 320, 512], [3, 4, 18, 3]) 152 | 153 | class MiTB4(MiT): 154 | def __init__(self, in_channels: int = 3): 155 | super().__init__(in_channels, [64, 128, 320, 512], [3, 8, 27, 3]) 156 | 157 | class MiTB5(MiT): 158 | def __init__(self, in_channels: int = 3): 159 | super().__init__(in_channels, [64, 128, 320, 512], [3, 6, 40, 3]) 160 | 161 | if __name__ == '__main__': 162 | inputs = torch.randn(1, 3, 128, 128) 163 | for i in range(6): 164 | # model = MiT(in_channels=3) 165 | model = eval(f'MiTB{i}')(in_channels=3) 166 | outputs = model(inputs) 167 | print(outputs.shape) -------------------------------------------------------------------------------- /fusionlab/encoders/resnetv1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/encoders/resnetv1/__init__.py -------------------------------------------------------------------------------- /fusionlab/encoders/resnetv1/tfresnetv1.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Model, Sequential, layers 3 | 4 | # ResNet50 5 | # Ref: 6 | # https://github.com/keras-team/keras-applications/blob/master/keras_applications/resnet50.py 7 | # https://github.com/raghakot/keras-resnet/blob/master/README.md 8 | 9 | 10 | class Identity(layers.Layer): 11 | def __init__(self): 12 | super(Identity, self).__init__() 13 | 14 | def call(self, inputs, training=None): 15 | return inputs 16 | 17 | 18 | class ConvBlock(Model): 19 | def __init__(self, cout, kernel_size=3, stride=1, activation=True, padding=True): 20 | 
super().__init__() 21 | self.conv = layers.Conv2D(cout, kernel_size, stride, 22 | padding='same' if padding else 'valid') 23 | self.bn = layers.BatchNormalization() 24 | self.act = layers.ReLU() if activation else Identity() 25 | 26 | def call(self, inputs, training=None): 27 | x = self.conv(inputs) 28 | x = self.bn(x, training) 29 | x = self.act(x) 30 | return x 31 | 32 | 33 | class Bottleneck(Model): 34 | def __init__(self, dims, kernel_size=3, stride=None): 35 | super().__init__() 36 | dim1, dim2, dim3 = dims 37 | self.conv1 = ConvBlock(dim1, kernel_size=1) 38 | self.conv2 = ConvBlock(dim2, kernel_size=kernel_size, 39 | stride=stride if stride else 1) 40 | self.conv3 = ConvBlock(dim3, kernel_size=1, activation=False) 41 | self.act = layers.ReLU() 42 | self.skip = Identity() if not stride else ConvBlock(dim3, 43 | kernel_size=1, 44 | stride=stride, 45 | activation=False) 46 | 47 | def call(self, inputs, training=None): 48 | identity = self.skip(inputs, training) 49 | 50 | x = self.conv1(inputs, training) 51 | x = self.conv2(x, training) 52 | x = self.conv3(x, training) 53 | 54 | x += identity 55 | x = self.act(x) 56 | return x 57 | 58 | 59 | class TFResNet50V1(Model): 60 | def __init__(self): 61 | super(TFResNet50V1, self).__init__() 62 | self.conv1 = Sequential([ 63 | ConvBlock(64, 7, stride=2), 64 | layers.MaxPool2D(3, strides=2, padding='same'), 65 | ]) 66 | self.conv2 = Sequential([ 67 | Bottleneck([64, 64, 256], 3, stride=1), 68 | Bottleneck([64, 64, 256], 3), 69 | Bottleneck([64, 64, 256], 3), 70 | ]) 71 | self.conv3 = Sequential([ 72 | Bottleneck([128, 128, 512], 3, stride=2), 73 | Bottleneck([128, 128, 512], 3), 74 | Bottleneck([128, 128, 512], 3), 75 | Bottleneck([128, 128, 512], 3), 76 | ]) 77 | self.conv4 = Sequential([ 78 | Bottleneck([256, 256, 1024], 3, stride=2), 79 | Bottleneck([256, 256, 1024], 3), 80 | Bottleneck([256, 256, 1024], 3), 81 | Bottleneck([256, 256, 1024], 3), 82 | Bottleneck([256, 256, 1024], 3), 83 | Bottleneck([256, 256, 1024], 3), 84 | ]) 85 | self.conv5 = Sequential([ 86 | Bottleneck([512, 512, 2048], 3, stride=2), 87 | Bottleneck([512, 512, 2048], 3), 88 | Bottleneck([512, 512, 2048], 3), 89 | ]) 90 | 91 | def call(self, inputs, training=None): 92 | x = self.conv1(inputs, training) 93 | x = self.conv2(x, training) 94 | x = self.conv3(x, training) 95 | x = self.conv4(x, training) 96 | x = self.conv5(x, training) 97 | return x 98 | 99 | 100 | if __name__ == '__main__': 101 | inputs = tf.random.normal((1, 224, 224, 128)) 102 | output = Bottleneck([64, 64, 128])(inputs) 103 | shape = output.shape 104 | print("Bottleneck", shape) 105 | assert shape == (1, 224, 224, 128) 106 | 107 | output = Bottleneck([128, 128, 256], stride=1)(inputs) 108 | shape = output.shape 109 | print("Bottleneck first conv for aligh dims", shape) 110 | assert shape == (1, 224, 224, 256) 111 | 112 | output = Bottleneck([64, 64, 128], stride=2)(inputs) 113 | shape = output.shape 114 | print("Bottleneck downsample", shape) 115 | assert shape == (1, 112, 112, 128) 116 | 117 | output = Identity()(inputs) 118 | shape = output.shape 119 | print("Identity", shape) 120 | assert shape == (1, 224, 224, 128) 121 | 122 | 123 | output = TFResNet50V1()(inputs) 124 | shape = output.shape 125 | print("TFResNet50V1", shape) 126 | assert shape == (1, 7, 7, 2048) 127 | 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /fusionlab/encoders/vgg/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/encoders/vgg/__init__.py -------------------------------------------------------------------------------- /fusionlab/encoders/vgg/tfvgg.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | # Official pytorch ref: https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py 4 | class TFVGG16(tf.keras.Model): 5 | def __init__(self): 6 | super().__init__() 7 | ksize = 3 8 | self.features = tf.keras.Sequential([ 9 | tf.keras.layers.Conv2D(64, ksize, padding='same'), 10 | tf.keras.layers.ReLU(), 11 | tf.keras.layers.Conv2D(64, ksize, padding='same'), 12 | tf.keras.layers.ReLU(), 13 | tf.keras.layers.MaxPool2D(), 14 | 15 | tf.keras.layers.Conv2D(128, ksize, padding='same'), 16 | tf.keras.layers.ReLU(), 17 | tf.keras.layers.Conv2D(128, ksize, padding='same'), 18 | tf.keras.layers.ReLU(), 19 | tf.keras.layers.MaxPool2D(), 20 | 21 | tf.keras.layers.Conv2D(256, ksize, padding='same'), 22 | tf.keras.layers.ReLU(), 23 | tf.keras.layers.Conv2D(256, ksize, padding='same'), 24 | tf.keras.layers.ReLU(), 25 | tf.keras.layers.Conv2D(256, ksize, padding='same'), 26 | tf.keras.layers.ReLU(), 27 | tf.keras.layers.MaxPool2D(), 28 | 29 | tf.keras.layers.Conv2D(512, ksize, padding='same'), 30 | tf.keras.layers.ReLU(), 31 | tf.keras.layers.Conv2D(512, ksize, padding='same'), 32 | tf.keras.layers.ReLU(), 33 | tf.keras.layers.Conv2D(512, ksize, padding='same'), 34 | tf.keras.layers.ReLU(), 35 | tf.keras.layers.MaxPool2D(), 36 | 37 | tf.keras.layers.Conv2D(512, ksize, padding='same'), 38 | tf.keras.layers.ReLU(), 39 | tf.keras.layers.Conv2D(512, ksize, padding='same'), 40 | tf.keras.layers.ReLU(), 41 | tf.keras.layers.Conv2D(512, ksize, padding='same'), 42 | tf.keras.layers.ReLU(), 43 | tf.keras.layers.MaxPool2D(), 44 | ]) 45 | 46 | def call(self, inputs): 47 | return self.features(inputs) 48 | 49 | 50 | class TFVGG19(tf.keras.Model): 51 | def __init__(self): 52 | super().__init__() 53 | ksize = 3 54 | self.features = tf.keras.Sequential([ 55 | tf.keras.layers.Conv2D(64, ksize, padding='same'), 56 | tf.keras.layers.ReLU(), 57 | tf.keras.layers.Conv2D(64, ksize, padding='same'), 58 | tf.keras.layers.ReLU(), 59 | tf.keras.layers.MaxPool2D(), 60 | 61 | tf.keras.layers.Conv2D(128, ksize, padding='same'), 62 | tf.keras.layers.ReLU(), 63 | tf.keras.layers.Conv2D(128, ksize, padding='same'), 64 | tf.keras.layers.ReLU(), 65 | tf.keras.layers.MaxPool2D(), 66 | 67 | tf.keras.layers.Conv2D(256, ksize, padding='same'), 68 | tf.keras.layers.ReLU(), 69 | tf.keras.layers.Conv2D(256, ksize, padding='same'), 70 | tf.keras.layers.ReLU(), 71 | tf.keras.layers.Conv2D(256, ksize, padding='same'), 72 | tf.keras.layers.ReLU(), 73 | tf.keras.layers.Conv2D(256, ksize, padding='same'), 74 | tf.keras.layers.ReLU(), 75 | tf.keras.layers.MaxPool2D(), 76 | 77 | tf.keras.layers.Conv2D(512, ksize, padding='same'), 78 | tf.keras.layers.ReLU(), 79 | tf.keras.layers.Conv2D(512, ksize, padding='same'), 80 | tf.keras.layers.ReLU(), 81 | tf.keras.layers.Conv2D(512, ksize, padding='same'), 82 | tf.keras.layers.ReLU(), 83 | tf.keras.layers.Conv2D(512, ksize, padding='same'), 84 | tf.keras.layers.ReLU(), 85 | tf.keras.layers.MaxPool2D(), 86 | 87 | tf.keras.layers.Conv2D(512, ksize, padding='same'), 88 | tf.keras.layers.ReLU(), 89 | tf.keras.layers.Conv2D(512, ksize, padding='same'), 90 | tf.keras.layers.ReLU(), 91 | tf.keras.layers.Conv2D(512, ksize, padding='same'), 92 | 
tf.keras.layers.ReLU(), 93 | tf.keras.layers.Conv2D(512, ksize, padding='same'), 94 | tf.keras.layers.ReLU(), 95 | tf.keras.layers.MaxPool2D(), 96 | ]) 97 | 98 | def call(self, inputs): 99 | return self.features(inputs) 100 | 101 | 102 | if __name__ == '__main__': 103 | # VGG16 104 | inputs = tf.random.normal((1, 224, 224, 3)) 105 | output = TFVGG16()(inputs) 106 | shape = output.shape 107 | assert shape[1:3] == [7, 7] 108 | 109 | # VGG19 110 | inputs = tf.random.normal((1, 224, 224, 3)) 111 | output = TFVGG19()(inputs) 112 | shape = output.shape 113 | assert shape[1:3] == [7, 7] -------------------------------------------------------------------------------- /fusionlab/encoders/vgg/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from fusionlab.layers import ConvND, MaxPool 4 | 5 | # Official pytorch ref: https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py 6 | class VGG16(nn.Module): 7 | def __init__(self, cin=3, spatial_dims=2): 8 | super().__init__() 9 | ksize = 3 10 | self.features = nn.Sequential( 11 | ConvND(spatial_dims, cin, 64, ksize, padding=1), 12 | nn.ReLU(inplace=True), 13 | ConvND(spatial_dims, 64, 64, ksize, padding=1), 14 | nn.ReLU(inplace=True), 15 | MaxPool(spatial_dims, kernel_size=2, stride=2), 16 | 17 | ConvND(spatial_dims, 64, 128, ksize, padding=1), 18 | nn.ReLU(inplace=True), 19 | ConvND(spatial_dims, 128, 128, ksize, padding=1), 20 | nn.ReLU(inplace=True), 21 | MaxPool(spatial_dims, kernel_size=2, stride=2), 22 | 23 | ConvND(spatial_dims, 128, 256, ksize, padding=1), 24 | nn.ReLU(inplace=True), 25 | ConvND(spatial_dims, 256, 256, ksize, padding=1), 26 | nn.ReLU(inplace=True), 27 | ConvND(spatial_dims, 256, 256, ksize, padding=1), 28 | nn.ReLU(inplace=True), 29 | MaxPool(spatial_dims, kernel_size=2, stride=2), 30 | 31 | ConvND(spatial_dims, 256, 512, ksize, padding=1), 32 | nn.ReLU(inplace=True), 33 | ConvND(spatial_dims, 512, 512, ksize, padding=1), 34 | nn.ReLU(inplace=True), 35 | ConvND(spatial_dims, 512, 512, ksize, padding=1), 36 | nn.ReLU(inplace=True), 37 | MaxPool(spatial_dims, kernel_size=2, stride=2), 38 | 39 | ConvND(spatial_dims, 512, 512, ksize, padding=1), 40 | nn.ReLU(inplace=True), 41 | ConvND(spatial_dims, 512, 512, ksize, padding=1), 42 | nn.ReLU(inplace=True), 43 | ConvND(spatial_dims, 512, 512, ksize, padding=1), 44 | nn.ReLU(inplace=True), 45 | MaxPool(spatial_dims, kernel_size=2, stride=2), 46 | ) 47 | 48 | def forward(self, x): 49 | return self.features(x) 50 | 51 | 52 | class VGG19(nn.Module): 53 | def __init__(self, cin=3, spatial_dims=2): 54 | super().__init__() 55 | ksize = 3 56 | self.features = nn.Sequential( 57 | ConvND(spatial_dims, cin, 64, ksize, padding=1), 58 | nn.ReLU(inplace=True), 59 | ConvND(spatial_dims, 64, 64, ksize, padding=1), 60 | nn.ReLU(inplace=True), 61 | MaxPool(spatial_dims, kernel_size=2, stride=2), 62 | 63 | ConvND(spatial_dims, 64, 128, ksize, padding=1), 64 | nn.ReLU(inplace=True), 65 | ConvND(spatial_dims, 128, 128, ksize, padding=1), 66 | nn.ReLU(inplace=True), 67 | MaxPool(spatial_dims, kernel_size=2, stride=2), 68 | 69 | ConvND(spatial_dims, 128, 256, ksize, padding=1), 70 | nn.ReLU(inplace=True), 71 | ConvND(spatial_dims, 256, 256, ksize, padding=1), 72 | nn.ReLU(inplace=True), 73 | ConvND(spatial_dims, 256, 256, ksize, padding=1), 74 | nn.ReLU(inplace=True), 75 | ConvND(spatial_dims, 256, 256, ksize, padding=1), 76 | nn.ReLU(inplace=True), 77 | MaxPool(spatial_dims, kernel_size=2, stride=2), 78 | 79 | 
ConvND(spatial_dims, 256, 512, ksize, padding=1), 80 | nn.ReLU(inplace=True), 81 | ConvND(spatial_dims, 512, 512, ksize, padding=1), 82 | nn.ReLU(inplace=True), 83 | ConvND(spatial_dims, 512, 512, ksize, padding=1), 84 | nn.ReLU(inplace=True), 85 | ConvND(spatial_dims, 512, 512, ksize, padding=1), 86 | nn.ReLU(inplace=True), 87 | MaxPool(spatial_dims, kernel_size=2, stride=2), 88 | 89 | ConvND(spatial_dims, 512, 512, ksize, padding=1), 90 | nn.ReLU(inplace=True), 91 | ConvND(spatial_dims, 512, 512, ksize, padding=1), 92 | nn.ReLU(inplace=True), 93 | ConvND(spatial_dims, 512, 512, ksize, padding=1), 94 | nn.ReLU(inplace=True), 95 | ConvND(spatial_dims, 512, 512, ksize, padding=1), 96 | nn.ReLU(inplace=True), 97 | MaxPool(spatial_dims, kernel_size=2, stride=2), 98 | ) 99 | 100 | def forward(self, x): 101 | return self.features(x) 102 | 103 | 104 | 105 | 106 | if __name__ == '__main__': 107 | # VGG16 108 | inputs = torch.normal(0, 1, (1, 3, 224)) 109 | output = VGG16(cin=3, spatial_dims=1)(inputs) 110 | shape = list(output.shape) 111 | assert shape[2:] == [7] 112 | 113 | # VGG19 114 | inputs = torch.normal(0, 1, (1, 3, 224)) 115 | output = VGG19(cin=3, spatial_dims=1)(inputs) 116 | shape = list(output.shape) 117 | assert shape[2:] == [7] 118 | 119 | # VGG16 120 | inputs = torch.normal(0, 1, (1, 3, 224, 224)) 121 | output = VGG16(cin=3, spatial_dims=2)(inputs) 122 | shape = list(output.shape) 123 | assert shape[2:] == [7, 7] 124 | 125 | # VGG19 126 | inputs = torch.normal(0, 1, (1, 3, 224, 224)) 127 | output = VGG19(cin=3, spatial_dims=2)(inputs) 128 | shape = list(output.shape) 129 | assert shape[2:] == [7, 7] -------------------------------------------------------------------------------- /fusionlab/encoders/vit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/encoders/vit/__init__.py -------------------------------------------------------------------------------- /fusionlab/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from fusionlab import BACKEND 2 | if BACKEND['torch']: 3 | from .dice import dice_score 4 | from .iou import ( 5 | iou_score, 6 | jaccard_score, 7 | ) 8 | if BACKEND['tf']: 9 | from .tfdice import tf_dice_score 10 | from .tfiou import ( 11 | tf_iou_score, 12 | tf_jaccard_score, 13 | ) -------------------------------------------------------------------------------- /fusionlab/functional/dice.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | from fusionlab.configs import EPS 4 | 5 | 6 | def dice_score(pred: torch.Tensor, 7 | target: torch.Tensor, 8 | dims: Tuple[int, ...]=None) -> torch.Tensor: 9 | """ 10 | Computes the dice score 11 | 12 | Args: 13 | pred: (N, C, *) 14 | target: (N, C, *) 15 | dims: dimensions to sum over 16 | 17 | """ 18 | assert pred.size() == target.size() 19 | intersection = torch.sum(pred * target, dim=dims) 20 | cardinality = torch.sum(pred + target, dim=dims) 21 | return (2.0 * intersection) / cardinality.clamp(min=EPS) -------------------------------------------------------------------------------- /fusionlab/functional/iou.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fusionlab.configs import EPS 3 | 4 | 5 | def iou_score(pred, target, dims=None): 6 | """ 7 | Shape: 8 | - pred: :math:`(N, C, 
*)` 9 | - target: :math:`(N, C, *)` 10 | - Output: scalar. 11 | """ 12 | assert pred.size() == target.size() 13 | 14 | intersection = torch.sum(pred * target, dim=dims) 15 | cardinality = torch.sum(pred + target, dim=dims) 16 | union = cardinality - intersection 17 | iou = intersection / union.clamp_min(EPS) 18 | return iou 19 | 20 | jaccard_score = iou_score -------------------------------------------------------------------------------- /fusionlab/functional/tfdice.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | from fusionlab.configs import EPS 4 | 5 | def tf_dice_score(pred, target, axis=None): 6 | """ 7 | Shape: 8 | - pred: :math:`(N, *, C)` where :math:`*` means any number of additional dimensions 9 | - target: :math:`(N, *, C)`, same shape as the input 10 | - Output: scalar. 11 | """ 12 | intersection = tf.reduce_sum(pred * target, axis=axis) 13 | cardinality = tf.reduce_sum(pred + target, axis=axis) 14 | cardinality = tf.clip_by_value(cardinality, 15 | clip_value_min=EPS, 16 | clip_value_max=cardinality.dtype.max) 17 | return (2.0 * intersection) / cardinality 18 | -------------------------------------------------------------------------------- /fusionlab/functional/tfiou.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | from fusionlab.configs import EPS 4 | 5 | def tf_iou_score(pred, target, axis=None): 6 | """ 7 | Shape: 8 | - pred: :math:`(N, *, C)` where :math:`*` means any number of additional dimensions 9 | - target: :math:`(N, *, C)`, same shape as the input 10 | - Output: scalar. 11 | """ 12 | intersection = tf.reduce_sum(pred * target, axis=axis) 13 | cardinality = tf.reduce_sum(pred + target, axis=axis) 14 | cardinality = tf.clip_by_value(cardinality, 15 | clip_value_min=EPS, 16 | clip_value_max=cardinality.dtype.max) 17 | union = cardinality - intersection 18 | return intersection / tf.clip_by_value(union, EPS, union.dtype.max)  # clamp union, mirroring the torch version 19 | 20 | tf_jaccard_score = tf_iou_score 21 | -------------------------------------------------------------------------------- /fusionlab/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from fusionlab import BACKEND 2 | if BACKEND['torch']: 3 | from .factories import ( 4 | ConvND, 5 | ConvT, 6 | Upsample, 7 | BatchNorm, 8 | InstanceNorm, 9 | MaxPool, 10 | AvgPool, 11 | AdaptiveMaxPool, 12 | AdaptiveAvgPool, 13 | ReplicationPad, 14 | ConstantPad 15 | ) 16 | from .squeeze_excitation.se import SEModule 17 | from .base import ( 18 | ConvNormAct, 19 | Rearrange, 20 | DropPath, 21 | ) 22 | from .patch_embed.patch_embedding import PatchEmbedding 23 | from .selfattention.selfattention import ( 24 | SelfAttention, 25 | SRAttention, 26 | ) 27 | if BACKEND['tf']: 28 | from .squeeze_excitation.tfse import TFSEModule -------------------------------------------------------------------------------- /fusionlab/layers/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import Union, Sequence, Optional, Callable 3 | from einops import rearrange 4 | from torchvision.ops import StochasticDepth 5 | 6 | from fusionlab.layers import ConvND, BatchNorm 7 | from fusionlab.utils import make_ntuple 8 | 9 | class ConvNormAct(nn.Module): 10 | ''' 11 | ref: 12 | https://pytorch.org/vision/main/generated/torchvision.ops.Conv2dNormActivation.html 13 |
https://github.com/pytorch/vision/blob/main/torchvision/ops/misc.py#L68 14 | 15 | Convolution + Normalization + Activation 16 | 17 | Args: 18 | spatial_dims (int): number of spatial dimensions of the input image. 19 | in_channels (int): number of channels of the input image. 20 | out_channels (int): number of channels of the output image. 21 | kernel_size (Union[Sequence[int], int]): size of the convolving kernel. 22 | stride (Union[Sequence[int], int], optional): stride of the convolution. Default: 1 23 | padding (Union[Sequence[int], str], optional): Padding added to all sides of the input. Default: None, 24 | in which case it will be calculated as padding = (kernel_size - 1) // 2 * dilation 25 | dilation (Union[Sequence[int], int], optional): spacing between kernel elements. Default: 1 26 | groups (int, optional): number of blocked connections from input channels to output channels. Default: 1 27 | bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if norm_layer is None. 28 | norm_layer (Optional[Callable[..., nn.Module]], optional): normalization layer. Default: BatchNorm 29 | act_layer (Optional[Callable[..., nn.Module]], optional): activation layer. Default: nn.ReLU 30 | padding_mode (str, optional): mode of padding. Default: 'zeros' 31 | inplace (Optional[bool], optional): Parameter for the activation layer, 32 | which can optionally do the operation in-place. Default: True 33 | 34 | ''' 35 | def __init__(self, 36 | spatial_dims: int, 37 | in_channels: int, 38 | out_channels: int, 39 | kernel_size: Union[Sequence[int], int], 40 | stride: Union[Sequence[int], int] = 1, 41 | padding: Union[Sequence[int], str] = None, 42 | dilation: Union[Sequence[int], int] = 1, 43 | groups: int = 1, 44 | bias: Optional[bool] = None, 45 | norm_layer: Optional[Callable[..., nn.Module]] = BatchNorm, 46 | act_layer: Optional[Callable[..., nn.Module]] = nn.ReLU, 47 | padding_mode: str = 'zeros', 48 | inplace: Optional[bool] = True, 49 | ): 50 | super().__init__() 51 | # padding 52 | if padding is None: 53 | if isinstance(kernel_size, int) and isinstance(dilation, int): 54 | padding = (kernel_size - 1) // 2 * dilation 55 | else: 56 | _conv_dim = spatial_dims 57 | kernel_size = make_ntuple(kernel_size, _conv_dim) 58 | dilation = make_ntuple(dilation, _conv_dim) 59 | padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim)) 60 | # bias 61 | if bias is None: 62 | bias = norm_layer is None 63 | 64 | self.conv = ConvND( 65 | spatial_dims, 66 | in_channels, 67 | out_channels, 68 | kernel_size, 69 | stride, 70 | padding, 71 | dilation, 72 | groups, 73 | bias, 74 | padding_mode 75 | ) 76 | self.norm = norm_layer(spatial_dims, out_channels) if norm_layer is not None else nn.Identity() 77 | params = {} if inplace is None else {"inplace": inplace} 78 | if act_layer is None: 79 | act_layer = nn.Identity 80 | self.act = act_layer(**params) 81 | 82 | def forward(self, x): 83 | x = self.conv(x) 84 | x = self.norm(x) 85 | x = self.act(x) 86 | return x 87 | 88 | class Rearrange(nn.Module): 89 | ''' 90 | nn.Module wrapper for einops' rearrange function 91 | ''' 92 | 93 | def __init__(self, pattern: str, **kwargs): 94 | super().__init__() 95 | self.pattern = pattern 96 | self.kwargs = kwargs 97 | 98 | def forward(self, x): 99 | return rearrange(x, self.pattern, **self.kwargs) 100 | 101 | DropPath = StochasticDepth 102 | 103 | if __name__ == '__main__': 104 | import torch 105 | inputs = torch.randn(1, 3, 128, 128) 106 | l = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16) 107
| outputs = l(inputs) 108 | print(outputs.shape) 109 | 110 | inputs = torch.randn(1, 3, 128, 128) 111 | l = Rearrange('b c h w -> b (c h w)') 112 | outputs = l(inputs) 113 | print(outputs.shape) -------------------------------------------------------------------------------- /fusionlab/layers/patch_embed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/layers/patch_embed/__init__.py -------------------------------------------------------------------------------- /fusionlab/layers/patch_embed/patch_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Union 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | from fusionlab.layers import Rearrange, ConvND 7 | from fusionlab.utils import make_ntuple, trunc_normal_ 8 | 9 | EMBEDDING_TYPES = ["conv", "fc"] 10 | 11 | class PatchEmbedding(nn.Module): 12 | """ 13 | A patch embedding block, based on: "Dosovitskiy et al., 14 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " 15 | 16 | """ 17 | 18 | def __init__( 19 | self, 20 | in_channels: int, 21 | img_size: Union[int, Sequence[int]], 22 | patch_size: Union[int, Sequence[int]], 23 | hidden_size: int, 24 | pos_embed_type: str = 'conv', 25 | dropout_rate: float = 0.0, 26 | spatial_dims: int = 2, 27 | ) -> None: 28 | """ 29 | Args: 30 | in_channels: dimension of input channels. 31 | img_size: dimension of input image. 32 | patch_size: dimension of patch size. 33 | hidden_size: dimension of hidden layer. 34 | num_heads: number of attention heads. 35 | pos_embed_type: position embedding layer type. 36 | dropout_rate: faction of the input units to drop. 37 | spatial_dims: number of spatial dimensions. 
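Example (shapes consistent with the __main__ checks below): PatchEmbedding(3, img_size=224, patch_size=16, hidden_size=768) maps a (1, 3, 224, 224) input to (1, 196, 768), since (224 // 16) ** 2 = 196 patches.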
38 | """ 39 | 40 | super().__init__() 41 | assert pos_embed_type in EMBEDDING_TYPES, f"pos_embed_type must be in {EMBEDDING_TYPES}" 42 | self.pos_embed_type = pos_embed_type 43 | 44 | img_sizes = make_ntuple(img_size, spatial_dims) 45 | patch_sizes = make_ntuple(patch_size, spatial_dims) 46 | for m, p in zip(img_sizes, patch_sizes): 47 | if self.pos_embed_type == "fc" and m % p != 0: 48 | raise ValueError("img_size should be divisible by patch_size for the fc embedding type.") 49 | self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_sizes, patch_sizes)]) 50 | self.patch_dim = int(in_channels * np.prod(patch_sizes)) 51 | 52 | self.patch_embeddings: nn.Module 53 | if self.pos_embed_type == "conv": 54 | self.patch_embeddings = nn.Sequential( 55 | ConvND( 56 | spatial_dims, 57 | in_channels, 58 | hidden_size, 59 | kernel_size=patch_size, 60 | stride=patch_size 61 | ), 62 | nn.Flatten(2), 63 | Rearrange('b d n -> b n d'), 64 | ) 65 | # self.patch_embeddings = Conv[Conv.CONV, spatial_dims]( 66 | # in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size 67 | # ) 68 | elif self.pos_embed_type == "fc": 69 | # for 3d: "b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)" 70 | chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims] 71 | from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars) 72 | to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)" 73 | axes_len = {f"p{i+1}": p for i, p in enumerate(patch_sizes)} 74 | self.patch_embeddings = nn.Sequential( 75 | Rearrange(f"{from_chars} -> {to_chars}", **axes_len), 76 | nn.Linear(self.patch_dim, hidden_size), 77 | ) 78 | self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) 79 | self.dropout = nn.Dropout(dropout_rate) 80 | trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) 81 | self.apply(self._init_weights) 82 | 83 | def forward(self, x): 84 | x = self.patch_embeddings(x) 85 | # if self.pos_embed_type == "conv": 86 | # x = x.flatten(2).transpose(-1, -2) # (b c w h) -> (b c wh) -> (b wh c) 87 | embeddings = x + self.position_embeddings 88 | embeddings = self.dropout(embeddings) 89 | return embeddings 90 | 91 | def _init_weights(self, m): 92 | if isinstance(m, nn.Linear): 93 | trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0) 94 | if isinstance(m, nn.Linear) and m.bias is not None: 95 | nn.init.constant_(m.bias, 0) 96 | elif isinstance(m, nn.LayerNorm): 97 | nn.init.constant_(m.bias, 0) 98 | nn.init.constant_(m.weight, 1.0) 99 | 100 | 101 | if __name__ == '__main__': 102 | # 2D 103 | inputs = torch.randn(1, 3, 224, 224) 104 | l = PatchEmbedding(3, 224, 16, 768, pos_embed_type='conv') 105 | outputs = l(inputs) 106 | print(outputs.shape) 107 | 108 | inputs = torch.randn(1, 3, 224, 224) 109 | l = PatchEmbedding(3, 224, 16, 768, pos_embed_type='fc') 110 | print(l) 111 | outputs = l(inputs) 112 | print(outputs.shape) 113 | 114 | # 1D 115 | inputs = torch.randn(1, 3, 224) 116 | l = PatchEmbedding(3, 224, 16, 768, pos_embed_type='conv', spatial_dims=1) 117 | outputs = l(inputs) 118 | print(outputs.shape) 119 | 120 | inputs = torch.randn(1, 3, 224) 121 | l = PatchEmbedding(3, 224, 16, 768, pos_embed_type='fc', spatial_dims=1) 122 | outputs = l(inputs) 123 | print(outputs.shape) 124 | 125 | # 3D 126 | inputs = torch.randn(1, 3, 112, 112, 112) 127 | l = PatchEmbedding(3, 112, 16, 768, pos_embed_type='conv', spatial_dims=3) 128 | outputs = l(inputs) 129 | print(outputs.shape) 130 | 131 | inputs =
torch.randn(1, 3, 112, 112, 112) 132 | l = PatchEmbedding(3, 112, 16, 768, pos_embed_type='fc', spatial_dims=3) 133 | outputs = l(inputs) 134 | print(outputs.shape) -------------------------------------------------------------------------------- /fusionlab/layers/selfattention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/layers/selfattention/__init__.py -------------------------------------------------------------------------------- /fusionlab/layers/selfattention/selfattention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from fusionlab.layers import Rearrange 4 | 5 | 6 | class SelfAttention(nn.Module): 7 | """ 8 | A self-attention block, based on: "Dosovitskiy et al., 9 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " 10 | 11 | source code: https://github.com/Project-MONAI/MONAI/blob/main/monai/networks/blocks/selfattention.py#L22 12 | """ 13 | 14 | def __init__( 15 | self, 16 | hidden_size: int, 17 | num_heads: int, 18 | dropout_rate: float = 0.0, 19 | qkv_bias: bool = False, 20 | save_attn: bool = False, 21 | ) -> None: 22 | """ 23 | Args: 24 | hidden_size (int): dimension of hidden layer. 25 | num_heads (int): number of attention heads. 26 | dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0. 27 | qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. 28 | save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. 29 | 30 | """ 31 | 32 | super().__init__() 33 | 34 | if not (0 <= dropout_rate <= 1): 35 | raise ValueError("dropout_rate should be between 0 and 1.") 36 | 37 | if hidden_size % num_heads != 0: 38 | raise ValueError("hidden size should be divisible by num_heads.") 39 | 40 | self.num_heads = num_heads 41 | self.out_proj = nn.Linear(hidden_size, hidden_size) 42 | self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) 43 | # b: batch size, h: num_patches, l: num_heads, d: head_dim 44 | self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) 45 | self.out_rearrange = Rearrange("b h l d -> b l (h d)") 46 | self.drop_output = nn.Dropout(dropout_rate) 47 | self.drop_weights = nn.Dropout(dropout_rate) 48 | self.head_dim = hidden_size // num_heads 49 | self.scale = self.head_dim**-0.5 50 | self.save_attn = save_attn 51 | self.att_mat = torch.Tensor() 52 | 53 | def forward(self, x): 54 | qkv = self.input_rearrange(self.qkv(x)) 55 | q, k, v = qkv[0], qkv[1], qkv[2] # (b, l, h, d) 56 | att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) # (b, l, h, h) 57 | if self.save_attn: 58 | self.att_mat = att_mat.detach() 59 | 60 | att_mat = self.drop_weights(att_mat) 61 | x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) # (b, l, h, d) 62 | x = self.out_rearrange(x) 63 | x = self.out_proj(x) 64 | x = self.drop_output(x) 65 | return x 66 | 67 | class SRAttention(nn.Module): 68 | """ 69 | Spatial Reduction Attention (SR-Attention) block, based on "Wang et al., 70 | Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions " 71 | 72 | source code: https://github.com/sithu31296/semantic-segmentation/blob/main/semseg/models/backbones/mit.py 73 | """ 74 | def __init__(self, dim, head, sr_ratio): 75 | """ 76 | Args: 77 | dim (int): input dimension 78 | 
head (int): number of attention heads 79 | sr_ratio (int): spatial reduction ratio 80 | 81 | """ 82 | super().__init__() 83 | self.head = head 84 | self.sr_ratio = sr_ratio 85 | self.scale = (dim // head) ** -0.5 86 | self.q = nn.Linear(dim, dim) 87 | self.kv = nn.Linear(dim, dim*2) 88 | self.proj = nn.Linear(dim, dim) 89 | 90 | if sr_ratio > 1: 91 | self.sr = nn.Conv2d(dim, dim, sr_ratio, sr_ratio) 92 | self.norm = nn.LayerNorm(dim) 93 | 94 | def forward(self, x: torch.Tensor, H, W) -> torch.Tensor: 95 | B, N, C = x.shape 96 | q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3) 97 | 98 | if self.sr_ratio > 1: 99 | x = x.permute(0, 2, 1).reshape(B, C, H, W) 100 | x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) 101 | x = self.norm(x) 102 | 103 | k, v = self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4) 104 | 105 | attn = (q @ k.transpose(-2, -1)) * self.scale 106 | attn = attn.softmax(dim=-1) 107 | 108 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 109 | x = self.proj(x) 110 | return x -------------------------------------------------------------------------------- /fusionlab/layers/squeeze_excitation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/layers/squeeze_excitation/__init__.py -------------------------------------------------------------------------------- /fusionlab/layers/squeeze_excitation/se.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import torch 3 | from torch import Tensor 4 | import torch.nn as nn 5 | from fusionlab.layers import ConvND, AdaptiveAvgPool 6 | 7 | class SEModule(nn.Module): 8 | """ 9 | source: https://github.com/pytorch/vision/blob/main/torchvision/ops/misc.py#L224 10 | This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1). 11 | Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in eq. 3. 12 | 13 | Args: 14 | input_channels (int): Number of channels in the input image 15 | squeeze_channels (int): Number of squeeze channels 16 | activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU`` 17 | scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. 
Default: ``torch.nn.Sigmoid`` 18 | """ 19 | 20 | def __init__( 21 | self, 22 | input_channels: int, 23 | squeeze_channels: int, 24 | act_layer: Callable[..., torch.nn.Module] = torch.nn.ReLU, 25 | scale_layer: Callable[..., torch.nn.Module] = torch.nn.Sigmoid, 26 | spatial_dims: int = 2, 27 | ) -> None: 28 | super().__init__() 29 | self.avgpool = AdaptiveAvgPool(spatial_dims, 1) 30 | self.fc1 = ConvND(spatial_dims, input_channels, squeeze_channels, kernel_size=1) 31 | self.fc2 = ConvND(spatial_dims, squeeze_channels, input_channels, kernel_size=1) 32 | self.act_layer = act_layer() 33 | self.scale_layer = scale_layer() 34 | 35 | def _scale(self, input: Tensor) -> Tensor: 36 | scale = self.avgpool(input) 37 | scale = self.fc1(scale) 38 | scale = self.act_layer(scale) 39 | scale = self.fc2(scale) 40 | return self.scale_layer(scale) 41 | 42 | def forward(self, input: Tensor) -> Tensor: 43 | scale = self._scale(input) 44 | return scale * input 45 | 46 | 47 | if __name__ == '__main__': 48 | print('SEModule') 49 | inputs = torch.normal(0, 1, (1, 256, 16, 16)) 50 | layer = SEModule(256, 16)  # squeeze_channels = 256 // 16, matching TFSEModule's default ratio 51 | outputs = layer(inputs) 52 | assert list(outputs.shape) == [1, 256, 16, 16] 53 | 54 | inputs = torch.normal(0, 1, (1, 256, 16, 16, 16)) 55 | layer = SEModule(256, 16, spatial_dims=3) 56 | outputs = layer(inputs) 57 | assert list(outputs.shape) == [1, 256, 16, 16, 16] 58 | -------------------------------------------------------------------------------- /fusionlab/layers/squeeze_excitation/tfse.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import layers, Sequential 3 | 4 | 5 | class TFSEModule(layers.Layer): 6 | def __init__(self, cin, ratio=16): 7 | super().__init__() 8 | cout = int(cin / ratio) 9 | self.gate = Sequential([ 10 | layers.Conv2D(cout, kernel_size=1), 11 | layers.ReLU(), 12 | layers.Conv2D(cin, kernel_size=1), 13 | layers.Activation(tf.nn.sigmoid), 14 | ]) 15 | 16 | def call(self, inputs): 17 | x = tf.reduce_mean(inputs, (1, 2), keepdims=True) 18 | x = self.gate(x) 19 | return inputs * x 20 | 21 | 22 | if __name__ == '__main__': 23 | inputs = tf.random.normal((1, 224, 224, 256), 0, 1) 24 | layer = TFSEModule(256) 25 | outputs = layer(inputs) 26 | assert list(outputs.shape) == [1, 224, 224, 256] 27 | -------------------------------------------------------------------------------- /fusionlab/losses/README.md: -------------------------------------------------------------------------------- 1 | # Loss functions 2 | 3 | | Name | PyTorch | TensorFlow | 4 | |:------------------------------|:------------|---------------| 5 | | Dice Loss | DiceLoss | TFDiceLoss | 6 | | Tversky Loss | TverskyLoss | TFTverskyLoss | 7 | | IoU Loss | IoULoss | TFIoULoss | 8 | | Dice Loss + CrossEntropy Loss | DiceCELoss | TFDiceCE | 9 | 10 | -------------------------------------------------------------------------------- /fusionlab/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from fusionlab import BACKEND 2 | if BACKEND['torch']: 3 | from .diceloss.dice import DiceLoss, DiceCELoss 4 | from .iouloss.iou import IoULoss 5 | from .tversky.tversky import TverskyLoss 6 | if BACKEND['tf']: 7 | from .diceloss.tfdice import TFDiceLoss, TFDiceCE 8 | from .iouloss.tfiou import TFIoULoss 9 | from .tversky.tftversky import TFTverskyLoss -------------------------------------------------------------------------------- /fusionlab/losses/diceloss/__init__.py:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/losses/diceloss/__init__.py -------------------------------------------------------------------------------- /fusionlab/losses/diceloss/dice.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | from fusionlab.functional import dice_score 6 | from fusionlab.configs import EPS 7 | 8 | __all__ = ["DiceLoss", "DiceCELoss"] 9 | 10 | BINARY_MODE = "binary" 11 | MULTICLASS_MODE = "multiclass" 12 | 13 | 14 | class DiceCELoss(nn.Module): 15 | def __init__(self, w_dice=0.5, w_ce=0.5, cls_weight=None): 16 | """ 17 | Dice Loss + Cross Entropy Loss 18 | Args: 19 | w_dice: weight of Dice Loss 20 | w_ce: weight of CrossEntropy loss 21 | cls_weight: 22 | """ 23 | super().__init__() 24 | self.w_dice = w_dice 25 | self.w_ce = w_ce 26 | self.cls_weight = cls_weight 27 | self.dice = DiceLoss() 28 | self.ce = nn.CrossEntropyLoss(weight=cls_weight) 29 | 30 | def forward(self, y_pred, y_true): 31 | loss_dice = self.dice(y_pred, y_true) 32 | loss_ce = self.ce(y_pred, y_true) 33 | return self.w_dice * loss_dice + self.w_ce * loss_ce 34 | 35 | 36 | class DiceLoss(nn.Module): 37 | def __init__( 38 | self, 39 | mode="multiclass", # binary, multiclass 40 | log_loss=False, 41 | from_logits=True, 42 | ): 43 | """ 44 | Implementation of Dice loss for image segmentation task. 45 | It supports "binary", "multiclass" 46 | ref: https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/losses/dice.py 47 | Args: 48 | mode: Metric mode {'binary', 'multiclass'} 49 | log_loss: If True, loss computed as `-log(dice)`; otherwise `1 - dice` 50 | from_logits: If True assumes input is raw logits 51 | """ 52 | super().__init__() 53 | self.mode = mode 54 | self.from_logits = from_logits 55 | self.log_loss = log_loss 56 | 57 | def forward(self, y_pred, y_true) -> torch.Tensor: 58 | """ 59 | :param y_pred: (N, C, *) 60 | :param y_true: (N, *) 61 | :return: scalar 62 | """ 63 | assert y_true.size(0) == y_pred.size(0) 64 | num_classes = y_pred.size(1) 65 | dims = (0, 2) # (N, C, HW) 66 | 67 | if self.from_logits: 68 | # get [0..1] class probabilities 69 | if self.mode == MULTICLASS_MODE: 70 | y_pred = F.softmax(y_pred, dim=1) 71 | else: 72 | y_pred = torch.sigmoid(y_pred) 73 | 74 | if self.mode == BINARY_MODE: 75 | y_true = rearrange(y_true, "N ... -> N 1 (...)") 76 | y_pred = rearrange(y_pred, "N 1 ... -> N 1 (...)") 77 | elif self.mode == MULTICLASS_MODE: 78 | y_pred = rearrange(y_pred, "N C ... -> N C (...)") 79 | y_true = F.one_hot(y_true, num_classes) # (N, *) -> (N, *, C) 80 | y_true = rearrange(y_true, "N ... C -> N C (...)") 81 | else: 82 | AssertionError("Not implemented") 83 | 84 | scores = dice_score(y_pred, y_true.type_as(y_pred), dims=dims) 85 | if self.log_loss: 86 | loss = -torch.log(scores.clamp_min(EPS)) 87 | else: 88 | loss = 1.0 - scores 89 | return loss.mean() 90 | 91 | 92 | if __name__ == "__main__": 93 | 94 | print("multiclass") 95 | pred = torch.tensor([[ 96 | [1., 2., 3., 4.], 97 | [2., 6., 4., 4.], 98 | [9., 6., 3., 4.] 
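# rows are per-class logits (C=3) over four positions; .view(1, 3, 4) below yields the (N, C, L) layout DiceLoss expects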
99 | ]]).view(1, 3, 4) 100 | true = torch.tensor([[2, 1, 0, 2]]).view(1, 4) 101 | 102 | dice = DiceLoss("multiclass", from_logits=True) 103 | loss = dice(pred, true) 104 | 105 | print("Binary") 106 | pred = torch.tensor([0.4, 0.2, 0.3, 0.5]).reshape(1, 1, 2, 2) 107 | true = torch.tensor([0, 1, 0, 1]).reshape(1, 2, 2) 108 | dice = DiceLoss("binary", from_logits=True) 109 | loss = dice(pred, true) 110 | 111 | print("Binary Logloss") 112 | pred = torch.tensor([0.4, 0.2, 0.3, 0.5]).reshape(1, 1, 2, 2) 113 | true = torch.tensor([0, 1, 0, 1]).reshape(1, 2, 2) 114 | dice = DiceLoss("binary", from_logits=True, log_loss=True) 115 | loss = dice(pred, true) 116 | 117 | -------------------------------------------------------------------------------- /fusionlab/losses/diceloss/tfdice.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from einops import rearrange 3 | from fusionlab.functional.tfdice import tf_dice_score 4 | 5 | __all__ = ["TFDiceLoss", "TFDiceCE"] 6 | 7 | BINARY_MODE = "binary" 8 | MULTICLASS_MODE = "multiclass" 9 | 10 | # TODO: Test code 11 | class TFDiceCE(tf.keras.losses.Loss): 12 | def __init__(self, mode="binary", from_logits=False, w_dice=0.5, w_ce=0.5): 13 | """ 14 | Dice Loss + Cross Entropy Loss 15 | Args: 16 | w_dice: weight of Dice Loss 17 | w_ce: weight of CrossEntropy loss 18 | mode: Metric mode {'binary', 'multiclass'} 19 | """ 20 | super().__init__() 21 | self.w_dice = w_dice 22 | self.w_ce = w_ce 23 | self.dice = TFDiceLoss(mode, from_logits) 24 | if mode == BINARY_MODE: 25 | self.ce = tf.keras.losses.BinaryCrossentropy(from_logits) 26 | elif mode == MULTICLASS_MODE: 27 | self.ce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits) 28 | 29 | def call(self, y_true, y_pred): 30 | loss_dice = self.dice(y_true, y_pred) 31 | loss_ce = self.ce(y_true, y_pred) 32 | return self.w_dice * loss_dice + self.w_ce * loss_ce 33 | 34 | 35 | class TFDiceLoss(tf.keras.losses.Loss): 36 | def __init__( 37 | self, 38 | mode="multiclass", # binary, multiclass 39 | log_loss=False, 40 | from_logits=False, 41 | ): 42 | """ 43 | Implementation of Dice loss for image segmentation task. 44 | It supports "binary", "multiclass" 45 | https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/losses/dice.py 46 | Args: 47 | mode: Metric mode {'binary', 'multiclass'} 48 | log_loss: If True, loss computed as `-log(dice)`; otherwise `1 - dice` 49 | from_logits: If True assumes input is raw logits 50 | """ 51 | super().__init__() 52 | self.mode = mode 53 | self.from_logits = from_logits 54 | self.log_loss = log_loss 55 | 56 | def call(self, y_true, y_pred): 57 | """ 58 | :param y_true: (N, *) 59 | :param y_pred: (N, *, C) 60 | :return: scalar 61 | """ 62 | y_true_shape = y_true.shape.as_list() 63 | y_pred_shape = y_pred.shape.as_list() 64 | assert y_true_shape[0] == y_pred_shape[0] 65 | num_classes = y_pred_shape[-1] 66 | axis = [0] 67 | 68 | if self.from_logits: 69 | # get [0..1] class probabilities 70 | if self.mode == MULTICLASS_MODE: 71 | y_pred = tf.nn.softmax(y_pred, axis=-1) 72 | else: 73 | y_pred = tf.nn.sigmoid(y_pred) 74 | 75 | if self.mode == BINARY_MODE: 76 | y_true = rearrange(y_true, "... -> (...) 1") 77 | y_pred = rearrange(y_pred, "... -> (...) 1") 78 | elif self.mode == MULTICLASS_MODE: 79 | y_true = tf.cast(y_true, tf.int32) 80 | y_true = tf.one_hot(y_true, num_classes) 81 | y_true = rearrange(y_true, "... C -> (...) C") 82 | y_pred = rearrange(y_pred, "... C -> (...) 
C") 83 | else: 84 | AssertionError("Not implemented") 85 | 86 | scores = tf_dice_score(y_pred, tf.cast(y_true, y_pred.dtype), axis=axis) 87 | if self.log_loss: 88 | loss = -tf.math.log(tf.clip_by_value(scores, clip_value_min=1e-7, clip_value_max=scores.dtype.max)) 89 | else: 90 | loss = 1.0 - scores 91 | return tf.math.reduce_mean(loss) 92 | 93 | 94 | if __name__ == '__main__': 95 | print("Multiclass") 96 | pred = tf.convert_to_tensor([[ 97 | [1., 2., 3., 4.], 98 | [2., 6., 4., 4.], 99 | [9., 6., 3., 4.] 100 | ]]) 101 | pred = rearrange(pred, "N C H -> N H C") 102 | true = tf.convert_to_tensor([[2, 1, 0, 2]]) 103 | 104 | dice = TFDiceLoss("multiclass", from_logits=True) 105 | loss = dice(true, pred) 106 | 107 | print("Binary") 108 | pred = tf.convert_to_tensor([0.4, 0.2, 0.3, 0.5]) 109 | pred = tf.reshape(pred, [1, 2, 2, 1]) 110 | true = tf.convert_to_tensor([0, 1, 0, 1]) 111 | true = tf.reshape(true, [1, 2, 2]) 112 | dice = TFDiceLoss("binary", from_logits=True) 113 | loss = dice(true, pred) 114 | 115 | print("Binary Log loss") 116 | pred = tf.convert_to_tensor([0.4, 0.2, 0.3, 0.5]) 117 | pred = tf.reshape(pred, [1, 2, 2, 1]) 118 | true = tf.convert_to_tensor([0, 1, 0, 1]) 119 | true = tf.reshape(true, [1, 2, 2]) 120 | dice = TFDiceLoss("binary", from_logits=True, log_loss=True) 121 | loss = dice(true, pred) -------------------------------------------------------------------------------- /fusionlab/losses/iouloss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/losses/iouloss/__init__.py -------------------------------------------------------------------------------- /fusionlab/losses/iouloss/iou.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | from fusionlab.functional import iou_score 6 | 7 | __all__ = ["IoULoss"] 8 | 9 | BINARY_MODE = "binary" 10 | MULTICLASS_MODE = "multiclass" 11 | 12 | 13 | class IoULoss(nn.Module): 14 | def __init__( 15 | self, 16 | mode="multiclass", # binary, multiclass 17 | log_loss=False, 18 | from_logits=True, 19 | ): 20 | """ 21 | Implementation of Iou loss for image segmentation task. 22 | It supports "binary", "multiclass" 23 | Args: 24 | mode: Metric mode {'binary', 'multiclass'} 25 | log_loss: If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard` 26 | from_logits: If True assumes input is raw logits 27 | """ 28 | super().__init__() 29 | self.mode = mode 30 | self.from_logits = from_logits 31 | self.log_loss = log_loss 32 | 33 | def forward(self, y_pred, y_true): 34 | """ 35 | :param y_pred: (N, C, *) 36 | :param y_true: (N, *) 37 | :return: scalar 38 | """ 39 | assert y_true.size(0) == y_pred.size(0) 40 | num_classes = y_pred.size(1) 41 | dims = (0, 2) # (N, C, *) 42 | 43 | if self.from_logits: 44 | # get [0..1] class probabilities 45 | if self.mode == MULTICLASS_MODE: 46 | y_pred = F.softmax(y_pred, dim=1) 47 | else: 48 | y_pred = torch.sigmoid(y_pred) 49 | 50 | if self.mode == BINARY_MODE: 51 | y_true = rearrange(y_true, "N ... -> N 1 (...)") 52 | y_pred = rearrange(y_pred, "N 1 ... -> N 1 (...)") 53 | elif self.mode == MULTICLASS_MODE: 54 | y_pred = rearrange(y_pred, "N C ... -> N C (...)") 55 | y_true = F.one_hot(y_true, num_classes) # (N, *) -> (N, *, C) 56 | y_true = rearrange(y_true, "N ... 
C -> N C (...)")
57 |         else:
58 |             raise AssertionError("Not implemented")
59 | 
60 |         scores = iou_score(y_pred, y_true.type_as(y_pred), dims=dims)
61 |         if self.log_loss:
62 |             loss = -torch.log(scores.clamp_min(1e-7))
63 |         else:
64 |             loss = 1.0 - scores
65 |         return loss.mean()
66 | 
67 | 
68 | if __name__ == "__main__":
69 | 
70 |     print("multiclass")
71 |     pred = torch.tensor([[
72 |         [1., 2., 3., 4.],
73 |         [2., 6., 4., 4.],
74 |         [9., 6., 3., 4.]
75 |     ]]).view(1, 3, 4)
76 |     true = torch.tensor([[2, 1, 0, 2]]).view(1, 4)
77 | 
78 |     loss = IoULoss("multiclass", from_logits=True)(pred, true)
79 | 
80 |     print("Binary")
81 |     pred = torch.tensor([0.4, 0.2, 0.3, 0.5]).reshape(1, 1, 2, 2)
82 |     true = torch.tensor([0, 1, 0, 1]).reshape(1, 2, 2)
83 |     loss = IoULoss("binary", from_logits=True)(pred, true)
84 | 
85 |     print("Binary Logloss")
86 |     pred = torch.tensor([0.4, 0.2, 0.3, 0.5]).reshape(1, 1, 2, 2)
87 |     true = torch.tensor([0, 1, 0, 1]).reshape(1, 2, 2)
88 |     loss = IoULoss("binary", from_logits=True, log_loss=True)(pred, true)
89 | 
90 | 
--------------------------------------------------------------------------------
/fusionlab/losses/iouloss/tfiou.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from einops import rearrange
3 | from fusionlab.functional.tfiou import tf_iou_score
4 | 
5 | __all__ = ["TFIoULoss"]
6 | 
7 | BINARY_MODE = "binary"
8 | MULTICLASS_MODE = "multiclass"
9 | 
10 | 
11 | class TFIoULoss(tf.keras.losses.Loss):
12 |     def __init__(
13 |         self,
14 |         mode="binary",  # binary, multiclass
15 |         log_loss=False,
16 |         from_logits=False,
17 |     ):
18 |         """
19 |         Implementation of IoU loss for image segmentation tasks.
20 |         It supports "binary", "multiclass"
21 |         Args:
22 |             mode: Metric mode {'binary', 'multiclass'}
23 |             log_loss: If True, loss computed as `-log(iou)`; otherwise `1 - iou`
24 |             from_logits: If True assumes input is raw logits
25 |         """
26 |         super().__init__()
27 |         self.mode = mode
28 |         self.from_logits = from_logits
29 |         self.log_loss = log_loss
30 | 
31 |     def call(self, y_true, y_pred):
32 |         """
33 |         :param y_true: (N, *)
34 |         :param y_pred: (N, *, C)
35 |         :return: scalar
36 |         """
37 |         y_true_shape = y_true.shape.as_list()
38 |         y_pred_shape = y_pred.shape.as_list()
39 |         assert y_true_shape[0] == y_pred_shape[0]
40 |         num_classes = y_pred_shape[-1]
41 |         axis = [0]
42 | 
43 |         if self.from_logits:
44 |             # get [0..1] class probabilities
45 |             if self.mode == MULTICLASS_MODE:
46 |                 y_pred = tf.nn.softmax(y_pred, axis=-1)
47 |             else:
48 |                 y_pred = tf.nn.sigmoid(y_pred)
49 | 
50 |         if self.mode == BINARY_MODE:
51 |             y_true = rearrange(y_true, "... -> (...) 1")
52 |             y_pred = rearrange(y_pred, "... -> (...) 1")
53 |         elif self.mode == MULTICLASS_MODE:
54 |             y_true = tf.cast(y_true, tf.int32)
55 |             y_true = tf.one_hot(y_true, num_classes)
56 |             y_true = rearrange(y_true, "... C -> (...) C")
57 |             y_pred = rearrange(y_pred, "... C -> (...) C")
58 |         else:
59 |             raise AssertionError("Not implemented")
60 | 
61 |         scores = tf_iou_score(y_pred, tf.cast(y_true, y_pred.dtype), axis=axis)
62 |         scores = tf.clip_by_value(scores, clip_value_min=1e-7, clip_value_max=scores.dtype.max)
63 |         if self.log_loss:
64 |             loss = -tf.math.log(scores)
65 |         else:
66 |             loss = 1.0 - scores
67 |         return tf.math.reduce_mean(loss)
68 | 
69 | 
70 | if __name__ == '__main__':
71 |     print("Multiclass")
72 |     pred = tf.convert_to_tensor([[
73 |         [1., 2., 3., 4.],
74 |         [2., 6., 4., 4.],
75 |         [9., 6., 3., 4.]
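        # Sanity note, assuming the standard Dice/Jaccard identity
        # dice = 2 * iou / (1 + iou): on identical inputs the soft-Jaccard
        # loss computed here is always >= the Dice loss from tfdice.py.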
76 | ]]) 77 | pred = rearrange(pred, "N C H -> N H C") 78 | true = tf.convert_to_tensor([[2, 1, 0, 2]]) 79 | 80 | loss = TFIoULoss("multiclass", from_logits=True)(true, pred) 81 | 82 | print("Binary") 83 | pred = tf.convert_to_tensor([0.4, 0.2, 0.3, 0.5]) 84 | pred = tf.reshape(pred, [1, 2, 2, 1]) 85 | true = tf.convert_to_tensor([0, 1, 0, 1]) 86 | true = tf.reshape(true, [1, 2, 2]) 87 | loss = TFIoULoss("binary", from_logits=True)(true, pred) 88 | 89 | print("Binary Log loss") 90 | pred = tf.convert_to_tensor([0.4, 0.2, 0.3, 0.5]) 91 | pred = tf.reshape(pred, [1, 2, 2, 1]) 92 | true = tf.convert_to_tensor([0, 1, 0, 1]) 93 | true = tf.reshape(true, [1, 2, 2]) 94 | loss = TFIoULoss("binary", from_logits=True, log_loss=True)(true, pred) -------------------------------------------------------------------------------- /fusionlab/losses/tversky/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/losses/tversky/__init__.py -------------------------------------------------------------------------------- /fusionlab/losses/tversky/tftversky.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from einops import rearrange 3 | from fusionlab.configs import EPS 4 | 5 | __all__ = ["TFTverskyLoss"] 6 | 7 | BINARY_MODE = "binary" 8 | MULTICLASS_MODE = "multiclass" 9 | 10 | 11 | class TFTverskyLoss(tf.keras.losses.Loss): 12 | def __init__(self, 13 | alpha, 14 | beta, 15 | mode="binary", # binary, multiclass 16 | log_loss=False, 17 | from_logits=False, 18 | ): 19 | """ 20 | Implementation of Dice loss for image segmentation task. 21 | It supports "binary", "multiclass" 22 | ref: https://github.com/kornia/kornia/blob/master/kornia/losses/tversky.py 23 | ref: https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py 24 | Args: 25 | alpha: controls the penalty for false positives(FP). 26 | beta: controls the penalty for false negatives(FN). 27 | mode: Metric mode {'binary', 'multiclass'} 28 | log_loss: If True, loss computed as `-log(dice)`; otherwise `1 - dice` 29 | from_logits: If True assumes input is raw logits 30 | """ 31 | super().__init__() 32 | self.alpha = alpha 33 | self.beta = beta 34 | self.mode = mode 35 | self.from_logits = from_logits 36 | self.log_loss = log_loss 37 | 38 | def call(self, y_true, y_pred): 39 | """ 40 | :param y_true: (N, *) 41 | :param y_pred: (N, *, C) 42 | :return: scalar 43 | """ 44 | y_true_shape = y_true.shape.as_list() 45 | y_pred_shape = y_pred.shape.as_list() 46 | assert y_true_shape[0] == y_pred_shape[0] 47 | num_classes = y_pred_shape[-1] 48 | axis = [0] 49 | 50 | if self.from_logits: 51 | # get [0..1] class probabilities 52 | if self.mode == MULTICLASS_MODE: 53 | y_pred = tf.nn.softmax(y_pred, axis=-1) 54 | else: 55 | y_pred = tf.nn.sigmoid(y_pred) 56 | 57 | if self.mode == BINARY_MODE: 58 | y_true = rearrange(y_true, "... -> (...) 1") 59 | y_pred = rearrange(y_pred, "... -> (...) 1") 60 | elif self.mode == MULTICLASS_MODE: 61 | y_true = tf.cast(y_true, tf.int32) 62 | y_true = tf.one_hot(y_true, num_classes) 63 | y_true = rearrange(y_true, "... C -> (...) C") 64 | y_pred = rearrange(y_pred, "... C -> (...) 
C") 65 | else: 66 | AssertionError("Not implemented") 67 | 68 | scores = tf_tversky_score(y_pred, tf.cast(y_true, y_pred.dtype), 69 | self.alpha, 70 | self.beta, 71 | axis=axis) 72 | if self.log_loss: 73 | loss = -tf.math.log(tf.clip_by_value(scores, clip_value_min=1e-7, clip_value_max=scores.dtype.max)) 74 | else: 75 | loss = 1.0 - scores 76 | return tf.math.reduce_mean(loss) 77 | 78 | 79 | def tf_tversky_score(pred, target, alpha, beta, axis=None): 80 | """ 81 | Shape: 82 | - pred: :math:`(N, *, C)` where :math:`*` means any number of additional dimensions 83 | - target: :math:`(N, *, C)`, same shape as the input 84 | - Output: scalar. 85 | """ 86 | intersection = tf.reduce_sum(pred * target, axis=axis) 87 | fp = tf.reduce_sum(pred * (1. - target), axis) 88 | fn = tf.reduce_sum((1. - pred) * target, axis) 89 | denominator = intersection + alpha * fp + beta * fn 90 | denominator = tf.clip_by_value(denominator, 91 | clip_value_min=EPS, 92 | clip_value_max=denominator.dtype.max) 93 | return intersection / denominator 94 | 95 | 96 | if __name__ == "__main__": 97 | print("Multiclass") 98 | pred = tf.convert_to_tensor([[ 99 | [1., 2., 3., 4.], 100 | [2., 6., 4., 4.], 101 | [9., 6., 3., 4.] 102 | ]]) 103 | pred = rearrange(pred, "N C H -> N H C") 104 | true = tf.convert_to_tensor([[2, 1, 0, 2]]) 105 | 106 | loss_fn = TFTverskyLoss(0.5, 0.5, "multiclass", from_logits=True) 107 | loss = loss_fn(true, pred) 108 | print(float(loss)) 109 | 110 | print("Binary") 111 | pred = tf.convert_to_tensor([0.4, 0.2, 0.3, 0.5]) 112 | pred = tf.reshape(pred, [1, 2, 2, 1]) 113 | true = tf.convert_to_tensor([0, 1, 0, 1]) 114 | true = tf.reshape(true, [1, 2, 2]) 115 | loss_fn = TFTverskyLoss(0.5, 0.5, "binary", from_logits=True) 116 | loss = loss_fn(true, pred) 117 | print(float(loss)) 118 | 119 | print("Binary Log loss") 120 | pred = tf.convert_to_tensor([0.4, 0.2, 0.3, 0.5]) 121 | pred = tf.reshape(pred, [1, 2, 2, 1]) 122 | true = tf.convert_to_tensor([0, 1, 0, 1]) 123 | true = tf.reshape(true, [1, 2, 2]) 124 | loss_fn = TFTverskyLoss(0.5, 0.5, "binary", from_logits=True, log_loss=True) 125 | loss = loss_fn(true, pred) 126 | print(float(loss)) 127 | -------------------------------------------------------------------------------- /fusionlab/losses/tversky/tversky.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | from fusionlab.configs import EPS 6 | 7 | __all__ = ["TverskyLoss"] 8 | 9 | BINARY_MODE = "binary" 10 | MULTICLASS_MODE = "multiclass" 11 | 12 | 13 | class TverskyLoss(nn.Module): 14 | def __init__( 15 | self, 16 | alpha, 17 | beta, 18 | mode="multiclass", # binary, multiclass 19 | log_loss=False, 20 | from_logits=True, 21 | ): 22 | """ 23 | Implementation of Tversky loss for image segmentation task. 24 | It supports "binary", "multiclass" 25 | ref: https://github.com/kornia/kornia/blob/master/kornia/losses/tversky.py 26 | ref: https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py 27 | Args: 28 | alpha: controls the penalty for false positives(FP). 29 | beta: controls the penalty for false negatives(FN). 
30 | mode: Metric mode {'binary', 'multiclass'} 31 | log_loss: If True, loss computed as `-log(dice)`; otherwise `1 - dice` 32 | from_logits: If True assumes input is raw logits 33 | """ 34 | super().__init__() 35 | self.alpha = alpha 36 | self.beta = beta 37 | self.mode = mode 38 | self.from_logits = from_logits 39 | self.log_loss = log_loss 40 | 41 | def forward(self, y_pred, y_true): 42 | """ 43 | :param y_pred: (N, C, *) 44 | :param y_true: (N, *) 45 | :return: scalar 46 | """ 47 | assert y_true.size(0) == y_pred.size(0) 48 | num_classes = y_pred.size(1) 49 | dims = (0, 2) # (N, C, HW) 50 | 51 | if self.from_logits: 52 | # get [0..1] class probabilities 53 | if self.mode == MULTICLASS_MODE: 54 | y_pred = F.softmax(y_pred, dim=1) 55 | else: 56 | y_pred = torch.sigmoid(y_pred) 57 | 58 | if self.mode == BINARY_MODE: 59 | y_true = rearrange(y_true, "N ... -> N 1 (...)") 60 | y_pred = rearrange(y_pred, "N 1 ... -> N 1 (...)") 61 | elif self.mode == MULTICLASS_MODE: 62 | y_pred = rearrange(y_pred, "N C ... -> N C (...)") 63 | y_true = F.one_hot(y_true, num_classes) # (N, *) -> (N, *, C) 64 | y_true = rearrange(y_true, "N ... C -> N C (...)") 65 | else: 66 | AssertionError("Not implemented") 67 | 68 | scores = tversky_score(y_pred, y_true.type_as(y_pred), 69 | self.alpha, self.beta, 70 | dims=dims) 71 | if self.log_loss: 72 | loss = -torch.log(scores.clamp_min(EPS)) 73 | else: 74 | loss = 1.0 - scores 75 | return loss.mean() 76 | 77 | 78 | def tversky_score(pred, target, alpha, beta, dims): 79 | """ 80 | Shape: 81 | - pred: :math:`(N, C, *)` 82 | - target: :math:`(N, C, *)` 83 | - Output: scalar. 84 | """ 85 | assert pred.size() == target.size() 86 | 87 | intersection = torch.sum(pred * target, dim=dims) 88 | fp = torch.sum(pred * (1. - target), dims) 89 | fn = torch.sum((1. - pred) * target, dims) 90 | 91 | denominator = intersection + alpha * fp + beta * fn 92 | return intersection / denominator.clamp(min=EPS) 93 | 94 | if __name__ == "__main__": 95 | print("multiclass") 96 | pred = torch.tensor([[ 97 | [1., 2., 3., 4.], 98 | [2., 6., 4., 4.], 99 | [9., 6., 3., 4.] 
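    # Note on the (alpha, beta) choices below: with alpha = beta = 0.5 the
    # Tversky index reduces exactly to the Dice score, and with
    # alpha = beta = 1.0 it reduces to IoU/Jaccard, so these runs should
    # match the package's DiceLoss on the same inputs.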
100 | ]]).unsqueeze(-1) 101 | true = torch.tensor([[2, 1, 0, 2]]).view(1, 4).unsqueeze(-1) 102 | 103 | loss_fn = TverskyLoss(0.5, 0.5, "multiclass", from_logits=True) 104 | loss = loss_fn(pred, true) 105 | print(loss.item()) 106 | 107 | print("Binary") 108 | pred = torch.tensor([0.4, 0.2, 0.3, 0.5]).reshape(1, 1, 2, 2) 109 | true = torch.tensor([0, 1, 0, 1]).reshape(1, 2, 2) 110 | loss_fn = TverskyLoss(0.5, 0.5, "binary", from_logits=True) 111 | loss = loss_fn(pred, true) 112 | print(loss.item()) 113 | 114 | print("Binary Logloss") 115 | pred = torch.tensor([0.4, 0.2, 0.3, 0.5]).reshape(1, 1, 2, 2) 116 | true = torch.tensor([0, 1, 0, 1]).reshape(1, 2, 2) 117 | loss_fn = TverskyLoss(0.5, 0.5, "binary", from_logits=True, log_loss=True) 118 | loss = loss_fn(pred, true) 119 | print(loss.item()) 120 | -------------------------------------------------------------------------------- /fusionlab/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from fusionlab import BACKEND 2 | if BACKEND['torch']: 3 | from .dicescore.dice import DiceScore, JaccardScore 4 | from .iouscore.iou import IoUScore -------------------------------------------------------------------------------- /fusionlab/metrics/dicescore/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/metrics/dicescore/__init__.py -------------------------------------------------------------------------------- /fusionlab/metrics/dicescore/dice.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | from fusionlab.functional import dice_score 6 | 7 | BINARY_MODE = "binary" 8 | MULTICLASS_MODE = "multiclass" 9 | 10 | class DiceScore(nn.Module): 11 | def __init__( 12 | self, 13 | mode="multiclass", # binary, multiclass 14 | from_logits=True, 15 | reduction="none", # mean, none 16 | ): 17 | """ 18 | Computer dice score for binary or multiclass input 19 | 20 | Args: 21 | mode: "binary" or "multiclass" 22 | from_logits: if True, assumes input is raw logits 23 | reduction: "mean" or "none", if "none" returns dice score for each channels, else returns mean 24 | """ 25 | super().__init__() 26 | self.mode = mode 27 | self.from_logits = from_logits 28 | self.reduction = reduction 29 | 30 | def forward(self, y_pred, y_true) -> torch.Tensor: 31 | """ 32 | :param y_pred: (N, C, *) 33 | :param y_true: (N, *) 34 | :return: scalar 35 | """ 36 | assert y_true.size(0) == y_pred.size(0) 37 | num_classes = y_pred.size(1) 38 | dims = (0, 2) # dimensions to sum over (N, C, *) 39 | 40 | if self.from_logits: 41 | # get [0..1] class probabilities 42 | if self.mode == MULTICLASS_MODE: 43 | y_pred = F.softmax(y_pred, dim=1) 44 | else: 45 | y_pred = torch.sigmoid(y_pred) 46 | 47 | if self.mode == BINARY_MODE: 48 | y_true = rearrange(y_true, "N ... -> N 1 (...)") 49 | y_pred = rearrange(y_pred, "N 1 ... -> N 1 (...)") 50 | elif self.mode == MULTICLASS_MODE: 51 | y_pred = rearrange(y_pred, "N C ... -> N C (...)") 52 | y_true = F.one_hot(y_true, num_classes) # (N, *) -> (N, *, C) 53 | y_true = rearrange(y_true, "N ... 
C -> N C (...)") 54 | else: 55 | AssertionError("Not implemented") 56 | 57 | scores = dice_score(y_pred, y_true.type_as(y_pred), dims=dims) 58 | if self.reduction == "none": 59 | return scores 60 | else: 61 | return scores.mean() 62 | 63 | JaccardScore = DiceScore -------------------------------------------------------------------------------- /fusionlab/metrics/iouscore/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/metrics/iouscore/__init__.py -------------------------------------------------------------------------------- /fusionlab/metrics/iouscore/iou.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | from fusionlab.functional import iou_score 6 | 7 | __all__ = ["IoUScore"] 8 | 9 | BINARY_MODE = "binary" 10 | MULTICLASS_MODE = "multiclass" 11 | 12 | class IoUScore(nn.Module): 13 | def __init__( 14 | self, 15 | mode="multiclass", # binary, multiclass 16 | from_logits=True, 17 | reduction="none", # mean, none 18 | ): 19 | """ 20 | Implementation of Iou score for segmentation task. 21 | It supports "binary", "multiclass" 22 | Args: 23 | mode: Metric mode {'binary', 'multiclass'} 24 | from_logits: If True assumes input is raw logits 25 | reduction: "mean" or "none", if "none" returns dice score for each channels, else returns mean 26 | """ 27 | super().__init__() 28 | self.mode = mode 29 | self.from_logits = from_logits 30 | self.reduction = reduction 31 | 32 | def forward(self, y_pred, y_true): 33 | """ 34 | :param y_pred: (N, C, *) 35 | :param y_true: (N, *) 36 | :return: (C, ) if mode is 'multiclass' else (1, ) 37 | """ 38 | assert y_true.size(0) == y_pred.size(0) 39 | num_classes = y_pred.size(1) 40 | dims = (0, 2) # (N, C, *) 41 | 42 | if self.from_logits: 43 | # get [0..1] class probabilities 44 | if self.mode == MULTICLASS_MODE: 45 | y_pred = F.softmax(y_pred, dim=1) 46 | else: 47 | y_pred = torch.sigmoid(y_pred) 48 | 49 | if self.mode == BINARY_MODE: 50 | y_true = rearrange(y_true, "N ... -> N 1 (...)") 51 | y_pred = rearrange(y_pred, "N 1 ... -> N 1 (...)") 52 | elif self.mode == MULTICLASS_MODE: 53 | y_pred = rearrange(y_pred, "N C ... -> N C (...)") 54 | y_true = F.one_hot(y_true, num_classes) # (N, *) -> (N, *, C) 55 | y_true = rearrange(y_true, "N ... 
C -> N C (...)")
56 |         else:
57 |             raise AssertionError("Not implemented")
58 | 
59 |         scores = iou_score(y_pred, y_true.type_as(y_pred), dims=dims)
60 |         if self.reduction == "none":
61 |             return scores
62 |         else:
63 |             return scores.mean()
--------------------------------------------------------------------------------
/fusionlab/segmentation/README.md:
--------------------------------------------------------------------------------
1 | # Segmentation
2 | 
3 | | Name      | PyTorch   | TensorFlow  | Paper |
4 | | :-------- | :-------- | ----------- | ------------------------------------------------------------------------------------------------------------------------ |
5 | | UNet      | UNet      | TFUNet      | [U-Net: Convolutional Networks for Biomedical Image Segmentation (2015)](https://arxiv.org/abs/1505.04597) |
6 | | ResUNet   | ResUNet   | TFResUNet   | [Road Extraction by Deep Residual U-Net (2017)](https://arxiv.org/abs/1711.10684) |
7 | | UNet++    | UNet2plus | TFUNet2plus | [UNet++: A Nested U-Net Architecture for Medical Image Segmentation (2018)](https://arxiv.org/abs/1807.10165) |
8 | | TransUNet | TransUNet | X           | [TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation (2021)](https://arxiv.org/abs/2102.04306) |
9 | | UNETR     | UNETR     | X           | [UNETR: Transformers for 3D Medical Image Segmentation (2021)](https://arxiv.org/abs/2103.10504) |
10 | | SegFormer | SegFormer | X           | [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers (2021)](https://arxiv.org/abs/2105.15203) |
11 | 
--------------------------------------------------------------------------------
/fusionlab/segmentation/__init__.py:
--------------------------------------------------------------------------------
1 | from fusionlab import BACKEND
2 | if BACKEND['torch']:
3 |     from .unet.unet import UNet
4 |     from .resunet.resunet import ResUNet
5 |     from .unet2plus.unet2plus import UNet2plus
6 |     from .transunet.transunet import TransUNet
7 |     from .unetr.unetr import UNETR
8 |     from .segformer.segformer import SegFormer
9 |     from .base import HFSegmentationModel
10 | if BACKEND['tf']:
11 |     from .unet.tfunet import TFUNet
12 |     from .resunet.tfresunet import TFResUNet
13 |     from .unet2plus.tfunet2plus import TFUNet2plus
--------------------------------------------------------------------------------
/fusionlab/segmentation/base.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | class SegmentationModel(nn.Module):
4 |     """
5 |     Base PyTorch segmentation model composed of Encoder, Bridger, Decoder, and Head
6 |     """
7 |     def forward(self, x):
8 |         features = self.encoder(x)
9 |         feature_fusion = self.bridger(features)
10 |         decoder_output = self.decoder(feature_fusion)
11 |         output = self.head(decoder_output)
12 |         return output
13 | 
14 | class HFSegmentationModel(nn.Module):
15 |     """
16 |     Base Hugging Face-style PyTorch wrapper class for the segmentation models
17 |     """
18 |     def __init__(self, model, num_cls=None,
19 |                  loss_fct=nn.CrossEntropyLoss()):
20 |         super().__init__()
21 |         self.net = model
22 |         if 'num_cls' in model.__dict__:
23 |             self.num_cls = model.num_cls
24 |         else:
25 |             self.num_cls = num_cls
26 |         assert self.num_cls is not None, "num_cls is not defined"
27 |         self.loss_fct = loss_fct
28 |     def forward(self, x, labels=None):
29 |         logits = self.net(x)  # Forward pass the model
30 |         if labels is not None:
31 |             # logits => [BATCH, NUM_CLS]
32 |             # labels => [BATCH]
33 |             loss = self.loss_fct(logits, labels)  # Calculate loss
34 |         else:
35 |             loss = None
36 |         # 
return dictionary for hugginface trainer 37 | return {'loss':loss, 'logits':logits, 'hidden_states':None} 38 | 39 | 40 | if __name__ == '__main__': 41 | import torch 42 | from fusionlab.segmentation import ResUNet 43 | H = W = 224 44 | cout = 5 45 | inputs = torch.normal(0, 1, (1, 3, H, W)) 46 | 47 | model = ResUNet(3, cout, 64) 48 | hf_model = HFSegmentationModel(model, cout) 49 | output = hf_model(inputs) 50 | assert list(output.keys()) == ['loss', 'logits', 'hidden_states'] 51 | print(output['logits'].shape) 52 | assert list(output['logits'].shape) == [1, cout, H, W] -------------------------------------------------------------------------------- /fusionlab/segmentation/resunet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/segmentation/resunet/__init__.py -------------------------------------------------------------------------------- /fusionlab/segmentation/resunet/resunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from fusionlab.segmentation.base import SegmentationModel 4 | from fusionlab.utils import autopad 5 | from fusionlab.layers.factories import ConvND, ConvT, BatchNorm 6 | 7 | 8 | 9 | class ResUNet(SegmentationModel): 10 | def __init__( 11 | self, 12 | cin, 13 | num_cls, 14 | base_dim=64, 15 | spatial_dims=2 16 | ): 17 | super().__init__() 18 | self.num_cls = num_cls 19 | self.encoder = Encoder(cin, base_dim, spatial_dims) 20 | self.bridger = Bridger() 21 | self.decoder = Decoder(base_dim, spatial_dims) 22 | self.head = Head(base_dim, num_cls, spatial_dims) 23 | 24 | 25 | class Encoder(nn.Module): 26 | def __init__(self, cin, base_dim, spatial_dims=2): 27 | super().__init__() 28 | dims = [base_dim * (2 ** i) for i in range(4)] 29 | self.stem = Stem(cin, dims[0], spatial_dims) 30 | self.stage1 = ResConv(dims[0], dims[1], spatial_dims, stride=2) 31 | self.stage2 = ResConv(dims[1], dims[2], spatial_dims, stride=2) 32 | self.stage3 = ResConv(dims[2], dims[3], spatial_dims, stride=2) 33 | 34 | def forward(self, x): 35 | s0 = self.stem(x) 36 | s1 = self.stage1(s0) 37 | s2 = self.stage2(s1) 38 | s3 = self.stage3(s2) 39 | return [s0, s1, s2, s3] 40 | 41 | 42 | class Decoder(nn.Module): 43 | def __init__(self, base_dim, spatial_dims=2): 44 | """ 45 | Base UNet decoder 46 | Args: 47 | base_dim (int): output dim of deepest stage output or input channels 48 | """ 49 | super().__init__() 50 | dims = [base_dim*(2**i) for i in range(4)] 51 | self.d3 = DecoderBlock(dims[3], dims[2], spatial_dims) 52 | self.d2 = DecoderBlock(dims[2], dims[1], spatial_dims) 53 | self.d1 = DecoderBlock(dims[1], dims[0], spatial_dims) 54 | 55 | def forward(self, x): 56 | s0, s1, s2, s3 = x 57 | 58 | x = self.d3(s3, s2) 59 | x = self.d2(x, s1) 60 | x = self.d1(x, s0) 61 | return x 62 | 63 | 64 | class DecoderBlock(nn.Module): 65 | def __init__(self, cin, cout, spatial_dims=2): 66 | super().__init__() 67 | self.upsample = ConvT(spatial_dims, cin, cout, 2, stride=2) 68 | self.conv = ResConv(cout*2, cout, spatial_dims, 1) 69 | 70 | def forward(self, x1, x2): 71 | x1 = self.upsample(x1) 72 | x = torch.cat([x1, x2], dim=1) 73 | return self.conv(x) 74 | 75 | 76 | class Bridger(nn.Module): 77 | def __init__(self): 78 | super().__init__() 79 | 80 | def forward(self, x): 81 | outputs = [nn.Identity()(i) for i in x] 82 | return outputs 83 | 84 | 85 | class Stem(nn.Module): 86 | def 
__init__(self, cin, cout, spatial_dims=2): 87 | super().__init__() 88 | self.conv = nn.Sequential( 89 | ConvND(spatial_dims, cin, cout, 3, padding=autopad(3)), 90 | BatchNorm(spatial_dims, cout), 91 | nn.ReLU(), 92 | ConvND(spatial_dims, cout, cout, 3, padding=autopad(3)), 93 | ) 94 | self.skip = nn.Sequential( 95 | ConvND(spatial_dims, cin, cout, 3, padding=autopad(3)), 96 | ) 97 | 98 | def forward(self, x): 99 | return self.conv(x) + self.skip(x) 100 | 101 | 102 | class ResConv(nn.Module): 103 | def __init__(self, cin, cout, spatial_dims=2, stride=1): 104 | super().__init__() 105 | 106 | self.conv = nn.Sequential( 107 | BatchNorm(spatial_dims, cin), 108 | nn.ReLU(), 109 | ConvND(spatial_dims, cin, cout, 3, stride, padding=autopad(3)), 110 | BatchNorm(spatial_dims, cout), 111 | nn.ReLU(), 112 | ConvND(spatial_dims, cout, cout, 3, padding=autopad(3)), 113 | ) 114 | self.skip = nn.Sequential( 115 | ConvND(spatial_dims, cin, cout, 3, stride=stride, padding=autopad(3)), 116 | BatchNorm(spatial_dims, cout), 117 | ) 118 | 119 | def forward(self, x): 120 | return self.conv(x) + self.skip(x) 121 | 122 | 123 | class Head(nn.Sequential): 124 | def __init__(self, cin, cout, spatial_dims): 125 | """ 126 | Basic conv head 127 | :param int cin: input channel 128 | :param int cout: output channel 129 | """ 130 | conv = ConvND(spatial_dims, cin, cout, 1) 131 | super().__init__(conv) 132 | 133 | if __name__ == '__main__': 134 | H = W = 224 135 | cout = 64 136 | inputs = torch.normal(0, 1, (1, 3, H, W)) 137 | 138 | model = ResUNet(3, 100, cout) 139 | output = model(inputs) 140 | print(output.shape) 141 | 142 | dblock = DecoderBlock(64, 128) 143 | inputs2 = torch.normal(0, 1, (1, 128, H, W)) 144 | inputs1 = torch.normal(0, 1, (1, 64, H//2, W//2)) 145 | outputs = dblock(inputs1, inputs2) 146 | print(outputs.shape) 147 | 148 | encoder = Encoder(3, cout) 149 | outputs = encoder(inputs) 150 | for o in outputs: 151 | print(o.shape) 152 | 153 | decoder = Decoder(cout) 154 | outputs = decoder(outputs) 155 | print("Encoder + Decoder ", outputs.shape) 156 | 157 | stem = Stem(3, cout) 158 | outputs = stem(inputs) 159 | print(outputs.shape) 160 | assert list(outputs.shape) == [1, cout, H, W] 161 | 162 | resconv = ResConv(3, cout, stride=1) 163 | outputs = resconv(inputs) 164 | print(outputs.shape) 165 | assert list(outputs.shape) == [1, cout, H, W] 166 | 167 | resconv = ResConv(3, cout, stride=2) 168 | outputs = resconv(inputs) 169 | print(outputs.shape) 170 | assert list(outputs.shape) == [1, cout, H//2, W//2] 171 | 172 | 173 | print("3D ResUNet") 174 | D = H = W = 64 175 | cout = 32 176 | inputs = torch.rand(1, 3, D, H, W) 177 | 178 | model = ResUNet(3, 100, cout, spatial_dims=3) 179 | output = model(inputs) 180 | print(output.shape) 181 | 182 | print("1D ResUNet") 183 | L = 64 184 | cout = 32 185 | inputs = torch.rand(1, 3, L) 186 | 187 | model = ResUNet(3, 100, cout, spatial_dims=1) 188 | output = model(inputs) 189 | print(output.shape) 190 | -------------------------------------------------------------------------------- /fusionlab/segmentation/resunet/tfresunet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import layers, Model, Sequential 3 | from fusionlab.segmentation.tfbase import TFSegmentationModel 4 | 5 | 6 | class TFResUNet(TFSegmentationModel): 7 | def __init__(self, num_cls, base_dim): 8 | super().__init__() 9 | self.encoder = Encoder(base_dim) 10 | self.bridger = Bridger() 11 | self.decoder = Decoder(base_dim) 
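        # NOTE: the TF models in this file expect channels-last (N, H, W, C)
        # inputs, unlike the PyTorch ResUNet above, which is channels-first.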
12 | self.head = Head(num_cls) 13 | 14 | 15 | class Encoder(Model): 16 | def __init__(self, base_dim): 17 | super().__init__() 18 | dims = [base_dim * (2 ** i) for i in range(4)] 19 | self.stem = Stem(dims[0]) 20 | self.stage1 = ResConv(dims[1], stride=2) 21 | self.stage2 = ResConv(dims[2], stride=2) 22 | self.stage3 = ResConv(dims[3], stride=2) 23 | 24 | def call(self, x, training): 25 | s0 = self.stem(x, training) 26 | s1 = self.stage1(s0, training) 27 | s2 = self.stage2(s1, training) 28 | s3 = self.stage3(s2, training) 29 | return [s0, s1, s2, s3] 30 | 31 | 32 | class Decoder(Model): 33 | def __init__(self, base_dim): 34 | """ 35 | Base UNet decoder 36 | Args: 37 | base_dim (int): output dim of deepest stage output or input channels 38 | """ 39 | super().__init__() 40 | dims = [base_dim*(2**i) for i in range(4)] 41 | self.d3 = DecoderBlock(dims[2]) 42 | self.d2 = DecoderBlock(dims[1]) 43 | self.d1 = DecoderBlock(dims[0]) 44 | 45 | def call(self, x, training): 46 | s0, s1, s2, s3 = x 47 | 48 | x = self.d3(s3, s2, training) 49 | x = self.d2(x, s1, training) 50 | x = self.d1(x, s0, training) 51 | return x 52 | 53 | 54 | class DecoderBlock(Model): 55 | def __init__(self, cout): 56 | super().__init__() 57 | self.upsample = layers.Conv2DTranspose(cout, 2, strides=2) 58 | self.conv = ResConv(cout, stride=1) 59 | 60 | def call(self, x1, x2, training): 61 | x1 = self.upsample(x1) 62 | x = tf.concat([x1, x2], axis=-1) 63 | return self.conv(x) 64 | 65 | 66 | class Bridger(Model): 67 | def __init__(self): 68 | super().__init__() 69 | 70 | def call(self, x, training): 71 | outputs = [tf.identity(i) for i in x] 72 | return outputs 73 | 74 | 75 | class Stem(Model): 76 | def __init__(self, cout): 77 | super().__init__() 78 | self.conv = Sequential([ 79 | layers.Conv2D(cout, 3, padding='same'), 80 | layers.BatchNormalization(), 81 | layers.ReLU(), 82 | layers.Conv2D(cout, 3, padding='same'), 83 | ]) 84 | self.skip = Sequential( 85 | layers.Conv2D(cout, 3, padding='same'), 86 | ) 87 | 88 | def call(self, x, training): 89 | return self.conv(x) + self.skip(x) 90 | 91 | 92 | class ResConv(Model): 93 | def __init__(self, cout, stride=1): 94 | super().__init__() 95 | 96 | self.conv = Sequential([ 97 | layers.BatchNormalization(), 98 | layers.ReLU(), 99 | layers.Conv2D(cout, 3, stride, padding='same'), 100 | layers.BatchNormalization(), 101 | layers.ReLU(), 102 | layers.Conv2D(cout, 3, padding='same'), 103 | ]) 104 | self.skip = Sequential([ 105 | layers.Conv2D(cout, 3, strides=stride, padding='same'), 106 | layers.BatchNormalization(), 107 | ]) 108 | 109 | def call(self, x, training=None): 110 | return self.conv(x, training) + self.skip(x, training) 111 | 112 | 113 | class Head(Sequential): 114 | def __init__(self, cout): 115 | """ 116 | Basic conv head 117 | :param int cout: number of classes 118 | """ 119 | conv = layers.Conv2D(cout, 1) 120 | super().__init__(conv) 121 | 122 | 123 | if __name__ == '__main__': 124 | H = W = 224 125 | cout = 64 126 | inputs = tf.random.normal((1, H, W, 3)) 127 | 128 | model = TFResUNet(100, cout) 129 | output = model(inputs, training=True) 130 | print(output.shape) 131 | 132 | dblock = DecoderBlock(128) 133 | inputs2 = tf.random.normal((1, H, W, 128)) 134 | inputs1 = tf.random.normal((1, H//2, W//2, 64)) 135 | outputs = dblock(inputs1, inputs2, training=True) 136 | print(outputs.shape) 137 | 138 | encoder = Encoder(cout) 139 | outputs = encoder(inputs, training=True) 140 | for o in outputs: 141 | print(o.shape) 142 | 143 | decoder = Decoder(cout) 144 | outputs = 
decoder(outputs, training=True) 145 | print("Encoder + Decoder ", outputs.shape) 146 | 147 | stem = Stem(cout) 148 | outputs = stem(inputs, training=True) 149 | print(outputs.shape) 150 | assert list(outputs.shape) == [1, H, W, cout] 151 | 152 | resconv = ResConv(cout, 1) 153 | outputs = resconv(inputs, training=True) 154 | print(outputs.shape) 155 | assert list(outputs.shape) == [1, H, W, cout] 156 | 157 | resconv = ResConv(cout, stride=2) 158 | outputs = resconv(inputs, training=True) 159 | print(outputs.shape) 160 | assert list(outputs.shape) == [1, H//2, W//2, cout] -------------------------------------------------------------------------------- /fusionlab/segmentation/segformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/segmentation/segformer/__init__.py -------------------------------------------------------------------------------- /fusionlab/segmentation/segformer/segformer.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from fusionlab.encoders import ( 6 | MiT, 7 | MiTB0, 8 | MiTB1, 9 | MiTB2, 10 | MiTB3, 11 | MiTB4, 12 | MiTB5, 13 | ) 14 | 15 | class MLP(nn.Module): 16 | def __init__(self, dim, embed_dim): 17 | super().__init__() 18 | self.proj = nn.Linear(dim, embed_dim) 19 | 20 | def forward(self, x: torch.Tensor) -> torch.Tensor: 21 | x = x.flatten(2).transpose(1, 2) 22 | x = self.proj(x) 23 | return x 24 | 25 | 26 | class ConvModule(nn.Module): 27 | def __init__(self, c1, c2): 28 | super().__init__() 29 | self.conv = nn.Conv2d(c1, c2, 1, bias=False) 30 | self.bn = nn.BatchNorm2d(c2) # use SyncBN in original 31 | self.activate = nn.ReLU(True) 32 | 33 | def forward(self, x: torch.Tensor) -> torch.Tensor: 34 | return self.activate(self.bn(self.conv(x))) 35 | 36 | class SegFormerHead(nn.Module): 37 | def __init__(self, dims: list, embed_dim: int = 256, num_classes: int = 19): 38 | super().__init__() 39 | for i, dim in enumerate(dims): 40 | self.add_module(f"linear_c{i+1}", MLP(dim, embed_dim)) 41 | 42 | self.linear_fuse = ConvModule(embed_dim*4, embed_dim) 43 | self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1) 44 | self.dropout = nn.Dropout2d(0.1) 45 | 46 | def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor: 47 | B, _, H, W = features[0].shape 48 | outs = [self.linear_c1(features[0]).permute(0, 2, 1).reshape(B, -1, *features[0].shape[-2:])] 49 | 50 | for i, feature in enumerate(features[1:]): 51 | cf = eval(f"self.linear_c{i+2}")(feature).permute(0, 2, 1).reshape(B, -1, *feature.shape[-2:]) 52 | outs.append(F.interpolate(cf, size=(H, W), mode='bilinear', align_corners=False)) 53 | 54 | seg = self.linear_fuse(torch.cat(outs[::-1], dim=1)) 55 | seg = self.linear_pred(self.dropout(seg)) 56 | return seg 57 | 58 | class SegFormer(nn.Module): 59 | """ 60 | SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers 61 | 62 | source code: https://github.com/sithu31296/semantic-segmentation/blob/main/semseg/models/segformer.py 63 | 64 | Args: 65 | 66 | num_classes (int): number of classes to segment 67 | mit_encoder_type (str): type of MiT encoder, one of ['B0', 'B1', 'B2', 'B3', 'B4', 'B5'] 68 | """ 69 | def __init__( 70 | self, 71 | num_classes: int = 6, 72 | mit_encoder_type: str = 'B0' 73 | ): 74 | super().__init__() 75 | 
self.encoder: MiT = eval(f'MiT{mit_encoder_type}')() 76 | embed_dim = self.encoder.channels[-1] 77 | self.decode_head = SegFormerHead( 78 | self.encoder.channels, 79 | embed_dim, 80 | num_classes, 81 | ) 82 | 83 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 84 | _, features = self.encoder(inputs, return_features=True) 85 | x = self.decode_head(features) # 4x reduction in image size 86 | x = F.interpolate(x, size=inputs.shape[2:], mode='bilinear', align_corners=False) 87 | return x 88 | 89 | if __name__ == '__main__': 90 | model = SegFormer(num_classes=6) 91 | x = torch.randn(1, 3, 128, 128) 92 | outputs = model(x) 93 | print(outputs.shape) 94 | 95 | -------------------------------------------------------------------------------- /fusionlab/segmentation/tfbase.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras import Model 2 | 3 | class TFSegmentationModel(Model): 4 | """ 5 | Base PyTorch class of the segmentation model with Encoder, Bridger, Decoder, Head 6 | """ 7 | def call(self, x, training=None): 8 | """ 9 | 10 | Args: 11 | x: input tensor 12 | training: flag for BatchNormalization and Dropout, whether the layer should behave in training mode or in 13 | inference mode 14 | 15 | Returns: 16 | 17 | """ 18 | features = self.encoder(x, training) 19 | feature_fusion = self.bridger(features, training) 20 | decoder_output = self.decoder(feature_fusion, training) 21 | output = self.head(decoder_output, training) 22 | return output 23 | -------------------------------------------------------------------------------- /fusionlab/segmentation/transunet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/segmentation/transunet/__init__.py -------------------------------------------------------------------------------- /fusionlab/segmentation/unet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/segmentation/unet/__init__.py -------------------------------------------------------------------------------- /fusionlab/segmentation/unet/tfunet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import layers, Model, Sequential 3 | from fusionlab.segmentation.tfbase import TFSegmentationModel 4 | 5 | 6 | class TFUNet(TFSegmentationModel): 7 | def __init__(self, num_cls, base_dim=64): 8 | """ 9 | Base Unet 10 | Args: 11 | num_cls (int): number of classes 12 | base_dim (int): 1st stage dim of conv output 13 | """ 14 | super().__init__() 15 | stage = 5 16 | self.encoder = Encoder(base_dim) 17 | self.bridger = Bridger() 18 | self.decoder = Decoder(base_dim*(2**(stage-2))) # 512 19 | self.head = Head(num_cls) 20 | 21 | 22 | class Encoder(Model): 23 | def __init__(self, base_dim): 24 | """ 25 | UNet Encoder 26 | Args: 27 | base_dim (int): 1st stage dim of conv output 28 | """ 29 | super().__init__() 30 | self.pool = layers.MaxPool2D() 31 | self.stage1 = BasicBlock(base_dim) 32 | self.stage2 = BasicBlock(base_dim * 2) 33 | self.stage3 = BasicBlock(base_dim * 4) 34 | self.stage4 = BasicBlock(base_dim * 8) 35 | self.stage5 = BasicBlock(base_dim * 16) 36 | 37 | def call(self, x, training=None): 38 | s1 = self.stage1(x, training) 39 | x = self.pool(s1) 40 | s2 = 
self.stage2(x, training) 41 | x = self.pool(s2) 42 | s3 = self.stage3(x, training) 43 | x = self.pool(s3) 44 | s4 = self.stage4(x, training) 45 | x = self.pool(s4) 46 | s5 = self.stage5(x, training) 47 | 48 | return [s1, s2, s3, s4, s5] 49 | 50 | 51 | class Decoder(Model): 52 | def __init__(self, base_dim): 53 | """ 54 | Base UNet decoder 55 | Args: 56 | base_dim (int): output dim of deepest stage output 57 | """ 58 | super().__init__() 59 | self.d4 = DecoderBlock(base_dim) 60 | self.d3 = DecoderBlock(base_dim//2) 61 | self.d2 = DecoderBlock(base_dim//4) 62 | self.d1 = DecoderBlock(base_dim//8) 63 | 64 | def call(self, x, training=None): 65 | f1, f2, f3, f4, f5 = x 66 | x = self.d4(f5, f4, training) 67 | x = self.d3(x, f3, training) 68 | x = self.d2(x, f2, training) 69 | x = self.d1(x, f1, training) 70 | return x 71 | 72 | 73 | class Bridger(Model): 74 | def __init__(self): 75 | super().__init__() 76 | 77 | def call(self, x, training=None): 78 | outputs = [tf.identity(i) for i in x] 79 | return outputs 80 | 81 | 82 | class Head(Sequential): 83 | def __init__(self, cout): 84 | """ 85 | Basic Identity 86 | :param int cout: output channel 87 | """ 88 | conv = layers.Conv2D(cout, 1) 89 | super().__init__(conv) 90 | 91 | 92 | class BasicBlock(Sequential): 93 | def __init__(self, cout): 94 | conv1 = Sequential([ 95 | layers.Conv2D(cout, 3, 1, padding='same'), 96 | layers.ReLU(), 97 | ]) 98 | conv2 = Sequential([ 99 | layers.Conv2D(cout, 3, 1, padding='same'), 100 | layers.ReLU(), 101 | ]) 102 | super().__init__([conv1, conv2]) 103 | 104 | 105 | class DecoderBlock(Model): 106 | def __init__(self, cout): 107 | """ 108 | Base Unet decoder block for merging the outputs from 2 stages 109 | Args: 110 | cout: output dim of the block 111 | """ 112 | super().__init__() 113 | self.up = layers.UpSampling2D() 114 | self.conv = BasicBlock(cout) 115 | 116 | def call(self, x1, x2, training=None): 117 | x1 = self.up(x1) 118 | x = tf.concat([x1, x2], axis=-1) 119 | x = self.conv(x, training) 120 | return x 121 | 122 | 123 | if __name__ == '__main__': 124 | H = W = 224 125 | dim = 64 126 | inputs = tf.random.normal((1, H, W, 3)) 127 | 128 | encoder = Encoder(dim) 129 | encoder.build((1, H, W, 3)) 130 | outputs = encoder(inputs) 131 | for i, o in enumerate(outputs): 132 | assert list(o.shape) == [1, H // (2 ** i), W // (2 ** i), dim * (2 ** i)] 133 | 134 | bridger = Bridger() 135 | outputs = bridger(outputs) 136 | for i, o in enumerate(outputs): 137 | assert list(o.shape) == [1, H // (2 ** i), W // (2 ** i), dim * (2 ** i)] 138 | 139 | features = [tf.random.normal((1, H // (2 ** i), W // (2 ** i), dim * (2 ** i))) for i in range(5)] 140 | decoder = Decoder(512) 141 | outputs = decoder(features) 142 | assert list(outputs.shape) == [1, H, W, 64] 143 | 144 | head = Head(10) 145 | outputs = head(outputs) 146 | assert list(outputs.shape) == [1, H, W, 10] 147 | 148 | unet = TFUNet(10) 149 | outputs = unet(inputs) 150 | assert list(outputs.shape) == [1, H, W, 10] 151 | -------------------------------------------------------------------------------- /fusionlab/segmentation/unet/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from fusionlab.segmentation.base import SegmentationModel 4 | from fusionlab.utils import autopad 5 | from fusionlab.layers import ConvND, MaxPool 6 | 7 | 8 | class UNet(SegmentationModel): 9 | def __init__(self, cin, num_cls, base_dim=64, spatial_dims=2): 10 | """ 11 | Base Unet 12 | Args: 13 | cin (int): input 
channels 14 | num_cls (int): number of classes 15 | base_dim (int): 1st stage dim of conv output 16 | """ 17 | super().__init__() 18 | stage = 5 19 | self.num_cls = num_cls 20 | self.encoder = Encoder(cin, base_dim=base_dim, spatial_dims=spatial_dims) 21 | self.bridger = Bridger() 22 | self.decoder = Decoder(cin=base_dim*(2**(stage-1)), 23 | base_dim=base_dim*(2**(stage-2)), 24 | spatial_dims=spatial_dims) # 1024, 512 25 | self.head = Head(base_dim, num_cls, spatial_dims=spatial_dims) 26 | 27 | 28 | class Encoder(nn.Module): 29 | def __init__(self, cin, base_dim, spatial_dims=2): 30 | """ 31 | UNet Encoder 32 | Args: 33 | cin (int): input channels 34 | base_dim (int): 1st stage dim of conv output 35 | """ 36 | super().__init__() 37 | self.pool = MaxPool(spatial_dims, 2, 2) 38 | self.stage1 = BasicBlock(cin, base_dim, spatial_dims) 39 | self.stage2 = BasicBlock(base_dim, base_dim * 2, spatial_dims) 40 | self.stage3 = BasicBlock(base_dim * 2, base_dim * 4, spatial_dims) 41 | self.stage4 = BasicBlock(base_dim * 4, base_dim * 8, spatial_dims) 42 | self.stage5 = BasicBlock(base_dim * 8, base_dim * 16, spatial_dims) 43 | 44 | def forward(self, x): 45 | s1 = self.stage1(x) 46 | x = self.pool(s1) 47 | s2 = self.stage2(x) 48 | x = self.pool(s2) 49 | s3 = self.stage3(x) 50 | x = self.pool(s3) 51 | s4 = self.stage4(x) 52 | x = self.pool(s4) 53 | s5 = self.stage5(x) 54 | 55 | return [s1, s2, s3, s4, s5] 56 | 57 | 58 | class Decoder(nn.Module): 59 | def __init__(self, cin, base_dim, spatial_dims=2): 60 | """ 61 | Base UNet decoder 62 | Args: 63 | cin (int): input channels 64 | base_dim (int): output dim of deepest stage output 65 | """ 66 | super().__init__() 67 | self.d4 = DecoderBlock(cin, cin//2, base_dim, spatial_dims) 68 | self.d3 = DecoderBlock(base_dim, cin//4, base_dim//2, spatial_dims) 69 | self.d2 = DecoderBlock(base_dim//2, cin//8, base_dim//4, spatial_dims) 70 | self.d1 = DecoderBlock(base_dim//4, cin//16, base_dim//8, spatial_dims) 71 | 72 | def forward(self, x): 73 | f1, f2, f3, f4, f5 = x 74 | x = self.d4(f5, f4) 75 | x = self.d3(x, f3) 76 | x = self.d2(x, f2) 77 | x = self.d1(x, f1) 78 | return x 79 | 80 | 81 | class Bridger(nn.Module): 82 | def __init__(self): 83 | super().__init__() 84 | 85 | def forward(self, x): 86 | outputs = [nn.Identity()(i) for i in x] 87 | return outputs 88 | 89 | 90 | class Head(nn.Sequential): 91 | def __init__(self, cin, cout, spatial_dims=2): 92 | """ 93 | Basic conv head 94 | :param int cin: input channel 95 | :param int cout: output channel 96 | """ 97 | conv = ConvND(spatial_dims, cin, cout, 1) 98 | super().__init__(conv) 99 | 100 | 101 | class BasicBlock(nn.Sequential): 102 | def __init__(self, cin, cout, spatial_dims=2): 103 | conv1 = nn.Sequential( 104 | ConvND(spatial_dims, cin, cout, 3, 1, autopad(3)), 105 | nn.ReLU(), 106 | ) 107 | conv2 = nn.Sequential( 108 | ConvND(spatial_dims, cout, cout, 3, 1, autopad(3)), 109 | nn.ReLU(), 110 | ) 111 | super().__init__(conv1, conv2) 112 | 113 | 114 | class DecoderBlock(nn.Module): 115 | def __init__(self, c1, c2, cout, spatial_dims=2): 116 | """ 117 | Base Unet decoder block for merging the outputs from 2 stages 118 | Args: 119 | c1: input dim of the deeper stage 120 | c2: input dim of the shallower stage 121 | cout: output dim of the block 122 | """ 123 | super().__init__() 124 | self.up = nn.Upsample(scale_factor=2) 125 | self.conv = BasicBlock(c1 + c2, cout, spatial_dims) 126 | 127 | def forward(self, x1, x2): 128 | x1 = self.up(x1) 129 | x = torch.concat([x1, x2], dim=1) 130 | x = self.conv(x) 131 | 
return x 132 | 133 | 134 | if __name__ == '__main__': 135 | print("2D UNet") 136 | H = W = 224 137 | dim = 64 138 | inputs = torch.normal(0, 1, (1, 3, H, W)) 139 | 140 | encoder = Encoder(3, base_dim=dim) 141 | outputs = encoder(inputs) 142 | for i, o in enumerate(outputs): 143 | assert list(o.shape) == [1, dim*(2**i), H//(2**i), W//(2**i)] 144 | 145 | bridger = Bridger() 146 | outputs = bridger(outputs) 147 | for i, o in enumerate(outputs): 148 | assert list(o.shape) == [1, dim * (2 ** i), H // (2 ** i), W // (2 ** i)] 149 | 150 | features = [torch.normal(0, 1, (1, dim * (2 ** i), H // (2 ** i), W // (2 ** i))) for i in range(5)] 151 | decoder = Decoder(1024, 512) 152 | outputs = decoder(features) 153 | assert list(outputs.shape) == [1, 64, H, W] 154 | 155 | head = Head(64, 10) 156 | outputs = head(outputs) 157 | assert list(outputs.shape) == [1, 10, H, W] 158 | 159 | unet = UNet(3, 10) 160 | outputs = unet(inputs) 161 | assert list(outputs.shape) == [1, 10, H, W] 162 | 163 | print("1D UNet") 164 | L = 224 165 | dim = 64 166 | inputs = torch.rand(1, 3, L) 167 | unet = UNet(3, 10, spatial_dims=1) 168 | outputs = unet(inputs) 169 | assert list(outputs.shape) == [1, 10, L] 170 | 171 | print("3D UNet") 172 | D = H = W = 64 173 | dim = 32 174 | inputs = torch.rand(1, 3, D, H, W) 175 | unet = UNet(3, 10, spatial_dims=3) 176 | outputs = unet(inputs) 177 | assert list(outputs.shape) == [1, 10, D, H, W] -------------------------------------------------------------------------------- /fusionlab/segmentation/unet2plus/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/segmentation/unet2plus/__init__.py -------------------------------------------------------------------------------- /fusionlab/segmentation/unet2plus/tfunet2plus.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import layers, Model, Sequential 3 | from fusionlab.segmentation.tfbase import TFSegmentationModel 4 | 5 | 6 | class TFUNet2plus(TFSegmentationModel): 7 | def __init__(self, num_cls, base_dim): 8 | super().__init__() 9 | self.encoder = Encoder(base_dim) 10 | self.bridger = Bridger() 11 | self.decoder = Decoder(base_dim) 12 | self.head = Head(num_cls) 13 | 14 | 15 | class BasicBlock(Sequential): 16 | def __init__(self, cout): 17 | conv1 = Sequential([ 18 | layers.Conv2D(cout, 3, 1, padding='same'), 19 | layers.BatchNormalization(), 20 | layers.Activation(tf.nn.relu), 21 | ]) 22 | conv2 = Sequential([ 23 | layers.Conv2D(cout, 3, 1, padding='same'), 24 | layers.BatchNormalization(), 25 | layers.Activation(tf.nn.relu), 26 | ]) 27 | super().__init__([conv1, conv2]) 28 | 29 | 30 | class Encoder(Model): 31 | def __init__(self, base_dim): 32 | """ 33 | UNet Encoder 34 | Args: 35 | base_dim (int): 1st stage dim of conv output 36 | """ 37 | super().__init__() 38 | self.pool = layers.MaxPool2D() 39 | self.conv0_0 = BasicBlock(base_dim) 40 | self.conv1_0 = BasicBlock(base_dim * 2) 41 | self.conv2_0 = BasicBlock(base_dim * 4) 42 | self.conv3_0 = BasicBlock(base_dim * 8) 43 | self.conv4_0 = BasicBlock(base_dim * 16) 44 | 45 | def call(self, x, training): 46 | x0_0 = self.conv0_0(x, training) 47 | x1_0 = self.conv1_0(self.pool(x0_0), training) 48 | x2_0 = self.conv2_0(self.pool(x1_0), training) 49 | x3_0 = self.conv3_0(self.pool(x2_0), training) 50 | x4_0 = self.conv4_0(self.pool(x3_0), training) 51 | return 
[x0_0, x1_0, x2_0, x3_0, x4_0] 52 | 53 | 54 | class Bridger(Model): 55 | def call(self, x, training=None): 56 | return [tf.identity(i) for i in x] 57 | 58 | 59 | class Decoder(Model): 60 | def __init__(self, base_dim): 61 | super().__init__() 62 | dims = [base_dim*(2**i) for i in range(5)] # [base_dim, base_dim*2, base_dim*4, base_dim*8, base_dim*16] 63 | self.conv0_1 = BasicBlock(dims[0]) 64 | self.conv1_1 = BasicBlock(dims[1]) 65 | self.conv2_1 = BasicBlock(dims[2]) 66 | self.conv3_1 = BasicBlock(dims[3]) 67 | 68 | self.conv0_2 = BasicBlock(dims[0]) 69 | self.conv1_2 = BasicBlock(dims[1]) 70 | self.conv2_2 = BasicBlock(dims[2]) 71 | 72 | self.conv0_3 = BasicBlock(dims[0]) 73 | self.conv1_3 = BasicBlock(dims[1]) 74 | 75 | self.conv0_4 = BasicBlock(dims[0]) 76 | self.up = layers.UpSampling2D() 77 | 78 | def call(self, x, training=None): 79 | x0_0, x1_0, x2_0, x3_0, x4_0 = x 80 | 81 | x0_1 = self.conv0_1(layers.concatenate([x0_0, self.up(x1_0)], -1)) 82 | x1_1 = self.conv1_1(layers.concatenate([x1_0, self.up(x2_0)], -1)) 83 | x0_2 = self.conv0_2(layers.concatenate([x0_0, x0_1, self.up(x1_1)], -1)) 84 | 85 | x2_1 = self.conv2_1(layers.concatenate([x2_0, self.up(x3_0)], -1)) 86 | x1_2 = self.conv1_2(layers.concatenate([x1_0, x1_1, self.up(x2_1)], -1)) 87 | x0_3 = self.conv0_3(layers.concatenate([x0_0, x0_1, x0_2, self.up(x1_2)], -1)) 88 | 89 | x3_1 = self.conv3_1(layers.concatenate([x3_0, self.up(x4_0)], -1)) 90 | x2_2 = self.conv2_2(layers.concatenate([x2_0, x2_1, self.up(x3_1)], -1)) 91 | x1_3 = self.conv1_3(layers.concatenate([x1_0, x1_1, x1_2, self.up(x2_2)], -1)) 92 | x0_4 = self.conv0_4(layers.concatenate([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], -1)) 93 | 94 | return x0_4 95 | 96 | 97 | class Head(Sequential): 98 | def __init__(self, cout): 99 | """ 100 | Basic Identity 101 | :param int cout: output channel 102 | """ 103 | conv = layers.Conv2D(cout, 1) 104 | super().__init__(conv) 105 | 106 | 107 | if __name__ == '__main__': 108 | H = W = 224 109 | dim = 32 110 | num_cls = 10 111 | inputs = tf.random.normal((1, H, W, 3)) 112 | 113 | encoder = Encoder(base_dim=dim) 114 | outputs = encoder(inputs, training=True) 115 | for i, o in enumerate(outputs): 116 | assert list(o.shape) == [1, H // (2 ** i), W // (2 ** i), dim * (2 ** i)] 117 | 118 | bridger = Bridger() 119 | outputs = bridger(outputs, training=True) 120 | for i, o in enumerate(outputs): 121 | assert list(o.shape) == [1, H // (2 ** i), W // (2 ** i), dim * (2 ** i)] 122 | 123 | features = [tf.random.normal((1, H // (2 ** i), W // (2 ** i), dim * (2 ** i))) for i in range(5)] 124 | decoder = Decoder(dim) 125 | decoder.build([f.shape for f in features]) 126 | outputs = decoder(features, training=True) 127 | assert list(outputs.shape) == [1, H, W, dim] 128 | 129 | head = Head(num_cls) 130 | outputs = head(outputs, training=True) 131 | assert list(outputs.shape) == [1, H, W, num_cls] 132 | 133 | unet = TFUNet2plus(num_cls, dim) 134 | outputs = unet(inputs, training=True) 135 | assert list(outputs.shape) == [1, H, W, num_cls] 136 | 137 | -------------------------------------------------------------------------------- /fusionlab/segmentation/unet2plus/unet2plus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from fusionlab.segmentation.base import SegmentationModel 4 | from fusionlab.utils import autopad 5 | from fusionlab.layers import ConvND, BatchNorm, MaxPool, Upsample 6 | 7 | # ref: https://github.com/4uiiurz1/pytorch-nested-unet 8 | 9 | 10 | class 
UNet2plus(SegmentationModel): 11 | def __init__(self, cin, num_cls, base_dim, spatial_dims=2): 12 | super().__init__() 13 | self.num_cls = num_cls 14 | self.encoder = Encoder(cin, base_dim, spatial_dims) 15 | self.bridger = Bridger() 16 | self.decoder = Decoder(base_dim, spatial_dims) 17 | self.head = Head(base_dim, num_cls, spatial_dims) 18 | 19 | 20 | class BasicBlock(nn.Sequential): 21 | def __init__(self, cin, cout, spatial_dims=2): 22 | conv1 = nn.Sequential( 23 | ConvND(spatial_dims, cin, cout, 3, 1, autopad(3)), 24 | BatchNorm(spatial_dims, cout), 25 | nn.ReLU(), 26 | ) 27 | conv2 = nn.Sequential( 28 | ConvND(spatial_dims, cout, cout, 3, 1, autopad(3)), 29 | BatchNorm(spatial_dims, cout), 30 | nn.ReLU(), 31 | ) 32 | super().__init__(conv1, conv2) 33 | 34 | 35 | class Encoder(nn.Module): 36 | def __init__(self, cin, base_dim, spatial_dims=2): 37 | """ 38 | UNet Encoder 39 | Args: 40 | cin (int): input channels 41 | base_dim (int): 1st stage dim of conv output 42 | """ 43 | super().__init__() 44 | self.pool = MaxPool(spatial_dims, 2, 2) 45 | self.conv0_0 = BasicBlock(cin, base_dim, spatial_dims) 46 | self.conv1_0 = BasicBlock(base_dim, base_dim * 2, spatial_dims) 47 | self.conv2_0 = BasicBlock(base_dim * 2, base_dim * 4, spatial_dims) 48 | self.conv3_0 = BasicBlock(base_dim * 4, base_dim * 8, spatial_dims) 49 | self.conv4_0 = BasicBlock(base_dim * 8, base_dim * 16, spatial_dims) 50 | 51 | def forward(self, x): 52 | x0_0 = self.conv0_0(x) 53 | x1_0 = self.conv1_0(self.pool(x0_0)) 54 | x2_0 = self.conv2_0(self.pool(x1_0)) 55 | x3_0 = self.conv3_0(self.pool(x2_0)) 56 | x4_0 = self.conv4_0(self.pool(x3_0)) 57 | return [x0_0, x1_0, x2_0, x3_0, x4_0] 58 | 59 | 60 | class Bridger(nn.Module): 61 | def __init__(self): 62 | super().__init__() 63 | 64 | def forward(self, x): 65 | return [nn.Identity()(i) for i in x] 66 | 67 | 68 | class Decoder(nn.Module): 69 | def __init__(self, base_dim, spatial_dims=2): 70 | super().__init__() 71 | dims = [base_dim*(2**i) for i in range(5)] # [base_dim, base_dim*2, base_dim*4, base_dim*8, base_dim*16] 72 | self.conv0_1 = BasicBlock(dims[0] + dims[1], dims[0], spatial_dims) 73 | self.conv1_1 = BasicBlock(dims[1] + dims[2], dims[1], spatial_dims) 74 | self.conv2_1 = BasicBlock(dims[2] + dims[3], dims[2], spatial_dims) 75 | self.conv3_1 = BasicBlock(dims[3] + dims[4], dims[3], spatial_dims) 76 | 77 | self.conv0_2 = BasicBlock(dims[0] * 2 + dims[1], dims[0], spatial_dims) 78 | self.conv1_2 = BasicBlock(dims[1] * 2 + dims[2], dims[1], spatial_dims) 79 | self.conv2_2 = BasicBlock(dims[2] * 2 + dims[3], dims[2], spatial_dims) 80 | 81 | self.conv0_3 = BasicBlock(dims[0] * 3 + dims[1], dims[0], spatial_dims) 82 | self.conv1_3 = BasicBlock(dims[1] * 3 + dims[2], dims[1], spatial_dims) 83 | 84 | self.conv0_4 = BasicBlock(dims[0] * 4 + dims[1], dims[0], spatial_dims) 85 | self.up = Upsample(spatial_dims, scale_factor=2) 86 | 87 | def forward(self, x): 88 | x0_0, x1_0, x2_0, x3_0, x4_0 = x 89 | 90 | x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1)) 91 | x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1)) 92 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1)) 93 | 94 | x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1)) 95 | x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1)) 96 | x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1)) 97 | 98 | x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1)) 99 | x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1)) 100 | x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, 
x1_2, self.up(x2_2)], 1))
101 | x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
102 | 
103 | return x0_4
104 | 
105 | 
106 | class Head(nn.Sequential):
107 | def __init__(self, cin, cout, spatial_dims=2):
108 | """
109 | 1x1 convolution head that projects decoder features to class logits
110 | :param int cin: input channel
111 | :param int cout: output channel
112 | """
113 | conv = ConvND(spatial_dims, cin, cout, 1)
114 | super().__init__(conv)
115 | 
116 | 
117 | if __name__ == '__main__':
118 | H = W = 224
119 | dim = 32
120 | inputs = torch.normal(0, 1, (1, 3, H, W))
121 | 
122 | encoder = Encoder(3, base_dim=dim)
123 | outputs = encoder(inputs)
124 | for i, o in enumerate(outputs):
125 | assert list(o.shape) == [1, dim * (2 ** i), H // (2 ** i), W // (2 ** i)]
126 | 
127 | bridger = Bridger()
128 | outputs = bridger(outputs)
129 | for i, o in enumerate(outputs):
130 | assert list(o.shape) == [1, dim * (2 ** i), H // (2 ** i), W // (2 ** i)]
131 | 
132 | features = [torch.normal(0, 1, (1, dim * (2 ** i), H // (2 ** i), W // (2 ** i))) for i in range(5)]
133 | decoder = Decoder(dim)
134 | outputs = decoder(features)
135 | assert list(outputs.shape) == [1, dim, H, W]
136 | 
137 | head = Head(dim, 10)
138 | outputs = head(outputs)
139 | assert list(outputs.shape) == [1, 10, H, W]
140 | 
141 | unet = UNet2plus(3, 10, dim)
142 | outputs = unet(inputs)
143 | assert list(outputs.shape) == [1, 10, H, W]
144 | 
145 | print("1D UNet++")
146 | L = 128
147 | dim = 32
148 | inputs = torch.rand(1, 3, L)
149 | unet = UNet2plus(3, 10, dim, spatial_dims=1)
150 | 
151 | outputs = unet(inputs)
152 | assert list(outputs.shape) == [1, 10, L]
153 | 
154 | print("3D UNet++")
155 | D = H = W = 32
156 | dim = 32
157 | inputs = torch.rand(1, 3, D, H, W)
158 | unet = UNet2plus(3, 10, dim, spatial_dims=3)
159 | 
160 | outputs = unet(inputs)
161 | assert list(outputs.shape) == [1, 10, D, H, W]
162 | 
-------------------------------------------------------------------------------- /fusionlab/segmentation/unetr/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/segmentation/unetr/__init__.py
-------------------------------------------------------------------------------- /fusionlab/trainers/__init__.py: --------------------------------------------------------------------------------
1 | from fusionlab import BACKEND
2 | if BACKEND['torch']:
3 | from .dcgan import *
4 | from .trainer import *
-------------------------------------------------------------------------------- /fusionlab/trainers/test.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | m = tf.keras.metrics.Accuracy()
4 | m.reset_state()
-------------------------------------------------------------------------------- /fusionlab/trainers/trainer.py: --------------------------------------------------------------------------------
1 | import torch
2 | from tqdm.auto import tqdm
3 | import numpy as np
4 | 
5 | # ref: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback
6 | 
7 | class Trainer:
8 | def __init__(self, device):
9 | self.device = device
10 | 
11 | def train_step(self, data):
12 | data = self._data_to_device(data)
13 | inputs, target = data
14 | pred = self.model(inputs)
15 | loss = self.loss_fn(pred, target)
16 | loss.backward()
17 | self.optimizer.step()
18 | self.optimizer.zero_grad()
19 | return loss.item()
20 | 
21 | def val_step(self, data):
22 | data = 
self._data_to_device(data)
23 | inputs, target = data
24 | with torch.no_grad():
25 | pred = self.model(inputs)
26 | loss = self.loss_fn(pred, target)
27 | return loss.item()
28 | 
29 | def train_epoch(self):
30 | self.model.train()
31 | epoch_loss = []
32 | for _, data in enumerate(tqdm(self.train_dataloader, leave=False)):
33 | batch_loss = self.train_step(data)
34 | epoch_loss.append(batch_loss)
35 | return np.mean(epoch_loss)
36 | 
37 | def val_epoch(self):
38 | self.model.eval()
39 | epoch_loss = []
40 | for _, data in enumerate(tqdm(self.val_dataloader, leave=False)):
41 | batch_loss = self.val_step(data)
42 | epoch_loss.append(batch_loss)
43 | return np.mean(epoch_loss)
44 | 
45 | def on_fit_begin(self):
46 | pass
47 | def on_fit_end(self):
48 | pass
49 | 
50 | def on_epoch_begin(self):
51 | pass
52 | def on_epoch_end(self):
53 | pass
54 | 
55 | def _data_to_device(self, data):
56 | if isinstance(data, torch.Tensor):
57 | return data.to(self.device)
58 | elif isinstance(data, dict):
59 | return {k: v.to(self.device) for k, v in data.items()}
60 | elif isinstance(data, (list, tuple)):
61 | return [v.to(self.device) for v in data]
62 | else:
63 | raise NotImplementedError
64 | 
65 | def fit(self, model, train_dataloader, val_dataloader, epochs, optimizer, loss_fn):
66 | self.model = model.to(self.device)
67 | self.train_dataloader = train_dataloader
68 | self.val_dataloader = val_dataloader
69 | self.epochs = epochs
70 | self.optimizer = optimizer
71 | self.loss_fn = loss_fn
72 | 
73 | self.train_log = {'loss': []}
74 | self.val_log = {'loss': []}
75 | 
76 | self.on_fit_begin()
77 | for epoch in tqdm(range(epochs)):
78 | self.on_epoch_begin()
79 | train_epoch_loss = self.train_epoch()
80 | self.train_log['loss'].append(train_epoch_loss)
81 | 
82 | if self.val_dataloader:
83 | val_epoch_loss = self.val_epoch()
84 | self.val_log['loss'].append(val_epoch_loss)
85 | msg = f"[{epoch}/{epochs}] train_loss: {self.train_log['loss'][-1]:.4f}"
86 | if self.val_dataloader:
87 | msg += f" val_loss: {self.val_log['loss'][-1]:.4f}"
88 | print(msg)
89 | self.on_epoch_end()
90 | self.on_fit_end()
91 | return
92 | 
93 | if __name__ == "__main__":
94 | class FakeModel(torch.nn.Module):
95 | def __init__(self):
96 | super().__init__()
97 | self.conv = torch.nn.Conv2d(1, 3, 3)
98 | self.pool = torch.nn.Sequential(
99 | torch.nn.AdaptiveAvgPool2d(1),
100 | torch.nn.Flatten(),
101 | )
102 | self.cls = torch.nn.Linear(3, 10)
103 | def forward(self, x):
104 | x = self.conv(x)
105 | x = self.pool(x)
106 | x = self.cls(x)
107 | return x
108 | 
109 | from abc import ABC, abstractmethod
110 | class Metric(ABC):
111 | def __init__(self):
112 | pass
113 | 
114 | @abstractmethod
115 | def reset(self):
116 | raise NotImplementedError("reset method is not implemented!")
117 | 
118 | @abstractmethod
119 | def update(self):
120 | raise NotImplementedError("update method is not implemented!")
121 | 
122 | @abstractmethod
123 | def compute(self):
124 | raise NotImplementedError("compute method is not implemented!")
125 | 
126 | # class Accuracy(Metric):
127 | 
128 | 
129 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
130 | print(device)
131 | 
132 | # mnist
133 | from torchvision.datasets import MNIST
134 | from torchvision.transforms import ToTensor
135 | from torch.utils.data import DataLoader
136 | train_dataset = MNIST(root='data/', train=True, transform=ToTensor(), download=True)
137 | val_dataset = MNIST(root='data/', train=False, transform=ToTensor(), download=True)
138 | train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
139 | val_dataloader 
= DataLoader(val_dataset, batch_size=32, shuffle=False) 140 | model = FakeModel() 141 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) 142 | loss_fn = torch.nn.CrossEntropyLoss() 143 | 144 | trainer = Trainer(device) 145 | trainer.fit(model, 146 | train_dataloader, 147 | val_dataloader, 148 | 10, 149 | optimizer, 150 | loss_fn) 151 | 152 | 153 | -------------------------------------------------------------------------------- /fusionlab/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import autopad, make_ntuple 2 | from .trace import show_classtree 3 | from .plots import plot_channels 4 | from .labelme import convert_labelme_json2mask 5 | 6 | from fusionlab import BACKEND 7 | if BACKEND['torch']: 8 | from .trunc_normal.trunc_normal import trunc_normal_ -------------------------------------------------------------------------------- /fusionlab/utils/basic.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from itertools import repeat 3 | from typing import Any, Tuple 4 | 5 | def autopad(kernel_size, padding=None, dilation=1, spatial_dims=2): 6 | ''' 7 | Auto padding for convolutional layers 8 | ''' 9 | if padding is None: 10 | if isinstance(kernel_size, int) and isinstance(dilation, int): 11 | padding = (kernel_size - 1) // 2 * dilation 12 | else: 13 | kernel_size = make_ntuple(kernel_size, spatial_dims) 14 | dilation = make_ntuple(dilation, spatial_dims) 15 | padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(spatial_dims)) 16 | return padding 17 | 18 | def make_ntuple(x: Any, n: int) -> Tuple[Any, ...]: 19 | """ 20 | Make n-tuple from input x. If x is an iterable, then we just convert it to tuple. 21 | Otherwise, we will make a tuple of length n, all with value of x. 
22 | reference: https://github.com/pytorch/vision/blob/main/torchvision/utils.py#L585C1-L597C31
23 | 
24 | Args:
25 | x (Any): input value
26 | n (int): length of the resulting tuple
27 | """
28 | if isinstance(x, collections.abc.Iterable):
29 | return tuple(x)
30 | return tuple(repeat(x, n))
-------------------------------------------------------------------------------- /fusionlab/utils/labelme.py: --------------------------------------------------------------------------------
1 | from typing import Sequence
2 | import json
3 | import numpy as np
4 | import os
5 | from glob import glob
6 | from tqdm.auto import tqdm
7 | 
8 | def convert_labelme_json2mask(
9 | class_names: Sequence[str],
10 | json_dir: str,
11 | output_dir: str,
12 | single_mask: bool = True
13 | ):
14 | """
15 | Convert labelme JSON files to mask files (.png)
16 | 
17 | Args:
18 | class_names (list): list of class names; the background class must come first
19 | json_dir (str): path to json files directory
20 | output_dir (str): path to output directory
21 | single_mask (bool): if True, save a single mask file with per-pixel class indices (uint8); otherwise save one binary mask file (0/255) per class
22 | """
23 | try:
24 | import cv2
25 | except ImportError:
26 | raise ImportError("opencv-python package is not installed")
27 | 
28 | num_classes = len(class_names)
29 | if num_classes > 255:
30 | raise ValueError("Maximum number of classes is 255")
31 | 
32 | cls_map = {name: i for i, name in enumerate(class_names)}
33 | print(f"Number of classes: {num_classes}")
34 | print("class name to index: ", cls_map)
35 | json_paths = glob(os.path.join(json_dir, '*.json'))
36 | 
37 | for path in tqdm(json_paths):
38 | with open(path) as f: json_data = json.load(f)
39 | h = json_data['imageHeight']
40 | w = json_data['imageWidth']
41 | 
42 | # Draw one binary mask per class from the polygon annotations
43 | mask = np.zeros((len(class_names), h, w), dtype=np.uint8)
44 | for shape in json_data['shapes']:
45 | if shape["shape_type"] != "polygon":
46 | continue
47 | cls_name = shape['label']
48 | cls_idx = cls_map[cls_name]
49 | points = shape['points']
50 | cv2.fillPoly(
51 | mask[cls_idx],
52 | np.array([points], dtype=np.int32),
53 | 255
54 | )
55 | # update background mask: background is wherever no foreground class was drawn
56 | mask[0] = 255 - np.max(mask[1:], axis=0)
57 | # Save Mask File
58 | filename = ".".join(os.path.split(path)[-1].split('.')[:-1])
59 | if single_mask:
60 | mask_single = np.argmax(mask, axis=0).astype(np.uint8)
61 | cv2.imwrite(os.path.join(output_dir, f"{filename}.png"), mask_single)
62 | else:
63 | for i, m in enumerate(mask):
64 | cv2.imwrite(os.path.join(output_dir, f'{filename}_{i:03d}.png'), m.astype(np.uint8))
65 | 
66 | if __name__ == '__main__':
67 | convert_labelme_json2mask(
68 | ['bg', 'dog', 'cat'],
69 | "json",
70 | "mask",
71 | single_mask=True,
72 | )
-------------------------------------------------------------------------------- /fusionlab/utils/plots.py: --------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | 
3 | def plot_channels(signals, show=True):
4 | '''
5 | plot signals by channels
6 | 
7 | Args:
8 | signals: numpy array, shape (num_samples, num_channels)
9 | '''
10 | num_channels = signals.shape[1]
11 | fig, axes = plt.subplots(num_channels, 1, figsize=(10, 10), squeeze=False)  # squeeze=False keeps axes 2D even for a single channel
12 | for i in range(num_channels):
13 | axes[i, 0].plot(signals[:, i])
14 | 
15 | if show: plt.show()
16 | return fig
17 | 
-------------------------------------------------------------------------------- /fusionlab/utils/trace.py: 
-------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | # Define the show_classtree function 4 | def show_classtree(clss, indent=0): 5 | # Get the full argument spec for the class 6 | argspec = inspect.getfullargspec(clss) 7 | # Get the arguments for the class 8 | args = argspec.args 9 | # If the class has a varargs argument, append it to args 10 | if argspec.varargs: 11 | args.append('*' + argspec.varargs) 12 | # If the class has a varkw argument, append it to args 13 | if argspec.varkw: 14 | args.append('**' + argspec.varkw) 15 | # Print the class name and arguments, indented by indent spaces 16 | print(' ' * indent + f'{clss} | input: {args}') 17 | # For each base class of the class 18 | for supercls in clss.__bases__: 19 | # Recursively call show_classtree on the base class, with an indent of 3 more spaces 20 | show_classtree(supercls, indent + 3) -------------------------------------------------------------------------------- /fusionlab/utils/trunc_normal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/utils/trunc_normal/__init__.py -------------------------------------------------------------------------------- /fusionlab/utils/trunc_normal/trunc_normal.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 5 | """Tensor initialization with truncated normal distribution. 6 | Based on: 7 | https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 8 | https://github.com/rwightman/pytorch-image-models 9 | 10 | Args: 11 | tensor: an n-dimensional `torch.Tensor`. 12 | mean: the mean of the normal distribution. 13 | std: the standard deviation of the normal distribution. 14 | a: the minimum cutoff value. 15 | b: the maximum cutoff value. 16 | """ 17 | 18 | def norm_cdf(x): 19 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 20 | 21 | with torch.no_grad(): 22 | l = norm_cdf((a - mean) / std) 23 | u = norm_cdf((b - mean) / std) 24 | tensor.uniform_(2 * l - 1, 2 * u - 1) 25 | tensor.erfinv_() 26 | tensor.mul_(std * math.sqrt(2.0)) 27 | tensor.add_(mean) 28 | tensor.clamp_(min=a, max=b) 29 | return tensor 30 | 31 | 32 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): 33 | """Tensor initialization with truncated normal distribution. 
34 | Based on:
35 | https://github.com/rwightman/pytorch-image-models
36 | 
37 | Args:
38 | tensor: an n-dimensional `torch.Tensor`
39 | mean: the mean of the normal distribution
40 | std: the standard deviation of the normal distribution
41 | a: the minimum cutoff value
42 | b: the maximum cutoff value
43 | """
44 | 
45 | if std <= 0:
46 | raise ValueError("the standard deviation should be greater than zero.")
47 | 
48 | if a >= b:
49 | raise ValueError("minimum cutoff value (a) should be smaller than maximum cutoff value (b).")
50 | 
51 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
-------------------------------------------------------------------------------- /make_init.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Set the package directory
4 | PACKAGE_DIR="fusionlab"
5 | 
6 | # Find all directories in the package
7 | DIRS=$(find "$PACKAGE_DIR" -type d)
8 | 
9 | # Create __init__.py files in all directories
10 | for DIR in $DIRS; do
11 | touch "$DIR/__init__.py"
12 | done
13 | 
14 | # Add import statements to __init__.py files
15 | for DIR in $DIRS; do
16 | # Get the relative path of the directory
17 | RELATIVE_PATH=${DIR#$PACKAGE_DIR/}
18 | # Get the list of subdirectories and files in the directory
19 | SUBDIRS=$(find "$DIR" -maxdepth 1 -type d ! -path "$DIR")
20 | FILES=$(find "$DIR" -maxdepth 1 -type f -name "*.py" ! -name "__init__.py")
21 | # Add import statements for subdirectories and files
22 | for SUBDIR in $SUBDIRS; do
23 | MODULE_NAME=${SUBDIR#$DIR/}
24 | echo "from .$MODULE_NAME import *" >> "$DIR/__init__.py"
25 | done
26 | for FILE in $FILES; do
27 | MODULE_NAME=$(basename "$FILE" .py)
28 | echo "from .$MODULE_NAME import *" >> "$DIR/__init__.py"
29 | done
30 | done
-------------------------------------------------------------------------------- /release_logs.md: --------------------------------------------------------------------------------
1 | 0.1.12
2 | 
3 | * Fix: remove print in SegFormer
4 | 
5 | 0.1.11
6 | 
7 | * Add: convert labelme json to mask files
8 | 
9 | 0.1.10
10 | 
11 | * Add: TransUNet, ViT, UNETR, SegFormer
12 | * Add layer: Rearrange, PatchEmbedding, InstanceNorm ND, DropPath
13 | * Add utils: trunc_normal_
14 | 
15 | 0.1.9
16 | 
17 | * Add LSTimeClassificationDataset for time series classification
18 | 
19 | 0.1.7
20 | 
21 | * Add Dice, IoU score
22 | 
23 | 0.1.6
24 | 
25 | * Add API Documentation
26 | * Add utils.count_parameters
27 | * Update backend support
28 | 
29 | 0.1.5
30 | 
31 | * Remove numpy from requirements.txt
32 | * Extract make_ntuple from torchvision to fusionlab.utils
33 | 
34 | 0.1.4
35 | 
36 | * Add EfficientNetB0 ~ B7
37 | * Add ConvNeXt Tiny ~ XLarge
38 | 
39 | 0.1.3
40 | 
41 | * Add
42 | * ECGCSVClassificationDataset(cinc2017)
43 | * LUDBDataset
44 | * LSTimeSegDataset
45 | * HFDataset
46 | 
47 | 0.1.2
48 | 
49 | * Add LUDB dataset
50 | 
51 | 
52 | 0.0.52
53 | 
54 | * Tversky Loss for Torch and TF
55 | 
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | einops
2 | tqdm
3 | traceback2
4 | 
5 | pandas>=1.5.3
6 | xmltodict
7 | xlwt
8 | scipy
9 | wfdb
-------------------------------------------------------------------------------- /scripts/build_pip.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # ref: https://datainpoint.substack.com/p/c0a
3 | 
4 | python ../setup.py sdist bdist_wheel
5 | python -m twine upload 
dist/*
6 | 
7 | # pip install -e .
8 | 
9 | # build
10 | python setup.py sdist bdist_wheel
11 | # pypi
12 | twine upload dist/*
13 | # testpypi
14 | twine upload --repository testpypi dist/*
15 | 
16 | # testpypi install (notebook syntax kept as comments; replace ? with the target version)
17 | # !pip install --upgrade fusionlab==?
18 | # !pip install -i https://test.pypi.org/simple/ --upgrade fusionlab==?
19 | 
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | # Ref: https://github.com/qubvel/segmentation_models.pytorch
5 | 
6 | # Note: To use the 'upload' functionality of this file, you must:
7 | # $ pip install twine
8 | 
9 | import io
10 | import os
11 | import sys
12 | from shutil import rmtree
13 | 
14 | from setuptools import find_packages, setup, Command
15 | 
16 | # Package meta-data.
17 | NAME = "fusionlab"
18 | DESCRIPTION = "Useful packages for DL"
19 | URL = "https://github.com/taipingeric/fusionlab"
20 | EMAIL = "taipingeric@gmail.com"
21 | AUTHOR = "Chih-Yang Li"
22 | REQUIRES_PYTHON = ">=3.8.0"
23 | VERSION = None
24 | 
25 | # The rest you shouldn't have to touch too much :)
26 | # ------------------------------------------------
27 | # Except, perhaps the License and Trove Classifiers!
28 | # If you do change the License, remember to change the Trove Classifier for that!
29 | 
30 | here = os.path.abspath(os.path.dirname(__file__))
31 | 
32 | # What packages are required for this module to be executed?
33 | try:
34 | with open(os.path.join(here, "requirements.txt"), encoding="utf-8") as f:
35 | REQUIRED = f.read().split("\n")
36 | except FileNotFoundError:
37 | REQUIRED = []
38 | 
39 | # What packages are optional?
40 | EXTRAS = {
41 | "test": [
42 | "pytest",
43 | "mock",
44 | "pre-commit",
45 | "black==22.3.0",
46 | "flake8==4.0.1",
47 | "flake8-docstrings==1.6.0",
48 | "torchvision",
49 | "tensorflow",
50 | ],
51 | }
52 | 
53 | # Import the README and use it as the long-description.
54 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file!
55 | try:
56 | with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f:
57 | long_description = "\n" + f.read()
58 | except FileNotFoundError:
59 | long_description = DESCRIPTION
60 | 
61 | # Load the package's __version__.py module as a dictionary.
62 | about = {}
63 | if not VERSION:
64 | with open(os.path.join(here, NAME, "__version__.py")) as f:
65 | exec(f.read(), about)
66 | else:
67 | about["__version__"] = VERSION
68 | 
69 | 
70 | class UploadCommand(Command):
71 | """Support setup.py upload."""
72 | 
73 | description = "Build and publish the package." 
74 | user_options = [] 75 | 76 | @staticmethod 77 | def status(s): 78 | """Prints things in bold.""" 79 | print(s) 80 | 81 | def initialize_options(self): 82 | pass 83 | 84 | def finalize_options(self): 85 | pass 86 | 87 | def run(self): 88 | try: 89 | self.status("Removing previous builds...") 90 | rmtree(os.path.join(here, "dist")) 91 | except OSError: 92 | pass 93 | 94 | self.status("Building Source and Wheel (universal) distribution...") 95 | os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable)) 96 | 97 | self.status("Uploading the package to PyPI via Twine...") 98 | os.system("twine upload dist/*") 99 | 100 | self.status("Pushing git tags...") 101 | os.system("git tag v{0}".format(about["__version__"])) 102 | os.system("git push --tags") 103 | 104 | sys.exit() 105 | 106 | 107 | # Where the magic happens: 108 | setup( 109 | name=NAME, 110 | version=about["__version__"], 111 | description=DESCRIPTION, 112 | long_description=long_description, 113 | long_description_content_type="text/markdown", 114 | author=AUTHOR, 115 | author_email=EMAIL, 116 | python_requires=REQUIRES_PYTHON, 117 | url=URL, 118 | packages=find_packages(exclude=("tests", "docs", "images")), 119 | # If your package is a single module, use this instead of 'packages': 120 | # py_modules=['mypackage'], 121 | # entry_points={ 122 | # 'console_scripts': ['mycli=mymodule:cli'], 123 | # }, 124 | install_requires=REQUIRED, 125 | extras_require=EXTRAS, 126 | include_package_data=True, 127 | license="MIT", 128 | classifiers=[ 129 | # Trove classifiers 130 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 131 | "License :: OSI Approved :: MIT License", 132 | # "Programming Language :: Python", 133 | "Programming Language :: Python :: 3", 134 | "Operating System :: OS Independent", 135 | # "Programming Language :: Python :: Implementation :: CPython", 136 | # "Programming Language :: Python :: Implementation :: PyPy", 137 | ], 138 | # $ setup.py publish support. 
139 | # cmdclass={ 140 | # "upload": UploadCommand, 141 | # }, 142 | ) -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_classification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fusionlab.classification import HFClassificationModel 3 | from fusionlab.classification import VGG16Classifier 4 | from fusionlab.classification import LSTMClassifier 5 | 6 | H = W = 224 7 | cout = 5 8 | inputs = torch.normal(0, 1, (1, 3, W)) 9 | # Test CNNClassification 10 | model = VGG16Classifier(3, cout, spatial_dims=1) 11 | hf_model = HFClassificationModel(model, cout) 12 | output = hf_model(inputs) 13 | print(output['logits'].shape) 14 | assert list(output.keys()) == ['loss', 'logits', 'hidden_states'] 15 | 16 | inputs = torch.normal(0, 1, (1, 3, H, W)) 17 | # Test CNNClassification 18 | model = VGG16Classifier(3, cout, spatial_dims=2) 19 | hf_model = HFClassificationModel(model, cout) 20 | output = hf_model(inputs) 21 | print(output['logits'].shape) 22 | assert list(output.keys()) == ['loss', 'logits', 'hidden_states'] 23 | 24 | inputs = torch.normal(0, 1, (1, 3, H)) 25 | model = LSTMClassifier(3, cout) 26 | hf_model = HFClassificationModel(model, cout) 27 | output = hf_model(inputs) 28 | print(output['logits'].shape) 29 | assert list(output.keys()) == ['loss', 'logits', 'hidden_states'] -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | from fusionlab import datasets 2 | 3 | class TestHFDataset: 4 | def test_HFDataset(self): 5 | import torch 6 | print("Test HFDataset") 7 | 8 | NUM_DATA = 20 9 | NUM_FEATURES = 16 10 | ds = torch.utils.data.TensorDataset( 11 | torch.zeros(NUM_DATA, NUM_FEATURES), 12 | torch.zeros(NUM_DATA)) 13 | for x, y in ds: 14 | assert list(x.shape) == [NUM_FEATURES] 15 | pass 16 | hf_ds = datasets.HFDataset(ds) 17 | assert hf_ds is not None 18 | for i in range(len(hf_ds)): 19 | data_dict = hf_ds[i] 20 | assert data_dict.keys() == set(['x', 'labels']) 21 | assert list(data_dict['x'].shape) == [NUM_FEATURES] 22 | pass 23 | 24 | class TestLSTimeSegDataset: 25 | def test_LSTimeSegDataset(self, tmpdir): 26 | import numpy as np 27 | import os 28 | import pandas as pd 29 | import json 30 | import torch 31 | filename = "29be6360-12lead.csv" 32 | annotaion_path = os.path.join(tmpdir, "12.json") 33 | annotation = [ 34 | { 35 | "csv": f"/data/upload/12/{filename}", 36 | "label": [ 37 | { 38 | "start": 0.004, 39 | "end": 0.764, 40 | "instant": False, 41 | "timeserieslabels": ["N"] 42 | }, 43 | { 44 | "start": 0.762, 45 | "end": 1.468, 46 | "instant": False, 47 | "timeserieslabels": ["p"] 48 | }, 49 | { 50 | "start": 1.466, 51 | "end": 2.5, 52 | "instant": False, 53 | "timeserieslabels": ["t"] 54 | } 55 | ], 56 | "number": [{"number": 500}], 57 | } 58 | ] 59 | with open(annotaion_path, "w") as f: 60 | json.dump(annotation, f) 61 | col_names = ['i', 'ii', 'iii', 'avr', 'avl', 'avf', 'v1', 'v2', 'v3', 'v4', 'v5', 'v6'] 62 | df = pd.DataFrame() 63 | sample_rate = 500 64 | num_samples = sample_rate * 10 65 | df['time'] = np.arange(num_samples) / sample_rate 66 | for 
col_name in col_names: 67 | df[col_name] = np.random.randn(num_samples) 68 | df.to_csv(os.path.join(tmpdir, filename), index=False) 69 | ds = datasets.LSTimeSegDataset(data_dir=tmpdir, 70 | annotation_path=annotaion_path, 71 | class_map={"N": 1, "p": 2, "t": 3}, 72 | column_names=col_names) 73 | signals, mask = ds[0] 74 | assert signals.shape == (len(col_names), num_samples) 75 | assert mask.shape == (num_samples, ) 76 | assert type(signals) == torch.Tensor 77 | assert type(mask) == torch.Tensor 78 | assert len(ds) == 1 -------------------------------------------------------------------------------- /tests/test_factories.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fusionlab.layers.factories import ( 3 | ConvND, 4 | ConvT, 5 | Upsample, 6 | BatchNorm, 7 | MaxPool, 8 | AvgPool, 9 | AdaptiveMaxPool, 10 | AdaptiveAvgPool, 11 | ReplicationPad, 12 | ConstantPad 13 | ) 14 | 15 | 16 | 17 | # Test Code for ConvND 18 | inputs = torch.randn(1, 3, 16) # create random input tensor 19 | layer = ConvND(spatial_dims=1, in_channels=3, out_channels=2, kernel_size=5) # create model instance 20 | outputs = layer(inputs) # pass input through model 21 | print(outputs.shape) 22 | assert list(outputs.shape) == [1, 2, 16] # check output shape is correct 23 | 24 | # Test code for ConvT 25 | inputs = torch.randn(1, 3, 16) # create random input tensor 26 | layer = ConvT(spatial_dims=1, in_channels=3, out_channels=2, kernel_size=5) # create model instance 27 | outputs = layer(inputs) # pass input through model 28 | print(outputs.shape) 29 | assert list(outputs.shape) == [1, 2, 16] # check output shape is correct 30 | 31 | # Test code for Upsample 32 | inputs = torch.randn(1, 3, 16) # create random input tensor 33 | layer = Upsample(spatial_dims=1, scale_factor=2) # create model instance 34 | outputs = layer(inputs) # pass input through model 35 | print(outputs.shape) 36 | assert list(outputs.shape) == [1, 3, 32] # check output shape is correct 37 | 38 | # Test code for BatchNormND 39 | inputs = torch.randn(1, 3, 16) # create random input tensor 40 | layer = BatchNorm(spatial_dims=1, num_features=3) # create model instance 41 | outputs = layer(inputs) # pass input through model 42 | 43 | # Test code for MaxPool 44 | for Module in [MaxPool, AvgPool]: 45 | inputs = torch.randn(1, 3, 16) # create random input tensor 46 | layer = Module(spatial_dims=1, kernel_size=2) # create model instance 47 | outputs = layer(inputs) # pass input through model 48 | print(outputs.shape) 49 | assert list(outputs.shape) == [1, 3, 8] # check output shape is correct 50 | 51 | # Test code for Pool 52 | for Module in [AdaptiveMaxPool, AdaptiveAvgPool]: 53 | inputs = torch.randn(1, 3, 16) # create random input tensor 54 | layer = Module(spatial_dims=1, output_size=8) # create model instance 55 | outputs = layer(inputs) # pass input through model 56 | print(outputs.shape) 57 | assert list(outputs.shape) == [1, 3, 8] # check output shape is correct 58 | 59 | # Test code for Padding 60 | inputs = torch.randn(1, 3, 16) # create random input tensor 61 | layer = ReplicationPad(spatial_dims=1, padding=2) # create model instance 62 | outputs = layer(inputs) # pass input through model 63 | print(outputs.shape) 64 | assert list(outputs.shape) == [1, 3, 20] # check output shape is correct 65 | 66 | layer = ConstantPad(spatial_dims=1, padding=2, value=0) # create model instance 67 | outputs = layer(inputs) # pass input through model 68 | print(outputs.shape) 69 | assert list(outputs.shape) 
== [1, 3, 20] # check output shape is correct 70 | -------------------------------------------------------------------------------- /tests/test_layers.py: -------------------------------------------------------------------------------- 1 | from fusionlab.layers import ConvNormAct 2 | 3 | class TestConvNormAct: 4 | def test_ConvNormAct(self): 5 | import torch 6 | inputs = torch.randn(1, 3, 16) 7 | l = ConvNormAct(1, 3, 4, 3, 1, 1) 8 | outputs = l(inputs) 9 | assert outputs.shape == torch.Size([1, 4, 16]) 10 | 11 | inputs = torch.randn(1, 3, 16, 16) 12 | l = ConvNormAct(2, 3, 4, 3, 1, 1) 13 | outputs = l(inputs) 14 | assert outputs.shape == torch.Size([1, 4, 16, 16]) 15 | 16 | inputs = torch.randn(1, 3, 16, 16, 16) 17 | l = ConvNormAct(3, 3, 4, 3, 1, 1) 18 | outputs = l(inputs) 19 | assert outputs.shape == torch.Size([1, 4, 16, 16, 16]) 20 | 21 | class TestSqueezeExcitation: 22 | def test_se(self): 23 | import torch 24 | from fusionlab.layers import SEModule 25 | cin = 64 26 | squeeze_channels = 16 27 | 28 | for i in range(1, 4): 29 | size = tuple([1, cin] + [16] * i) 30 | inputs = torch.randn(size) 31 | layer = SEModule(cin, squeeze_channels, spatial_dims=i) 32 | outputs = layer(inputs) 33 | assert outputs.shape == size -------------------------------------------------------------------------------- /tests/test_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tensorflow as tf 3 | from einops import rearrange 4 | from fusionlab.losses import * 5 | from pytest import approx 6 | 7 | EPS = 1e-6 8 | 9 | 10 | class Data: 11 | def __init__(self): 12 | self.pred = [[ 13 | [1., 2., 3., 4.], 14 | [2., 6., 4., 4.], 15 | [9., 6., 3., 4.] 16 | ]] 17 | self.target = [[2, 1, 0, 2]] 18 | 19 | 20 | class BinaryData: 21 | def __init__(self): 22 | self.pred = [0.4, 0.2, 0.3, 0.5] 23 | self.target = [0, 1, 0, 1] 24 | 25 | 26 | class TestSegLoss: 27 | def test_dice_loss(self): 28 | from fusionlab.losses.diceloss.tfdice import TFDiceLoss 29 | # Multi class 30 | data = Data() 31 | true_loss = 0.5519775748252869 32 | # PT 33 | pred = torch.tensor(data.pred).view(1, 3, 4) 34 | true = torch.tensor(data.target).view(1, 4) 35 | loss = DiceLoss("multiclass", from_logits=True)(pred, true) 36 | assert loss == approx(true_loss, EPS) 37 | # TF 38 | pred = tf.convert_to_tensor(data.pred) 39 | pred = rearrange(pred, "N C H -> N H C") 40 | true = tf.convert_to_tensor(data.target) 41 | loss = TFDiceLoss("multiclass", from_logits=True)(true, pred) 42 | assert float(loss) == approx(true_loss, EPS) 43 | 44 | # Binary Loss 45 | data = BinaryData() 46 | true_loss = 0.46044695377349854 47 | 48 | pred = torch.tensor(data.pred).reshape(1, 1, 2, 2) 49 | true = torch.tensor(data.target).reshape(1, 2, 2) 50 | # PT 51 | loss = DiceLoss("binary", from_logits=True)(pred, true) 52 | assert loss == approx(true_loss, EPS) 53 | # TF 54 | pred = tf.convert_to_tensor(data.pred) 55 | pred = tf.reshape(pred, [1, 2, 2, 1]) 56 | true = tf.convert_to_tensor(data.target) 57 | true = tf.reshape(true, [1, 2, 2]) 58 | loss = TFDiceLoss("binary", from_logits=True)(true, pred) 59 | assert float(loss) == approx(true_loss, EPS) 60 | 61 | # Binary Log loss 62 | true_loss = 0.6170141696929932 63 | 64 | # PT 65 | pred = torch.tensor(data.pred).reshape(1, 1, 2, 2) 66 | true = torch.tensor(data.target).reshape(1, 2, 2) 67 | loss = DiceLoss("binary", from_logits=True, log_loss=True)(pred, true) 68 | assert loss == approx(true_loss, EPS) 69 | # TF 70 | pred = tf.convert_to_tensor(data.pred) 71 | 
pred = tf.reshape(pred, [1, 2, 2, 1]) 72 | true = tf.convert_to_tensor(data.target) 73 | true = tf.reshape(true, [1, 2, 2]) 74 | loss = TFDiceLoss("binary", from_logits=True, log_loss=True)(true, pred) 75 | assert float(loss) == approx(true_loss, EPS) 76 | 77 | def test_iou_loss(self): 78 | from fusionlab.losses.iouloss.tfiou import TFIoULoss 79 | # multi class 80 | true_loss = 0.6969285607337952 81 | data = Data() 82 | 83 | # PT 84 | pred = torch.tensor(data.pred).view(1, 3, 4) 85 | true = torch.tensor(data.target).view(1, 4) 86 | loss = IoULoss("multiclass", from_logits=True)(pred, true) 87 | assert loss == approx(true_loss, EPS) 88 | # TF 89 | pred = tf.convert_to_tensor(data.pred) 90 | pred = rearrange(pred, "N C H -> N H C") 91 | true = tf.convert_to_tensor(data.target) 92 | loss = TFIoULoss("multiclass", from_logits=True)(true, pred) 93 | assert float(loss) == approx(true_loss, EPS) 94 | 95 | # Binary 96 | data = BinaryData() 97 | true_loss = 0.6305561661720276 98 | # PT 99 | pred = torch.tensor(data.pred).reshape(1, 1, 2, 2) 100 | true = torch.tensor(data.target).reshape(1, 2, 2) 101 | loss = IoULoss("binary", from_logits=True)(pred, true) 102 | assert loss == approx(true_loss, EPS) 103 | # TF 104 | pred = tf.convert_to_tensor(data.pred) 105 | pred = tf.reshape(pred, [1, 2, 2, 1]) 106 | true = tf.convert_to_tensor(data.target) 107 | true = tf.reshape(true, [1, 2, 2]) 108 | loss = TFIoULoss("binary", from_logits=True)(true, pred) 109 | assert float(loss) == approx(true_loss, EPS) 110 | 111 | # Binary Log loss 112 | data = BinaryData() 113 | true_loss = 0.9957565665245056 114 | # PT 115 | pred = torch.tensor(data.pred).reshape(1, 1, 2, 2) 116 | true = torch.tensor(data.target).reshape(1, 2, 2) 117 | loss = IoULoss("binary", from_logits=True, log_loss=True)(pred, true) 118 | assert loss == approx(true_loss, EPS) 119 | # TF 120 | pred = tf.convert_to_tensor(data.pred) 121 | pred = tf.reshape(pred, [1, 2, 2, 1]) 122 | true = tf.convert_to_tensor(data.target) 123 | true = tf.reshape(true, [1, 2, 2]) 124 | loss = TFIoULoss("binary", from_logits=True, log_loss=True)(true, pred) 125 | assert float(loss) == approx(true_loss, EPS) -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tensorflow as tf 3 | from einops import rearrange 4 | from fusionlab.metrics import ( 5 | DiceScore, 6 | IoUScore, 7 | ) 8 | from pytest import approx 9 | 10 | EPS = 1e-6 11 | 12 | class Data: 13 | def __init__(self): 14 | self.pred = [[ 15 | [1., 2., 3., 4.], 16 | [2., 6., 4., 4.], 17 | [9., 6., 3., 4.] 
18 | ]] 19 | self.target = [[2, 1, 0, 2]] 20 | 21 | 22 | class BinaryData: 23 | def __init__(self): 24 | self.pred = [0.4, 0.2, 0.3, 0.5] 25 | self.target = [0, 1, 0, 1] 26 | 27 | 28 | class TestMetrics: 29 | def test_dice_score(self): 30 | # Multi class 31 | data = Data() 32 | true_scores = [0.27264893, 0.41188607, 0.65953231] 33 | true_score_mean = 0.44802245 34 | # PT 35 | pred = torch.tensor(data.pred).view(1, 3, 4) 36 | true = torch.tensor(data.target).view(1, 4) 37 | scores = DiceScore("multiclass")(pred, true) 38 | assert scores.mean() == approx(true_score_mean, EPS) 39 | assert scores == approx(true_scores, EPS) 40 | 41 | # TODO: Add binary class test 42 | data = BinaryData() 43 | true_loss = 0.46044695377349854 44 | 45 | def test_iou_score(self): 46 | # multi class 47 | true_loss = 0.6969285607337952 48 | true_scores = [0.15784223, 0.25935552, 0.49201655] 49 | true_score_mean = 0.30307144 50 | data = Data() 51 | 52 | # PT 53 | pred = torch.tensor(data.pred).view(1, 3, 4) 54 | true = torch.tensor(data.target).view(1, 4) 55 | scores = IoUScore("multiclass", from_logits=True)(pred, true) 56 | assert scores.mean() == approx(true_score_mean, EPS) 57 | assert scores == approx(true_scores, EPS) 58 | 59 | # TODO: Add binary class test 60 | 61 | -------------------------------------------------------------------------------- /tests/test_seg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def generate_inputs(img_size, spatial_dims): 4 | size = tuple([1, 3] + [img_size] * spatial_dims) 5 | return torch.randn(size) 6 | 7 | class TestSeg: 8 | def test_unet(self): 9 | from fusionlab.segmentation import UNet 10 | for i in range(1, 4): 11 | inputs = generate_inputs(64, i) 12 | model = UNet(3, 2, spatial_dims=i) 13 | outputs = model(inputs) 14 | assert outputs.shape == tuple([1, 2] + [64] * i) 15 | 16 | def test_resunet(self): 17 | from fusionlab.segmentation import ResUNet 18 | for i in range(1, 4): 19 | inputs = generate_inputs(64, i) 20 | model = ResUNet(3, 2, spatial_dims=i) 21 | outputs = model(inputs) 22 | assert outputs.shape == tuple([1, 2] + [64] * i) 23 | 24 | def test_unet2plus(self): 25 | from fusionlab.segmentation import UNet2plus 26 | for i in range(1, 4): 27 | inputs = generate_inputs(64, i) 28 | model = UNet2plus(3, 2, 16, spatial_dims=i) 29 | outputs = model(inputs) 30 | assert outputs.shape == tuple([1, 2] + [64] * i) 31 | 32 | def test_transunet(self): 33 | from fusionlab.segmentation import TransUNet 34 | inputs = generate_inputs(64, 2) 35 | model = TransUNet( 36 | in_channels=3, 37 | img_size=64, 38 | num_classes=2, 39 | ) 40 | outputs = model(inputs) 41 | assert outputs.shape == tuple([1, 2] + [64] * 2) 42 | 43 | def test_unetr(self): 44 | from fusionlab.segmentation import UNETR 45 | for i in range(1, 4): 46 | inputs = generate_inputs(64, i) 47 | model = UNETR(3, 2, 64, spatial_dims=i) 48 | outputs = model(inputs) 49 | assert outputs.shape == tuple([1, 2] + [64] * i) 50 | 51 | def test_segformer(self): 52 | from fusionlab.segmentation import SegFormer 53 | inputs = torch.rand(1, 3, 64, 64) 54 | for i in range(6): 55 | mit_type = f'B{i}' 56 | model = SegFormer(num_classes=2, mit_encoder_type=mit_type) 57 | outputs = model(inputs) 58 | assert outputs.shape == (1, 2, 64, 64) -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | class TestPlotChannels: 2 | def test_plot_channels(self): 
3 | from fusionlab.utils import plot_channels 4 | import numpy as np 5 | import matplotlib 6 | """Test that plot_channels() returns a figure.""" 7 | signals = np.random.randn(500, 12) 8 | fig = plot_channels(signals, show=False) 9 | assert isinstance(fig, matplotlib.figure.Figure) --------------------------------------------------------------------------------
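A note on the UNet++ decoder in fusionlab/segmentation/unet2plus/unet2plus.py: each nested node x_{i,j} concatenates the j earlier features at level i (each of width dims[i]) with one feature upsampled from level i+1 (width dims[i+1]), which is exactly where the dims[i] * j + dims[i + 1] input widths in Decoder.__init__ come from. A minimal sketch of that channel arithmetic; the cin helper is illustrative, not part of the library:

base_dim = 32
dims = [base_dim * (2 ** i) for i in range(5)]  # [32, 64, 128, 256, 512]

def cin(i, j):
    # node x_{i,j}: j same-level features of width dims[i],
    # plus one feature upsampled from level i + 1 of width dims[i + 1]
    return dims[i] * j + dims[i + 1]

assert cin(0, 1) == dims[0] + dims[1]        # conv0_1
assert cin(0, 2) == dims[0] * 2 + dims[1]    # conv0_2
assert cin(0, 4) == dims[0] * 4 + dims[1]    # conv0_4
assert cin(3, 1) == dims[3] + dims[4]        # conv3_1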
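The Trainer class in fusionlab/trainers/trainer.py exposes a Keras-style fit() over plain PyTorch dataloaders, with the signature fit(model, train_dataloader, val_dataloader, epochs, optimizer, loss_fn). A minimal end-to-end sketch, assuming the torch backend is available; the synthetic TensorDataset and tiny model below are stand-ins for real data, and passing None as val_dataloader skips validation:

import torch
from torch.utils.data import DataLoader, TensorDataset
from fusionlab.trainers import Trainer

# Synthetic 10-class image classification data standing in for a real dataset
x = torch.randn(64, 1, 28, 28)
y = torch.randint(0, 10, (64,))
train_dl = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, 3),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(8, 10),
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainer = Trainer(device)
trainer.fit(model, train_dl, None, 2,
            torch.optim.Adam(model.parameters(), lr=1e-3),
            torch.nn.CrossEntropyLoss())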
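fusionlab/utils/basic.py's autopad encodes the usual "same" padding rule for stride-1 convolutions, padding = (kernel_size - 1) // 2 * dilation, applied per spatial dimension when kernel_size or dilation is a tuple. A quick check of that formula against torch conv layers:

import torch
from torch import nn
from fusionlab.utils import autopad

# With padding = (k - 1) // 2, a stride-1 conv preserves spatial size for odd k
x = torch.randn(1, 3, 224, 224)
for k in (3, 5, 7):
    conv = nn.Conv2d(3, 8, k, stride=1, padding=autopad(k))
    assert conv(x).shape[-2:] == x.shape[-2:]

print(autopad((3, 5)))         # per-dimension padding -> (1, 2)
print(autopad(3, dilation=2))  # dilation widens the effective kernel -> 2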
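trunc_normal_ in fusionlab/utils/trunc_normal/trunc_normal.py draws from a normal distribution restricted to [a, b]: _no_grad_trunc_normal_ samples uniformly between the CDF values of the cutoffs, then maps back through erfinv before scaling, shifting, and clamping. A small sketch verifying the cutoffs; the values below mirror the function's default ±2-std truncation:

import torch
from fusionlab.utils import trunc_normal_

w = torch.empty(1000, 1000)
trunc_normal_(w, mean=0.0, std=0.02, a=-0.04, b=0.04)  # truncate at mean ± 2 std
assert w.min().item() >= -0.04 and w.max().item() <= 0.04
# the sample std comes out slightly below 0.02 because the tails are cut off
print(w.mean().item(), w.std().item())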
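tests/test_losses.py pins DiceLoss and IoULoss to fixed reference values; everyday usage is simpler. A sketch assuming the same layout the tests use — logits of shape (N, C, ...) and integer targets of shape (N, ...) for "multiclass" mode with from_logits=True:

import torch
from fusionlab.losses import DiceLoss, IoULoss

logits = torch.randn(2, 3, 64, 64)           # (N, C, H, W) raw scores
target = torch.randint(0, 3, (2, 64, 64))    # (N, H, W) integer class labels

dice_loss = DiceLoss("multiclass", from_logits=True)(logits, target)
iou_loss = IoULoss("multiclass", from_logits=True)(logits, target)
print(float(dice_loss), float(iou_loss))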