├── .github
│   └── workflows
│       └── python-app.yml
├── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── TODO.md
├── assets
│   └── imgs
│       └── fusionlab_banner.png
├── configs
│   ├── Install on Macbook M1.md
│   ├── requirements-m1.txt
│   └── tf-apple-m1-conda.yaml
├── docs
│   ├── requirements.txt
│   └── source
│       ├── _templates
│       │   ├── custom-class-template.rst
│       │   └── custom-module-template.rst
│       ├── conf.py
│       ├── datasets.rst
│       ├── encoders.rst
│       ├── index.rst
│       ├── layers.rst
│       ├── losses.rst
│       ├── metrics.rst
│       ├── segmentation.rst
│       └── utils.rst
├── fusionlab
│   ├── __init__.py
│   ├── __version__.py
│   ├── classification
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── lstm.py
│   │   └── vgg.py
│   ├── configs.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── a12lead.py
│   │   ├── cinc2017.py
│   │   ├── csv_sample.csv
│   │   ├── csvread.py
│   │   ├── ludb.py
│   │   ├── muse.py
│   │   └── utils.py
│   ├── encoders
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── alexnet
│   │   │   ├── __init__.py
│   │   │   ├── alexnet.py
│   │   │   └── tfalexnet.py
│   │   ├── convnext
│   │   │   ├── __init__.py
│   │   │   └── convnext.py
│   │   ├── efficientnet
│   │   │   ├── __init__.py
│   │   │   └── efficientnet.py
│   │   ├── inceptionv1
│   │   │   ├── __init__.py
│   │   │   ├── inceptionv1.py
│   │   │   └── tfinceptionv1.py
│   │   ├── mit
│   │   │   ├── __init__.py
│   │   │   └── mit.py
│   │   ├── resnetv1
│   │   │   ├── __init__.py
│   │   │   ├── resnetv1.py
│   │   │   └── tfresnetv1.py
│   │   ├── vgg
│   │   │   ├── __init__.py
│   │   │   ├── tfvgg.py
│   │   │   └── vgg.py
│   │   └── vit
│   │       ├── __init__.py
│   │       └── vit.py
│   ├── functional
│   │   ├── __init__.py
│   │   ├── dice.py
│   │   ├── iou.py
│   │   ├── tfdice.py
│   │   └── tfiou.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── factories.py
│   │   ├── patch_embed
│   │   │   ├── __init__.py
│   │   │   └── patch_embedding.py
│   │   ├── selfattention
│   │   │   ├── __init__.py
│   │   │   └── selfattention.py
│   │   └── squeeze_excitation
│   │       ├── __init__.py
│   │       ├── se.py
│   │       └── tfse.py
│   ├── losses
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── diceloss
│   │   │   ├── __init__.py
│   │   │   ├── dice.py
│   │   │   └── tfdice.py
│   │   ├── iouloss
│   │   │   ├── __init__.py
│   │   │   ├── iou.py
│   │   │   └── tfiou.py
│   │   └── tversky
│   │       ├── __init__.py
│   │       ├── tftversky.py
│   │       └── tversky.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── dicescore
│   │   │   ├── __init__.py
│   │   │   └── dice.py
│   │   └── iouscore
│   │       ├── __init__.py
│   │       └── iou.py
│   ├── segmentation
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── resunet
│   │   │   ├── __init__.py
│   │   │   ├── resunet.py
│   │   │   └── tfresunet.py
│   │   ├── segformer
│   │   │   ├── __init__.py
│   │   │   └── segformer.py
│   │   ├── tfbase.py
│   │   ├── transunet
│   │   │   ├── __init__.py
│   │   │   └── transunet.py
│   │   ├── unet
│   │   │   ├── __init__.py
│   │   │   ├── tfunet.py
│   │   │   └── unet.py
│   │   ├── unet2plus
│   │   │   ├── __init__.py
│   │   │   ├── tfunet2plus.py
│   │   │   └── unet2plus.py
│   │   └── unetr
│   │       ├── __init__.py
│   │       └── unetr.py
│   ├── trainers
│   │   ├── __init__.py
│   │   ├── dcgan.py
│   │   ├── test.py
│   │   └── trainer.py
│   └── utils
│       ├── __init__.py
│       ├── basic.py
│       ├── labelme.py
│       ├── plots.py
│       ├── trace.py
│       └── trunc_normal
│           ├── __init__.py
│           └── trunc_normal.py
├── make_init.sh
├── release_logs.md
├── requirements.txt
├── scripts
│   └── build_pip.sh
├── setup.py
└── tests
    ├── __init__.py
    ├── test_classification.py
    ├── test_datasets.py
    ├── test_encoders.py
    ├── test_factories.py
    ├── test_layers.py
    ├── test_losses.py
    ├── test_metrics.py
    ├── test_seg.py
    └── test_utils.py
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: FusionLab:build
5 |
6 | on:
7 | push:
8 | branches: [ "main", "dev"]
9 | pull_request:
10 | branches: [ "main", "dev"]
11 | types: [ "opened", "synchronize", "reopened", "edited" ]
12 |
13 | permissions:
14 | contents: read
15 |
16 | jobs:
17 | build:
18 |
19 | runs-on: ubuntu-latest
20 |
21 | steps:
22 | - uses: actions/checkout@main
23 | - name: Set up Python 3.8
24 | uses: actions/setup-python@main
25 | with:
26 | python-version: "3.8"
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | pip install flake8 pytest
31 | pip install -r requirements.txt
32 | pip install torchvision
33 | pip install tensorflow
34 | pip install torchinfo
35 | - name: Lint with flake8
36 | run: |
37 | # stop the build if there are Python syntax errors or undefined names
38 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
39 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
40 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
41 | - name: Test with pytest
42 | run: |
43 | pytest
44 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .idea/
6 | .venv*
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 | .pytest_cache/
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # Environments
87 | .env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 |
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 |
99 | # Rope project settings
100 | .ropeproject
101 |
102 | # mkdocs documentation
103 | /site
104 |
105 | # mypy
106 | .mypy_cache/
107 |
108 | # Dataset
109 | dataset/
110 | data/
111 |
112 | # VSCode
113 | .vscode
114 |
115 | # Training Files
116 | *mlruns/
117 | *results/
118 | tryout_notebooks
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Set the OS, Python version and other tools you might need
9 | build:
10 | os: ubuntu-22.04
11 | tools:
12 | python: "3.9"
13 | # You can also specify other tool versions:
14 | # nodejs: "19"
15 | # rust: "1.64"
16 | # golang: "1.19"
17 |
18 | # Build documentation in the "docs/" directory with Sphinx
19 | sphinx:
20 | configuration: docs/source/conf.py
21 |
22 | # Optionally build your docs in additional formats such as PDF and ePub
23 | # formats:
24 | # - pdf
25 | # - epub
26 |
27 | # Optional but recommended, declare the Python requirements required
28 | # to build your documentation
29 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
30 | python:
31 | install:
32 | - requirements: docs/requirements.txt
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 SAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FusionLab
2 |
3 |
4 | ![FusionLab banner](assets/imgs/fusionlab_banner.png)
5 |
9 | [PyPI version](https://badge.fury.io/py/fusionlab)  [Downloads](https://pepy.tech/project/fusionlab)
10 |
11 | [Documentation](https://fusionlab.readthedocs.io/)
12 |
13 | FusionLab is an open-source framework for deep learning research, written in PyTorch and TensorFlow. The code is easy to read and modify,
14 | especially for newcomers. Feel free to send pull requests :D
15 |
16 | * [What's New](#news)
17 | * [Installation](#installation)
18 | * [How to use](#how-to-use)
19 | * [Encoders](#encoders)
20 | * [Losses](#losses)
21 | * [Segmentation](#segmentation)
22 | * [1D, 2D, 3D Model](#n-dimensional-model)
23 | * [Acknowledgements](#acknowledgements)
24 |
25 | ## Installation
26 |
27 | ### With pip
28 |
29 | ```bash
30 | pip install fusionlab
31 | ```
32 |
33 | #### For Mac M1 chip users
34 | [Install on Macbook M1 chip](./configs/Install%20on%20Macbook%20M1.md)
35 |
36 | ## How to use
37 |
38 | ```python
39 | import fusionlab as fl
40 |
41 | # PyTorch
42 | encoder = fl.encoders.VGG16()
43 | # Tensorflow
44 | encoder = fl.encoders.TFVGG16()
45 |
46 | ```
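47 |
48 | As a quick smoke test (a minimal sketch; the exact spatial size depends on the input resolution):
49 |
50 | ```python
51 | import torch
52 | import fusionlab as fl
53 |
54 | encoder = fl.encoders.VGG16(cin=3, spatial_dims=2)
55 | features = encoder(torch.randn(1, 3, 224, 224))
56 | print(features.shape)  # expected: torch.Size([1, 512, 7, 7]), i.e. 1/32 of the input resolution
57 | ```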
47 |
48 | ## Documentation
49 |
50 | [Doc](https://fusionlab.readthedocs.io/en/latest/encoders.html)
51 |
52 | ## Encoders
53 |
54 | [encoder list](fusionlab/encoders/README.md)
55 |
56 | ## Losses
57 |
58 | [Loss func list](fusionlab/losses/README.md)
59 | * Dice Loss
60 | * Tversky Loss
61 | * IoU Loss
62 |
63 |
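64 | All three are overlap-based losses; for reference, a standard soft-Dice formulation is shown below (this repo defines a smoothing constant `EPS = 1e-07` in `fusionlab/configs.py`):
65 |
66 | $$L_{Dice} = 1 - \frac{2\,|P \cap G| + \epsilon}{|P| + |G| + \epsilon}$$
67 |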
64 | ```python
65 | # Dice Loss (Multiclass)
66 | import torch
67 | import tensorflow as tf
68 | import fusionlab as fl
69 |
70 | # PyTorch
71 | pred = torch.randn(1, 3, 4, 4)  # (N, C, *)
72 | target = torch.randint(0, 3, (1, 4, 4))  # (N, *)
73 | loss_fn = fl.losses.DiceLoss()
74 | loss = loss_fn(pred, target)
75 |
76 | # Tensorflow
77 | pred = tf.random.normal((1, 4, 4, 3), 0., 1.)  # (N, *, C)
78 | target = tf.random.uniform((1, 4, 4), 0, 3)  # (N, *)
79 | loss_fn = fl.losses.TFDiceLoss("multiclass")
80 | loss = loss_fn(target, pred)
81 |
82 | # Dice Loss (Binary)
83 |
84 | # PyTorch
85 | pred = torch.randn(1, 1, 4, 4)  # (N, 1, *)
86 | target = torch.randint(0, 2, (1, 4, 4))  # (N, *), labels in {0, 1}
87 | loss_fn = fl.losses.DiceLoss("binary")
88 | loss = loss_fn(pred, target)
89 |
90 | # Tensorflow
91 | pred = tf.random.normal((1, 4, 4, 1), 0., 1.)  # (N, *, 1)
92 | target = tf.random.uniform((1, 4, 4), 0, 2)  # (N, *), labels in {0, 1}
93 | loss_fn = fl.losses.TFDiceLoss("binary")
94 | loss = loss_fn(target, pred)
95 | ```
97 |
98 | ## Segmentation
99 |
100 | ```python
101 | import tensorflow as tf
102 | import fusionlab as fl
103 | # PyTorch UNet
103 | unet = fl.segmentation.UNet(cin=3, num_cls=10)
104 |
105 | # Tensorflow UNet
106 | # Multiclass Segmentation
107 | unet = tf.keras.Sequential([
108 | fl.segmentation.TFUNet(num_cls=10, base_dim=64),
109 | tf.keras.layers.Activation(tf.nn.softmax),
110 | ])
111 |
112 | # Binary Segmentation
113 | unet = tf.keras.Sequential([
114 | fl.segmentation.TFUNet(num_cls=1, base_dim=64),
115 | tf.keras.layers.Activation(tf.nn.sigmoid),
116 | ])
117 | ```
118 |
119 | [Segmentation model list](fusionlab/segmentation/README.md)
120 |
121 | * UNet
122 | * ResUNet
123 | * UNet2plus
124 | * TransUNet
125 | * UNETR
126 | * SegFormer
124 |
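127 | For orientation, these models map `(N, cin, H, W)` inputs to `(N, num_cls, H, W)` logits. A minimal sketch (the expected shape assumes the default configuration preserves input resolution):
128 |
129 | ```python
130 | import torch
131 | import fusionlab as fl
132 |
133 | unet = fl.segmentation.UNet(cin=3, num_cls=10)
134 | logits = unet(torch.randn(1, 3, 64, 64))
135 | print(logits.shape)  # expected: torch.Size([1, 10, 64, 64])
136 | ```
137 |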
125 | ## N Dimensional Model
126 |
127 | Some models can be used in 1D, 2D, and 3D by setting the `spatial_dims` argument:
128 |
129 | ```python
130 | import fusionlab as fl
131 |
132 | resnet1d = fl.encoders.ResNet50V1(cin=3, spatial_dims=1)
133 | resnet2d = fl.encoders.ResNet50V1(cin=3, spatial_dims=2)
134 | resnet3d = fl.encoders.ResNet50V1(cin=3, spatial_dims=3)
135 |
136 | unet1d = fl.segmentation.UNet(cin=3, num_cls=10, spatial_dims=1)
137 | unet2d = fl.segmentation.UNet(cin=3, num_cls=10, spatial_dims=2)
138 | unet3d = fl.segmentation.UNet(cin=3, num_cls=10, spatial_dims=3)
139 | ```
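140 |
141 | Under the hood, the N-dimensional variants are built on the dimension-dispatching layer factories in `fusionlab/layers/factories.py`. A minimal sketch of that pattern (the expected shape is shown in the comment):
142 |
143 | ```python
144 | import torch
145 | from fusionlab.layers import ConvND, MaxPool
146 |
147 | conv1d = ConvND(1, 3, 8, kernel_size=3, padding=1)  # dispatches to nn.Conv1d
148 | pool2d = MaxPool(2, kernel_size=2)                  # dispatches to nn.MaxPool2d
149 | print(conv1d(torch.randn(1, 3, 16)).shape)  # expected: torch.Size([1, 8, 16])
150 | ```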
140 |
141 | ## News
142 |
143 | [Release logs](./release_logs.md)
144 |
145 | ## Acknowledgements
146 |
147 | * [BloodAxe/pytorch-toolbelt](https://github.com/BloodAxe/pytorch-toolbelt)
148 |
--------------------------------------------------------------------------------
/TODO.md:
--------------------------------------------------------------------------------
1 | ### Encoders
2 |
3 | * SENet
4 | * DenseNet
5 |
6 | ### Segmentation
7 |
8 | * SegFormer
9 | * nnUNet
10 | * Swin-Unetr
11 | * AttUNet
12 | * EnsDiff
13 | * SegDiff
14 | * FCN
15 | * DeepLab
16 | * FPN
17 | * PSPNet
18 |
19 | ### Metrics
20 |
21 | * Dice, IoU score module
22 | * HD95
23 | * Normalized Surface Distance (NSD)
24 |
25 | ### Loss
26 |
27 | * Focal loss
28 |
29 | ### Modifications
30 |
31 | * Replace torchvision's StochasticDepth with timm's DropPath
32 |
33 | ### Dataset
34 |
35 |
36 |
--------------------------------------------------------------------------------
/assets/imgs/fusionlab_banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/assets/imgs/fusionlab_banner.png
--------------------------------------------------------------------------------
/configs/Install on Macbook M1.md:
--------------------------------------------------------------------------------
1 | # Install on Macbook M1 chip
2 |
3 | Ref: [https://github.com/jeffheaton/t81_558_deep_learning/blob/master/install/tensorflow-install-conda-mac-metal-jan-2023.ipynb](https://github.com/jeffheaton/t81_558_deep_learning/blob/master/install/tensorflow-install-conda-mac-metal-jan-2023.ipynb)
4 |
5 | [video link](https://www.youtube.com/watch?v=5DgWvU0p2bk)
6 |
7 | **NOTE: It has been tested on a MacBook Air (M1, 2021) only**
8 |
9 | Requirements:
10 | * Apple Mac with an M1 chip
11 | * macOS > 12.6 (Monterey)
12 |
13 | Follow these steps:
14 | 1. Clone this repo and change into its directory
15 | ```bash
16 | git clone https://github.com/taipingeric/fusionlab.git
17 | cd fusionlab
18 | ```
19 | 2. Uninstall Anaconda (optional): https://docs.anaconda.com/anaconda/install/uninstall/
20 | 3. Install [Miniconda](https://docs.conda.io/en/latest/miniconda.html)
21 |     1. Miniconda3 macOS Apple M1 64-bit pkg
22 |     2. Miniconda3 macOS Apple M1 64-bit bash
23 | 4. Install the xcode-select command-line tools
24 | ```bash
25 | xcode-select --install
26 | ```
27 | 5. Deactivate the base environment
28 | ```bash
29 | conda deactivate
30 | ```
31 | 6. Create the conda environment from the [config](./tf-apple-m1-conda.yaml) and activate it
32 | ```bash
33 | conda env create -f ./configs/tf-apple-m1-conda.yaml -n fusionlab
34 | conda activate fusionlab
35 | ```
36 | 7. Replace [requirements.txt](../requirements.txt) with [requirements-m1.txt](requirements-m1.txt)
37 | 8. Install the M1 requirements with pip
38 | ```bash
39 | pip install -r requirements-m1.txt
40 | ```
45 |
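46 | To verify the installation afterwards (a quick sanity check, not part of the original guide), the Metal plugin should expose the M1 GPU to TensorFlow:
47 |
48 | ```python
49 | import tensorflow as tf
50 |
51 | print(tf.__version__)
52 | print(tf.config.list_physical_devices("GPU"))  # non-empty if tensorflow-metal is active
53 | ```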
--------------------------------------------------------------------------------
/configs/requirements-m1.txt:
--------------------------------------------------------------------------------
1 | # Mac M1
2 | # ref: https://www.youtube.com/watch?v=5DgWvU0p2bk
3 | # ref: https://github.com/jeffheaton/t81_558_deep_learning/blob/master/install/tensorflow-install-conda-mac-metal-jul-2022.ipynb
4 |
5 | torchvision>=0.5.0
6 | tensorflow-macos>=2.9.2
7 | tensorflow-metal>=0.5.0
--------------------------------------------------------------------------------
/configs/tf-apple-m1-conda.yaml:
--------------------------------------------------------------------------------
1 | name: fusionlab
2 | channels:
3 | - apple
4 | - conda-forge
5 | dependencies:
6 | - python=3.9
7 | - pip>=19.0
8 | - jupyter
9 | - tensorflow-deps
10 | - scikit-learn
11 | - scipy
12 | - pandas
13 | - pandas-datareader
14 | - matplotlib
15 | - pillow
16 | - tqdm
17 | - requests
18 | - h5py
19 | - pyyaml
20 | - flask
21 | - boto3
22 | - ipykernel
23 | - pip:
24 | - tensorflow-macos
25 | - tensorflow-metal
26 | - bayesian-optimization
27 | - gym
28 | - kaggle
29 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | einops
2 | tqdm
3 | traceback2
4 |
5 | pandas>=1.5.3
6 | xmltodict
7 | xlwt
9 | scipy
10 | wfdb
11 | #
12 | torch>=1.9
13 | torchvision
14 | tensorflow
15 | # Doc
16 | Sphinx
17 | pydata-sphinx-theme
18 | sphinxcontrib-applehelp
19 | sphinxcontrib-devhelp
20 | sphinxcontrib-htmlhelp
21 | sphinxcontrib-jsmath
22 | sphinxcontrib-qthelp
23 | sphinxcontrib-serializinghtml
24 | sphinx-autodoc-typehints==1.11.1
25 |
--------------------------------------------------------------------------------
/docs/source/_templates/custom-class-template.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 | :members:
7 | :show-inheritance:
8 | :inherited-members:
9 |
10 | {% block methods %}
11 | .. automethod:: __init__
12 |
13 | {% if methods %}
14 | .. rubric:: {{ _('Methods') }}
15 |
16 | .. autosummary::
17 | {% for item in methods %}
18 | ~{{ name }}.{{ item }}
19 | {%- endfor %}
20 | {% endif %}
21 | {% endblock %}
22 |
23 | {% block attributes %}
24 | {% if attributes %}
25 | .. rubric:: {{ _('Attributes') }}
26 |
27 | .. autosummary::
28 | {% for item in attributes %}
29 | ~{{ name }}.{{ item }}
30 | {%- endfor %}
31 | {% endif %}
32 | {% endblock %}
--------------------------------------------------------------------------------
/docs/source/_templates/custom-module-template.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. automodule:: {{ fullname }}
4 |
5 | {% block attributes %}
6 | {% if attributes %}
7 | .. rubric:: Module Attributes
8 |
9 | .. autosummary::
10 | :toctree:
11 | {% for item in attributes %}
12 | {{ item }}
13 | {%- endfor %}
14 | {% endif %}
15 | {% endblock %}
16 |
17 | {% block functions %}
18 | {% if functions %}
19 | .. rubric:: {{ _('Functions') }}
20 |
21 | .. autosummary::
22 | :toctree:
23 | {% for item in functions %}
24 | {{ item }}
25 | {%- endfor %}
26 | {% endif %}
27 | {% endblock %}
28 |
29 | {% block classes %}
30 | {% if classes %}
31 | .. rubric:: {{ _('Classes') }}
32 |
33 | .. autosummary::
34 | :toctree:
35 | :template: custom-class-template.rst
36 | {% for item in classes %}
37 | {{ item }}
38 | {%- endfor %}
39 | {% endif %}
40 | {% endblock %}
41 |
42 | {% block exceptions %}
43 | {% if exceptions %}
44 | .. rubric:: {{ _('Exceptions') }}
45 |
46 | .. autosummary::
47 | :toctree:
48 | {% for item in exceptions %}
49 | {{ item }}
50 | {%- endfor %}
51 | {% endif %}
52 | {% endblock %}
53 |
54 | {% block modules %}
55 | {% if modules %}
56 | .. rubric:: Modules
57 |
58 | .. autosummary::
59 | :toctree:
60 | :template: custom-module-template.rst
61 | :recursive:
62 | {% for item in modules %}
63 | {{ item }}
64 | {%- endfor %}
65 | {% endif %}
66 | {% endblock %}
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | import subprocess
15 | import sys
16 |
17 | sys.path.insert(0, os.path.abspath(".."))
18 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
19 | print(sys.path)
20 |
21 | import fusionlab # noqa: E402
22 |
23 | # -- Project information -----------------------------------------------------
24 | project = "FusionLab"
25 | copyright = "taipingeric"
26 | author = "fusionlab Contributors"
27 |
28 | # The full version, including alpha/beta/rc tags
29 | short_version = fusionlab.__version__ #.split("+")[0]
30 | release = short_version
31 | version = short_version
32 |
33 | # List of patterns, relative to source directory, that match files and
34 | # directories to ignore when looking for source files.
35 | # This pattern also affects html_static_path and html_extra_path.
36 | exclude_patterns = [
37 | "transforms",
38 | "networks",
39 | "metrics",
40 | "engines",
41 | "data",
42 | "apps",
43 | "fl",
44 | "bundle",
45 | "config",
46 | "handlers",
47 | "losses",
48 | "visualize",
49 | "utils",
50 | "inferers",
51 | "optimizers",
52 | "auto3dseg",
53 | "'**/tf*'",
54 | ]
55 |
56 |
57 | def generate_apidocs(*args):
58 | """Generate API docs automatically by trawling the available modules"""
59 | module_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "fusionlab"))
60 | output_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "apidocs"))
61 | apidoc_command_path = "sphinx-apidoc"
62 | if hasattr(sys, "real_prefix"): # called from a virtualenv
63 | apidoc_command_path = os.path.join(sys.prefix, "bin", "sphinx-apidoc")
64 | apidoc_command_path = os.path.abspath(apidoc_command_path)
65 | print(f"output_path {output_path}")
66 | print(f"module_path {module_path}")
67 | subprocess.check_call(
68 | [apidoc_command_path, "-e"]
69 | + ["-o", output_path]
70 | + [module_path]
71 | + [os.path.join(module_path, p) for p in exclude_patterns]
72 | )
73 |
74 |
75 | # -- General configuration ---------------------------------------------------
76 |
77 | # Add any Sphinx extension module names here, as strings. They can be
78 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
79 | # ones.
80 | source_suffix = {".rst": "restructuredtext", ".txt": "restructuredtext", ".md": "markdown"}
81 |
82 | extensions = [
83 | "recommonmark",
84 | "sphinx.ext.intersphinx",
85 | "sphinx.ext.mathjax",
86 | "sphinx.ext.napoleon",
87 | "sphinx.ext.autodoc",
88 | "sphinx.ext.viewcode",
89 | "sphinx.ext.autosectionlabel",
90 | "sphinx.ext.autosummary",
91 | "sphinx_autodoc_typehints",
92 | ]
93 | # "sphinx.ext.autosummary"
94 | autosummary_generate = True
95 |
96 | autoclass_content = "class"
97 | add_module_names = True
98 | source_encoding = "utf-8"
99 | autosectionlabel_prefix_document = True
100 | napoleon_use_param = True
101 | napoleon_include_init_with_doc = True
102 | set_type_checking_flag = True
103 |
104 | # Add any paths that contain templates here, relative to this directory.
105 | templates_path = ["_templates"]
106 |
107 | # -- Options for HTML output -------------------------------------------------
108 |
109 | # The theme to use for HTML and HTML Help pages. See the documentation for
110 | # a list of builtin themes.
111 | #
112 | html_theme = "pydata_sphinx_theme"
113 | # html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
114 | html_theme_options = {
115 | "external_links": [{"url": "https://github.com/taipingeric/fusionlab", "name": "Tutorials"}],
116 | # "external_links": [{"url": "https://github.com/Project-MONAI/tutorials", "name": "Tutorials"}],
117 | # "icon_links": [
118 | # {"name": "GitHub", "url": "https://github.com/project-monai/monai", "icon": "fab fa-github-square"},
119 | # {"name": "Twitter", "url": "https://twitter.com/projectmonai", "icon": "fab fa-twitter-square"},
120 | # ],
121 | "collapse_navigation": True,
122 | "navigation_depth": 1,
123 | "show_toc_level": 1,
124 | "footer_start": ["copyright"],
125 | "navbar_align": "content",
126 | # "logo": {"image_light": "MONAI-logo-color.png", "image_dark": "MONAI-logo-color.png"},
127 | }
128 | html_context = {
129 | "github_user": "taipingeric",
130 | "github_repo": "fusionlab",
131 | "github_version": "dev",
132 | "doc_path": "docs/",
133 | "conf_py_path": "/docs/",
134 | "VERSION": version,
135 | }
136 | html_scaled_image_link = False
137 | html_show_sourcelink = True
138 | # html_favicon = "../images/favicon.ico"
139 | # html_logo = "../images/MONAI-logo-color.png"
140 | html_sidebars = {"**": ["search-field", "sidebar-nav-bs"]}
141 | pygments_style = "sphinx"
142 |
143 | # Add any paths that contain custom static files (such as style sheets) here,
144 | # relative to this directory. They are copied after the builtin static files,
145 | # so a file named "default.css" will overwrite the builtin "default.css".
146 | html_static_path = ["../_static"]
147 | html_css_files = ["custom.css"]
148 | html_title = f"{project} {version} Documentation"
149 |
150 | # -- Auto-convert markdown pages to demo --------------------------------------
151 |
152 |
153 | def setup(app):
154 | # Hook to allow for automatic generation of API docs
155 | # before doc deployment begins.
156 | app.connect("builder-inited", generate_apidocs)
--------------------------------------------------------------------------------
/docs/source/datasets.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/taipingeric/fusionlab
2 |
3 | .. _datasets:
4 |
5 | Datasets
6 | ==============
7 |
8 | .. automodule:: fusionlab.datasets
9 | .. currentmodule:: fusionlab.datasets
10 |
11 | `GE Muse XML Reader`
12 | ~~~~~~~~~~
13 | .. autoclass:: GEMuseXMLReader
14 | :members:
15 |
16 | `ECG Classification Dataset`
17 | ~~~~~~~~~~
18 | .. autoclass:: ECGClassificationDataset
19 | :members:
20 |
21 | Cinc2017 Dataset
22 | ----------------
23 |
24 | `ECG CSV Classification Dataset`
25 | ~~~~~~~~~~
26 | .. autoclass:: ECGCSVClassificationDataset
27 | :members:
28 |
29 | `Validate CinC 2017 Dataset`
30 | ~~~~~~~~~~
31 | .. autofunction:: validate_data
32 |
33 | `.mat to .csv`
34 | ~~~~~~~~~~
35 | .. autofunction:: convert_mat_to_csv
36 |
37 | LUDB Dataset
38 | ----------------
39 |
40 | `LUDB Dataset`
41 | ~~~~~~~~~~
42 | .. autoclass:: LUDBDataset
43 | :members:
44 |
45 | .. autofunction:: plot
46 |
47 | Utils
48 | ----------------
49 |
50 | `Download file`
51 | ~~~~~~~~~~
52 | .. autofunction:: download_file
53 |
54 | `HuggingFace Dataset`
55 | ~~~~~~~~~~
56 | .. autoclass:: HFDataset
57 | :members:
58 |
59 | `LabelStudio Time series Segmentation Dataset`
60 | ~~~~~~~~~~
61 | .. autoclass:: LSTimeSegDataset
62 | :members:
63 |
64 | `Read csv`
65 | ~~~~~~~~~~
66 | .. autofunction:: read_csv
67 |
68 |
--------------------------------------------------------------------------------
/docs/source/encoders.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/taipingeric/fusionlab
2 |
3 | .. _encoders:
4 |
5 | Encoders
6 | ==============
7 |
8 | PyTorch Encoders
9 | -------------------
10 |
11 | .. automodule:: fusionlab.encoders
12 | .. currentmodule:: fusionlab.encoders
13 |
14 | `AlexNet`
15 | ~~~~~~~~~~
16 | .. autoclass:: AlexNet
17 | :members:
18 |
19 | `VGG`
20 | ~~~~~~~~~~
21 | .. autoclass:: VGG16
22 | :members:
23 | .. autoclass:: VGG19
24 | :members:
25 |
26 | `InceptionNet`
27 | ~~~~~~~~~~
28 | .. autoclass:: InceptionNetV1
29 | :members:
30 |
31 | `ResNet`
32 | ~~~~~~~~~~
33 | .. autoclass:: ResNet
34 | :members:
35 | .. autoclass:: ResNet18
36 | :members:
37 | .. autoclass:: ResNet34
38 | :members:
39 | .. autoclass:: ResNet50
40 | :members:
41 | .. autoclass:: ResNet101
42 | :members:
43 | .. autoclass:: ResNet152
44 | :members:
45 |
46 | `EfficientNet`
47 | ~~~~~~~~~~
48 | .. autoclass:: EfficientNet
49 | :members:
50 | .. autoclass:: EfficientNetB0
51 | :members:
52 | .. autoclass:: EfficientNetB1
53 | :members:
54 | .. autoclass:: EfficientNetB2
55 | :members:
56 | .. autoclass:: EfficientNetB3
57 | :members:
58 | .. autoclass:: EfficientNetB4
59 | :members:
60 | .. autoclass:: EfficientNetB5
61 | :members:
62 | .. autoclass:: EfficientNetB6
63 | :members:
64 | .. autoclass:: EfficientNetB7
65 | :members:
66 |
67 | `ConvNeXt`
68 | ~~~~~~~~~~
69 | .. autoclass:: ConvNeXt
70 | :members:
71 | .. autoclass:: ConvNeXtTiny
72 | :members:
73 | .. autoclass:: ConvNeXtSmall
74 | :members:
75 | .. autoclass:: ConvNeXtBase
76 | :members:
77 | .. autoclass:: ConvNeXtLarge
78 | :members:
79 | .. autoclass:: ConvNeXtXLarge
80 | :members:
81 |
82 | `Vision Transformer`
83 | ~~~~~~~~~~
84 | .. autoclass:: ViT
85 | :members:
86 | .. autoclass:: VisionTransformer
87 | :members:
88 |
89 | `Mix Transformer`
90 | ~~~~~~~~~~
91 | .. autoclass:: MiT
92 | :members:
93 | .. autoclass:: MiTB0
94 | :members:
95 | .. autoclass:: MiTB1
96 | :members:
97 | .. autoclass:: MiTB2
98 | :members:
99 | .. autoclass:: MiTB3
100 | :members:
101 | .. autoclass:: MiTB4
102 | :members:
103 | .. autoclass:: MiTB5
104 | :members:
105 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | FusionLab API Document
2 | ======================
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | encoders
8 | losses
9 | segmentation
10 | layers
11 | metrics
12 | datasets
13 | utils
--------------------------------------------------------------------------------
/docs/source/layers.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/taipingeric/fusionlab
2 |
3 | .. _layers:
4 |
5 | Layers
6 | ==============
7 | Custom layers
8 |
9 | .. automodule:: fusionlab.layers
10 | .. currentmodule:: fusionlab.layers
11 |
12 | N-dimensional layer
13 | --------
14 |
15 | `Convolution ND`
16 | ~~~~~~~~~~
17 | .. autoclass:: ConvND
18 | :members:
19 |
20 | `Transposed Convolution ND`
21 | ~~~~~~~~~~
22 | .. autoclass:: ConvT
23 | :members:
24 |
25 | `Upsample ND`
26 | ~~~~~~~~~~
27 | .. autoclass:: Upsample
28 | :members:
29 |
30 | `BatchNorm ND`
31 | ~~~~~~~~~~
32 | .. autoclass:: BatchNorm
33 | :members:
34 |
35 | `InstanceNorm ND`
36 | ~~~~~~~~~~
37 | .. autoclass:: InstanceNorm
38 | :members:
39 |
40 | `Maxpool ND`
41 | ~~~~~~~~~~
42 | .. autoclass:: MaxPool
43 | :members:
44 |
45 | `AvgPool ND`
46 | ~~~~~~~~~~
47 | .. autoclass:: AvgPool
48 | :members:
49 |
50 | `Global Max Pool (Adaptive Max Pool) ND`
51 | ~~~~~~~~~~
52 | .. autoclass:: AdaptiveMaxPool
53 | :members:
54 |
55 | `Global Avg Pool (Adaptive Avg Pool) ND`
56 | ~~~~~~~~~~
57 | .. autoclass:: AdaptiveAvgPool
58 | :members:
59 |
60 | `Replication Padding ND`
61 | ~~~~~~~~~~
62 | .. autoclass:: ReplicationPad
63 | :members:
64 |
65 | `Constant Padding ND`
66 | ~~~~~~~~~~
67 | .. autoclass:: ConstantPad
68 | :members:
69 |
70 | `Conv|Norm|Act`
71 | ~~~~~~~~~~
72 | .. autoclass:: ConvNormAct
73 | :members:
74 |
75 | `Rearrange`
76 | ~~~~~~~~~~
77 | .. autoclass:: Rearrange
78 | :members:
79 |
80 | `Patch Embedding`
81 | ~~~~~~~~~~
82 | .. autoclass:: PatchEmbedding
83 | :members:
84 |
85 | `Squeeze-Excitation`
86 | ~~~~~~~~~~
87 | .. autoclass:: SEModule
88 | :members:
89 | .. autoclass:: TFSEModule
90 | :members:
--------------------------------------------------------------------------------
/docs/source/losses.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/taipingeric/fusionlab
2 |
3 | .. _losses:
4 |
5 | Loss Function
6 | ==============
7 | This module contains the implementation of loss functions for semantic segmentation.
8 |
9 | .. automodule:: fusionlab.losses
10 | .. currentmodule:: fusionlab.losses
11 |
12 | `Dice Loss`
13 | ~~~~~~~~~~
14 | .. autoclass:: DiceLoss
15 | :members:
16 | .. autoclass:: TFDiceLoss
17 | :members:
18 |
19 | `Dice Cross Entropy Loss`
20 | ~~~~~~~~~~
21 | .. autoclass:: DiceCELoss
22 | :members:
23 | .. autoclass:: TFDiceCE
24 | :members:
25 |
26 | `IoU Loss`
27 | ~~~~~~~~~~
28 | .. autoclass:: IoULoss
29 | :members:
30 | .. autoclass:: TFIoULoss
31 | :members:
32 |
33 | `Tversky Loss`
34 | ~~~~~~~~~~
35 | .. autoclass:: TverskyLoss
36 | :members:
37 | .. autoclass:: TFTverskyLoss
38 | :members:
39 |
--------------------------------------------------------------------------------
/docs/source/metrics.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/taipingeric/fusionlab
2 |
3 | .. _metrics:
4 |
5 | Metrics
6 | ==============
7 | This module contains the implementation of the evaluation metrics.
8 |
9 | .. automodule:: fusionlab.metrics
10 | .. currentmodule:: fusionlab.metrics
11 |
12 | `Dice`
13 | ~~~~~~~~~~
14 | .. autoclass:: DiceScore
15 | :members:
16 | .. autoclass:: JaccardScore
17 | :members:
18 |
19 | `IoU`
20 | ~~~~~~~~~~
21 | .. autoclass:: IoUScore
22 | :members:
--------------------------------------------------------------------------------
/docs/source/segmentation.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/taipingeric/fusionlab
2 |
3 | .. _segmentation:
4 |
5 | Segmentation Model
6 | ==================
7 |
8 |
9 | .. automodule:: fusionlab.segmentation
10 | .. currentmodule:: fusionlab.segmentation
11 |
12 | `UNet`
13 | ~~~~~~~~~~
14 |
15 | .. autoclass:: UNet
16 | :members:
17 | .. autoclass:: TFUNet
18 | :members:
19 |
20 | `ResUNet`
21 | ~~~~~~~~~~
22 |
23 | .. autoclass:: ResUNet
24 | :members:
25 | .. autoclass:: TFResUNet
26 | :members:
27 |
28 | `UNet++`
29 | ~~~~~~~~~~
30 |
31 | .. autoclass:: UNet2plus
32 | :members:
33 | .. autoclass:: TFUNet2plus
34 | :members:
35 |
36 | `TransUNet`
37 | ~~~~~~~~~~
38 |
39 | .. autoclass:: TransUNet
40 | :members:
41 |
42 | `UNETR`
43 | ~~~~~~~~~~
44 |
45 | .. autoclass:: UNETR
46 | :members:
47 |
48 | `SegFormer`
49 | ~~~~~~~~~~
50 |
51 | .. autoclass:: SegFormer
52 | :members:
53 |
54 | `HuggingFace Segmentation Model`
55 | ~~~~~~~~~~
56 |
57 | .. autoclass:: HFSegmentationModel
58 | :members:
--------------------------------------------------------------------------------
/docs/source/utils.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/taipingeric/fusionlab
2 |
3 | .. _utils:
4 |
5 | Utils
6 | ==============
7 | This module contains utility functions used across the project.
8 |
9 | .. automodule:: fusionlab.utils
10 | .. currentmodule:: fusionlab.utils
11 |
12 | `Auto Padding`
13 | ~~~~~~~~~~
14 | .. autofunction:: autopad
15 |
16 | `Make N tuple`
17 | ~~~~~~~~~~
18 | .. autofunction:: make_ntuple
19 |
20 | `Show Class Tree`
21 | ~~~~~~~~~~
22 | .. autofunction:: show_classtree
23 |
24 | `Plot channels`
25 | ~~~~~~~~~~
26 | .. autofunction:: plot_channels
27 |
28 | `Convert LabelMe json to Mask`
29 | ~~~~~~~~~~
30 | .. autofunction:: convert_labelme_json2mask
--------------------------------------------------------------------------------
/fusionlab/__init__.py:
--------------------------------------------------------------------------------
1 | from .__version__ import __version__
2 |
3 | # Check if PyTorch or TensorFlow is installed
4 | try:
5 | import torch
6 | torch_installed = True
7 | except ImportError:
8 | torch_installed = False
9 |
10 | try:
11 | import tensorflow as tf
12 | tf_installed = True
13 | except ImportError:
14 | tf_installed = False
15 |
16 | print(f"PyTorch installed: {torch_installed}")
17 | print(f"TensorFlow installed: {tf_installed}")
18 |
19 | BACKEND = {
20 | "torch": torch_installed,
21 | "tf": tf_installed
22 | }
23 |
24 | # check if no backend installed
25 | if not any(BACKEND.values()):
26 | print("None of supported backend installed")
27 |
28 | from . import (
29 | functional,
30 | encoders,
31 | utils,
32 | datasets,
33 | layers,
34 | classification,
35 | segmentation,
36 | losses,
37 | trainers,
38 | configs,
39 | metrics,
40 | )
--------------------------------------------------------------------------------
/fusionlab/__version__.py:
--------------------------------------------------------------------------------
1 | VERSION = (0, 1, 13)
2 |
3 | __version__ = ".".join(map(str, VERSION))
4 |
--------------------------------------------------------------------------------
/fusionlab/classification/__init__.py:
--------------------------------------------------------------------------------
1 | from fusionlab import BACKEND
2 |
3 | if BACKEND['torch']:
4 | from .base import (
5 | CNNClassificationModel,
6 | RNNClassificationModel,
7 | HFClassificationModel
8 | )
9 | from .lstm import LSTMClassifier
10 | from .vgg import (
11 | VGG16Classifier,
12 | VGG19Classifier
13 | )
--------------------------------------------------------------------------------
/fusionlab/classification/base.py:
--------------------------------------------------------------------------------
1 | # These base classes are rarely instantiated directly; the concrete classifiers below subclass them.
2 | # Subclass them yourself if you want to build a custom classifier.
3 |
4 | import torch.nn as nn
5 |
6 | class CNNClassificationModel(nn.Module):
7 | """
8 | Base PyTorch class of the classification model with Encoder, Head for CNN
9 | """
10 | def forward(self, x):
11 | # 1D signal x => [BATCH, CHANNEL, TIME]
12 | # 1D spectrum x => [BATCH, FREQUENCY, TIME] (single channel)
13 | # 2D spectrum x => [BATCH, CHANNEL, FREQUENCY, TIME] (multi channel)
14 | # 2D image x => [BATCH, CHANNEL, HEIGHT, WIDTH]
15 | # 3D volumetric x => [BATCH, CHANNEL, HEIGHT, WIDTH, DEPTH]
16 | features = self.encoder(x) # => [BATCH, 512, ...]
17 | features_agg = self.globalpooling(features) # => [BATCH, 512, 1, (1, (1))]
18 | output = self.head(features_agg.view(x.shape[0],-1)) # => [BATCH, NUM_CLS]
19 | return output
20 |
21 | class RNNClassificationModel(nn.Module):
22 | """
23 | Base PyTorch class of the classification model with Encoder, Head for RNN
24 | """
25 | def forward(self, x):
26 | # 1D signal x => [BATCH, CHANNEL, TIME]
27 | x = x.transpose(1,2)
28 | features, _ = self.encoder(x) # RNN will output feature and states
29 | output = self.head(features[:, -1, :])
30 | return output
31 |
32 | class HFClassificationModel(nn.Module):
33 | """
34 |     Base Hugging Face-compatible PyTorch wrapper class for classification models
35 | """
36 | def __init__(self, model,
37 | num_cls=None,
38 | loss_fct=nn.CrossEntropyLoss()):
39 | super().__init__()
40 | self.net = model
41 | if 'num_cls' in model.__dict__.keys():
42 | self.num_cls = model.num_cls
43 | else:
44 | self.num_cls = num_cls
45 | assert self.num_cls is not None, "num_cls is not defined"
46 | self.loss_fct = loss_fct
47 | def forward(self, x, labels=None):
48 | logits = self.net(x) # Forward pass the model
49 | if labels is not None:
50 | # logits => [BATCH, NUM_CLS]
51 | # labels => [BATCH]
52 | loss = self.loss_fct(logits.view(-1, self.num_cls), labels.view(-1)) # Calculate loss
53 | else:
54 | loss = None
55 |         # return dictionary for the Hugging Face Trainer
56 | return {'loss':loss, 'logits':logits, 'hidden_states':None}
57 |
58 | # Test the function
59 | if __name__ == '__main__':
60 | import torch
61 | from fusionlab.classification import VGG16Classifier
62 | from fusionlab.classification import LSTMClassifier
63 |
64 | H = W = 224
65 | cout = 5
66 | inputs = torch.normal(0, 1, (1, 3, W))
67 | # Test CNNClassification
68 | model = VGG16Classifier(3, cout, spatial_dims=1)
69 | hf_model = HFClassificationModel(model, cout)
70 | output = hf_model(inputs)
71 | print(output['logits'].shape)
72 | assert list(output.keys()) == ['loss', 'logits', 'hidden_states']
73 |
74 | inputs = torch.normal(0, 1, (1, 3, H, W))
75 | # Test CNNClassification
76 | model = VGG16Classifier(3, cout, spatial_dims=2)
77 | hf_model = HFClassificationModel(model, cout)
78 | output = hf_model(inputs)
79 | print(output['logits'].shape)
80 | assert list(output.keys()) == ['loss', 'logits', 'hidden_states']
81 |
82 | inputs = torch.normal(0, 1, (1, 3, H))
83 | model = LSTMClassifier(3, cout)
84 | hf_model = HFClassificationModel(model, cout)
85 | output = hf_model(inputs)
86 | print(output['logits'].shape)
87 | assert list(output.keys()) == ['loss', 'logits', 'hidden_states']
--------------------------------------------------------------------------------
/fusionlab/classification/lstm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from fusionlab.classification.base import RNNClassificationModel
4 |
5 |
6 | class LSTMClassifier(RNNClassificationModel):
7 | def __init__(self, cin, cout, hidden_size=512):
8 | super().__init__()
9 | self.encoder = nn.LSTM(input_size=cin, hidden_size=hidden_size, batch_first=True) # define LSTM layer
10 | self.head = nn.Linear(hidden_size, cout) # define output head layer
11 |
12 | if __name__ == '__main__':
13 | inputs = torch.randn(1, 5, 3) # create random input tensor
14 | model = LSTMClassifier(cin=3, hidden_size=4, cout=2) # create model instance
15 | outputs = model(inputs) # pass input through model
16 | assert list(outputs.shape) == [1, 2] # check output shape is correct
--------------------------------------------------------------------------------
/fusionlab/classification/vgg.py:
--------------------------------------------------------------------------------
1 | # VGG Classifier
2 | import torch
3 | from torch import nn
4 | from fusionlab.classification.base import CNNClassificationModel
5 | from fusionlab.encoders import VGG16, VGG19
6 | from fusionlab.layers import AdaptiveAvgPool
7 |
8 | class VGG16Classifier(CNNClassificationModel):
9 | def __init__(self, cin, num_cls, spatial_dims=2):
10 | super().__init__()
11 | self.num_cls = num_cls
12 | self.encoder = VGG16(cin, spatial_dims) # Create VGG16 instance
13 | self.globalpooling = AdaptiveAvgPool(spatial_dims, 1)
14 | self.head = nn.Linear(512, num_cls)
15 |
16 | class VGG19Classifier(CNNClassificationModel):
17 | def __init__(self, cin, num_cls, spatial_dims=2):
18 | super().__init__()
19 | self.num_cls = num_cls
20 |         self.encoder = VGG19(cin, spatial_dims) # Create VGG19 instance
21 | self.globalpooling = AdaptiveAvgPool(spatial_dims, 1)
22 | self.head = nn.Linear(512, num_cls)
23 |
24 |
25 | if __name__ == '__main__':
26 | inputs = torch.randn(1, 3, 224) # create random input tensor
27 | model = VGG16Classifier(cin=3, num_cls=2, spatial_dims=1) # create model instance
28 | outputs = model(inputs) # pass input through model
29 | assert list(outputs.shape) == [1, 2] # check output shape is correct
30 |
31 | inputs = torch.randn(1, 3, 224) # create random input tensor
32 | model = VGG19Classifier(cin=3, num_cls=2, spatial_dims=1) # create model instance
33 |
34 | outputs = model(inputs) # pass input through model
35 | assert list(outputs.shape) == [1, 2] # check output shape is correct
--------------------------------------------------------------------------------
/fusionlab/configs.py:
--------------------------------------------------------------------------------
1 | EPS = 1e-07 # epsilon
--------------------------------------------------------------------------------
/fusionlab/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .muse import GEMuseXMLReader
2 | from .csvread import read_csv
3 |
4 | from fusionlab import BACKEND
5 | if BACKEND['torch']:
6 | from .a12lead import ECGClassificationDataset
7 | from .cinc2017 import (
8 | ECGCSVClassificationDataset,
9 | convert_mat_to_csv,
10 | validate_data
11 | )
12 | from .ludb import (
13 | LUDBDataset,
14 | plot
15 | )
16 | from .utils import (
17 | download_file,
18 | HFDataset,
19 | LSTimeSegDataset,
20 | )
--------------------------------------------------------------------------------
/fusionlab/datasets/a12lead.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from torch.utils.data import Dataset  # PyTorch data utilities
3 | from fusionlab.datasets import csvread  # reads ECG signals
4 |
5 |
6 | # A simple Dataset class
7 | class ECGClassificationDataset(Dataset):
8 |     def __init__(self,
9 |                  annotation_file,  # path to the annotation file
10 |                  data_dir,  # path to the ECG data
11 |                  transform=None,  # data transform function
12 |                  channels=None,
13 |                  class_names=None):
14 |         self.signal_paths = data_dir
15 |         self.ecg_signals = pd.read_csv(annotation_file)  # read the annotation file
16 |         self.transform = transform
17 |         self.channels = channels
18 |         self.class_map = {n: i for i, n in enumerate(class_names)}
19 |     def __len__(self):
20 |         return len(self.ecg_signals)  # number of samples
21 |     def __getitem__(self, idx):
22 |         entry = self.ecg_signals.iloc[idx]  # get the entry at the given index
23 |         data = csvread.read_csv(self.signal_paths + entry['filename'])  # read the ECG signal
24 |         label = entry['label']  # get the label
25 |         if self.transform:
26 |             data = self.transform(data)  # apply the transform
27 |         return data, label  # return the signal and its label
/fusionlab/datasets/cinc2017.py:
--------------------------------------------------------------------------------
1 | from typing import * # Importing all types from typing module
2 | import os # Importing os module
3 | from glob import glob # Importing glob function from glob module
4 | from tqdm.auto import tqdm
5 | from fusionlab.datasets.utils import download_file  # Importing the download_file helper
6 |
7 | from scipy import io # Importing scipy module
8 | import pandas as pd # Importing pandas module and aliasing it as pd
9 |
10 | import torch # Importing torch module
11 | from torch.utils.data import Dataset # Importing Dataset class from torch.utils.data module
12 |
13 | URL_ECG = "https://physionet.org/files/challenge-2017/1.0.0/training2017.zip"
14 | # URL_LABEL = "https://physionet.org/content/challenge-2017/1.0.0/REFERENCE-v3.csv"
15 | URL_LABEL = "https://physionet.org/files/challenge-2017/1.0.0/REFERENCE-v3.csv"
16 |
17 |
18 | class ECGCSVClassificationDataset(Dataset):
19 |
20 | def __init__(self,
21 | data_root,
22 | label_filename="REFERENCE-v3.csv",
23 | channels=["lead"],
24 | class_names=["N", "O", "A", "~"]):
25 | """
26 | Args:
27 | data_root (str): root directory of the dataset
28 | label_filename (str): filename of the label file
29 | channels (list): list of target lead names
30 | class_names (list): list of class names for mapping class name to class id
31 | """
32 | # read the label file and store it in a pandas dataframe
33 | self.df_label = pd.read_csv(os.path.join(data_root, label_filename),
34 | header=None,
35 | names=["pat", "label"])
36 | # set the directory where the signal files are stored
37 | self.signal_dir = os.path.join(data_root, "csv")
38 | # get the paths of all the signal files and sort them
39 | self.signal_paths = sorted(glob(os.path.join(self.signal_dir, "*.csv")))
40 | # create a dictionary to map class names to class ids
41 | self.class_map = {n: i for i, n in enumerate(class_names)}
42 | # set the target leads for the dataset
43 | self.channels = channels
44 | print("dataset class map: ", self.class_map)
45 |
46 | def __len__(self):
47 | return len(self.signal_paths)
48 |
49 | def __getitem__(self, idx):
50 | row = self.df_label.iloc[idx]
51 | signal_filename = row["pat"] + ".csv"
52 | signal_path = os.path.join(self.signal_dir, signal_filename)
53 | df_csv = pd.read_csv(signal_path)
54 | signal = df_csv["lead"].values
55 |
56 | class_name = row["label"]
57 | class_id = self.class_map[class_name]
58 |
59 | signal = torch.tensor(signal, dtype=torch.float)
60 | class_id = torch.tensor(class_id, dtype=torch.long)
61 |
62 | # preprocess
63 | signal = signal.unsqueeze(0)
64 | return signal, class_id
65 |
66 | def _check_validate(self):
67 | assert len(self.df_label) == len(
68 | self.signal_paths), "csv files and label files are not matched"
69 |
70 |
71 | def convert_mat_to_csv(root, target_dir="csv"):
72 | paths = glob(os.path.join(root, "training2017",
73 | "*.mat")) # get all paths of .mat files in the training2017 folder
74 | os.makedirs(os.path.join(root, target_dir),
75 | exist_ok=True) # create a new directory named target_dir in the root directory
76 | print("mat files: ", len(paths)) # print the number of .mat files found
77 | print("start to convert mat files to csv files"
78 | ) # print a message indicating the start of the conversion process
79 | for path in tqdm(paths): # iterate through each path in paths
80 | filename = os.path.basename(path) # get the filename from the path
81 | file_id = filename.split(".")[
82 | 0] # get the file ID by splitting the filename at the "." and taking the first part
83 | target_filename = file_id + ".csv" # create the target filename by appending ".csv" to the file ID
84 | signal = io.loadmat(path)["val"][0] # load the .mat file and extract the "val" array
85 | df = pd.DataFrame(columns=["lead"
86 | ]) # create a new DataFrame with a single column named "lead"
87 | df["lead"] = signal # set the "lead" column to the "val" array
88 | df.to_csv(
89 | os.path.join(root, target_dir, target_filename)
90 | ) # save the DataFrame as a CSV file in the target directory with the target filename
91 |
92 |
93 | def validate_data(csv_dir, label_path):
94 | """
95 | check if the number of csv files and label files are matched
96 | """
97 | csv_paths = glob(os.path.join(csv_dir, "*.csv")) # get all csv files in the directory
98 | df_label = pd.read_csv(label_path, header=None,
99 | names=["pat", "label"]) # read the label file as a dataframe
100 | print("csv files: ", len(csv_paths)) # print the number of csv files
101 | print("label files: ", len(df_label)) # print the number of label files
102 | assert len(csv_paths) == len(
103 | df_label
104 | ), "csv files and label files are not matched" # check if the number of csv files and label files are equal
105 | return # return nothing
106 |
107 |
108 | if __name__ == "__main__":
109 | root = "data"
110 | try:
111 | validate_data("./data/csv", "./data/REFERENCE-v3.csv")
112 |     except Exception:
113 |         print("validation failed, start to download and convert data")
114 | download_file(URL_ECG, root, extract=True)
115 | download_file(URL_LABEL, root, extract=False)
116 | convert_mat_to_csv(root)
117 |
118 | ds = ECGCSVClassificationDataset("./data", label_filename="REFERENCE-v3.csv")
119 |
120 | sig, label = ds[0]
121 | print(sig.shape, label)
122 |
--------------------------------------------------------------------------------
/fusionlab/datasets/csv_sample.csv:
--------------------------------------------------------------------------------
1 | time, I, II, III
2 | 0.0, 1, 1, 1
3 | 0.01, 2, 2, 2
4 | 0.02, 3, 3, 3
--------------------------------------------------------------------------------
/fusionlab/datasets/csvread.py:
--------------------------------------------------------------------------------
1 | # Reader for csv files
2 | import numpy as np
3 |
4 | def read_csv(fname):
5 | readout = np.genfromtxt(fname, dtype=np.float32, skip_header=1, delimiter=",")
6 | return readout[:, 1:]
7 |
8 | if __name__ == "__main__":
9 | signal = read_csv('csv_sample.csv')
10 | assert list(signal.shape) == [3, 3]
11 | assert list(signal[:,0]) == [1., 2., 3.]
--------------------------------------------------------------------------------
/fusionlab/encoders/README.md:
--------------------------------------------------------------------------------
1 | # Encoders
2 |
3 | | Name | PyTorch | Tensorflow | Paper|
4 | | :----------- | :----------------------------------------------- | :---------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
5 | | ConvNeXt | ConvNeXtTiny \| Small \| Base \| Large \| XLarge | X | [A ConvNet for the 2020s (2020)](https://arxiv.org/abs/2201.03545)|
6 | | EfficientNet | EfficientNetB0~7 | X | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks (2019)](https://arxiv.org/abs/1905.11946)|
7 | | ResNet | ResNet50 | TFResNet50 | [Deep Residual Learning for Image Recognition (2015)](https://arxiv.org/abs/1512.03385)|
8 | | InceptionV1 | InceptionNetV1 | TFInceptionNetV1 | [Going Deeper with Convolutions (2014)](https://arxiv.org/abs/1409.4842)|
9 | | VGG | VGG16, VGG19 | TFVGG16, TFVGG19 | [Very Deep Convolutional Networks for Large-Scale Image Recognition (2014)](https://arxiv.org/abs/1409.1556)|
10 | | AlexNet | AlexNet | TFAlexNet | [ImageNet Classification with Deep Convolutional Neural Networks (2012)](https://proceedings.neurips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)|
11 |
--------------------------------------------------------------------------------
/fusionlab/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | from fusionlab import BACKEND
2 | if BACKEND['torch']:
3 | from .alexnet.alexnet import AlexNet
4 | from .vgg.vgg import VGG16, VGG19
5 | from .inceptionv1.inceptionv1 import InceptionNetV1
6 | from .resnetv1.resnetv1 import *
7 | from .efficientnet.efficientnet import (
8 | EfficientNet,
9 | EfficientNetB0,
10 | EfficientNetB1,
11 | EfficientNetB2,
12 | EfficientNetB3,
13 | EfficientNetB4,
14 | EfficientNetB5,
15 | EfficientNetB6,
16 | EfficientNetB7
17 | )
18 | from .convnext.convnext import (
19 | ConvNeXt,
20 | ConvNeXtTiny,
21 | ConvNeXtSmall,
22 | ConvNeXtBase,
23 | ConvNeXtLarge,
24 | ConvNeXtXLarge
25 | )
26 | from .vit.vit import (
27 | ViT,
28 | VisionTransformer
29 | )
30 | from .mit.mit import (
31 | MiT,
32 | MiTB0,
33 | MiTB1,
34 | MiTB2,
35 | MiTB3,
36 | MiTB4,
37 | MiTB5,
38 | )
39 | if BACKEND['tf']:
40 | from .alexnet.tfalexnet import TFAlexNet
41 | from .vgg.tfvgg import TFVGG16, TFVGG19
42 | from .inceptionv1.tfinceptionv1 import TFInceptionNetV1
43 | from .resnetv1.tfresnetv1 import TFResNet50V1
44 |
--------------------------------------------------------------------------------
/fusionlab/encoders/alexnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/encoders/alexnet/__init__.py
--------------------------------------------------------------------------------
/fusionlab/encoders/alexnet/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from fusionlab.layers.factories import ConvND, MaxPool
4 |
5 | # Official pytorch ref: https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py
6 | class AlexNet(nn.Module):
7 | def __init__(self, cin=3, spatial_dims=2):
8 | super().__init__()
9 | self.features = nn.Sequential(
10 | ConvND(spatial_dims, cin, 64, kernel_size=11, stride=4, padding=2),
11 | nn.ReLU(inplace=True),
12 | MaxPool(spatial_dims, kernel_size=3, stride=2),
13 | ConvND(spatial_dims, 64, 192, kernel_size=5, padding=2),
14 | nn.ReLU(inplace=True),
15 | MaxPool(spatial_dims, kernel_size=3, stride=2),
16 | ConvND(spatial_dims, 192, 384, kernel_size=3, padding=1),
17 | nn.ReLU(inplace=True),
18 | ConvND(spatial_dims, 384, 256, kernel_size=3, padding=1),
19 | nn.ReLU(inplace=True),
20 | ConvND(spatial_dims, 256, 256, kernel_size=3, padding=1),
21 | nn.ReLU(inplace=True),
22 | MaxPool(spatial_dims, kernel_size=3, stride=2),
23 | )
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.features(x)
27 |
28 |
29 |
30 | if __name__ == '__main__':
31 | img_size = 224
32 | inputs = torch.normal(0, 1, (1, 3, img_size, img_size))
33 | output = AlexNet(3, spatial_dims=2)(inputs)
34 | print(output.shape)
35 | assert list(output.shape) == [1, 256, 6, 6]
36 |
37 | inputs = torch.normal(0, 1, (1, 3, img_size))
38 | output = AlexNet(3, spatial_dims=1)(inputs)
39 | print(output.shape)
40 | assert list(output.shape) == [1, 256, 6]
41 |
42 | img_size = 128
43 | inputs = torch.normal(0, 1, (1, 3, img_size, img_size, img_size))
44 | output = AlexNet(3, spatial_dims=3)(inputs)
45 | print(output.shape)
46 | assert list(output.shape) == [1, 256, 3, 3, 3]
47 |
--------------------------------------------------------------------------------
/fusionlab/encoders/alexnet/tfalexnet.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | class TFAlexNet(tf.keras.Model):
5 | def __init__(self):
6 | super().__init__()
7 | self.features = tf.keras.Sequential([
8 | tf.keras.layers.ZeroPadding2D(2),
9 | tf.keras.layers.Conv2D(64, kernel_size=11, strides=4),
10 | tf.keras.layers.ReLU(),
11 | tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
12 | tf.keras.layers.Conv2D(192, kernel_size=5, padding='same'),
13 | tf.keras.layers.ReLU(),
14 | tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
15 | tf.keras.layers.Conv2D(384, kernel_size=3, padding='same'),
16 | tf.keras.layers.ReLU(),
17 | tf.keras.layers.Conv2D(256, kernel_size=3, padding='same'),
18 | tf.keras.layers.ReLU(),
19 | tf.keras.layers.Conv2D(256, kernel_size=3, padding='same'),
20 | tf.keras.layers.ReLU(),
21 | tf.keras.layers.MaxPool2D(pool_size=3, strides=2)
22 | ])
23 |
24 | def call(self, inputs):
25 | return self.features(inputs)
26 |
27 |
28 | if __name__ == '__main__':
29 | inputs = tf.random.normal((1, 224, 224, 3))
30 | output = TFAlexNet()(inputs)
31 | shape = output.shape
32 | print(shape)
33 | assert shape[1:3] == [6, 6]
34 |
--------------------------------------------------------------------------------
/fusionlab/encoders/convnext/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/encoders/convnext/__init__.py
--------------------------------------------------------------------------------
/fusionlab/encoders/efficientnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/encoders/efficientnet/__init__.py
--------------------------------------------------------------------------------
/fusionlab/encoders/inceptionv1/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/encoders/inceptionv1/__init__.py
--------------------------------------------------------------------------------
/fusionlab/encoders/inceptionv1/inceptionv1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from fusionlab.layers.factories import ConvND, MaxPool
5 | from fusionlab.utils import autopad
6 |
7 | # ref: https://arxiv.org/abs/1409.4842
8 | # Going Deeper with Convolutions
9 | class ConvBlock(nn.Module):
10 | def __init__(self, cin, cout, kernel_size=3, spatial_dims=2, stride=1):
11 | super().__init__()
12 | self.conv = ConvND(spatial_dims, cin, cout, kernel_size, stride, padding=autopad(kernel_size))
13 | self.act = nn.ReLU(inplace=True)
14 |
15 | def forward(self, x):
16 | x = self.conv(x)
17 | x = self.act(x)
18 | return x
19 |
20 |
21 | class InceptionBlock(nn.Module):
22 | def __init__(self, cin, dim0, dim1, dim2, dim3, spatial_dims=2):
23 | super().__init__()
24 |         self.branch1 = ConvBlock(cin, dim0, 1, spatial_dims)  # 1x1 conv per the paper
25 | self.branch3 = nn.Sequential(ConvBlock(cin, dim1[0], 1, spatial_dims),
26 | ConvBlock(dim1[0], dim1[1], 3, spatial_dims))
27 | self.branch5 = nn.Sequential(ConvBlock(cin, dim2[0], 1, spatial_dims),
28 | ConvBlock(dim2[0], dim2[1], 5, spatial_dims))
29 |         self.pool = nn.Sequential(MaxPool(spatial_dims, 3, 1, autopad(3)),
30 |                                   ConvBlock(cin, dim3, 1, spatial_dims))  # 1x1 pool projection
31 |
32 | def forward(self, x):
33 | x0 = self.branch1(x)
34 | x1 = self.branch3(x)
35 | x2 = self.branch5(x)
36 | x3 = self.pool(x)
37 | x = torch.cat((x0, x1, x2, x3), 1)
38 | return x
39 |
40 |
41 | class InceptionNetV1(nn.Module):
42 | def __init__(self, cin=3, spatial_dims=2):
43 | super().__init__()
44 | self.stem = nn.Sequential(
45 | ConvBlock(cin, 64, 7, spatial_dims, stride=2),
46 | MaxPool(spatial_dims, 3, 2, padding=autopad(3)),
47 | ConvBlock(64, 192, 3, spatial_dims),
48 | MaxPool(spatial_dims, 3, 2, padding=autopad(3)),
49 | )
50 | self.incept3a = InceptionBlock(192, 64, (96, 128), (16, 32), 32, spatial_dims)
51 | self.incept3b = InceptionBlock(256, 128, (128, 192), (32, 96), 64, spatial_dims)
52 | self.pool3 = MaxPool(spatial_dims, 3, 2, padding=autopad(3))
53 | self.incept4a = InceptionBlock(480, 192, (96, 208), (16, 48), 64, spatial_dims)
54 | self.incept4b = InceptionBlock(512, 160, (112, 224), (24, 64), 64, spatial_dims)
55 | self.incept4c = InceptionBlock(512, 128, (128, 256), (24, 64), 64, spatial_dims)
56 | self.incept4d = InceptionBlock(512, 112, (144, 288), (32, 64), 64, spatial_dims)
57 | self.incept4e = InceptionBlock(528, 256, (160, 320), (32, 128), 128, spatial_dims)
58 | self.pool4 = MaxPool(spatial_dims, 3, 2, padding=autopad(3))
59 | self.incept5a = InceptionBlock(832, 256, (160, 320), (32, 128), 128, spatial_dims)
60 | self.incept5b = InceptionBlock(832, 384, (192, 384), (48, 128), 128, spatial_dims)
61 |
62 | def forward(self, x):
63 | x = self.stem(x)
64 | x = self.incept3a(x)
65 | x = self.incept3b(x)
66 | x = self.pool3(x)
67 | x = self.incept4a(x)
68 | x = self.incept4b(x)
69 | x = self.incept4c(x)
70 | x = self.incept4d(x)
71 | x = self.incept4e(x)
72 | x = self.pool4(x)
73 | x = self.incept5a(x)
74 | x = self.incept5b(x)
75 | return x
76 |
77 | if __name__ == "__main__":
78 | inputs = torch.normal(0, 1, (1, 3, 224, 224))
79 | outputs = InceptionBlock(3, 64, (96, 128), (16, 32), 32)(inputs)
80 | print(outputs.shape)
81 | assert list(outputs.shape) == [1, 256, 224, 224]
82 |
83 | outputs = InceptionNetV1()(inputs)
84 | print(outputs.shape)
85 | assert list(outputs.shape) == [1, 1024, 7, 7]
86 |
87 | inputs = torch.normal(0, 1, (1, 3, 224))
88 | outputs = InceptionNetV1(spatial_dims=1)(inputs)
89 | print(outputs.shape)
90 | assert list(outputs.shape) == [1, 1024, 7]
91 |
--------------------------------------------------------------------------------
/fusionlab/encoders/inceptionv1/tfinceptionv1.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import Model, layers, Sequential
3 |
4 | # ref: https://arxiv.org/abs/1409.4842
5 | # Going Deeper with Convolutions
6 |
7 |
8 | class ConvBlock(Model):
9 | def __init__(self, cout, kernel_size=3, stride=1):
10 | super().__init__()
11 | self.conv = layers.Conv2D(cout, kernel_size, stride, padding="same")
12 | self.act = layers.ReLU()
13 |
14 | def call(self, inputs):
15 | x = self.conv(inputs)
16 | x = self.act(x)
17 | return x
18 |
19 |
20 | class InceptionBlock(Model):
21 | def __init__(self, dim0, dim1, dim2, dim3):
22 | super().__init__()
23 | self.branch1 = ConvBlock(dim0, 1)
24 | self.branch3 = Sequential([
25 | ConvBlock(dim1[0], 1),
26 | ConvBlock(dim1[1], 3)
27 | ])
28 |         self.branch5 = Sequential([
29 |             ConvBlock(dim2[0], 1),
30 |             ConvBlock(dim2[1], 5)
31 |         ])
32 | self.pool = Sequential([
33 | layers.MaxPool2D(3, 1, padding='same'),
34 | ConvBlock(dim3, 1)
35 | ])
36 |
37 | def call(self, inputs):
38 | x = inputs
39 | x0 = self.branch1(x)
40 | x1 = self.branch3(x)
41 | x2 = self.branch5(x)
42 | x3 = self.pool(x)
43 | x = tf.concat([x0, x1, x2, x3], axis=-1)
44 | return x
45 |
46 |
47 | class TFInceptionNetV1(Model):
48 | def __init__(self):
49 | super().__init__()
50 | self.stem = Sequential([
51 | ConvBlock(64, 7, stride=2),
52 | layers.MaxPool2D(3, 2, padding='same'),
53 | ConvBlock(192, 3),
54 | layers.MaxPool2D(3, 2, padding='same'),
55 | ])
56 | self.incept3a = InceptionBlock(64, (96, 128), (16, 32), 32)
57 | self.incept3b = InceptionBlock(128, (128, 192), (32, 96), 64)
58 | self.pool3 = layers.MaxPool2D(3, 2, padding='same')
59 | self.incept4a = InceptionBlock(192, (96, 208), (16, 48), 64)
60 | self.incept4b = InceptionBlock(160, (112, 224), (24, 64), 64)
61 | self.incept4c = InceptionBlock(128, (128, 256), (24, 64), 64)
62 | self.incept4d = InceptionBlock(112, (144, 288), (32, 64), 64)
63 | self.incept4e = InceptionBlock(256, (160, 320), (32, 128), 128)
64 | self.pool4 = layers.MaxPool2D(3, 2, padding='same')
65 | self.incept5a = InceptionBlock(256, (160, 320), (32, 128), 128)
66 | self.incept5b = InceptionBlock(384, (192, 384), (48, 128), 128)
67 |
68 | def call(self, inputs):
69 | x = self.stem(inputs)
70 | x = self.incept3a(x)
71 | x = self.incept3b(x)
72 | x = self.pool3(x)
73 | x = self.incept4a(x)
74 | x = self.incept4b(x)
75 | x = self.incept4c(x)
76 | x = self.incept4d(x)
77 | x = self.incept4e(x)
78 | x = self.pool4(x)
79 | x = self.incept5a(x)
80 | x = self.incept5b(x)
81 | return x
82 |
83 |
84 | if __name__ == "__main__":
85 | inputs = tf.random.normal((1, 224, 224, 3))
86 | output = InceptionBlock(64, (96, 128), (16, 32), 32)(inputs)
87 | shape = output.shape
88 | print("InceptionBlock", shape)
89 | assert shape == (1, 224, 224, 256)
90 |
91 | output = TFInceptionNetV1()(inputs)
92 | shape = output.shape
93 | print("TFInceptionNetV1", shape)
94 | assert shape == (1, 7, 7, 1024)
95 |
--------------------------------------------------------------------------------
/fusionlab/encoders/mit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/encoders/mit/__init__.py
--------------------------------------------------------------------------------
/fusionlab/encoders/mit/mit.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence, Tuple
2 | import torch
3 | import torch.nn as nn
4 | from fusionlab.layers import (
5 | SRAttention,
6 | DropPath,
7 | )
8 |
9 | class PatchEmbed(nn.Module):
10 | def __init__(
11 | self,
12 | in_channels=3,
13 | out_channels=32,
14 | patch_size=7,
15 | stride=4
16 | ):
17 | super().__init__()
18 | self.proj = nn.Conv2d(
19 | in_channels,
20 | out_channels,
21 | patch_size, stride,
22 | patch_size//2)
23 | self.norm = nn.LayerNorm(out_channels)
24 |
25 |     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
26 | x = self.proj(x)
27 | _, _, h, w = x.shape
28 | x = x.flatten(2).transpose(1, 2)
29 | x = self.norm(x)
30 | return x, h, w
31 |
32 | class MiTBlock(nn.Module):
33 | def __init__(self, dim, head, spatio_reduction_ratio=1, drop_path_rate=0.):
34 | super().__init__()
35 | self.norm1 = nn.LayerNorm(dim)
36 | self.attn = SRAttention(dim, head, spatio_reduction_ratio)
37 | self.drop_path = DropPath(drop_path_rate, "row") if drop_path_rate > 0. else nn.Identity()
38 | self.norm2 = nn.LayerNorm(dim)
39 | self.mlp = MLP(dim, int(dim*4))
40 |
41 | def forward(self, x: torch.Tensor, h, w) -> torch.Tensor:
42 | x = x + self.drop_path(self.attn(self.norm1(x), h, w))
43 | x = x + self.drop_path(self.mlp(self.norm2(x), h, w))
44 | return x
45 |
46 | class MLP(nn.Module):
47 | def __init__(self, c1, c2):
48 | super().__init__()
49 | self.fc1 = nn.Linear(c1, c2)
50 | self.dwconv = nn.Conv2d(c2, c2, 3, 1, 1, groups=c2)
51 | self.fc2 = nn.Linear(c2, c1)
52 | self.act = nn.GELU()
53 |
54 | def forward(self, x: torch.Tensor, h, w) -> torch.Tensor:
55 | x = self.fc1(x)
56 | B, _, C = x.shape
57 | x = x.transpose(1, 2).view(B, C, h, w)
58 | x = self.dwconv(x)
59 | x = x.flatten(2).transpose(1, 2)
60 | x = self.act(x)
61 | x = self.fc2(x)
62 | return x
63 |
64 | class MiT(nn.Module):
65 | """
66 | Mix Transformer
67 |
68 | source code: https://github.com/sithu31296/semantic-segmentation/blob/main/semseg/models/backbones/mit.py
69 | """
70 | def __init__(
71 | self,
72 | in_channels: int = 3,
73 | embed_dims: Sequence[int] = [32, 64, 160, 256],
74 | depths: Sequence[int] = [2, 2, 2, 2]
75 | ):
76 | super().__init__()
77 | drop_path_rate = 0.1
78 | self.channels = embed_dims
79 |
80 | # patch_embed
81 | self.patch_embed1 = PatchEmbed(in_channels, embed_dims[0], 7, 4)
82 | self.patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2)
83 | self.patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2)
84 | self.patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2)
85 |
86 | drop_path_rate = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
87 |
88 | cur = 0
89 | self.block1 = nn.ModuleList([MiTBlock(embed_dims[0], 1, 8, drop_path_rate[cur+i]) for i in range(depths[0])])
90 | self.norm1 = nn.LayerNorm(embed_dims[0])
91 |
92 | cur += depths[0]
93 | self.block2 = nn.ModuleList([MiTBlock(embed_dims[1], 2, 4, drop_path_rate[cur+i]) for i in range(depths[1])])
94 | self.norm2 = nn.LayerNorm(embed_dims[1])
95 |
96 | cur += depths[1]
97 | self.block3 = nn.ModuleList([MiTBlock(embed_dims[2], 5, 2, drop_path_rate[cur+i]) for i in range(depths[2])])
98 | self.norm3 = nn.LayerNorm(embed_dims[2])
99 |
100 | cur += depths[2]
101 | self.block4 = nn.ModuleList([MiTBlock(embed_dims[3], 8, 1, drop_path_rate[cur+i]) for i in range(depths[3])])
102 | self.norm4 = nn.LayerNorm(embed_dims[3])
103 |
104 |
105 | def forward(self, x: torch.Tensor, return_features=False) -> torch.Tensor:
106 | bs = x.shape[0]
107 |
108 | # stage 1
109 | x, h, w = self.patch_embed1(x)
110 | for blk in self.block1:
111 | x = blk(x, h, w)
112 | x1 = self.norm1(x).reshape(bs, h, w, -1).permute(0, 3, 1, 2)
113 |
114 | # stage 2
115 | x, h, w = self.patch_embed2(x1)
116 | for blk in self.block2:
117 | x = blk(x, h, w)
118 | x2 = self.norm2(x).reshape(bs, h, w, -1).permute(0, 3, 1, 2)
119 |
120 | # stage 3
121 | x, h, w = self.patch_embed3(x2)
122 | for blk in self.block3:
123 | x = blk(x, h, w)
124 | x3 = self.norm3(x).reshape(bs, h, w, -1).permute(0, 3, 1, 2)
125 |
126 | # stage 4
127 | x, h, w = self.patch_embed4(x3)
128 | for blk in self.block4:
129 | x = blk(x, h, w)
130 | x4 = self.norm4(x).reshape(bs, h, w, -1).permute(0, 3, 1, 2)
131 |
132 | if return_features:
133 | return x4, [x1, x2, x3, x4]
134 | else:
135 | return x4
136 |
137 | class MiTB0(MiT):
138 | def __init__(self, in_channels: int = 3):
139 | super().__init__(in_channels, [32, 64, 160, 256], [2, 2, 2, 2])
140 |
141 | class MiTB1(MiT):
142 | def __init__(self, in_channels: int = 3):
143 | super().__init__(in_channels, [64, 128, 320, 512], [2, 2, 2, 2])
144 |
145 | class MiTB2(MiT):
146 | def __init__(self, in_channels: int = 3):
147 | super().__init__(in_channels, [64, 128, 320, 512], [3, 4, 6, 3])
148 |
149 | class MiTB3(MiT):
150 | def __init__(self, in_channels: int = 3):
151 | super().__init__(in_channels, [64, 128, 320, 512], [3, 4, 18, 3])
152 |
153 | class MiTB4(MiT):
154 | def __init__(self, in_channels: int = 3):
155 | super().__init__(in_channels, [64, 128, 320, 512], [3, 8, 27, 3])
156 |
157 | class MiTB5(MiT):
158 | def __init__(self, in_channels: int = 3):
159 | super().__init__(in_channels, [64, 128, 320, 512], [3, 6, 40, 3])
160 |
161 | if __name__ == '__main__':
162 |     inputs = torch.randn(1, 3, 128, 128)
163 |     for model_cls in (MiTB0, MiTB1, MiTB2, MiTB3, MiTB4, MiTB5):
164 |         model = model_cls(in_channels=3)
165 |         outputs = model(inputs)
166 |         print(outputs.shape)
--------------------------------------------------------------------------------
/fusionlab/encoders/resnetv1/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/encoders/resnetv1/__init__.py
--------------------------------------------------------------------------------
/fusionlab/encoders/resnetv1/tfresnetv1.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import Model, Sequential, layers
3 |
4 | # ResNet50
5 | # Ref:
6 | # https://github.com/keras-team/keras-applications/blob/master/keras_applications/resnet50.py
7 | # https://github.com/raghakot/keras-resnet/blob/master/README.md
8 |
9 |
10 | class Identity(layers.Layer):
11 | def __init__(self):
12 | super(Identity, self).__init__()
13 |
14 | def call(self, inputs, training=None):
15 | return inputs
16 |
17 |
18 | class ConvBlock(Model):
19 | def __init__(self, cout, kernel_size=3, stride=1, activation=True, padding=True):
20 | super().__init__()
21 | self.conv = layers.Conv2D(cout, kernel_size, stride,
22 | padding='same' if padding else 'valid')
23 | self.bn = layers.BatchNormalization()
24 | self.act = layers.ReLU() if activation else Identity()
25 |
26 | def call(self, inputs, training=None):
27 | x = self.conv(inputs)
28 | x = self.bn(x, training)
29 | x = self.act(x)
30 | return x
31 |
32 |
33 | class Bottleneck(Model):
34 | def __init__(self, dims, kernel_size=3, stride=None):
35 | super().__init__()
36 | dim1, dim2, dim3 = dims
37 | self.conv1 = ConvBlock(dim1, kernel_size=1)
38 | self.conv2 = ConvBlock(dim2, kernel_size=kernel_size,
39 | stride=stride if stride else 1)
40 | self.conv3 = ConvBlock(dim3, kernel_size=1, activation=False)
41 | self.act = layers.ReLU()
42 | self.skip = Identity() if not stride else ConvBlock(dim3,
43 | kernel_size=1,
44 | stride=stride,
45 | activation=False)
46 |
47 | def call(self, inputs, training=None):
48 | identity = self.skip(inputs, training)
49 |
50 | x = self.conv1(inputs, training)
51 | x = self.conv2(x, training)
52 | x = self.conv3(x, training)
53 |
54 | x += identity
55 | x = self.act(x)
56 | return x
57 |
58 |
59 | class TFResNet50V1(Model):
60 | def __init__(self):
61 | super(TFResNet50V1, self).__init__()
62 | self.conv1 = Sequential([
63 | ConvBlock(64, 7, stride=2),
64 | layers.MaxPool2D(3, strides=2, padding='same'),
65 | ])
66 | self.conv2 = Sequential([
67 | Bottleneck([64, 64, 256], 3, stride=1),
68 | Bottleneck([64, 64, 256], 3),
69 | Bottleneck([64, 64, 256], 3),
70 | ])
71 | self.conv3 = Sequential([
72 | Bottleneck([128, 128, 512], 3, stride=2),
73 | Bottleneck([128, 128, 512], 3),
74 | Bottleneck([128, 128, 512], 3),
75 | Bottleneck([128, 128, 512], 3),
76 | ])
77 | self.conv4 = Sequential([
78 | Bottleneck([256, 256, 1024], 3, stride=2),
79 | Bottleneck([256, 256, 1024], 3),
80 | Bottleneck([256, 256, 1024], 3),
81 | Bottleneck([256, 256, 1024], 3),
82 | Bottleneck([256, 256, 1024], 3),
83 | Bottleneck([256, 256, 1024], 3),
84 | ])
85 | self.conv5 = Sequential([
86 | Bottleneck([512, 512, 2048], 3, stride=2),
87 | Bottleneck([512, 512, 2048], 3),
88 | Bottleneck([512, 512, 2048], 3),
89 | ])
90 |
91 | def call(self, inputs, training=None):
92 | x = self.conv1(inputs, training)
93 | x = self.conv2(x, training)
94 | x = self.conv3(x, training)
95 | x = self.conv4(x, training)
96 | x = self.conv5(x, training)
97 | return x
98 |
99 |
100 | if __name__ == '__main__':
101 | inputs = tf.random.normal((1, 224, 224, 128))
102 | output = Bottleneck([64, 64, 128])(inputs)
103 | shape = output.shape
104 | print("Bottleneck", shape)
105 | assert shape == (1, 224, 224, 128)
106 |
107 | output = Bottleneck([128, 128, 256], stride=1)(inputs)
108 | shape = output.shape
109 | print("Bottleneck first conv for aligh dims", shape)
110 | assert shape == (1, 224, 224, 256)
111 |
112 | output = Bottleneck([64, 64, 128], stride=2)(inputs)
113 | shape = output.shape
114 | print("Bottleneck downsample", shape)
115 | assert shape == (1, 112, 112, 128)
116 |
117 | output = Identity()(inputs)
118 | shape = output.shape
119 | print("Identity", shape)
120 | assert shape == (1, 224, 224, 128)
121 |
122 |
123 | output = TFResNet50V1()(inputs)
124 | shape = output.shape
125 | print("TFResNet50V1", shape)
126 | assert shape == (1, 7, 7, 2048)
127 |
--------------------------------------------------------------------------------
/fusionlab/encoders/vgg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/encoders/vgg/__init__.py
--------------------------------------------------------------------------------
/fusionlab/encoders/vgg/tfvgg.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | # Official pytorch ref: https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py
4 | class TFVGG16(tf.keras.Model):
5 | def __init__(self):
6 | super().__init__()
7 | ksize = 3
8 | self.features = tf.keras.Sequential([
9 | tf.keras.layers.Conv2D(64, ksize, padding='same'),
10 | tf.keras.layers.ReLU(),
11 | tf.keras.layers.Conv2D(64, ksize, padding='same'),
12 | tf.keras.layers.ReLU(),
13 | tf.keras.layers.MaxPool2D(),
14 |
15 | tf.keras.layers.Conv2D(128, ksize, padding='same'),
16 | tf.keras.layers.ReLU(),
17 | tf.keras.layers.Conv2D(128, ksize, padding='same'),
18 | tf.keras.layers.ReLU(),
19 | tf.keras.layers.MaxPool2D(),
20 |
21 | tf.keras.layers.Conv2D(256, ksize, padding='same'),
22 | tf.keras.layers.ReLU(),
23 | tf.keras.layers.Conv2D(256, ksize, padding='same'),
24 | tf.keras.layers.ReLU(),
25 | tf.keras.layers.Conv2D(256, ksize, padding='same'),
26 | tf.keras.layers.ReLU(),
27 | tf.keras.layers.MaxPool2D(),
28 |
29 | tf.keras.layers.Conv2D(512, ksize, padding='same'),
30 | tf.keras.layers.ReLU(),
31 | tf.keras.layers.Conv2D(512, ksize, padding='same'),
32 | tf.keras.layers.ReLU(),
33 | tf.keras.layers.Conv2D(512, ksize, padding='same'),
34 | tf.keras.layers.ReLU(),
35 | tf.keras.layers.MaxPool2D(),
36 |
37 | tf.keras.layers.Conv2D(512, ksize, padding='same'),
38 | tf.keras.layers.ReLU(),
39 | tf.keras.layers.Conv2D(512, ksize, padding='same'),
40 | tf.keras.layers.ReLU(),
41 | tf.keras.layers.Conv2D(512, ksize, padding='same'),
42 | tf.keras.layers.ReLU(),
43 | tf.keras.layers.MaxPool2D(),
44 | ])
45 |
46 | def call(self, inputs):
47 | return self.features(inputs)
48 |
49 |
50 | class TFVGG19(tf.keras.Model):
51 | def __init__(self):
52 | super().__init__()
53 | ksize = 3
54 | self.features = tf.keras.Sequential([
55 | tf.keras.layers.Conv2D(64, ksize, padding='same'),
56 | tf.keras.layers.ReLU(),
57 | tf.keras.layers.Conv2D(64, ksize, padding='same'),
58 | tf.keras.layers.ReLU(),
59 | tf.keras.layers.MaxPool2D(),
60 |
61 | tf.keras.layers.Conv2D(128, ksize, padding='same'),
62 | tf.keras.layers.ReLU(),
63 | tf.keras.layers.Conv2D(128, ksize, padding='same'),
64 | tf.keras.layers.ReLU(),
65 | tf.keras.layers.MaxPool2D(),
66 |
67 | tf.keras.layers.Conv2D(256, ksize, padding='same'),
68 | tf.keras.layers.ReLU(),
69 | tf.keras.layers.Conv2D(256, ksize, padding='same'),
70 | tf.keras.layers.ReLU(),
71 | tf.keras.layers.Conv2D(256, ksize, padding='same'),
72 | tf.keras.layers.ReLU(),
73 | tf.keras.layers.Conv2D(256, ksize, padding='same'),
74 | tf.keras.layers.ReLU(),
75 | tf.keras.layers.MaxPool2D(),
76 |
77 | tf.keras.layers.Conv2D(512, ksize, padding='same'),
78 | tf.keras.layers.ReLU(),
79 | tf.keras.layers.Conv2D(512, ksize, padding='same'),
80 | tf.keras.layers.ReLU(),
81 | tf.keras.layers.Conv2D(512, ksize, padding='same'),
82 | tf.keras.layers.ReLU(),
83 | tf.keras.layers.Conv2D(512, ksize, padding='same'),
84 | tf.keras.layers.ReLU(),
85 | tf.keras.layers.MaxPool2D(),
86 |
87 | tf.keras.layers.Conv2D(512, ksize, padding='same'),
88 | tf.keras.layers.ReLU(),
89 | tf.keras.layers.Conv2D(512, ksize, padding='same'),
90 | tf.keras.layers.ReLU(),
91 | tf.keras.layers.Conv2D(512, ksize, padding='same'),
92 | tf.keras.layers.ReLU(),
93 | tf.keras.layers.Conv2D(512, ksize, padding='same'),
94 | tf.keras.layers.ReLU(),
95 | tf.keras.layers.MaxPool2D(),
96 | ])
97 |
98 | def call(self, inputs):
99 | return self.features(inputs)
100 |
101 |
102 | if __name__ == '__main__':
103 | # VGG16
104 | inputs = tf.random.normal((1, 224, 224, 3))
105 | output = TFVGG16()(inputs)
106 | shape = output.shape
107 | assert shape[1:3] == [7, 7]
108 |
109 | # VGG19
110 | inputs = tf.random.normal((1, 224, 224, 3))
111 | output = TFVGG19()(inputs)
112 | shape = output.shape
113 | assert shape[1:3] == [7, 7]
--------------------------------------------------------------------------------
/fusionlab/encoders/vgg/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from fusionlab.layers import ConvND, MaxPool
4 |
5 | # Official pytorch ref: https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py
6 | class VGG16(nn.Module):
7 | def __init__(self, cin=3, spatial_dims=2):
8 | super().__init__()
9 | ksize = 3
10 | self.features = nn.Sequential(
11 | ConvND(spatial_dims, cin, 64, ksize, padding=1),
12 | nn.ReLU(inplace=True),
13 | ConvND(spatial_dims, 64, 64, ksize, padding=1),
14 | nn.ReLU(inplace=True),
15 | MaxPool(spatial_dims, kernel_size=2, stride=2),
16 |
17 | ConvND(spatial_dims, 64, 128, ksize, padding=1),
18 | nn.ReLU(inplace=True),
19 | ConvND(spatial_dims, 128, 128, ksize, padding=1),
20 | nn.ReLU(inplace=True),
21 | MaxPool(spatial_dims, kernel_size=2, stride=2),
22 |
23 | ConvND(spatial_dims, 128, 256, ksize, padding=1),
24 | nn.ReLU(inplace=True),
25 | ConvND(spatial_dims, 256, 256, ksize, padding=1),
26 | nn.ReLU(inplace=True),
27 | ConvND(spatial_dims, 256, 256, ksize, padding=1),
28 | nn.ReLU(inplace=True),
29 | MaxPool(spatial_dims, kernel_size=2, stride=2),
30 |
31 | ConvND(spatial_dims, 256, 512, ksize, padding=1),
32 | nn.ReLU(inplace=True),
33 | ConvND(spatial_dims, 512, 512, ksize, padding=1),
34 | nn.ReLU(inplace=True),
35 | ConvND(spatial_dims, 512, 512, ksize, padding=1),
36 | nn.ReLU(inplace=True),
37 | MaxPool(spatial_dims, kernel_size=2, stride=2),
38 |
39 | ConvND(spatial_dims, 512, 512, ksize, padding=1),
40 | nn.ReLU(inplace=True),
41 | ConvND(spatial_dims, 512, 512, ksize, padding=1),
42 | nn.ReLU(inplace=True),
43 | ConvND(spatial_dims, 512, 512, ksize, padding=1),
44 | nn.ReLU(inplace=True),
45 | MaxPool(spatial_dims, kernel_size=2, stride=2),
46 | )
47 |
48 | def forward(self, x):
49 | return self.features(x)
50 |
51 |
52 | class VGG19(nn.Module):
53 | def __init__(self, cin=3, spatial_dims=2):
54 | super().__init__()
55 | ksize = 3
56 | self.features = nn.Sequential(
57 | ConvND(spatial_dims, cin, 64, ksize, padding=1),
58 | nn.ReLU(inplace=True),
59 | ConvND(spatial_dims, 64, 64, ksize, padding=1),
60 | nn.ReLU(inplace=True),
61 | MaxPool(spatial_dims, kernel_size=2, stride=2),
62 |
63 | ConvND(spatial_dims, 64, 128, ksize, padding=1),
64 | nn.ReLU(inplace=True),
65 | ConvND(spatial_dims, 128, 128, ksize, padding=1),
66 | nn.ReLU(inplace=True),
67 | MaxPool(spatial_dims, kernel_size=2, stride=2),
68 |
69 | ConvND(spatial_dims, 128, 256, ksize, padding=1),
70 | nn.ReLU(inplace=True),
71 | ConvND(spatial_dims, 256, 256, ksize, padding=1),
72 | nn.ReLU(inplace=True),
73 | ConvND(spatial_dims, 256, 256, ksize, padding=1),
74 | nn.ReLU(inplace=True),
75 | ConvND(spatial_dims, 256, 256, ksize, padding=1),
76 | nn.ReLU(inplace=True),
77 | MaxPool(spatial_dims, kernel_size=2, stride=2),
78 |
79 | ConvND(spatial_dims, 256, 512, ksize, padding=1),
80 | nn.ReLU(inplace=True),
81 | ConvND(spatial_dims, 512, 512, ksize, padding=1),
82 | nn.ReLU(inplace=True),
83 | ConvND(spatial_dims, 512, 512, ksize, padding=1),
84 | nn.ReLU(inplace=True),
85 | ConvND(spatial_dims, 512, 512, ksize, padding=1),
86 | nn.ReLU(inplace=True),
87 | MaxPool(spatial_dims, kernel_size=2, stride=2),
88 |
89 | ConvND(spatial_dims, 512, 512, ksize, padding=1),
90 | nn.ReLU(inplace=True),
91 | ConvND(spatial_dims, 512, 512, ksize, padding=1),
92 | nn.ReLU(inplace=True),
93 | ConvND(spatial_dims, 512, 512, ksize, padding=1),
94 | nn.ReLU(inplace=True),
95 | ConvND(spatial_dims, 512, 512, ksize, padding=1),
96 | nn.ReLU(inplace=True),
97 | MaxPool(spatial_dims, kernel_size=2, stride=2),
98 | )
99 |
100 | def forward(self, x):
101 | return self.features(x)
102 |
103 |
104 |
105 |
106 | if __name__ == '__main__':
107 | # VGG16
108 | inputs = torch.normal(0, 1, (1, 3, 224))
109 | output = VGG16(cin=3, spatial_dims=1)(inputs)
110 | shape = list(output.shape)
111 | assert shape[2:] == [7]
112 |
113 | # VGG19
114 | inputs = torch.normal(0, 1, (1, 3, 224))
115 | output = VGG19(cin=3, spatial_dims=1)(inputs)
116 | shape = list(output.shape)
117 | assert shape[2:] == [7]
118 |
119 | # VGG16
120 | inputs = torch.normal(0, 1, (1, 3, 224, 224))
121 | output = VGG16(cin=3, spatial_dims=2)(inputs)
122 | shape = list(output.shape)
123 | assert shape[2:] == [7, 7]
124 |
125 | # VGG19
126 | inputs = torch.normal(0, 1, (1, 3, 224, 224))
127 | output = VGG19(cin=3, spatial_dims=2)(inputs)
128 | shape = list(output.shape)
129 | assert shape[2:] == [7, 7]
--------------------------------------------------------------------------------
/fusionlab/encoders/vit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/encoders/vit/__init__.py
--------------------------------------------------------------------------------
/fusionlab/functional/__init__.py:
--------------------------------------------------------------------------------
1 | from fusionlab import BACKEND
2 | if BACKEND['torch']:
3 | from .dice import dice_score
4 | from .iou import (
5 | iou_score,
6 | jaccard_score,
7 | )
8 | if BACKEND['tf']:
9 | from .tfdice import tf_dice_score
10 | from .tfiou import (
11 | tf_iou_score,
12 | tf_jaccard_score,
13 | )
--------------------------------------------------------------------------------
/fusionlab/functional/dice.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import torch
3 | from fusionlab.configs import EPS
4 |
5 |
6 | def dice_score(pred: torch.Tensor,
7 | target: torch.Tensor,
8 |                dims: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
9 | """
10 | Computes the dice score
11 |
12 | Args:
13 | pred: (N, C, *)
14 | target: (N, C, *)
15 | dims: dimensions to sum over
16 |
17 | """
18 | assert pred.size() == target.size()
19 | intersection = torch.sum(pred * target, dim=dims)
20 | cardinality = torch.sum(pred + target, dim=dims)
21 | return (2.0 * intersection) / cardinality.clamp(min=EPS)
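22 |
23 |
24 | if __name__ == '__main__':
25 |     # Editor's sketch: a minimal smoke test in the style of the other modules,
26 |     # assuming one-hot targets with the same (N, C, *) layout as pred.
27 |     pred = torch.softmax(torch.randn(2, 3, 16), dim=1)
28 |     target = torch.nn.functional.one_hot(torch.randint(0, 3, (2, 16)), 3)
29 |     target = target.transpose(1, 2).type_as(pred)  # (N, C, L)
30 |     print(dice_score(pred, target, dims=(0, 2)))   # one score per class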
--------------------------------------------------------------------------------
/fusionlab/functional/iou.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from fusionlab.configs import EPS
3 |
4 |
5 | def iou_score(pred, target, dims=None):
6 | """
7 | Shape:
8 | - pred: :math:`(N, C, *)`
9 | - target: :math:`(N, C, *)`
10 | - Output: scalar.
11 | """
12 | assert pred.size() == target.size()
13 |
14 | intersection = torch.sum(pred * target, dim=dims)
15 | cardinality = torch.sum(pred + target, dim=dims)
16 | union = cardinality - intersection
17 | iou = intersection / union.clamp_min(EPS)
18 | return iou
19 |
20 | jaccard_score = iou_score
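21 |
22 |
23 | if __name__ == '__main__':
24 |     # Editor's sketch: a minimal smoke test in the style of the other modules.
25 |     pred = torch.tensor([[1., 1., 0., 0.]])
26 |     target = torch.tensor([[1., 0., 1., 0.]])
27 |     print(iou_score(pred, target))  # intersection 1 / union 3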
--------------------------------------------------------------------------------
/fusionlab/functional/tfdice.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from fusionlab.configs import EPS
3 |
4 |
5 | def tf_dice_score(pred, target, axis=None):
6 | """
7 | Shape:
8 | - pred: :math:`(N, *, C)` where :math:`*` means any number of additional dimensions
9 | - target: :math:`(N, *, C)`, same shape as the input
10 | - Output: scalar.
11 | """
12 | intersection = tf.reduce_sum(pred * target, axis=axis)
13 | cardinality = tf.reduce_sum(pred + target, axis=axis)
14 | cardinality = tf.clip_by_value(cardinality,
15 | clip_value_min=EPS,
16 | clip_value_max=cardinality.dtype.max)
17 | return (2.0 * intersection) / cardinality
18 |
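19 |
20 | if __name__ == '__main__':
21 |     # Editor's sketch: a minimal smoke test in the style of the other modules.
22 |     pred = tf.constant([[1., 1., 0., 0.]])
23 |     target = tf.constant([[1., 0., 1., 0.]])
24 |     print(tf_dice_score(pred, target))  # 2 * 1 / 4 = 0.5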
--------------------------------------------------------------------------------
/fusionlab/functional/tfiou.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from fusionlab.configs import EPS
3 |
4 |
5 | def tf_iou_score(pred, target, axis=None):
6 | """
7 | Shape:
8 | - pred: :math:`(N, *, C)` where :math:`*` means any number of additional dimensions
9 | - target: :math:`(N, *, C)`, same shape as the input
10 | - Output: scalar.
11 | """
12 | intersection = tf.reduce_sum(pred * target, axis=axis)
13 | cardinality = tf.reduce_sum(pred + target, axis=axis)
14 | cardinality = tf.clip_by_value(cardinality,
15 | clip_value_min=EPS,
16 | clip_value_max=cardinality.dtype.max)
17 | union = cardinality - intersection
18 | return intersection / union
19 |
20 | tf_jaccard_score = tf_iou_score
21 |
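22 |
23 | if __name__ == '__main__':
24 |     # Editor's sketch: a minimal smoke test in the style of the other modules.
25 |     pred = tf.constant([[1., 1., 0., 0.]])
26 |     target = tf.constant([[1., 0., 1., 0.]])
27 |     print(tf_iou_score(pred, target))  # intersection 1 / union 3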
--------------------------------------------------------------------------------
/fusionlab/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from fusionlab import BACKEND
2 | if BACKEND['torch']:
3 | from .factories import (
4 | ConvND,
5 | ConvT,
6 | Upsample,
7 | BatchNorm,
8 | InstanceNorm,
9 | MaxPool,
10 | AvgPool,
11 | AdaptiveMaxPool,
12 | AdaptiveAvgPool,
13 | ReplicationPad,
14 | ConstantPad
15 | )
16 | from .squeeze_excitation.se import SEModule
17 | from .base import (
18 | ConvNormAct,
19 | Rearrange,
20 | DropPath,
21 | )
22 | from .patch_embed.patch_embedding import PatchEmbedding
23 | from .selfattention.selfattention import (
24 | SelfAttention,
25 | SRAttention,
26 | )
27 | if BACKEND['tf']:
28 | from .squeeze_excitation.tfse import TFSEModule
--------------------------------------------------------------------------------
/fusionlab/layers/base.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from typing import Union, Sequence, Optional, Callable
3 | from einops import rearrange
4 | from torchvision.ops import StochasticDepth
5 |
6 | from fusionlab.layers import ConvND, BatchNorm
7 | from fusionlab.utils import make_ntuple
8 |
9 | class ConvNormAct(nn.Module):
10 | '''
11 | ref:
12 | https://pytorch.org/vision/main/generated/torchvision.ops.Conv2dNormActivation.html
13 | https://github.com/pytorch/vision/blob/main/torchvision/ops/misc.py#L68
14 |
15 | Convolution + Normalization + Activation
16 |
17 | Args:
18 | spatial_dims (int): number of spatial dimensions of the input image.
19 | in_channels (int): number of channels of the input image.
20 | out_channels (int): number of channels of the output image.
21 | kernel_size (Union[Sequence[int], int]): size of the convolving kernel.
22 | stride (Union[Sequence[int], int], optional): stride of the convolution. Default: 1
23 |         padding (Union[Sequence[int], str], optional): Padding added to all sides of the input. Default: None,
24 | in which case it will be calculated as padding = (kernel_size - 1) // 2 * dilation
25 | dilation (Union[Sequence[int], int], optional): spacing between kernel elements. Default: 1
26 | groups (int, optional): number of blocked connections from input channels to output channels. Default: 1
27 | bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if norm_layer is None.
28 | norm_layer (Optional[Callable[..., nn.Module]], optional): normalization layer. Default: BatchNorm
29 | act_layer (Optional[Callable[..., nn.Module]], optional): activation layer. Default: nn.ReLU
30 | padding_mode (str, optional): mode of padding. Default: 'zeros'
31 | inplace (Optional[bool], optional): Parameter for the activation layer,
32 | which can optionally do the operation in-place. Default True
33 |
34 | '''
35 | def __init__(self,
36 | spatial_dims: int,
37 | in_channels: int,
38 | out_channels: int,
39 | kernel_size: Union[Sequence[int], int],
40 | stride: Union[Sequence[int], int] = 1,
41 | padding: Union[Sequence[int], str] = None,
42 | dilation: Union[Sequence[int], int] = 1,
43 | groups: int = 1,
44 | bias: Optional[bool] = None,
45 | norm_layer: Optional[Callable[..., nn.Module]] = BatchNorm,
46 | act_layer: Optional[Callable[..., nn.Module]] = nn.ReLU,
47 | padding_mode: str = 'zeros',
48 |                  inplace: Optional[bool] = True,
49 | ):
50 | super().__init__()
51 | # padding
52 | if padding is None:
53 | if isinstance(kernel_size, int) and isinstance(dilation, int):
54 | padding = (kernel_size - 1) // 2 * dilation
55 | else:
56 | _conv_dim = spatial_dims
57 | kernel_size = make_ntuple(kernel_size, _conv_dim)
58 | dilation = make_ntuple(dilation, _conv_dim)
59 | padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
60 | # bias
61 | if bias is None:
62 | bias = norm_layer is None
63 |
64 | self.conv = ConvND(
65 | spatial_dims,
66 | in_channels,
67 | out_channels,
68 | kernel_size,
69 | stride,
70 | padding,
71 | dilation,
72 | groups,
73 | bias,
74 | padding_mode
75 | )
76 |         self.norm = norm_layer(spatial_dims, out_channels) if norm_layer is not None else nn.Identity()
77 | params = {} if inplace is None else {"inplace": inplace}
78 | if act_layer is None:
79 | act_layer = nn.Identity
80 | self.act = act_layer(**params)
81 |
82 | def forward(self, x):
83 | x = self.conv(x)
84 | x = self.norm(x)
85 | x = self.act(x)
86 | return x
87 |
88 | class Rearrange(nn.Module):
89 | '''
90 |     nn.Module wrapper for einops' rearrange function
91 | '''
92 |
93 | def __init__(self, pattern: str, **kwargs):
94 | super().__init__()
95 | self.pattern = pattern
96 | self.kwargs = kwargs
97 |
98 | def forward(self, x):
99 | return rearrange(x, self.pattern, **self.kwargs)
100 |
101 | DropPath = StochasticDepth
102 |
103 | if __name__ == '__main__':
104 | import torch
105 | inputs = torch.randn(1, 3, 128, 128)
106 | l = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
107 | outputs = l(inputs)
108 | print(outputs.shape)
109 |
110 | inputs = torch.randn(1, 3, 128, 128)
111 | l = Rearrange('b c h w -> b (c h w)')
112 | outputs = l(inputs)
113 | print(outputs.shape)
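114 |
115 |     # Editor's sketch: ConvNormAct with the default "same"-style padding,
116 |     # assuming the ConvND/BatchNorm factories behave like their nn counterparts.
117 |     conv = ConvNormAct(2, 3, 16, kernel_size=3)
118 |     outputs = conv(torch.randn(1, 3, 32, 32))
119 |     print(outputs.shape)  # torch.Size([1, 16, 32, 32])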
--------------------------------------------------------------------------------
/fusionlab/layers/patch_embed/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/layers/patch_embed/__init__.py
--------------------------------------------------------------------------------
/fusionlab/layers/patch_embed/patch_embedding.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence, Union
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 | from fusionlab.layers import Rearrange, ConvND
7 | from fusionlab.utils import make_ntuple, trunc_normal_
8 |
9 | EMBEDDING_TYPES = ["conv", "fc"]
10 |
11 | class PatchEmbedding(nn.Module):
12 | """
13 | A patch embedding block, based on: "Dosovitskiy et al.,
14 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale "
15 |
16 | """
17 |
18 | def __init__(
19 | self,
20 | in_channels: int,
21 | img_size: Union[int, Sequence[int]],
22 | patch_size: Union[int, Sequence[int]],
23 | hidden_size: int,
24 | pos_embed_type: str = 'conv',
25 | dropout_rate: float = 0.0,
26 | spatial_dims: int = 2,
27 | ) -> None:
28 | """
29 | Args:
30 | in_channels: dimension of input channels.
31 | img_size: dimension of input image.
32 | patch_size: dimension of patch size.
33 | hidden_size: dimension of hidden layer.
34 |             pos_embed_type: patch embedding layer type,
35 |                 one of {"conv", "fc"}.
36 |             dropout_rate: fraction of the input units to drop.
37 | spatial_dims: number of spatial dimensions.
38 | """
39 |
40 | super().__init__()
41 | assert pos_embed_type in EMBEDDING_TYPES, f"pos_embed_type must be in {EMBEDDING_TYPES}"
42 | self.pos_embed_type = pos_embed_type
43 |
44 | img_sizes = make_ntuple(img_size, spatial_dims)
45 | patch_sizes = make_ntuple(patch_size, spatial_dims)
46 | for m, p in zip(img_sizes, patch_sizes):
47 | if self.pos_embed_type == "fc" and m % p != 0:
48 | raise ValueError("patch_size should be divisible by img_size for fc embedding type.")
49 | self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_sizes, patch_sizes)])
50 | self.patch_dim = int(in_channels * np.prod(patch_sizes))
51 |
52 | self.patch_embeddings: nn.Module
53 | if self.pos_embed_type == "conv":
54 | self.patch_embeddings = nn.Sequential(
55 | ConvND(
56 | spatial_dims,
57 | in_channels,
58 | hidden_size,
59 | kernel_size=patch_size,
60 | stride=patch_size
61 | ),
62 | nn.Flatten(2),
63 | Rearrange('b d n -> b n d'),
64 | )
65 | # self.patch_embeddings = Conv[Conv.CONV, spatial_dims](
66 | # in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
67 | # )
68 | elif self.pos_embed_type == "fc":
69 | # for 3d: "b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)"
70 | chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
71 | from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
72 | to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
73 | axes_len = {f"p{i+1}": p for i, p in enumerate(patch_sizes)}
74 | self.patch_embeddings = nn.Sequential(
75 | Rearrange(f"{from_chars} -> {to_chars}", **axes_len),
76 | nn.Linear(self.patch_dim, hidden_size),
77 | )
78 | self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
79 | self.dropout = nn.Dropout(dropout_rate)
80 | trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
81 | self.apply(self._init_weights)
82 |
83 | def forward(self, x):
84 | x = self.patch_embeddings(x)
85 | # if self.pos_embed_type == "conv":
86 | # x = x.flatten(2).transpose(-1, -2) # (b c w h) -> (b c wh) -> (b wh c)
87 | embeddings = x + self.position_embeddings
88 | embeddings = self.dropout(embeddings)
89 | return embeddings
90 |
91 | def _init_weights(self, m):
92 | if isinstance(m, nn.Linear):
93 | trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)
94 | if isinstance(m, nn.Linear) and m.bias is not None:
95 | nn.init.constant_(m.bias, 0)
96 | elif isinstance(m, nn.LayerNorm):
97 | nn.init.constant_(m.bias, 0)
98 | nn.init.constant_(m.weight, 1.0)
99 |
100 |
101 | if __name__ == '__main__':
102 | # 2D
103 | inputs = torch.randn(1, 3, 224, 224)
104 | l = PatchEmbedding(3, 224, 16, 768, pos_embed_type='conv')
105 | outputs = l(inputs)
106 | print(outputs.shape)
107 |
108 | inputs = torch.randn(1, 3, 224, 224)
109 | l = PatchEmbedding(3, 224, 16, 768, pos_embed_type='fc')
110 | print(l)
111 | outputs = l(inputs)
112 | print(outputs.shape)
113 |
114 | # 1D
115 | inputs = torch.randn(1, 3, 224)
116 | l = PatchEmbedding(3, 224, 16, 768, pos_embed_type='conv', spatial_dims=1)
117 | outputs = l(inputs)
118 | print(outputs.shape)
119 |
120 | inputs = torch.randn(1, 3, 224)
121 | l = PatchEmbedding(3, 224, 16, 768, pos_embed_type='fc', spatial_dims=1)
122 | outputs = l(inputs)
123 | print(outputs.shape)
124 |
125 | # 3D
126 | inputs = torch.randn(1, 3, 112, 112, 112)
127 | l = PatchEmbedding(3, 112, 16, 768, pos_embed_type='conv', spatial_dims=3)
128 | outputs = l(inputs)
129 | print(outputs.shape)
130 |
131 | inputs = torch.randn(1, 3, 112, 112, 112)
132 | l = PatchEmbedding(3, 112, 16, 768, pos_embed_type='fc', spatial_dims=3)
133 | outputs = l(inputs)
134 | print(outputs.shape)
--------------------------------------------------------------------------------
/fusionlab/layers/selfattention/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/layers/selfattention/__init__.py
--------------------------------------------------------------------------------
/fusionlab/layers/selfattention/selfattention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from fusionlab.layers import Rearrange
4 |
5 |
6 | class SelfAttention(nn.Module):
7 | """
8 | A self-attention block, based on: "Dosovitskiy et al.,
9 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale "
10 |
11 | source code: https://github.com/Project-MONAI/MONAI/blob/main/monai/networks/blocks/selfattention.py#L22
12 | """
13 |
14 | def __init__(
15 | self,
16 | hidden_size: int,
17 | num_heads: int,
18 | dropout_rate: float = 0.0,
19 | qkv_bias: bool = False,
20 | save_attn: bool = False,
21 | ) -> None:
22 | """
23 | Args:
24 | hidden_size (int): dimension of hidden layer.
25 | num_heads (int): number of attention heads.
26 |             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
27 | qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
28 | save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
29 |
30 | """
31 |
32 | super().__init__()
33 |
34 | if not (0 <= dropout_rate <= 1):
35 | raise ValueError("dropout_rate should be between 0 and 1.")
36 |
37 | if hidden_size % num_heads != 0:
38 | raise ValueError("hidden size should be divisible by num_heads.")
39 |
40 | self.num_heads = num_heads
41 | self.out_proj = nn.Linear(hidden_size, hidden_size)
42 | self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
43 | # b: batch size, h: num_patches, l: num_heads, d: head_dim
44 | self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
45 | self.out_rearrange = Rearrange("b h l d -> b l (h d)")
46 | self.drop_output = nn.Dropout(dropout_rate)
47 | self.drop_weights = nn.Dropout(dropout_rate)
48 | self.head_dim = hidden_size // num_heads
49 | self.scale = self.head_dim**-0.5
50 | self.save_attn = save_attn
51 | self.att_mat = torch.Tensor()
52 |
53 | def forward(self, x):
54 | qkv = self.input_rearrange(self.qkv(x))
55 | q, k, v = qkv[0], qkv[1], qkv[2] # (b, l, h, d)
56 | att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) # (b, l, h, h)
57 | if self.save_attn:
58 | self.att_mat = att_mat.detach()
59 |
60 | att_mat = self.drop_weights(att_mat)
61 | x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) # (b, l, h, d)
62 | x = self.out_rearrange(x)
63 | x = self.out_proj(x)
64 | x = self.drop_output(x)
65 | return x
66 |
67 | class SRAttention(nn.Module):
68 | """
69 | Spatial Reduction Attention (SR-Attention) block, based on "Wang et al.,
70 | Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions "
71 |
72 | source code: https://github.com/sithu31296/semantic-segmentation/blob/main/semseg/models/backbones/mit.py
73 | """
74 | def __init__(self, dim, head, sr_ratio):
75 | """
76 | Args:
77 | dim (int): input dimension
78 | head (int): number of attention heads
79 | sr_ratio (int): spatial reduction ratio
80 |
81 | """
82 | super().__init__()
83 | self.head = head
84 | self.sr_ratio = sr_ratio
85 | self.scale = (dim // head) ** -0.5
86 | self.q = nn.Linear(dim, dim)
87 | self.kv = nn.Linear(dim, dim*2)
88 | self.proj = nn.Linear(dim, dim)
89 |
90 | if sr_ratio > 1:
91 | self.sr = nn.Conv2d(dim, dim, sr_ratio, sr_ratio)
92 | self.norm = nn.LayerNorm(dim)
93 |
94 | def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
95 | B, N, C = x.shape
96 | q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3)
97 |
98 | if self.sr_ratio > 1:
99 | x = x.permute(0, 2, 1).reshape(B, C, H, W)
100 | x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
101 | x = self.norm(x)
102 |
103 | k, v = self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4)
104 |
105 | attn = (q @ k.transpose(-2, -1)) * self.scale
106 | attn = attn.softmax(dim=-1)
107 |
108 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
109 | x = self.proj(x)
110 | return x
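111 |
112 |
113 | if __name__ == '__main__':
114 |     # Editor's sketch: smoke tests in the style of the other modules; shapes
115 |     # assume 2D feature maps flattened to (batch, num_patches, dim).
116 |     tokens = torch.randn(1, 196, 64)
117 |     attn = SelfAttention(hidden_size=64, num_heads=8)
118 |     print(attn(tokens).shape)  # torch.Size([1, 196, 64])
119 |
120 |     tokens = torch.randn(1, 16 * 16, 32)
121 |     sra = SRAttention(dim=32, head=2, sr_ratio=4)
122 |     print(sra(tokens, 16, 16).shape)  # torch.Size([1, 256, 32])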
--------------------------------------------------------------------------------
/fusionlab/layers/squeeze_excitation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/layers/squeeze_excitation/__init__.py
--------------------------------------------------------------------------------
/fusionlab/layers/squeeze_excitation/se.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 | import torch
3 | from torch import Tensor
4 | import torch.nn as nn
5 | from fusionlab.layers import ConvND, AdaptiveAvgPool
6 |
7 | class SEModule(nn.Module):
8 | """
9 | source: https://github.com/pytorch/vision/blob/main/torchvision/ops/misc.py#L224
10 | This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).
11 | Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in eq. 3.
12 |
13 | Args:
14 | input_channels (int): Number of channels in the input image
15 | squeeze_channels (int): Number of squeeze channels
16 | activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``
17 | scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``
18 | """
19 |
20 | def __init__(
21 | self,
22 | input_channels: int,
23 | squeeze_channels: int,
24 | act_layer: Callable[..., torch.nn.Module] = torch.nn.ReLU,
25 | scale_layer: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
26 | spatial_dims: int = 2,
27 | ) -> None:
28 | super().__init__()
29 | self.avgpool = AdaptiveAvgPool(spatial_dims, 1)
30 | self.fc1 = ConvND(spatial_dims, input_channels, squeeze_channels, kernel_size=1)
31 | self.fc2 = ConvND(spatial_dims, squeeze_channels, input_channels, kernel_size=1)
32 | self.act_layer = act_layer()
33 | self.scale_layer = scale_layer()
34 |
35 | def _scale(self, input: Tensor) -> Tensor:
36 | scale = self.avgpool(input)
37 | scale = self.fc1(scale)
38 | scale = self.act_layer(scale)
39 | scale = self.fc2(scale)
40 | return self.scale_layer(scale)
41 |
42 | def forward(self, input: Tensor) -> Tensor:
43 | scale = self._scale(input)
44 | return scale * input
45 |
46 |
47 | if __name__ == '__main__':
48 | print('SEModule')
49 | inputs = torch.normal(0, 1, (1, 256, 16, 16))
50 |     layer = SEModule(256, 16)
51 | outputs = layer(inputs)
52 | assert list(outputs.shape) == [1, 256, 16, 16]
53 |
54 | inputs = torch.normal(0, 1, (1, 256, 16, 16, 16))
55 |     layer = SEModule(256, 16, spatial_dims=3)
56 | outputs = layer(inputs)
57 | assert list(outputs.shape) == [1, 256, 16, 16, 16]
58 |
--------------------------------------------------------------------------------
/fusionlab/layers/squeeze_excitation/tfse.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import layers, Sequential
3 |
4 |
5 | class TFSEModule(layers.Layer):
6 | def __init__(self, cin, ratio=16):
7 | super().__init__()
8 | cout = int(cin / ratio)
9 | self.gate = Sequential([
10 | layers.Conv2D(cout, kernel_size=1),
11 | layers.ReLU(),
12 | layers.Conv2D(cin, kernel_size=1),
13 | layers.Activation(tf.nn.sigmoid),
14 | ])
15 |
16 | def call(self, inputs):
17 | x = tf.reduce_mean(inputs, (1, 2), keepdims=True)
18 | x = self.gate(x)
19 | return inputs * x
20 |
21 |
22 | if __name__ == '__main__':
23 | inputs = tf.random.normal((1, 224, 224, 256), 0, 1)
24 | layer = TFSEModule(256)
25 | outputs = layer(inputs)
26 | assert list(outputs.shape) == [1, 224, 224, 256]
27 |
--------------------------------------------------------------------------------
/fusionlab/losses/README.md:
--------------------------------------------------------------------------------
1 | # Loss function
2 |
3 | | Name | PyTorch | Tensorflow |
4 | |:------------------------------|:------------|---------------|
5 | | Dice Loss | DiceLoss | TFDiceLoss |
6 | | Tversky Loss | TverskyLoss | TFTverskyLoss |
7 | | IoU Loss | IoULoss | TFIoULoss |
8 | | Dice Loss + CrossEntropy Loss | DiceCELoss  | TFDiceCE      |
9 |
10 |
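11 | ## Usage
12 |
13 | A minimal sketch (PyTorch backend assumed; `DiceLoss` expects logits of shape
14 | `(N, C, *)` and integer targets of shape `(N, *)`, per `diceloss/dice.py`):
15 |
16 | ```python
17 | import torch
18 | from fusionlab.losses import DiceLoss
19 |
20 | pred = torch.randn(2, 3, 4, 4)            # (N, C, H, W) logits
21 | target = torch.randint(0, 3, (2, 4, 4))   # (N, H, W) class indices
22 | loss = DiceLoss("multiclass", from_logits=True)(pred, target)
23 | ```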
--------------------------------------------------------------------------------
/fusionlab/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from fusionlab import BACKEND
2 | if BACKEND['torch']:
3 | from .diceloss.dice import DiceLoss, DiceCELoss
4 | from .iouloss.iou import IoULoss
5 | from .tversky.tversky import TverskyLoss
6 | if BACKEND['tf']:
7 | from .diceloss.tfdice import TFDiceLoss, TFDiceCE
8 | from .iouloss.tfiou import TFIoULoss
9 | from .tversky.tftversky import TFTverskyLoss
--------------------------------------------------------------------------------
/fusionlab/losses/diceloss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/losses/diceloss/__init__.py
--------------------------------------------------------------------------------
/fusionlab/losses/diceloss/dice.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import rearrange
5 | from fusionlab.functional import dice_score
6 | from fusionlab.configs import EPS
7 |
8 | __all__ = ["DiceLoss", "DiceCELoss"]
9 |
10 | BINARY_MODE = "binary"
11 | MULTICLASS_MODE = "multiclass"
12 |
13 |
14 | class DiceCELoss(nn.Module):
15 | def __init__(self, w_dice=0.5, w_ce=0.5, cls_weight=None):
16 | """
17 | Dice Loss + Cross Entropy Loss
18 | Args:
19 | w_dice: weight of Dice Loss
20 | w_ce: weight of CrossEntropy loss
21 | cls_weight:
22 | """
23 | super().__init__()
24 | self.w_dice = w_dice
25 | self.w_ce = w_ce
26 | self.cls_weight = cls_weight
27 | self.dice = DiceLoss()
28 | self.ce = nn.CrossEntropyLoss(weight=cls_weight)
29 |
30 | def forward(self, y_pred, y_true):
31 | loss_dice = self.dice(y_pred, y_true)
32 | loss_ce = self.ce(y_pred, y_true)
33 | return self.w_dice * loss_dice + self.w_ce * loss_ce
34 |
35 |
36 | class DiceLoss(nn.Module):
37 | def __init__(
38 | self,
39 | mode="multiclass", # binary, multiclass
40 | log_loss=False,
41 | from_logits=True,
42 | ):
43 | """
44 | Implementation of Dice loss for image segmentation task.
45 | It supports "binary", "multiclass"
46 | ref: https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/losses/dice.py
47 | Args:
48 | mode: Metric mode {'binary', 'multiclass'}
49 | log_loss: If True, loss computed as `-log(dice)`; otherwise `1 - dice`
50 | from_logits: If True assumes input is raw logits
51 | """
52 | super().__init__()
53 | self.mode = mode
54 | self.from_logits = from_logits
55 | self.log_loss = log_loss
56 |
57 | def forward(self, y_pred, y_true) -> torch.Tensor:
58 | """
59 | :param y_pred: (N, C, *)
60 | :param y_true: (N, *)
61 | :return: scalar
62 | """
63 | assert y_true.size(0) == y_pred.size(0)
64 | num_classes = y_pred.size(1)
65 | dims = (0, 2) # (N, C, HW)
66 |
67 | if self.from_logits:
68 | # get [0..1] class probabilities
69 | if self.mode == MULTICLASS_MODE:
70 | y_pred = F.softmax(y_pred, dim=1)
71 | else:
72 | y_pred = torch.sigmoid(y_pred)
73 |
74 | if self.mode == BINARY_MODE:
75 | y_true = rearrange(y_true, "N ... -> N 1 (...)")
76 | y_pred = rearrange(y_pred, "N 1 ... -> N 1 (...)")
77 | elif self.mode == MULTICLASS_MODE:
78 | y_pred = rearrange(y_pred, "N C ... -> N C (...)")
79 | y_true = F.one_hot(y_true, num_classes) # (N, *) -> (N, *, C)
80 | y_true = rearrange(y_true, "N ... C -> N C (...)")
81 | else:
82 |             raise ValueError("Not implemented")
83 |
84 | scores = dice_score(y_pred, y_true.type_as(y_pred), dims=dims)
85 | if self.log_loss:
86 | loss = -torch.log(scores.clamp_min(EPS))
87 | else:
88 | loss = 1.0 - scores
89 | return loss.mean()
90 |
91 |
92 | if __name__ == "__main__":
93 |
94 | print("multiclass")
95 | pred = torch.tensor([[
96 | [1., 2., 3., 4.],
97 | [2., 6., 4., 4.],
98 | [9., 6., 3., 4.]
99 | ]]).view(1, 3, 4)
100 | true = torch.tensor([[2, 1, 0, 2]]).view(1, 4)
101 |
102 | dice = DiceLoss("multiclass", from_logits=True)
103 | loss = dice(pred, true)
104 |
105 | print("Binary")
106 | pred = torch.tensor([0.4, 0.2, 0.3, 0.5]).reshape(1, 1, 2, 2)
107 | true = torch.tensor([0, 1, 0, 1]).reshape(1, 2, 2)
108 | dice = DiceLoss("binary", from_logits=True)
109 | loss = dice(pred, true)
110 |
111 | print("Binary Logloss")
112 | pred = torch.tensor([0.4, 0.2, 0.3, 0.5]).reshape(1, 1, 2, 2)
113 | true = torch.tensor([0, 1, 0, 1]).reshape(1, 2, 2)
114 | dice = DiceLoss("binary", from_logits=True, log_loss=True)
115 | loss = dice(pred, true)
116 |
117 |
--------------------------------------------------------------------------------
/fusionlab/losses/diceloss/tfdice.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from einops import rearrange
3 | from fusionlab.functional.tfdice import tf_dice_score
4 |
5 | __all__ = ["TFDiceLoss", "TFDiceCE"]
6 |
7 | BINARY_MODE = "binary"
8 | MULTICLASS_MODE = "multiclass"
9 |
10 | # TODO: Test code
11 | class TFDiceCE(tf.keras.losses.Loss):
12 | def __init__(self, mode="binary", from_logits=False, w_dice=0.5, w_ce=0.5):
13 | """
14 | Dice Loss + Cross Entropy Loss
15 | Args:
16 | w_dice: weight of Dice Loss
17 | w_ce: weight of CrossEntropy loss
18 | mode: Metric mode {'binary', 'multiclass'}
19 | """
20 | super().__init__()
21 | self.w_dice = w_dice
22 | self.w_ce = w_ce
23 | self.dice = TFDiceLoss(mode, from_logits)
24 |         if mode == BINARY_MODE:
25 |             self.ce = tf.keras.losses.BinaryCrossentropy(from_logits=from_logits)
26 |         elif mode == MULTICLASS_MODE:
27 |             self.ce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits)
28 |
29 | def call(self, y_true, y_pred):
30 | loss_dice = self.dice(y_true, y_pred)
31 | loss_ce = self.ce(y_true, y_pred)
32 | return self.w_dice * loss_dice + self.w_ce * loss_ce
33 |
34 |
35 | class TFDiceLoss(tf.keras.losses.Loss):
36 | def __init__(
37 | self,
38 | mode="multiclass", # binary, multiclass
39 | log_loss=False,
40 | from_logits=False,
41 | ):
42 | """
43 |         Implementation of Dice loss for image segmentation tasks.
44 |         Supports "binary" and "multiclass" modes.
45 |         ref: https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/losses/dice.py
46 |         Args:
47 |             mode: Metric mode {'binary', 'multiclass'}
48 |             log_loss: If True, loss computed as `-log(dice)`; otherwise `1 - dice`
49 |             from_logits: If True, assumes input is raw logits
50 | """
51 | super().__init__()
52 | self.mode = mode
53 | self.from_logits = from_logits
54 | self.log_loss = log_loss
55 |
56 | def call(self, y_true, y_pred):
57 | """
58 | :param y_true: (N, *)
59 | :param y_pred: (N, *, C)
60 | :return: scalar
61 | """
62 | y_true_shape = y_true.shape.as_list()
63 | y_pred_shape = y_pred.shape.as_list()
64 | assert y_true_shape[0] == y_pred_shape[0]
65 | num_classes = y_pred_shape[-1]
66 | axis = [0]
67 |
68 | if self.from_logits:
69 | # get [0..1] class probabilities
70 | if self.mode == MULTICLASS_MODE:
71 | y_pred = tf.nn.softmax(y_pred, axis=-1)
72 | else:
73 | y_pred = tf.nn.sigmoid(y_pred)
74 |
75 | if self.mode == BINARY_MODE:
76 | y_true = rearrange(y_true, "... -> (...) 1")
77 | y_pred = rearrange(y_pred, "... -> (...) 1")
78 | elif self.mode == MULTICLASS_MODE:
79 | y_true = tf.cast(y_true, tf.int32)
80 | y_true = tf.one_hot(y_true, num_classes)
81 | y_true = rearrange(y_true, "... C -> (...) C")
82 | y_pred = rearrange(y_pred, "... C -> (...) C")
83 | else:
84 |             raise AssertionError("Not implemented")
85 |
86 | scores = tf_dice_score(y_pred, tf.cast(y_true, y_pred.dtype), axis=axis)
87 | if self.log_loss:
88 | loss = -tf.math.log(tf.clip_by_value(scores, clip_value_min=1e-7, clip_value_max=scores.dtype.max))
89 | else:
90 | loss = 1.0 - scores
91 | return tf.math.reduce_mean(loss)
92 |
93 |
94 | if __name__ == '__main__':
95 | print("Multiclass")
96 | pred = tf.convert_to_tensor([[
97 | [1., 2., 3., 4.],
98 | [2., 6., 4., 4.],
99 | [9., 6., 3., 4.]
100 | ]])
101 | pred = rearrange(pred, "N C H -> N H C")
102 | true = tf.convert_to_tensor([[2, 1, 0, 2]])
103 |
104 | dice = TFDiceLoss("multiclass", from_logits=True)
105 | loss = dice(true, pred)
106 |
107 | print("Binary")
108 | pred = tf.convert_to_tensor([0.4, 0.2, 0.3, 0.5])
109 | pred = tf.reshape(pred, [1, 2, 2, 1])
110 | true = tf.convert_to_tensor([0, 1, 0, 1])
111 | true = tf.reshape(true, [1, 2, 2])
112 | dice = TFDiceLoss("binary", from_logits=True)
113 | loss = dice(true, pred)
114 |
115 | print("Binary Log loss")
116 | pred = tf.convert_to_tensor([0.4, 0.2, 0.3, 0.5])
117 | pred = tf.reshape(pred, [1, 2, 2, 1])
118 | true = tf.convert_to_tensor([0, 1, 0, 1])
119 | true = tf.reshape(true, [1, 2, 2])
120 | dice = TFDiceLoss("binary", from_logits=True, log_loss=True)
121 | loss = dice(true, pred)
--------------------------------------------------------------------------------
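The TF losses use the Keras (y_true, y_pred) argument order and the channels-last (N, *, C) layout; a minimal eager-mode sketch of TFDiceLoss (assuming tensorflow and fusionlab are installed):

    import tensorflow as tf
    from fusionlab.losses.diceloss.tfdice import TFDiceLoss

    loss_fn = TFDiceLoss("multiclass", from_logits=True)
    y_pred = tf.random.normal((2, 8, 8, 4))                          # (N, H, W, C) logits
    y_true = tf.random.uniform((2, 8, 8), maxval=4, dtype=tf.int32)  # (N, H, W) integer masks
    print(float(loss_fn(y_true, y_pred)))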
/fusionlab/losses/iouloss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/losses/iouloss/__init__.py
--------------------------------------------------------------------------------
/fusionlab/losses/iouloss/iou.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import rearrange
5 | from fusionlab.functional import iou_score
6 |
7 | __all__ = ["IoULoss"]
8 |
9 | BINARY_MODE = "binary"
10 | MULTICLASS_MODE = "multiclass"
11 |
12 |
13 | class IoULoss(nn.Module):
14 | def __init__(
15 | self,
16 | mode="multiclass", # binary, multiclass
17 | log_loss=False,
18 | from_logits=True,
19 | ):
20 | """
21 |         Implementation of IoU loss for image segmentation tasks.
22 |         Supports "binary" and "multiclass" modes.
23 |         Args:
24 |             mode: Metric mode {'binary', 'multiclass'}
25 |             log_loss: If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard`
26 |             from_logits: If True, assumes input is raw logits
27 | """
28 | super().__init__()
29 | self.mode = mode
30 | self.from_logits = from_logits
31 | self.log_loss = log_loss
32 |
33 | def forward(self, y_pred, y_true):
34 | """
35 | :param y_pred: (N, C, *)
36 | :param y_true: (N, *)
37 | :return: scalar
38 | """
39 | assert y_true.size(0) == y_pred.size(0)
40 | num_classes = y_pred.size(1)
41 | dims = (0, 2) # (N, C, *)
42 |
43 | if self.from_logits:
44 | # get [0..1] class probabilities
45 | if self.mode == MULTICLASS_MODE:
46 | y_pred = F.softmax(y_pred, dim=1)
47 | else:
48 | y_pred = torch.sigmoid(y_pred)
49 |
50 | if self.mode == BINARY_MODE:
51 | y_true = rearrange(y_true, "N ... -> N 1 (...)")
52 | y_pred = rearrange(y_pred, "N 1 ... -> N 1 (...)")
53 | elif self.mode == MULTICLASS_MODE:
54 | y_pred = rearrange(y_pred, "N C ... -> N C (...)")
55 | y_true = F.one_hot(y_true, num_classes) # (N, *) -> (N, *, C)
56 | y_true = rearrange(y_true, "N ... C -> N C (...)")
57 | else:
58 |             raise AssertionError("Not implemented")
59 |
60 | scores = iou_score(y_pred, y_true.type_as(y_pred), dims=dims)
61 | if self.log_loss:
62 | loss = -torch.log(scores.clamp_min(1e-7))
63 | else:
64 | loss = 1.0 - scores
65 | return loss.mean()
66 |
67 |
68 | if __name__ == "__main__":
69 |
70 | print("multiclass")
71 | pred = torch.tensor([[
72 | [1., 2., 3., 4.],
73 | [2., 6., 4., 4.],
74 | [9., 6., 3., 4.]
75 | ]]).view(1, 3, 4)
76 | true = torch.tensor([[2, 1, 0, 2]]).view(1, 4)
77 |
78 | loss = IoULoss("multiclass", from_logits=True)(pred, true)
79 |
80 | print("Binary")
81 | pred = torch.tensor([0.4, 0.2, 0.3, 0.5]).reshape(1, 1, 2, 2)
82 | true = torch.tensor([0, 1, 0, 1]).reshape(1, 2, 2)
83 | loss = IoULoss("binary", from_logits=True)(pred, true)
84 |
85 | print("Binary Logloss")
86 | pred = torch.tensor([0.4, 0.2, 0.3, 0.5]).reshape(1, 1, 2, 2)
87 | true = torch.tensor([0, 1, 0, 1]).reshape(1, 2, 2)
88 | loss = IoULoss("binary", from_logits=True, log_loss=True)(pred, true)
89 |
90 |
--------------------------------------------------------------------------------
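Soft Dice and soft IoU are related by the identity D = 2J / (1 + J); a quick numerical check on the binary example from the __main__ block above (the equality holds up to the internal epsilon clamps):

    import torch
    from fusionlab.losses.diceloss.dice import DiceLoss
    from fusionlab.losses.iouloss.iou import IoULoss

    pred = torch.tensor([0.4, 0.2, 0.3, 0.5]).reshape(1, 1, 2, 2)
    true = torch.tensor([0, 1, 0, 1]).reshape(1, 2, 2)

    d = 1.0 - DiceLoss("binary", from_logits=True)(pred, true)  # Dice score D
    j = 1.0 - IoULoss("binary", from_logits=True)(pred, true)   # IoU (Jaccard) score J
    assert torch.isclose(d, 2 * j / (1 + j), atol=1e-4)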
/fusionlab/losses/iouloss/tfiou.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from einops import rearrange
3 | from fusionlab.functional.tfiou import tf_iou_score
4 |
5 | __all__ = ["TFIoULoss"]
6 |
7 | BINARY_MODE = "binary"
8 | MULTICLASS_MODE = "multiclass"
9 |
10 |
11 | class TFIoULoss(tf.keras.losses.Loss):
12 | def __init__(
13 | self,
14 | mode="binary", # binary, multiclass
15 | log_loss=False,
16 | from_logits=False,
17 | ):
18 | """
19 |         Implementation of IoU loss for image segmentation tasks.
20 |         Supports "binary" and "multiclass" modes.
21 |         Args:
22 |             mode: Metric mode {'binary', 'multiclass'}
23 |             log_loss: If True, loss computed as `-log(iou)`; otherwise `1 - iou`
24 |             from_logits: If True, assumes input is raw logits
25 | """
26 | super().__init__()
27 | self.mode = mode
28 | self.from_logits = from_logits
29 | self.log_loss = log_loss
30 |
31 | def call(self, y_true, y_pred):
32 | """
33 | :param y_true: (N, *)
34 | :param y_pred: (N, *, C)
35 | :return: scalar
36 | """
37 | y_true_shape = y_true.shape.as_list()
38 | y_pred_shape = y_pred.shape.as_list()
39 | assert y_true_shape[0] == y_pred_shape[0]
40 | num_classes = y_pred_shape[-1]
41 | axis = [0]
42 |
43 | if self.from_logits:
44 | # get [0..1] class probabilities
45 | if self.mode == MULTICLASS_MODE:
46 | y_pred = tf.nn.softmax(y_pred, axis=-1)
47 | else:
48 | y_pred = tf.nn.sigmoid(y_pred)
49 |
50 | if self.mode == BINARY_MODE:
51 | y_true = rearrange(y_true, "... -> (...) 1")
52 | y_pred = rearrange(y_pred, "... -> (...) 1")
53 | elif self.mode == MULTICLASS_MODE:
54 | y_true = tf.cast(y_true, tf.int32)
55 | y_true = tf.one_hot(y_true, num_classes)
56 | y_true = rearrange(y_true, "... C -> (...) C")
57 | y_pred = rearrange(y_pred, "... C -> (...) C")
58 | else:
59 |             raise AssertionError("Not implemented")
60 |
61 | scores = tf_iou_score(y_pred, tf.cast(y_true, y_pred.dtype), axis=axis)
62 | scores = tf.clip_by_value(scores, clip_value_min=1e-7, clip_value_max=scores.dtype.max)
63 | if self.log_loss:
64 | loss = -tf.math.log(scores)
65 | else:
66 | loss = 1.0 - scores
67 | return tf.math.reduce_mean(loss)
68 |
69 |
70 | if __name__ == '__main__':
71 | print("Multiclass")
72 | pred = tf.convert_to_tensor([[
73 | [1., 2., 3., 4.],
74 | [2., 6., 4., 4.],
75 | [9., 6., 3., 4.]
76 | ]])
77 | pred = rearrange(pred, "N C H -> N H C")
78 | true = tf.convert_to_tensor([[2, 1, 0, 2]])
79 |
80 | loss = TFIoULoss("multiclass", from_logits=True)(true, pred)
81 |
82 | print("Binary")
83 | pred = tf.convert_to_tensor([0.4, 0.2, 0.3, 0.5])
84 | pred = tf.reshape(pred, [1, 2, 2, 1])
85 | true = tf.convert_to_tensor([0, 1, 0, 1])
86 | true = tf.reshape(true, [1, 2, 2])
87 | loss = TFIoULoss("binary", from_logits=True)(true, pred)
88 |
89 | print("Binary Log loss")
90 | pred = tf.convert_to_tensor([0.4, 0.2, 0.3, 0.5])
91 | pred = tf.reshape(pred, [1, 2, 2, 1])
92 | true = tf.convert_to_tensor([0, 1, 0, 1])
93 | true = tf.reshape(true, [1, 2, 2])
94 | loss = TFIoULoss("binary", from_logits=True, log_loss=True)(true, pred)
--------------------------------------------------------------------------------
/fusionlab/losses/tversky/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/losses/tversky/__init__.py
--------------------------------------------------------------------------------
/fusionlab/losses/tversky/tftversky.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from einops import rearrange
3 | from fusionlab.configs import EPS
4 |
5 | __all__ = ["TFTverskyLoss"]
6 |
7 | BINARY_MODE = "binary"
8 | MULTICLASS_MODE = "multiclass"
9 |
10 |
11 | class TFTverskyLoss(tf.keras.losses.Loss):
12 | def __init__(self,
13 | alpha,
14 | beta,
15 | mode="binary", # binary, multiclass
16 | log_loss=False,
17 | from_logits=False,
18 | ):
19 | """
20 |         Implementation of Tversky loss for image segmentation tasks.
21 |         Supports "binary" and "multiclass" modes.
22 |         ref: https://github.com/kornia/kornia/blob/master/kornia/losses/tversky.py
23 |         ref: https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py
24 |         Args:
25 |             alpha: controls the penalty for false positives (FP).
26 |             beta: controls the penalty for false negatives (FN).
27 |             mode: Metric mode {'binary', 'multiclass'}
28 |             log_loss: If True, loss computed as `-log(tversky)`; otherwise `1 - tversky`
29 |             from_logits: If True, assumes input is raw logits
30 | """
31 | super().__init__()
32 | self.alpha = alpha
33 | self.beta = beta
34 | self.mode = mode
35 | self.from_logits = from_logits
36 | self.log_loss = log_loss
37 |
38 | def call(self, y_true, y_pred):
39 | """
40 | :param y_true: (N, *)
41 | :param y_pred: (N, *, C)
42 | :return: scalar
43 | """
44 | y_true_shape = y_true.shape.as_list()
45 | y_pred_shape = y_pred.shape.as_list()
46 | assert y_true_shape[0] == y_pred_shape[0]
47 | num_classes = y_pred_shape[-1]
48 | axis = [0]
49 |
50 | if self.from_logits:
51 | # get [0..1] class probabilities
52 | if self.mode == MULTICLASS_MODE:
53 | y_pred = tf.nn.softmax(y_pred, axis=-1)
54 | else:
55 | y_pred = tf.nn.sigmoid(y_pred)
56 |
57 | if self.mode == BINARY_MODE:
58 | y_true = rearrange(y_true, "... -> (...) 1")
59 | y_pred = rearrange(y_pred, "... -> (...) 1")
60 | elif self.mode == MULTICLASS_MODE:
61 | y_true = tf.cast(y_true, tf.int32)
62 | y_true = tf.one_hot(y_true, num_classes)
63 | y_true = rearrange(y_true, "... C -> (...) C")
64 | y_pred = rearrange(y_pred, "... C -> (...) C")
65 | else:
66 |             raise AssertionError("Not implemented")
67 |
68 | scores = tf_tversky_score(y_pred, tf.cast(y_true, y_pred.dtype),
69 | self.alpha,
70 | self.beta,
71 | axis=axis)
72 | if self.log_loss:
73 | loss = -tf.math.log(tf.clip_by_value(scores, clip_value_min=1e-7, clip_value_max=scores.dtype.max))
74 | else:
75 | loss = 1.0 - scores
76 | return tf.math.reduce_mean(loss)
77 |
78 |
79 | def tf_tversky_score(pred, target, alpha, beta, axis=None):
80 | """
81 | Shape:
82 | - pred: :math:`(N, *, C)` where :math:`*` means any number of additional dimensions
83 | - target: :math:`(N, *, C)`, same shape as the input
84 | - Output: scalar.
85 | """
86 | intersection = tf.reduce_sum(pred * target, axis=axis)
87 | fp = tf.reduce_sum(pred * (1. - target), axis)
88 | fn = tf.reduce_sum((1. - pred) * target, axis)
89 | denominator = intersection + alpha * fp + beta * fn
90 | denominator = tf.clip_by_value(denominator,
91 | clip_value_min=EPS,
92 | clip_value_max=denominator.dtype.max)
93 | return intersection / denominator
94 |
95 |
96 | if __name__ == "__main__":
97 | print("Multiclass")
98 | pred = tf.convert_to_tensor([[
99 | [1., 2., 3., 4.],
100 | [2., 6., 4., 4.],
101 | [9., 6., 3., 4.]
102 | ]])
103 | pred = rearrange(pred, "N C H -> N H C")
104 | true = tf.convert_to_tensor([[2, 1, 0, 2]])
105 |
106 | loss_fn = TFTverskyLoss(0.5, 0.5, "multiclass", from_logits=True)
107 | loss = loss_fn(true, pred)
108 | print(float(loss))
109 |
110 | print("Binary")
111 | pred = tf.convert_to_tensor([0.4, 0.2, 0.3, 0.5])
112 | pred = tf.reshape(pred, [1, 2, 2, 1])
113 | true = tf.convert_to_tensor([0, 1, 0, 1])
114 | true = tf.reshape(true, [1, 2, 2])
115 | loss_fn = TFTverskyLoss(0.5, 0.5, "binary", from_logits=True)
116 | loss = loss_fn(true, pred)
117 | print(float(loss))
118 |
119 | print("Binary Log loss")
120 | pred = tf.convert_to_tensor([0.4, 0.2, 0.3, 0.5])
121 | pred = tf.reshape(pred, [1, 2, 2, 1])
122 | true = tf.convert_to_tensor([0, 1, 0, 1])
123 | true = tf.reshape(true, [1, 2, 2])
124 | loss_fn = TFTverskyLoss(0.5, 0.5, "binary", from_logits=True, log_loss=True)
125 | loss = loss_fn(true, pred)
126 | print(float(loss))
127 |
--------------------------------------------------------------------------------
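A quick numeric check of the tf_tversky_score helper defined above: for a perfect one-hot prediction, fp and fn are zero, so each per-class score is 1 and the corresponding loss 1 - score vanishes:

    import tensorflow as tf
    from fusionlab.losses.tversky.tftversky import tf_tversky_score

    target = tf.one_hot([0, 1, 2, 1], depth=3)  # (4, 3) one-hot labels
    score = tf_tversky_score(target, target, alpha=0.5, beta=0.5, axis=[0])
    print(score.numpy())  # ~[1. 1. 1.], one score per class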
/fusionlab/losses/tversky/tversky.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import rearrange
5 | from fusionlab.configs import EPS
6 |
7 | __all__ = ["TverskyLoss"]
8 |
9 | BINARY_MODE = "binary"
10 | MULTICLASS_MODE = "multiclass"
11 |
12 |
13 | class TverskyLoss(nn.Module):
14 | def __init__(
15 | self,
16 | alpha,
17 | beta,
18 | mode="multiclass", # binary, multiclass
19 | log_loss=False,
20 | from_logits=True,
21 | ):
22 | """
23 |         Implementation of Tversky loss for image segmentation tasks.
24 |         Supports "binary" and "multiclass" modes.
25 |         ref: https://github.com/kornia/kornia/blob/master/kornia/losses/tversky.py
26 |         ref: https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py
27 |         Args:
28 |             alpha: controls the penalty for false positives (FP).
29 |             beta: controls the penalty for false negatives (FN).
30 |             mode: Metric mode {'binary', 'multiclass'}
31 |             log_loss: If True, loss computed as `-log(tversky)`; otherwise `1 - tversky`
32 |             from_logits: If True, assumes input is raw logits
33 | """
34 | super().__init__()
35 | self.alpha = alpha
36 | self.beta = beta
37 | self.mode = mode
38 | self.from_logits = from_logits
39 | self.log_loss = log_loss
40 |
41 | def forward(self, y_pred, y_true):
42 | """
43 | :param y_pred: (N, C, *)
44 | :param y_true: (N, *)
45 | :return: scalar
46 | """
47 | assert y_true.size(0) == y_pred.size(0)
48 | num_classes = y_pred.size(1)
49 | dims = (0, 2) # (N, C, HW)
50 |
51 | if self.from_logits:
52 | # get [0..1] class probabilities
53 | if self.mode == MULTICLASS_MODE:
54 | y_pred = F.softmax(y_pred, dim=1)
55 | else:
56 | y_pred = torch.sigmoid(y_pred)
57 |
58 | if self.mode == BINARY_MODE:
59 | y_true = rearrange(y_true, "N ... -> N 1 (...)")
60 | y_pred = rearrange(y_pred, "N 1 ... -> N 1 (...)")
61 | elif self.mode == MULTICLASS_MODE:
62 | y_pred = rearrange(y_pred, "N C ... -> N C (...)")
63 | y_true = F.one_hot(y_true, num_classes) # (N, *) -> (N, *, C)
64 | y_true = rearrange(y_true, "N ... C -> N C (...)")
65 | else:
66 |             raise AssertionError("Not implemented")
67 |
68 | scores = tversky_score(y_pred, y_true.type_as(y_pred),
69 | self.alpha, self.beta,
70 | dims=dims)
71 | if self.log_loss:
72 | loss = -torch.log(scores.clamp_min(EPS))
73 | else:
74 | loss = 1.0 - scores
75 | return loss.mean()
76 |
77 |
78 | def tversky_score(pred, target, alpha, beta, dims):
79 | """
80 | Shape:
81 | - pred: :math:`(N, C, *)`
82 | - target: :math:`(N, C, *)`
83 | - Output: scalar.
84 | """
85 | assert pred.size() == target.size()
86 |
87 | intersection = torch.sum(pred * target, dim=dims)
88 | fp = torch.sum(pred * (1. - target), dims)
89 | fn = torch.sum((1. - pred) * target, dims)
90 |
91 | denominator = intersection + alpha * fp + beta * fn
92 | return intersection / denominator.clamp(min=EPS)
93 |
94 | if __name__ == "__main__":
95 | print("multiclass")
96 | pred = torch.tensor([[
97 | [1., 2., 3., 4.],
98 | [2., 6., 4., 4.],
99 | [9., 6., 3., 4.]
100 | ]]).unsqueeze(-1)
101 | true = torch.tensor([[2, 1, 0, 2]]).view(1, 4).unsqueeze(-1)
102 |
103 | loss_fn = TverskyLoss(0.5, 0.5, "multiclass", from_logits=True)
104 | loss = loss_fn(pred, true)
105 | print(loss.item())
106 |
107 | print("Binary")
108 | pred = torch.tensor([0.4, 0.2, 0.3, 0.5]).reshape(1, 1, 2, 2)
109 | true = torch.tensor([0, 1, 0, 1]).reshape(1, 2, 2)
110 | loss_fn = TverskyLoss(0.5, 0.5, "binary", from_logits=True)
111 | loss = loss_fn(pred, true)
112 | print(loss.item())
113 |
114 | print("Binary Logloss")
115 | pred = torch.tensor([0.4, 0.2, 0.3, 0.5]).reshape(1, 1, 2, 2)
116 | true = torch.tensor([0, 1, 0, 1]).reshape(1, 2, 2)
117 | loss_fn = TverskyLoss(0.5, 0.5, "binary", from_logits=True, log_loss=True)
118 | loss = loss_fn(pred, true)
119 | print(loss.item())
120 |
--------------------------------------------------------------------------------
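Tversky generalizes the two earlier losses: alpha = beta = 0.5 recovers soft Dice, and alpha = beta = 1 recovers soft IoU. A sanity-check sketch on the binary example above (equalities hold up to the EPS clamps):

    import torch
    from fusionlab.losses.diceloss.dice import DiceLoss
    from fusionlab.losses.iouloss.iou import IoULoss
    from fusionlab.losses.tversky.tversky import TverskyLoss

    pred = torch.tensor([0.4, 0.2, 0.3, 0.5]).reshape(1, 1, 2, 2)
    true = torch.tensor([0, 1, 0, 1]).reshape(1, 2, 2)

    t_dice = TverskyLoss(0.5, 0.5, "binary", from_logits=True)(pred, true)
    t_iou = TverskyLoss(1.0, 1.0, "binary", from_logits=True)(pred, true)
    assert torch.isclose(t_dice, DiceLoss("binary", from_logits=True)(pred, true), atol=1e-4)
    assert torch.isclose(t_iou, IoULoss("binary", from_logits=True)(pred, true), atol=1e-4)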
/fusionlab/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from fusionlab import BACKEND
2 | if BACKEND['torch']:
3 | from .dicescore.dice import DiceScore, JaccardScore
4 | from .iouscore.iou import IoUScore
--------------------------------------------------------------------------------
/fusionlab/metrics/dicescore/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/metrics/dicescore/__init__.py
--------------------------------------------------------------------------------
/fusionlab/metrics/dicescore/dice.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import rearrange
5 | from fusionlab.functional import dice_score
6 |
7 | BINARY_MODE = "binary"
8 | MULTICLASS_MODE = "multiclass"
9 |
10 | class DiceScore(nn.Module):
11 | def __init__(
12 | self,
13 | mode="multiclass", # binary, multiclass
14 | from_logits=True,
15 | reduction="none", # mean, none
16 | ):
17 | """
18 |         Compute Dice score for binary or multiclass input
19 |
20 |         Args:
21 |             mode: "binary" or "multiclass"
22 |             from_logits: if True, assumes input is raw logits
23 |             reduction: "mean" or "none"; if "none", returns the Dice score for each class, else the mean
24 | """
25 | super().__init__()
26 | self.mode = mode
27 | self.from_logits = from_logits
28 | self.reduction = reduction
29 |
30 | def forward(self, y_pred, y_true) -> torch.Tensor:
31 | """
32 | :param y_pred: (N, C, *)
33 | :param y_true: (N, *)
34 |         :return: (C,) per-class scores if reduction is "none", else a scalar
35 | """
36 | assert y_true.size(0) == y_pred.size(0)
37 | num_classes = y_pred.size(1)
38 | dims = (0, 2) # dimensions to sum over (N, C, *)
39 |
40 | if self.from_logits:
41 | # get [0..1] class probabilities
42 | if self.mode == MULTICLASS_MODE:
43 | y_pred = F.softmax(y_pred, dim=1)
44 | else:
45 | y_pred = torch.sigmoid(y_pred)
46 |
47 | if self.mode == BINARY_MODE:
48 | y_true = rearrange(y_true, "N ... -> N 1 (...)")
49 | y_pred = rearrange(y_pred, "N 1 ... -> N 1 (...)")
50 | elif self.mode == MULTICLASS_MODE:
51 | y_pred = rearrange(y_pred, "N C ... -> N C (...)")
52 | y_true = F.one_hot(y_true, num_classes) # (N, *) -> (N, *, C)
53 | y_true = rearrange(y_true, "N ... C -> N C (...)")
54 | else:
55 |             raise AssertionError("Not implemented")
56 |
57 | scores = dice_score(y_pred, y_true.type_as(y_pred), dims=dims)
58 | if self.reduction == "none":
59 | return scores
60 | else:
61 | return scores.mean()
62 |
63 | JaccardScore = DiceScore  # NOTE: alias of DiceScore; strictly, the Jaccard index equals IoU (see IoUScore), not Dice
--------------------------------------------------------------------------------
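A short sketch of the two reduction modes of DiceScore (assuming torch and fusionlab are installed); with the default reduction="none" the score comes back per class:

    import torch
    from fusionlab.metrics import DiceScore

    y_pred = torch.randn(2, 3, 8, 8)         # (N, C, H, W) logits
    y_true = torch.randint(0, 3, (2, 8, 8))  # (N, H, W) integer labels

    per_class = DiceScore("multiclass", from_logits=True)(y_pred, y_true)
    print(per_class.shape)                              # torch.Size([3]): one score per class
    print(DiceScore(reduction="mean")(y_pred, y_true))  # scalar mean over classes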
/fusionlab/metrics/iouscore/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/metrics/iouscore/__init__.py
--------------------------------------------------------------------------------
/fusionlab/metrics/iouscore/iou.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import rearrange
5 | from fusionlab.functional import iou_score
6 |
7 | __all__ = ["IoUScore"]
8 |
9 | BINARY_MODE = "binary"
10 | MULTICLASS_MODE = "multiclass"
11 |
12 | class IoUScore(nn.Module):
13 | def __init__(
14 | self,
15 | mode="multiclass", # binary, multiclass
16 | from_logits=True,
17 | reduction="none", # mean, none
18 | ):
19 | """
20 |         Implementation of IoU score for segmentation tasks.
21 |         Supports "binary" and "multiclass" modes.
22 |         Args:
23 |             mode: Metric mode {'binary', 'multiclass'}
24 |             from_logits: If True, assumes input is raw logits
25 |             reduction: "mean" or "none"; if "none", returns the IoU score for each class, else the mean
26 | """
27 | super().__init__()
28 | self.mode = mode
29 | self.from_logits = from_logits
30 | self.reduction = reduction
31 |
32 | def forward(self, y_pred, y_true):
33 | """
34 | :param y_pred: (N, C, *)
35 | :param y_true: (N, *)
36 |         :return: (C,) per-class scores if reduction is "none" (C=1 in binary mode), else a scalar
37 | """
38 | assert y_true.size(0) == y_pred.size(0)
39 | num_classes = y_pred.size(1)
40 | dims = (0, 2) # (N, C, *)
41 |
42 | if self.from_logits:
43 | # get [0..1] class probabilities
44 | if self.mode == MULTICLASS_MODE:
45 | y_pred = F.softmax(y_pred, dim=1)
46 | else:
47 | y_pred = torch.sigmoid(y_pred)
48 |
49 | if self.mode == BINARY_MODE:
50 | y_true = rearrange(y_true, "N ... -> N 1 (...)")
51 | y_pred = rearrange(y_pred, "N 1 ... -> N 1 (...)")
52 | elif self.mode == MULTICLASS_MODE:
53 | y_pred = rearrange(y_pred, "N C ... -> N C (...)")
54 | y_true = F.one_hot(y_true, num_classes) # (N, *) -> (N, *, C)
55 | y_true = rearrange(y_true, "N ... C -> N C (...)")
56 | else:
57 |             raise AssertionError("Not implemented")
58 |
59 | scores = iou_score(y_pred, y_true.type_as(y_pred), dims=dims)
60 | if self.reduction == "none":
61 | return scores
62 | else:
63 | return scores.mean()
--------------------------------------------------------------------------------
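Since JaccardScore above is currently just an alias of DiceScore, IoUScore is the class that actually computes the Jaccard index. On the same soft prediction, Dice is always at least as large as IoU (D = 2J / (1 + J)), as this sketch illustrates:

    import torch
    from fusionlab.metrics import DiceScore, IoUScore

    y_pred = torch.randn(1, 3, 4, 4)
    y_true = torch.randint(0, 3, (1, 4, 4))

    print(DiceScore()(y_pred, y_true))  # per-class Dice
    print(IoUScore()(y_pred, y_true))   # per-class IoU, elementwise <= Dice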
/fusionlab/segmentation/README.md:
--------------------------------------------------------------------------------
1 | # Segmentation
2 |
3 | | Name | PyTorch | Tensorflow | Paper |
4 | | :-------- | :-------- | ----------- | ---------------------------------------------------------------------------------------------------------------------- |
5 | | UNet | UNet | TFUNet | [U-Net: Convolutional Networks for Biomedical Image Segmentation (2015)](https://arxiv.org/abs/1505.04597) |
6 | | ResUNet | ResUNet | TFResUNet | [Road Extraction by Deep Residual U-Net (2017)](https://arxiv.org/abs/1711.10684) |
7 | | UNet++ | Unet2plus | TFUnet2plus | [UNet++: A Nested U-Net Architecture for Medical Image Segmentation (2018)](https://arxiv.org/abs/1807.10165) |
8 | | TransUNet | TransUNet | X | [TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation (2021)](https://arxiv.org/abs/2102.04306) |
9 | | UNETR | UNETR | X | [UNETR: Transformers for 3D Medical Image Segmentation (2021)](https://arxiv.org/abs/2103.10504) |
10 | | SegFormer | SegFormer | X | [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers (2021)](https://arxiv.org/abs/2105.15203) |
11 |
--------------------------------------------------------------------------------
/fusionlab/segmentation/__init__.py:
--------------------------------------------------------------------------------
1 | from fusionlab import BACKEND
2 | if BACKEND['torch']:
3 | from .unet.unet import UNet
4 | from .resunet.resunet import ResUNet
5 | from .unet2plus.unet2plus import UNet2plus
6 | from .transunet.transunet import TransUNet
7 | from .unetr.unetr import UNETR
8 | from .segformer.segformer import SegFormer
9 | from .base import HFSegmentationModel
10 | if BACKEND['tf']:
11 | from .unet.tfunet import TFUNet
12 | from .resunet.tfresunet import TFResUNet
13 | from .unet2plus.tfunet2plus import TFUNet2plus
--------------------------------------------------------------------------------
/fusionlab/segmentation/base.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class SegmentationModel(nn.Module):
4 | """
5 | Base PyTorch class of the segmentation model with Encoder, Bridger, Decoder, Head
6 | """
7 | def forward(self, x):
8 | features = self.encoder(x)
9 | feature_fusion = self.bridger(features)
10 | decoder_output = self.decoder(feature_fusion)
11 | output = self.head(decoder_output)
12 | return output
13 |
14 | class HFSegmentationModel(nn.Module):
15 | """
16 |     Base Hugging Face PyTorch wrapper class for segmentation models
17 | """
18 | def __init__(self, model, num_cls=None,
19 | loss_fct=nn.CrossEntropyLoss()):
20 | super().__init__()
21 | self.net = model
22 |         if 'num_cls' in model.__dict__:
23 | self.num_cls = model.num_cls
24 | else:
25 | self.num_cls = num_cls
26 | assert self.num_cls is not None, "num_cls is not defined"
27 | self.loss_fct = loss_fct
28 | def forward(self, x, labels=None):
29 |         logits = self.net(x)  # forward pass through the wrapped model
30 | if labels is not None:
31 | # logits => [BATCH, NUM_CLS]
32 | # labels => [BATCH]
33 | loss = self.loss_fct(logits, labels) # Calculate loss
34 | else:
35 | loss = None
36 |         # return a dictionary in the format the Hugging Face Trainer expects
37 | return {'loss':loss, 'logits':logits, 'hidden_states':None}
38 |
39 |
40 | if __name__ == '__main__':
41 | import torch
42 | from fusionlab.segmentation import ResUNet
43 | H = W = 224
44 | cout = 5
45 | inputs = torch.normal(0, 1, (1, 3, H, W))
46 |
47 | model = ResUNet(3, cout, 64)
48 | hf_model = HFSegmentationModel(model, cout)
49 | output = hf_model(inputs)
50 | assert list(output.keys()) == ['loss', 'logits', 'hidden_states']
51 | print(output['logits'].shape)
52 | assert list(output['logits'].shape) == [1, cout, H, W]
--------------------------------------------------------------------------------
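The __main__ block above exercises the label-free path; a complementary sketch of the loss path that the Hugging Face Trainer relies on (assuming torch and fusionlab are installed):

    import torch
    from fusionlab.segmentation import ResUNet, HFSegmentationModel

    model = HFSegmentationModel(ResUNet(3, 5, 64))  # num_cls is read from the wrapped model
    x = torch.randn(2, 3, 224, 224)
    labels = torch.randint(0, 5, (2, 224, 224))     # per-pixel class indices

    out = model(x, labels=labels)                   # {'loss', 'logits', 'hidden_states'}
    out['loss'].backward()                          # trainable end to end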
/fusionlab/segmentation/resunet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/segmentation/resunet/__init__.py
--------------------------------------------------------------------------------
/fusionlab/segmentation/resunet/resunet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from fusionlab.segmentation.base import SegmentationModel
4 | from fusionlab.utils import autopad
5 | from fusionlab.layers.factories import ConvND, ConvT, BatchNorm
6 |
7 |
8 |
9 | class ResUNet(SegmentationModel):
10 | def __init__(
11 | self,
12 | cin,
13 | num_cls,
14 | base_dim=64,
15 | spatial_dims=2
16 | ):
17 | super().__init__()
18 | self.num_cls = num_cls
19 | self.encoder = Encoder(cin, base_dim, spatial_dims)
20 | self.bridger = Bridger()
21 | self.decoder = Decoder(base_dim, spatial_dims)
22 | self.head = Head(base_dim, num_cls, spatial_dims)
23 |
24 |
25 | class Encoder(nn.Module):
26 | def __init__(self, cin, base_dim, spatial_dims=2):
27 | super().__init__()
28 | dims = [base_dim * (2 ** i) for i in range(4)]
29 | self.stem = Stem(cin, dims[0], spatial_dims)
30 | self.stage1 = ResConv(dims[0], dims[1], spatial_dims, stride=2)
31 | self.stage2 = ResConv(dims[1], dims[2], spatial_dims, stride=2)
32 | self.stage3 = ResConv(dims[2], dims[3], spatial_dims, stride=2)
33 |
34 | def forward(self, x):
35 | s0 = self.stem(x)
36 | s1 = self.stage1(s0)
37 | s2 = self.stage2(s1)
38 | s3 = self.stage3(s2)
39 | return [s0, s1, s2, s3]
40 |
41 |
42 | class Decoder(nn.Module):
43 | def __init__(self, base_dim, spatial_dims=2):
44 | """
45 | Base UNet decoder
46 | Args:
47 | base_dim (int): output dim of deepest stage output or input channels
48 | """
49 | super().__init__()
50 | dims = [base_dim*(2**i) for i in range(4)]
51 | self.d3 = DecoderBlock(dims[3], dims[2], spatial_dims)
52 | self.d2 = DecoderBlock(dims[2], dims[1], spatial_dims)
53 | self.d1 = DecoderBlock(dims[1], dims[0], spatial_dims)
54 |
55 | def forward(self, x):
56 | s0, s1, s2, s3 = x
57 |
58 | x = self.d3(s3, s2)
59 | x = self.d2(x, s1)
60 | x = self.d1(x, s0)
61 | return x
62 |
63 |
64 | class DecoderBlock(nn.Module):
65 | def __init__(self, cin, cout, spatial_dims=2):
66 | super().__init__()
67 | self.upsample = ConvT(spatial_dims, cin, cout, 2, stride=2)
68 | self.conv = ResConv(cout*2, cout, spatial_dims, 1)
69 |
70 | def forward(self, x1, x2):
71 | x1 = self.upsample(x1)
72 | x = torch.cat([x1, x2], dim=1)
73 | return self.conv(x)
74 |
75 |
76 | class Bridger(nn.Module):
77 | def __init__(self):
78 | super().__init__()
79 |
80 | def forward(self, x):
81 |         outputs = list(x)  # identity pass-through for each feature map
82 | return outputs
83 |
84 |
85 | class Stem(nn.Module):
86 | def __init__(self, cin, cout, spatial_dims=2):
87 | super().__init__()
88 | self.conv = nn.Sequential(
89 | ConvND(spatial_dims, cin, cout, 3, padding=autopad(3)),
90 | BatchNorm(spatial_dims, cout),
91 | nn.ReLU(),
92 | ConvND(spatial_dims, cout, cout, 3, padding=autopad(3)),
93 | )
94 | self.skip = nn.Sequential(
95 | ConvND(spatial_dims, cin, cout, 3, padding=autopad(3)),
96 | )
97 |
98 | def forward(self, x):
99 | return self.conv(x) + self.skip(x)
100 |
101 |
102 | class ResConv(nn.Module):
103 | def __init__(self, cin, cout, spatial_dims=2, stride=1):
104 | super().__init__()
105 |
106 | self.conv = nn.Sequential(
107 | BatchNorm(spatial_dims, cin),
108 | nn.ReLU(),
109 | ConvND(spatial_dims, cin, cout, 3, stride, padding=autopad(3)),
110 | BatchNorm(spatial_dims, cout),
111 | nn.ReLU(),
112 | ConvND(spatial_dims, cout, cout, 3, padding=autopad(3)),
113 | )
114 | self.skip = nn.Sequential(
115 | ConvND(spatial_dims, cin, cout, 3, stride=stride, padding=autopad(3)),
116 | BatchNorm(spatial_dims, cout),
117 | )
118 |
119 | def forward(self, x):
120 | return self.conv(x) + self.skip(x)
121 |
122 |
123 | class Head(nn.Sequential):
124 | def __init__(self, cin, cout, spatial_dims):
125 | """
126 | Basic conv head
127 | :param int cin: input channel
128 | :param int cout: output channel
129 | """
130 | conv = ConvND(spatial_dims, cin, cout, 1)
131 | super().__init__(conv)
132 |
133 | if __name__ == '__main__':
134 | H = W = 224
135 | cout = 64
136 | inputs = torch.normal(0, 1, (1, 3, H, W))
137 |
138 | model = ResUNet(3, 100, cout)
139 | output = model(inputs)
140 | print(output.shape)
141 |
142 | dblock = DecoderBlock(64, 128)
143 | inputs2 = torch.normal(0, 1, (1, 128, H, W))
144 | inputs1 = torch.normal(0, 1, (1, 64, H//2, W//2))
145 | outputs = dblock(inputs1, inputs2)
146 | print(outputs.shape)
147 |
148 | encoder = Encoder(3, cout)
149 | outputs = encoder(inputs)
150 | for o in outputs:
151 | print(o.shape)
152 |
153 | decoder = Decoder(cout)
154 | outputs = decoder(outputs)
155 | print("Encoder + Decoder ", outputs.shape)
156 |
157 | stem = Stem(3, cout)
158 | outputs = stem(inputs)
159 | print(outputs.shape)
160 | assert list(outputs.shape) == [1, cout, H, W]
161 |
162 | resconv = ResConv(3, cout, stride=1)
163 | outputs = resconv(inputs)
164 | print(outputs.shape)
165 | assert list(outputs.shape) == [1, cout, H, W]
166 |
167 | resconv = ResConv(3, cout, stride=2)
168 | outputs = resconv(inputs)
169 | print(outputs.shape)
170 | assert list(outputs.shape) == [1, cout, H//2, W//2]
171 |
172 |
173 | print("3D ResUNet")
174 | D = H = W = 64
175 | cout = 32
176 | inputs = torch.rand(1, 3, D, H, W)
177 |
178 | model = ResUNet(3, 100, cout, spatial_dims=3)
179 | output = model(inputs)
180 | print(output.shape)
181 |
182 | print("1D ResUNet")
183 | L = 64
184 | cout = 32
185 | inputs = torch.rand(1, 3, L)
186 |
187 | model = ResUNet(3, 100, cout, spatial_dims=1)
188 | output = model(inputs)
189 | print(output.shape)
190 |
--------------------------------------------------------------------------------
/fusionlab/segmentation/resunet/tfresunet.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import layers, Model, Sequential
3 | from fusionlab.segmentation.tfbase import TFSegmentationModel
4 |
5 |
6 | class TFResUNet(TFSegmentationModel):
7 | def __init__(self, num_cls, base_dim):
8 | super().__init__()
9 | self.encoder = Encoder(base_dim)
10 | self.bridger = Bridger()
11 | self.decoder = Decoder(base_dim)
12 | self.head = Head(num_cls)
13 |
14 |
15 | class Encoder(Model):
16 | def __init__(self, base_dim):
17 | super().__init__()
18 | dims = [base_dim * (2 ** i) for i in range(4)]
19 | self.stem = Stem(dims[0])
20 | self.stage1 = ResConv(dims[1], stride=2)
21 | self.stage2 = ResConv(dims[2], stride=2)
22 | self.stage3 = ResConv(dims[3], stride=2)
23 |
24 | def call(self, x, training):
25 | s0 = self.stem(x, training)
26 | s1 = self.stage1(s0, training)
27 | s2 = self.stage2(s1, training)
28 | s3 = self.stage3(s2, training)
29 | return [s0, s1, s2, s3]
30 |
31 |
32 | class Decoder(Model):
33 | def __init__(self, base_dim):
34 | """
35 | Base UNet decoder
36 | Args:
37 | base_dim (int): output dim of deepest stage output or input channels
38 | """
39 | super().__init__()
40 | dims = [base_dim*(2**i) for i in range(4)]
41 | self.d3 = DecoderBlock(dims[2])
42 | self.d2 = DecoderBlock(dims[1])
43 | self.d1 = DecoderBlock(dims[0])
44 |
45 | def call(self, x, training):
46 | s0, s1, s2, s3 = x
47 |
48 | x = self.d3(s3, s2, training)
49 | x = self.d2(x, s1, training)
50 | x = self.d1(x, s0, training)
51 | return x
52 |
53 |
54 | class DecoderBlock(Model):
55 | def __init__(self, cout):
56 | super().__init__()
57 | self.upsample = layers.Conv2DTranspose(cout, 2, strides=2)
58 | self.conv = ResConv(cout, stride=1)
59 |
60 | def call(self, x1, x2, training):
61 | x1 = self.upsample(x1)
62 | x = tf.concat([x1, x2], axis=-1)
63 |         return self.conv(x, training)
64 |
65 |
66 | class Bridger(Model):
67 | def __init__(self):
68 | super().__init__()
69 |
70 | def call(self, x, training):
71 | outputs = [tf.identity(i) for i in x]
72 | return outputs
73 |
74 |
75 | class Stem(Model):
76 | def __init__(self, cout):
77 | super().__init__()
78 | self.conv = Sequential([
79 | layers.Conv2D(cout, 3, padding='same'),
80 | layers.BatchNormalization(),
81 | layers.ReLU(),
82 | layers.Conv2D(cout, 3, padding='same'),
83 | ])
84 |         self.skip = Sequential([
85 |             layers.Conv2D(cout, 3, padding='same'),
86 |         ])
87 |
88 | def call(self, x, training):
89 |         return self.conv(x, training) + self.skip(x, training)
90 |
91 |
92 | class ResConv(Model):
93 | def __init__(self, cout, stride=1):
94 | super().__init__()
95 |
96 | self.conv = Sequential([
97 | layers.BatchNormalization(),
98 | layers.ReLU(),
99 | layers.Conv2D(cout, 3, stride, padding='same'),
100 | layers.BatchNormalization(),
101 | layers.ReLU(),
102 | layers.Conv2D(cout, 3, padding='same'),
103 | ])
104 | self.skip = Sequential([
105 | layers.Conv2D(cout, 3, strides=stride, padding='same'),
106 | layers.BatchNormalization(),
107 | ])
108 |
109 | def call(self, x, training=None):
110 | return self.conv(x, training) + self.skip(x, training)
111 |
112 |
113 | class Head(Sequential):
114 | def __init__(self, cout):
115 | """
116 | Basic conv head
117 | :param int cout: number of classes
118 | """
119 | conv = layers.Conv2D(cout, 1)
120 | super().__init__(conv)
121 |
122 |
123 | if __name__ == '__main__':
124 | H = W = 224
125 | cout = 64
126 | inputs = tf.random.normal((1, H, W, 3))
127 |
128 | model = TFResUNet(100, cout)
129 | output = model(inputs, training=True)
130 | print(output.shape)
131 |
132 | dblock = DecoderBlock(128)
133 | inputs2 = tf.random.normal((1, H, W, 128))
134 | inputs1 = tf.random.normal((1, H//2, W//2, 64))
135 | outputs = dblock(inputs1, inputs2, training=True)
136 | print(outputs.shape)
137 |
138 | encoder = Encoder(cout)
139 | outputs = encoder(inputs, training=True)
140 | for o in outputs:
141 | print(o.shape)
142 |
143 | decoder = Decoder(cout)
144 | outputs = decoder(outputs, training=True)
145 | print("Encoder + Decoder ", outputs.shape)
146 |
147 | stem = Stem(cout)
148 | outputs = stem(inputs, training=True)
149 | print(outputs.shape)
150 | assert list(outputs.shape) == [1, H, W, cout]
151 |
152 | resconv = ResConv(cout, 1)
153 | outputs = resconv(inputs, training=True)
154 | print(outputs.shape)
155 | assert list(outputs.shape) == [1, H, W, cout]
156 |
157 | resconv = ResConv(cout, stride=2)
158 | outputs = resconv(inputs, training=True)
159 | print(outputs.shape)
160 | assert list(outputs.shape) == [1, H//2, W//2, cout]
--------------------------------------------------------------------------------
/fusionlab/segmentation/segformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/segmentation/segformer/__init__.py
--------------------------------------------------------------------------------
/fusionlab/segmentation/segformer/segformer.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from fusionlab.encoders import (
6 | MiT,
7 | MiTB0,
8 | MiTB1,
9 | MiTB2,
10 | MiTB3,
11 | MiTB4,
12 | MiTB5,
13 | )
14 |
15 | class MLP(nn.Module):
16 | def __init__(self, dim, embed_dim):
17 | super().__init__()
18 | self.proj = nn.Linear(dim, embed_dim)
19 |
20 | def forward(self, x: torch.Tensor) -> torch.Tensor:
21 | x = x.flatten(2).transpose(1, 2)
22 | x = self.proj(x)
23 | return x
24 |
25 |
26 | class ConvModule(nn.Module):
27 | def __init__(self, c1, c2):
28 | super().__init__()
29 | self.conv = nn.Conv2d(c1, c2, 1, bias=False)
30 | self.bn = nn.BatchNorm2d(c2) # use SyncBN in original
31 | self.activate = nn.ReLU(True)
32 |
33 | def forward(self, x: torch.Tensor) -> torch.Tensor:
34 | return self.activate(self.bn(self.conv(x)))
35 |
36 | class SegFormerHead(nn.Module):
37 | def __init__(self, dims: list, embed_dim: int = 256, num_classes: int = 19):
38 | super().__init__()
39 | for i, dim in enumerate(dims):
40 | self.add_module(f"linear_c{i+1}", MLP(dim, embed_dim))
41 |
42 | self.linear_fuse = ConvModule(embed_dim*4, embed_dim)
43 | self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1)
44 | self.dropout = nn.Dropout2d(0.1)
45 |
46 | def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor:
47 | B, _, H, W = features[0].shape
48 | outs = [self.linear_c1(features[0]).permute(0, 2, 1).reshape(B, -1, *features[0].shape[-2:])]
49 |
50 | for i, feature in enumerate(features[1:]):
51 |             cf = getattr(self, f"linear_c{i+2}")(feature).permute(0, 2, 1).reshape(B, -1, *feature.shape[-2:])
52 | outs.append(F.interpolate(cf, size=(H, W), mode='bilinear', align_corners=False))
53 |
54 | seg = self.linear_fuse(torch.cat(outs[::-1], dim=1))
55 | seg = self.linear_pred(self.dropout(seg))
56 | return seg
57 |
58 | class SegFormer(nn.Module):
59 | """
60 | SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
61 |
62 | source code: https://github.com/sithu31296/semantic-segmentation/blob/main/semseg/models/segformer.py
63 |
64 | Args:
65 |
66 | num_classes (int): number of classes to segment
67 | mit_encoder_type (str): type of MiT encoder, one of ['B0', 'B1', 'B2', 'B3', 'B4', 'B5']
68 | """
69 | def __init__(
70 | self,
71 | num_classes: int = 6,
72 | mit_encoder_type: str = 'B0'
73 | ):
74 | super().__init__()
75 |         self.encoder: MiT = {'B0': MiTB0, 'B1': MiTB1, 'B2': MiTB2, 'B3': MiTB3, 'B4': MiTB4, 'B5': MiTB5}[mit_encoder_type]()
76 | embed_dim = self.encoder.channels[-1]
77 | self.decode_head = SegFormerHead(
78 | self.encoder.channels,
79 | embed_dim,
80 | num_classes,
81 | )
82 |
83 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
84 | _, features = self.encoder(inputs, return_features=True)
85 | x = self.decode_head(features) # 4x reduction in image size
86 | x = F.interpolate(x, size=inputs.shape[2:], mode='bilinear', align_corners=False)
87 | return x
88 |
89 | if __name__ == '__main__':
90 | model = SegFormer(num_classes=6)
91 | x = torch.randn(1, 3, 128, 128)
92 | outputs = model(x)
93 | print(outputs.shape)
94 |
95 |
--------------------------------------------------------------------------------
/fusionlab/segmentation/tfbase.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras import Model
2 |
3 | class TFSegmentationModel(Model):
4 | """
5 |     Base TensorFlow (Keras) class of the segmentation model with Encoder, Bridger, Decoder, Head
6 | """
7 | def call(self, x, training=None):
8 | """
9 |         Runs the encoder -> bridger -> decoder -> head pipeline.
10 |         Args:
11 |             x: input tensor
12 |             training: flag for BatchNormalization and Dropout, whether the layers should behave
13 |                 in training mode or in inference mode
14 |
15 |         Returns:
16 |             output tensor produced by the head
17 | """
18 | features = self.encoder(x, training)
19 | feature_fusion = self.bridger(features, training)
20 | decoder_output = self.decoder(feature_fusion, training)
21 | output = self.head(decoder_output, training)
22 | return output
23 |
--------------------------------------------------------------------------------
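A minimal sketch of the encoder/bridger/decoder/head contract that TFSegmentationModel.call expects; the tiny wrapper and toy layers here are illustrative only. Each component must accept the positional (x, training) call:

    import tensorflow as tf
    from tensorflow.keras import layers, Model
    from fusionlab.segmentation.tfbase import TFSegmentationModel

    class Block(Model):
        # wraps a single Keras layer so it accepts the (x, training) call signature
        def __init__(self, layer):
            super().__init__()
            self.layer = layer

        def call(self, x, training=None):
            return self.layer(x)

    class TinySeg(TFSegmentationModel):
        def __init__(self, num_cls):
            super().__init__()
            self.encoder = Block(layers.Conv2D(16, 3, padding='same'))  # feature extractor
            self.bridger = Block(layers.Activation('linear'))           # identity fusion
            self.decoder = Block(layers.Conv2D(16, 3, padding='same'))  # feature decoder
            self.head = Block(layers.Conv2D(num_cls, 1))                # per-pixel classifier

    model = TinySeg(num_cls=4)
    print(model(tf.random.normal((1, 32, 32, 3))).shape)  # (1, 32, 32, 4)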
/fusionlab/segmentation/transunet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/segmentation/transunet/__init__.py
--------------------------------------------------------------------------------
/fusionlab/segmentation/unet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/segmentation/unet/__init__.py
--------------------------------------------------------------------------------
/fusionlab/segmentation/unet/tfunet.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import layers, Model, Sequential
3 | from fusionlab.segmentation.tfbase import TFSegmentationModel
4 |
5 |
6 | class TFUNet(TFSegmentationModel):
7 | def __init__(self, num_cls, base_dim=64):
8 | """
9 | Base Unet
10 | Args:
11 | num_cls (int): number of classes
12 | base_dim (int): 1st stage dim of conv output
13 | """
14 | super().__init__()
15 | stage = 5
16 | self.encoder = Encoder(base_dim)
17 | self.bridger = Bridger()
18 | self.decoder = Decoder(base_dim*(2**(stage-2))) # 512
19 | self.head = Head(num_cls)
20 |
21 |
22 | class Encoder(Model):
23 | def __init__(self, base_dim):
24 | """
25 | UNet Encoder
26 | Args:
27 | base_dim (int): 1st stage dim of conv output
28 | """
29 | super().__init__()
30 | self.pool = layers.MaxPool2D()
31 | self.stage1 = BasicBlock(base_dim)
32 | self.stage2 = BasicBlock(base_dim * 2)
33 | self.stage3 = BasicBlock(base_dim * 4)
34 | self.stage4 = BasicBlock(base_dim * 8)
35 | self.stage5 = BasicBlock(base_dim * 16)
36 |
37 | def call(self, x, training=None):
38 | s1 = self.stage1(x, training)
39 | x = self.pool(s1)
40 | s2 = self.stage2(x, training)
41 | x = self.pool(s2)
42 | s3 = self.stage3(x, training)
43 | x = self.pool(s3)
44 | s4 = self.stage4(x, training)
45 | x = self.pool(s4)
46 | s5 = self.stage5(x, training)
47 |
48 | return [s1, s2, s3, s4, s5]
49 |
50 |
51 | class Decoder(Model):
52 | def __init__(self, base_dim):
53 | """
54 | Base UNet decoder
55 | Args:
56 | base_dim (int): output dim of deepest stage output
57 | """
58 | super().__init__()
59 | self.d4 = DecoderBlock(base_dim)
60 | self.d3 = DecoderBlock(base_dim//2)
61 | self.d2 = DecoderBlock(base_dim//4)
62 | self.d1 = DecoderBlock(base_dim//8)
63 |
64 | def call(self, x, training=None):
65 | f1, f2, f3, f4, f5 = x
66 | x = self.d4(f5, f4, training)
67 | x = self.d3(x, f3, training)
68 | x = self.d2(x, f2, training)
69 | x = self.d1(x, f1, training)
70 | return x
71 |
72 |
73 | class Bridger(Model):
74 | def __init__(self):
75 | super().__init__()
76 |
77 | def call(self, x, training=None):
78 | outputs = [tf.identity(i) for i in x]
79 | return outputs
80 |
81 |
82 | class Head(Sequential):
83 | def __init__(self, cout):
84 | """
85 |         Basic 1x1 conv head
86 | :param int cout: output channel
87 | """
88 | conv = layers.Conv2D(cout, 1)
89 | super().__init__(conv)
90 |
91 |
92 | class BasicBlock(Sequential):
93 | def __init__(self, cout):
94 | conv1 = Sequential([
95 | layers.Conv2D(cout, 3, 1, padding='same'),
96 | layers.ReLU(),
97 | ])
98 | conv2 = Sequential([
99 | layers.Conv2D(cout, 3, 1, padding='same'),
100 | layers.ReLU(),
101 | ])
102 | super().__init__([conv1, conv2])
103 |
104 |
105 | class DecoderBlock(Model):
106 | def __init__(self, cout):
107 | """
108 | Base Unet decoder block for merging the outputs from 2 stages
109 | Args:
110 | cout: output dim of the block
111 | """
112 | super().__init__()
113 | self.up = layers.UpSampling2D()
114 | self.conv = BasicBlock(cout)
115 |
116 | def call(self, x1, x2, training=None):
117 | x1 = self.up(x1)
118 | x = tf.concat([x1, x2], axis=-1)
119 | x = self.conv(x, training)
120 | return x
121 |
122 |
123 | if __name__ == '__main__':
124 | H = W = 224
125 | dim = 64
126 | inputs = tf.random.normal((1, H, W, 3))
127 |
128 | encoder = Encoder(dim)
129 | encoder.build((1, H, W, 3))
130 | outputs = encoder(inputs)
131 | for i, o in enumerate(outputs):
132 | assert list(o.shape) == [1, H // (2 ** i), W // (2 ** i), dim * (2 ** i)]
133 |
134 | bridger = Bridger()
135 | outputs = bridger(outputs)
136 | for i, o in enumerate(outputs):
137 | assert list(o.shape) == [1, H // (2 ** i), W // (2 ** i), dim * (2 ** i)]
138 |
139 | features = [tf.random.normal((1, H // (2 ** i), W // (2 ** i), dim * (2 ** i))) for i in range(5)]
140 | decoder = Decoder(512)
141 | outputs = decoder(features)
142 | assert list(outputs.shape) == [1, H, W, 64]
143 |
144 | head = Head(10)
145 | outputs = head(outputs)
146 | assert list(outputs.shape) == [1, H, W, 10]
147 |
148 | unet = TFUNet(10)
149 | outputs = unet(inputs)
150 | assert list(outputs.shape) == [1, H, W, 10]
151 |
--------------------------------------------------------------------------------
/fusionlab/segmentation/unet/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from fusionlab.segmentation.base import SegmentationModel
4 | from fusionlab.utils import autopad
5 | from fusionlab.layers import ConvND, MaxPool
6 |
7 |
8 | class UNet(SegmentationModel):
9 | def __init__(self, cin, num_cls, base_dim=64, spatial_dims=2):
10 | """
11 | Base Unet
12 | Args:
13 | cin (int): input channels
14 | num_cls (int): number of classes
15 | base_dim (int): 1st stage dim of conv output
16 | """
17 | super().__init__()
18 | stage = 5
19 | self.num_cls = num_cls
20 | self.encoder = Encoder(cin, base_dim=base_dim, spatial_dims=spatial_dims)
21 | self.bridger = Bridger()
22 | self.decoder = Decoder(cin=base_dim*(2**(stage-1)),
23 | base_dim=base_dim*(2**(stage-2)),
24 | spatial_dims=spatial_dims) # 1024, 512
25 | self.head = Head(base_dim, num_cls, spatial_dims=spatial_dims)
26 |
27 |
28 | class Encoder(nn.Module):
29 | def __init__(self, cin, base_dim, spatial_dims=2):
30 | """
31 | UNet Encoder
32 | Args:
33 | cin (int): input channels
34 | base_dim (int): 1st stage dim of conv output
35 | """
36 | super().__init__()
37 | self.pool = MaxPool(spatial_dims, 2, 2)
38 | self.stage1 = BasicBlock(cin, base_dim, spatial_dims)
39 | self.stage2 = BasicBlock(base_dim, base_dim * 2, spatial_dims)
40 | self.stage3 = BasicBlock(base_dim * 2, base_dim * 4, spatial_dims)
41 | self.stage4 = BasicBlock(base_dim * 4, base_dim * 8, spatial_dims)
42 | self.stage5 = BasicBlock(base_dim * 8, base_dim * 16, spatial_dims)
43 |
44 | def forward(self, x):
45 | s1 = self.stage1(x)
46 | x = self.pool(s1)
47 | s2 = self.stage2(x)
48 | x = self.pool(s2)
49 | s3 = self.stage3(x)
50 | x = self.pool(s3)
51 | s4 = self.stage4(x)
52 | x = self.pool(s4)
53 | s5 = self.stage5(x)
54 |
55 | return [s1, s2, s3, s4, s5]
56 |
57 |
58 | class Decoder(nn.Module):
59 | def __init__(self, cin, base_dim, spatial_dims=2):
60 | """
61 | Base UNet decoder
62 | Args:
63 | cin (int): input channels
64 | base_dim (int): output dim of deepest stage output
65 | """
66 | super().__init__()
67 | self.d4 = DecoderBlock(cin, cin//2, base_dim, spatial_dims)
68 | self.d3 = DecoderBlock(base_dim, cin//4, base_dim//2, spatial_dims)
69 | self.d2 = DecoderBlock(base_dim//2, cin//8, base_dim//4, spatial_dims)
70 | self.d1 = DecoderBlock(base_dim//4, cin//16, base_dim//8, spatial_dims)
71 |
72 | def forward(self, x):
73 | f1, f2, f3, f4, f5 = x
74 | x = self.d4(f5, f4)
75 | x = self.d3(x, f3)
76 | x = self.d2(x, f2)
77 | x = self.d1(x, f1)
78 | return x
79 |
80 |
81 | class Bridger(nn.Module):
82 | def __init__(self):
83 | super().__init__()
84 |
85 | def forward(self, x):
86 |         outputs = list(x)  # identity pass-through for each feature map
87 | return outputs
88 |
89 |
90 | class Head(nn.Sequential):
91 | def __init__(self, cin, cout, spatial_dims=2):
92 | """
93 | Basic conv head
94 | :param int cin: input channel
95 | :param int cout: output channel
96 | """
97 | conv = ConvND(spatial_dims, cin, cout, 1)
98 | super().__init__(conv)
99 |
100 |
101 | class BasicBlock(nn.Sequential):
102 | def __init__(self, cin, cout, spatial_dims=2):
103 | conv1 = nn.Sequential(
104 | ConvND(spatial_dims, cin, cout, 3, 1, autopad(3)),
105 | nn.ReLU(),
106 | )
107 | conv2 = nn.Sequential(
108 | ConvND(spatial_dims, cout, cout, 3, 1, autopad(3)),
109 | nn.ReLU(),
110 | )
111 | super().__init__(conv1, conv2)
112 |
113 |
114 | class DecoderBlock(nn.Module):
115 | def __init__(self, c1, c2, cout, spatial_dims=2):
116 | """
117 | Base Unet decoder block for merging the outputs from 2 stages
118 | Args:
119 | c1: input dim of the deeper stage
120 | c2: input dim of the shallower stage
121 | cout: output dim of the block
122 | """
123 | super().__init__()
124 | self.up = nn.Upsample(scale_factor=2)
125 | self.conv = BasicBlock(c1 + c2, cout, spatial_dims)
126 |
127 | def forward(self, x1, x2):
128 | x1 = self.up(x1)
129 | x = torch.concat([x1, x2], dim=1)
130 | x = self.conv(x)
131 | return x
132 |
133 |
134 | if __name__ == '__main__':
135 | print("2D UNet")
136 | H = W = 224
137 | dim = 64
138 | inputs = torch.normal(0, 1, (1, 3, H, W))
139 |
140 | encoder = Encoder(3, base_dim=dim)
141 | outputs = encoder(inputs)
142 | for i, o in enumerate(outputs):
143 | assert list(o.shape) == [1, dim*(2**i), H//(2**i), W//(2**i)]
144 |
145 | bridger = Bridger()
146 | outputs = bridger(outputs)
147 | for i, o in enumerate(outputs):
148 | assert list(o.shape) == [1, dim * (2 ** i), H // (2 ** i), W // (2 ** i)]
149 |
150 | features = [torch.normal(0, 1, (1, dim * (2 ** i), H // (2 ** i), W // (2 ** i))) for i in range(5)]
151 | decoder = Decoder(1024, 512)
152 | outputs = decoder(features)
153 | assert list(outputs.shape) == [1, 64, H, W]
154 |
155 | head = Head(64, 10)
156 | outputs = head(outputs)
157 | assert list(outputs.shape) == [1, 10, H, W]
158 |
159 | unet = UNet(3, 10)
160 | outputs = unet(inputs)
161 | assert list(outputs.shape) == [1, 10, H, W]
162 |
163 | print("1D UNet")
164 | L = 224
165 | dim = 64
166 | inputs = torch.rand(1, 3, L)
167 | unet = UNet(3, 10, spatial_dims=1)
168 | outputs = unet(inputs)
169 | assert list(outputs.shape) == [1, 10, L]
170 |
171 | print("3D UNet")
172 | D = H = W = 64
173 | dim = 32
174 | inputs = torch.rand(1, 3, D, H, W)
175 | unet = UNet(3, 10, spatial_dims=3)
176 | outputs = unet(inputs)
177 | assert list(outputs.shape) == [1, 10, D, H, W]
--------------------------------------------------------------------------------
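A channel-bookkeeping sketch for the Decoder above, using the defaults from UNet.__init__ (base_dim=64, so cin=1024 and the decoder's base_dim argument is 512): each DecoderBlock concatenates the upsampled deep feature (c1 channels) with the skip feature (c2 channels) and projects down to cout = c1 // 2:

    cin, base = 1024, 512  # Decoder(cin=64 * 2**4, base_dim=64 * 2**3)
    blocks = {
        'd4': (cin,       cin // 2,  base),
        'd3': (base,      cin // 4,  base // 2),
        'd2': (base // 2, cin // 8,  base // 4),
        'd1': (base // 4, cin // 16, base // 8),
    }
    for name, (c1, c2, cout) in blocks.items():
        print(f"{name}: cat({c1} + {c2}) -> {cout}")
    # d4: cat(1024 + 512) -> 512, ..., d1: cat(128 + 64) -> 64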
/fusionlab/segmentation/unet2plus/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/segmentation/unet2plus/__init__.py
--------------------------------------------------------------------------------
/fusionlab/segmentation/unet2plus/tfunet2plus.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import layers, Model, Sequential
3 | from fusionlab.segmentation.tfbase import TFSegmentationModel
4 |
5 |
6 | class TFUNet2plus(TFSegmentationModel):
7 | def __init__(self, num_cls, base_dim):
8 | super().__init__()
9 | self.encoder = Encoder(base_dim)
10 | self.bridger = Bridger()
11 | self.decoder = Decoder(base_dim)
12 | self.head = Head(num_cls)
13 |
14 |
15 | class BasicBlock(Sequential):
16 | def __init__(self, cout):
17 | conv1 = Sequential([
18 | layers.Conv2D(cout, 3, 1, padding='same'),
19 | layers.BatchNormalization(),
20 | layers.Activation(tf.nn.relu),
21 | ])
22 | conv2 = Sequential([
23 | layers.Conv2D(cout, 3, 1, padding='same'),
24 | layers.BatchNormalization(),
25 | layers.Activation(tf.nn.relu),
26 | ])
27 | super().__init__([conv1, conv2])
28 |
29 |
30 | class Encoder(Model):
31 | def __init__(self, base_dim):
32 | """
33 | UNet Encoder
34 | Args:
35 | base_dim (int): 1st stage dim of conv output
36 | """
37 | super().__init__()
38 | self.pool = layers.MaxPool2D()
39 | self.conv0_0 = BasicBlock(base_dim)
40 | self.conv1_0 = BasicBlock(base_dim * 2)
41 | self.conv2_0 = BasicBlock(base_dim * 4)
42 | self.conv3_0 = BasicBlock(base_dim * 8)
43 | self.conv4_0 = BasicBlock(base_dim * 16)
44 |
45 | def call(self, x, training):
46 | x0_0 = self.conv0_0(x, training)
47 | x1_0 = self.conv1_0(self.pool(x0_0), training)
48 | x2_0 = self.conv2_0(self.pool(x1_0), training)
49 | x3_0 = self.conv3_0(self.pool(x2_0), training)
50 | x4_0 = self.conv4_0(self.pool(x3_0), training)
51 | return [x0_0, x1_0, x2_0, x3_0, x4_0]
52 |
53 |
54 | class Bridger(Model):
55 | def call(self, x, training=None):
56 | return [tf.identity(i) for i in x]
57 |
58 |
59 | class Decoder(Model):
60 | def __init__(self, base_dim):
61 | super().__init__()
62 | dims = [base_dim*(2**i) for i in range(5)] # [base_dim, base_dim*2, base_dim*4, base_dim*8, base_dim*16]
63 | self.conv0_1 = BasicBlock(dims[0])
64 | self.conv1_1 = BasicBlock(dims[1])
65 | self.conv2_1 = BasicBlock(dims[2])
66 | self.conv3_1 = BasicBlock(dims[3])
67 |
68 | self.conv0_2 = BasicBlock(dims[0])
69 | self.conv1_2 = BasicBlock(dims[1])
70 | self.conv2_2 = BasicBlock(dims[2])
71 |
72 | self.conv0_3 = BasicBlock(dims[0])
73 | self.conv1_3 = BasicBlock(dims[1])
74 |
75 | self.conv0_4 = BasicBlock(dims[0])
76 | self.up = layers.UpSampling2D()
77 |
78 |     def call(self, x, training=None):
79 |         x0_0, x1_0, x2_0, x3_0, x4_0 = x
80 | 
81 |         x0_1 = self.conv0_1(layers.concatenate([x0_0, self.up(x1_0)], -1), training=training)
82 |         x1_1 = self.conv1_1(layers.concatenate([x1_0, self.up(x2_0)], -1), training=training)
83 |         x0_2 = self.conv0_2(layers.concatenate([x0_0, x0_1, self.up(x1_1)], -1), training=training)
84 | 
85 |         x2_1 = self.conv2_1(layers.concatenate([x2_0, self.up(x3_0)], -1), training=training)
86 |         x1_2 = self.conv1_2(layers.concatenate([x1_0, x1_1, self.up(x2_1)], -1), training=training)
87 |         x0_3 = self.conv0_3(layers.concatenate([x0_0, x0_1, x0_2, self.up(x1_2)], -1), training=training)
88 | 
89 |         x3_1 = self.conv3_1(layers.concatenate([x3_0, self.up(x4_0)], -1), training=training)
90 |         x2_2 = self.conv2_2(layers.concatenate([x2_0, x2_1, self.up(x3_1)], -1), training=training)
91 |         x1_3 = self.conv1_3(layers.concatenate([x1_0, x1_1, x1_2, self.up(x2_2)], -1), training=training)
92 |         x0_4 = self.conv0_4(layers.concatenate([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], -1), training=training)
93 | 
94 |         return x0_4
95 |
96 |
97 | class Head(Sequential):
98 | def __init__(self, cout):
99 | """
100 |         Segmentation head: a 1x1 Conv2D mapping features to class logits
101 | :param int cout: output channel
102 | """
103 | conv = layers.Conv2D(cout, 1)
104 | super().__init__(conv)
105 |
106 |
107 | if __name__ == '__main__':
108 | H = W = 224
109 | dim = 32
110 | num_cls = 10
111 | inputs = tf.random.normal((1, H, W, 3))
112 |
113 | encoder = Encoder(base_dim=dim)
114 | outputs = encoder(inputs, training=True)
115 | for i, o in enumerate(outputs):
116 | assert list(o.shape) == [1, H // (2 ** i), W // (2 ** i), dim * (2 ** i)]
117 |
118 | bridger = Bridger()
119 | outputs = bridger(outputs, training=True)
120 | for i, o in enumerate(outputs):
121 | assert list(o.shape) == [1, H // (2 ** i), W // (2 ** i), dim * (2 ** i)]
122 |
123 | features = [tf.random.normal((1, H // (2 ** i), W // (2 ** i), dim * (2 ** i))) for i in range(5)]
124 | decoder = Decoder(dim)
125 | decoder.build([f.shape for f in features])
126 | outputs = decoder(features, training=True)
127 | assert list(outputs.shape) == [1, H, W, dim]
128 |
129 | head = Head(num_cls)
130 | outputs = head(outputs, training=True)
131 | assert list(outputs.shape) == [1, H, W, num_cls]
132 |
133 | unet = TFUNet2plus(num_cls, dim)
134 | outputs = unet(inputs, training=True)
135 | assert list(outputs.shape) == [1, H, W, num_cls]
136 |
137 |
--------------------------------------------------------------------------------
/fusionlab/segmentation/unet2plus/unet2plus.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from fusionlab.segmentation.base import SegmentationModel
4 | from fusionlab.utils import autopad
5 | from fusionlab.layers import ConvND, BatchNorm, MaxPool, Upsample
6 |
7 | # ref: https://github.com/4uiiurz1/pytorch-nested-unet
8 |
9 |
10 | class UNet2plus(SegmentationModel):
11 | def __init__(self, cin, num_cls, base_dim, spatial_dims=2):
12 | super().__init__()
13 | self.num_cls = num_cls
14 | self.encoder = Encoder(cin, base_dim, spatial_dims)
15 | self.bridger = Bridger()
16 | self.decoder = Decoder(base_dim, spatial_dims)
17 | self.head = Head(base_dim, num_cls, spatial_dims)
18 |
19 |
20 | class BasicBlock(nn.Sequential):
21 | def __init__(self, cin, cout, spatial_dims=2):
22 | conv1 = nn.Sequential(
23 | ConvND(spatial_dims, cin, cout, 3, 1, autopad(3)),
24 | BatchNorm(spatial_dims, cout),
25 | nn.ReLU(),
26 | )
27 | conv2 = nn.Sequential(
28 | ConvND(spatial_dims, cout, cout, 3, 1, autopad(3)),
29 | BatchNorm(spatial_dims, cout),
30 | nn.ReLU(),
31 | )
32 | super().__init__(conv1, conv2)
33 |
34 |
35 | class Encoder(nn.Module):
36 | def __init__(self, cin, base_dim, spatial_dims=2):
37 | """
38 | UNet Encoder
39 | Args:
40 | cin (int): input channels
41 | base_dim (int): 1st stage dim of conv output
42 | """
43 | super().__init__()
44 | self.pool = MaxPool(spatial_dims, 2, 2)
45 | self.conv0_0 = BasicBlock(cin, base_dim, spatial_dims)
46 | self.conv1_0 = BasicBlock(base_dim, base_dim * 2, spatial_dims)
47 | self.conv2_0 = BasicBlock(base_dim * 2, base_dim * 4, spatial_dims)
48 | self.conv3_0 = BasicBlock(base_dim * 4, base_dim * 8, spatial_dims)
49 | self.conv4_0 = BasicBlock(base_dim * 8, base_dim * 16, spatial_dims)
50 |
51 | def forward(self, x):
52 | x0_0 = self.conv0_0(x)
53 | x1_0 = self.conv1_0(self.pool(x0_0))
54 | x2_0 = self.conv2_0(self.pool(x1_0))
55 | x3_0 = self.conv3_0(self.pool(x2_0))
56 | x4_0 = self.conv4_0(self.pool(x3_0))
57 | return [x0_0, x1_0, x2_0, x3_0, x4_0]
58 |
59 |
60 | class Bridger(nn.Module):
61 | def __init__(self):
62 | super().__init__()
63 |
64 | def forward(self, x):
65 |         return list(x)  # identity pass-through of encoder features
66 |
67 |
68 | class Decoder(nn.Module):
69 | def __init__(self, base_dim, spatial_dims=2):
70 | super().__init__()
71 | dims = [base_dim*(2**i) for i in range(5)] # [base_dim, base_dim*2, base_dim*4, base_dim*8, base_dim*16]
72 | self.conv0_1 = BasicBlock(dims[0] + dims[1], dims[0], spatial_dims)
73 | self.conv1_1 = BasicBlock(dims[1] + dims[2], dims[1], spatial_dims)
74 | self.conv2_1 = BasicBlock(dims[2] + dims[3], dims[2], spatial_dims)
75 | self.conv3_1 = BasicBlock(dims[3] + dims[4], dims[3], spatial_dims)
76 |
77 | self.conv0_2 = BasicBlock(dims[0] * 2 + dims[1], dims[0], spatial_dims)
78 | self.conv1_2 = BasicBlock(dims[1] * 2 + dims[2], dims[1], spatial_dims)
79 | self.conv2_2 = BasicBlock(dims[2] * 2 + dims[3], dims[2], spatial_dims)
80 |
81 | self.conv0_3 = BasicBlock(dims[0] * 3 + dims[1], dims[0], spatial_dims)
82 | self.conv1_3 = BasicBlock(dims[1] * 3 + dims[2], dims[1], spatial_dims)
83 |
84 | self.conv0_4 = BasicBlock(dims[0] * 4 + dims[1], dims[0], spatial_dims)
85 | self.up = Upsample(spatial_dims, scale_factor=2)
86 |
87 | def forward(self, x):
88 | x0_0, x1_0, x2_0, x3_0, x4_0 = x
89 |
90 | x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
91 | x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
92 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
93 |
94 | x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
95 | x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
96 | x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
97 |
98 | x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
99 | x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
100 | x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
101 | x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
102 |
103 | return x0_4
104 |
105 |
106 | class Head(nn.Sequential):
107 | def __init__(self, cin, cout, spatial_dims=2):
108 | """
109 |         Segmentation head: a 1x1 conv mapping features to class logits
110 | :param int cin: input channel
111 | :param int cout: output channel
112 | """
113 | conv = ConvND(spatial_dims, cin, cout, 1)
114 | super().__init__(conv)
115 |
116 |
117 | if __name__ == '__main__':
118 | H = W = 224
119 | dim = 32
120 | inputs = torch.normal(0, 1, (1, 3, H, W))
121 |
122 | encoder = Encoder(3, base_dim=dim)
123 | outputs = encoder(inputs)
124 | for i, o in enumerate(outputs):
125 | assert list(o.shape) == [1, dim * (2 ** i), H // (2 ** i), W // (2 ** i)]
126 |
127 | bridger = Bridger()
128 | outputs = bridger(outputs)
129 | for i, o in enumerate(outputs):
130 | assert list(o.shape) == [1, dim * (2 ** i), H // (2 ** i), W // (2 ** i)]
131 |
132 | features = [torch.normal(0, 1, (1, dim * (2 ** i), H // (2 ** i), W // (2 ** i))) for i in range(5)]
133 | decoder = Decoder(dim)
134 | outputs = decoder(features)
135 | assert list(outputs.shape) == [1, dim, H, W]
136 |
137 | head = Head(dim, 10)
138 | outputs = head(outputs)
139 | assert list(outputs.shape) == [1, 10, H, W]
140 |
141 | unet = UNet2plus(3, 10, dim)
142 | outputs = unet(inputs)
143 | assert list(outputs.shape) == [1, 10, H, W]
144 |
145 | print("1D UNet++")
146 | L = 128
147 | dim = 32
148 | inputs = torch.rand(1, 3, L)
149 | unet = UNet2plus(3, 10, dim, spatial_dims=1)
150 |
151 | outputs = unet(inputs)
152 | assert list(outputs.shape) == [1, 10, L]
153 |
154 | print("3D UNet++")
155 | D = H = W = 32
156 | dim = 32
157 | inputs = torch.rand(1, 3, D, H, W)
158 | unet = UNet2plus(3, 10, dim, spatial_dims=3)
159 |
160 | outputs = unet(inputs)
161 | assert list(outputs.shape) == [1, 10, D, H, W]
162 |
--------------------------------------------------------------------------------
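A note on the decoder's channel arithmetic in both implementations above: node x{i}_{j} concatenates j feature maps at resolution level i (each with dims[i] channels) with one upsampled map from level i+1 (dims[i+1] channels), which is exactly the `dims[i] * j + dims[i+1]` pattern in `Decoder.__init__`. A minimal sketch that prints these expected input widths, derived directly from the code above:

    base_dim = 32
    dims = [base_dim * (2 ** i) for i in range(5)]
    # conv{i}_{j} takes j skip maps of dims[i] plus one upsampled map of dims[i+1]
    for i, j in [(0, 1), (1, 1), (2, 1), (3, 1), (0, 2), (1, 2), (2, 2), (0, 3), (1, 3), (0, 4)]:
        print(f"conv{i}_{j}: in_channels = {dims[i] * j + dims[i + 1]}")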
/fusionlab/segmentation/unetr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/segmentation/unetr/__init__.py
--------------------------------------------------------------------------------
/fusionlab/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from fusionlab import BACKEND
2 | if BACKEND['torch']:
3 | from .dcgan import *
4 | from .trainer import *
--------------------------------------------------------------------------------
/fusionlab/trainers/test.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | m = tf.keras.metrics.Accuracy()
4 | m.reset_state()
--------------------------------------------------------------------------------
/fusionlab/trainers/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm.auto import tqdm
3 | import numpy as np
4 |
5 | # ref: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback
6 |
7 | class Trainer:
8 | def __init__(self, device):
9 | self.device = device
10 |
11 | def train_step(self, data):
12 | data = self._data_to_device(data)
13 | inputs, target = data
14 | pred = self.model(inputs)
15 | loss = self.loss_fn(pred, target)
16 | loss.backward()
17 | self.optimizer.step()
18 | self.optimizer.zero_grad()
19 | return loss.item()
20 |
21 | def val_step(self, data):
22 | data = self._data_to_device(data)
23 | inputs, target = data
24 | with torch.no_grad():
25 | pred = self.model(inputs)
26 | loss = self.loss_fn(pred, target)
27 | return loss.item()
28 |
29 | def train_epoch(self):
30 | self.model.train()
31 | epoch_loss = []
32 |         for data in tqdm(self.train_dataloader, leave=False):
33 | batch_loss = self.train_step(data)
34 | epoch_loss.append(batch_loss)
35 | return np.mean(epoch_loss)
36 |
37 | def val_epoch(self):
38 | self.model.eval()
39 | epoch_loss = []
40 |         for data in tqdm(self.val_dataloader, leave=False):
41 | batch_loss = self.val_step(data)
42 | epoch_loss.append(batch_loss)
43 | return np.mean(epoch_loss)
44 |
45 | def on_fit_begin(self):
46 | pass
47 | def on_fit_end(self):
48 | pass
49 |
50 | def on_epoch_begin(self):
51 | pass
52 | def on_epoch_end(self):
53 | pass
54 |
55 | def _data_to_device(self, data):
56 | if isinstance(data, torch.Tensor):
57 | return data.to(self.device)
58 | elif isinstance(data, dict):
59 | return {k: v.to(self.device) for k, v in data.items()}
60 | elif isinstance(data, list):
61 | return [v.to(self.device) for v in data]
62 | else:
63 | raise NotImplementedError
64 |
65 | def fit(self, model, train_dataloader, val_dataloader, epochs, optimizer, loss_fn):
66 | self.model = model.to(self.device)
67 | self.train_dataloader = train_dataloader
68 | self.val_dataloader = val_dataloader
69 | self.epochs = epochs
70 | self.optimizer = optimizer
71 | self.loss_fn = loss_fn
72 |
73 | self.train_log = {'loss': []}
74 | self.val_log = {'loss': []}
75 |
76 | self.on_fit_begin()
77 | for epoch in tqdm(range(epochs)):
78 | self.on_epoch_begin()
79 | train_epoch_loss = self.train_epoch()
80 | self.train_log['loss'].append(train_epoch_loss)
81 |
82 | if self.val_dataloader:
83 | val_epoch_loss = self.val_epoch()
84 | self.val_log['loss'].append(val_epoch_loss)
85 |
86 |         val_msg = f" val_loss: {self.val_log['loss'][-1]:.4f}" if self.val_log['loss'] else ""
87 |         print(f"[{epoch}/{epochs}] train_loss: {self.train_log['loss'][-1]:.4f}" + val_msg)
88 | self.on_epoch_end()
89 | self.on_fit_end()
90 | return
91 |
92 | if __name__ == "__main__":
93 | class FakeModel(torch.nn.Module):
94 | def __init__(self):
95 | super().__init__()
96 | self.conv = torch.nn.Conv2d(1, 3, 3)
97 | self.pool = torch.nn.Sequential(
98 | torch.nn.AdaptiveAvgPool2d(1),
99 | torch.nn.Flatten(),
100 | )
101 | self.cls = torch.nn.Linear(3, 10)
102 | def forward(self, x):
103 | x = self.conv(x)
104 | x = self.pool(x)
105 | x = self.cls(x)
106 | return x
107 |
108 | from abc import ABC, abstractmethod
109 | class Metric(ABC):
110 | def __init__(self):
111 | pass
112 |
113 |         @abstractmethod
114 |         def reset(self):
115 |             raise NotImplementedError("reset method is not implemented!")
116 | 
117 |         @abstractmethod
118 |         def update(self):
119 |             raise NotImplementedError("update method is not implemented!")
120 | 
121 |         @abstractmethod
122 |         def compute(self):
123 |             raise NotImplementedError("compute method is not implemented!")
124 |
125 | # class Accuracy(Metric):
126 |
127 |
128 |
129 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
130 | print(device)
131 |
132 | # mnist
133 | from torchvision.datasets import MNIST
134 | from torchvision.transforms import ToTensor
135 | from torch.utils.data import DataLoader
136 | train_dataset = MNIST(root='data/', train=True, transform=ToTensor(), download=True)
137 | val_dataset = MNIST(root='data/', train=False, transform=ToTensor(), download=True)
138 | train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
139 | val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)
140 | model = FakeModel()
141 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
142 | loss_fn = torch.nn.CrossEntropyLoss()
143 |
144 | trainer = Trainer(device)
145 | trainer.fit(model,
146 | train_dataloader,
147 | val_dataloader,
148 | 10,
149 | optimizer,
150 | loss_fn)
151 |
152 |
153 |
--------------------------------------------------------------------------------
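The empty on_fit_begin/on_epoch_end methods above are callback hooks in the style of the Keras Callback API referenced at the top of the file. A minimal sketch of how they can be used (CheckpointTrainer is a hypothetical example, not part of fusionlab): subclass Trainer and override on_epoch_end to checkpoint on the best validation loss.

    import torch
    from fusionlab.trainers import Trainer

    class CheckpointTrainer(Trainer):
        """Hypothetical: save weights whenever validation loss improves."""
        def __init__(self, device, ckpt_path="best.pt"):
            super().__init__(device)
            self.ckpt_path = ckpt_path
            self.best_loss = float("inf")

        def on_epoch_end(self):
            # fit() appends to val_log['loss'] before invoking this hook
            if self.val_log["loss"] and self.val_log["loss"][-1] < self.best_loss:
                self.best_loss = self.val_log["loss"][-1]
                torch.save(self.model.state_dict(), self.ckpt_path)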
/fusionlab/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic import autopad, make_ntuple
2 | from .trace import show_classtree
3 | from .plots import plot_channels
4 | from .labelme import convert_labelme_json2mask
5 |
6 | from fusionlab import BACKEND
7 | if BACKEND['torch']:
8 | from .trunc_normal.trunc_normal import trunc_normal_
--------------------------------------------------------------------------------
/fusionlab/utils/basic.py:
--------------------------------------------------------------------------------
1 | import collections
2 | from itertools import repeat
3 | from typing import Any, Tuple
4 |
5 | def autopad(kernel_size, padding=None, dilation=1, spatial_dims=2):
6 | '''
7 | Auto padding for convolutional layers
8 | '''
9 | if padding is None:
10 | if isinstance(kernel_size, int) and isinstance(dilation, int):
11 | padding = (kernel_size - 1) // 2 * dilation
12 | else:
13 | kernel_size = make_ntuple(kernel_size, spatial_dims)
14 | dilation = make_ntuple(dilation, spatial_dims)
15 | padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(spatial_dims))
16 | return padding
17 |
18 | def make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
19 | """
20 | Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
21 | Otherwise, we will make a tuple of length n, all with value of x.
22 | reference: https://github.com/pytorch/vision/blob/main/torchvision/utils.py#L585C1-L597C31
23 |
24 | Args:
25 | x (Any): input value
26 | n (int): length of the resulting tuple
27 | """
28 | if isinstance(x, collections.abc.Iterable):
29 | return tuple(x)
30 | return tuple(repeat(x, n))
--------------------------------------------------------------------------------
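A few illustrative values, derived directly from the definitions above:

    from fusionlab.utils import autopad, make_ntuple

    assert autopad(3) == 1                    # (3 - 1) // 2
    assert autopad(5, dilation=2) == 4        # (5 - 1) // 2 * 2
    assert autopad((3, 5)) == (1, 2)          # per-dimension 'same'-style padding
    assert make_ntuple(7, 3) == (7, 7, 7)     # scalar broadcast to an n-tuple
    assert make_ntuple([1, 2], 2) == (1, 2)   # iterable converted to a tuple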
/fusionlab/utils/labelme.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 | import json
3 | import numpy as np
4 | import os
5 | from glob import glob
6 | from tqdm.auto import tqdm
7 |
8 | def convert_labelme_json2mask(
9 | class_names: Sequence[str],
10 | json_dir: str,
11 | output_dir: str,
12 | single_mask: bool = True
13 | ):
14 | """
15 | Convert labelme json files to mask files(.png)
16 |
17 | Args:
18 | class_names (list): list of class names, background class must be included at first
19 | json_dir (str): path to json files directory
20 | output_dir (str): path to output directory
21 |         single_mask (bool): if True, save a single mask file per image with class indices (uint8); otherwise save one binary mask file (0/255) per class
22 | """
23 | try:
24 | import cv2
25 | except ImportError:
26 | raise ImportError("opencv-python package is not installed")
27 |
28 | num_classes = len(class_names)
29 | if num_classes > 255:
30 | raise ValueError("Maximum number of classes is 255")
31 |
32 | cls_map = {name: i for i, name in enumerate(class_names)}
33 | print(f"Number of classes: {num_classes}")
34 | print("class name to index: ", cls_map)
35 | json_paths = glob(os.path.join(json_dir, '*.json'))
36 |
37 | for path in tqdm(json_paths):
38 | json_data = json.load(open(path))
39 | h = json_data['imageHeight']
40 | w = json_data['imageWidth']
41 |
42 | # Draw Object mask
43 | mask = np.zeros((len(class_names), h, w), dtype=np.uint8)
44 | for shape in json_data['shapes']:
45 | if shape["shape_type"] != "polygon":
46 | continue
47 | cls_name = shape['label']
48 | cls_idx = cls_map[cls_name]
49 | points = shape['points']
50 | cv2.fillPoly(
51 | mask[cls_idx],
52 | np.array([points], dtype=np.int32),
53 | 255
54 | )
55 |         # update background mask
56 |         mask[0] = 255 - np.max(mask[1:], axis=0)
57 | # Save Mask File
58 | filename = ".".join(os.path.split(path)[-1].split('.')[:-1])
59 | if single_mask:
60 | mask_single = np.argmax(mask, axis=0).astype(np.uint8)
61 | cv2.imwrite(os.path.join(output_dir, f"{filename}.png"), mask_single)
62 | else:
63 | for i, m in enumerate(mask):
64 | cv2.imwrite(os.path.join(output_dir, f'{filename}_{i:03d}.png'), m.astype(np.uint8))
65 |
66 | if __name__ == '__main__':
67 | convert_labelme_json2mask(
68 | ['bg', 'dog', 'cat'],
69 | "json",
70 | "mask",
71 | single_mask=True,
72 | )
--------------------------------------------------------------------------------
/fusionlab/utils/plots.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | def plot_channels(signals, show=True):
4 | '''
5 | plot signals by channels
6 |
7 | Args:
8 | signals: numpy array, shape (num_samples, num_channels)
9 | '''
10 | num_channels = signals.shape[1]
11 |     fig, axes = plt.subplots(num_channels, 1, figsize=(10, 10), squeeze=False)
12 |     for i in range(num_channels):
13 |         axes[i, 0].plot(signals[:, i])
14 |
15 | if show: plt.show()
16 | return fig
17 |
--------------------------------------------------------------------------------
/fusionlab/utils/trace.py:
--------------------------------------------------------------------------------
1 | import inspect
2 |
3 | # Print a class's __init__ argument names, then recurse through its base classes
4 | def show_classtree(clss, indent=0):
5 | # Get the full argument spec for the class
6 | argspec = inspect.getfullargspec(clss)
7 | # Get the arguments for the class
8 | args = argspec.args
9 | # If the class has a varargs argument, append it to args
10 | if argspec.varargs:
11 | args.append('*' + argspec.varargs)
12 | # If the class has a varkw argument, append it to args
13 | if argspec.varkw:
14 | args.append('**' + argspec.varkw)
15 | # Print the class name and arguments, indented by indent spaces
16 | print(' ' * indent + f'{clss} | input: {args}')
17 | # For each base class of the class
18 | for supercls in clss.__bases__:
19 | # Recursively call show_classtree on the base class, with an indent of 3 more spaces
20 | show_classtree(supercls, indent + 3)
--------------------------------------------------------------------------------
/fusionlab/utils/trunc_normal/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/fusionlab/utils/trunc_normal/__init__.py
--------------------------------------------------------------------------------
/fusionlab/utils/trunc_normal/trunc_normal.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
5 | """Tensor initialization with truncated normal distribution.
6 | Based on:
7 | https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
8 | https://github.com/rwightman/pytorch-image-models
9 |
10 | Args:
11 | tensor: an n-dimensional `torch.Tensor`.
12 | mean: the mean of the normal distribution.
13 | std: the standard deviation of the normal distribution.
14 | a: the minimum cutoff value.
15 | b: the maximum cutoff value.
16 | """
17 |
18 | def norm_cdf(x):
19 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
20 |
21 | with torch.no_grad():
22 | l = norm_cdf((a - mean) / std)
23 | u = norm_cdf((b - mean) / std)
24 | tensor.uniform_(2 * l - 1, 2 * u - 1)
25 | tensor.erfinv_()
26 | tensor.mul_(std * math.sqrt(2.0))
27 | tensor.add_(mean)
28 | tensor.clamp_(min=a, max=b)
29 | return tensor
30 |
31 |
32 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
33 | """Tensor initialization with truncated normal distribution.
34 | Based on:
35 | https://github.com/rwightman/pytorch-image-models
36 |
37 | Args:
38 | tensor: an n-dimensional `torch.Tensor`
39 | mean: the mean of the normal distribution
40 | std: the standard deviation of the normal distribution
41 | a: the minimum cutoff value
42 | b: the maximum cutoff value
43 | """
44 |
45 | if std <= 0:
46 | raise ValueError("the standard deviation should be greater than zero.")
47 |
48 | if a >= b:
49 | raise ValueError("minimum cutoff value (a) should be smaller than maximum cutoff value (b).")
50 |
51 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
--------------------------------------------------------------------------------
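Typical usage (assuming the torch backend is active, so trunc_normal_ is re-exported from fusionlab.utils): initialize a layer's weights in place, ViT-style.

    import torch
    from fusionlab.utils import trunc_normal_

    linear = torch.nn.Linear(64, 32)
    trunc_normal_(linear.weight, mean=0.0, std=0.02)  # values kept within the [a, b] = [-2, 2] cutoffs
    assert float(linear.weight.min()) >= -2.0 and float(linear.weight.max()) <= 2.0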
/make_init.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set the package directory
4 | PACKAGE_DIR="fusionlab"
5 |
6 | # Find all directories in the package
7 | DIRS=$(find $PACKAGE_DIR -type d)
8 |
9 | # Create __init__.py files in all directories
10 | for DIR in $DIRS; do
11 | touch $DIR/__init__.py
12 | done
13 |
14 | # Add import statements to __init__.py files
15 | for DIR in $DIRS; do
16 |     # Relative path of the directory (computed but currently unused)
17 |     RELATIVE_PATH=${DIR#$PACKAGE_DIR/}
18 | # Get the list of subdirectories and files in the directory
19 | SUBDIRS=$(find $DIR -maxdepth 1 -type d ! -path $DIR)
20 | FILES=$(find $DIR -maxdepth 1 -type f -name "*.py" ! -name "__init__.py")
21 | # Add import statements for subdirectories and files
22 | for SUBDIR in $SUBDIRS; do
23 | MODULE_NAME=${SUBDIR#$DIR/}
24 | echo "from .$MODULE_NAME import *" >> $DIR/__init__.py
25 | done
26 | for FILE in $FILES; do
27 | MODULE_NAME=$(basename $FILE .py)
28 | echo "from .$MODULE_NAME import *" >> $DIR/__init__.py
29 | done
30 | done
--------------------------------------------------------------------------------
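For illustration: given a hypothetical layout fusionlab/foo/ containing a subpackage baz/ and a module bar.py, the two loops above append to fusionlab/foo/__init__.py:

    from .baz import *
    from .bar import *

Subdirectory imports come first because the SUBDIRS loop runs before the FILES loop. Note that re-running the script appends duplicate lines, since touch does not truncate existing __init__.py files.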
/release_logs.md:
--------------------------------------------------------------------------------
1 | 0.1.12
2 |
3 | * Fix: remove print in SegFormer
4 |
5 | 0.1.11
6 |
7 | * Add: convert labelme json to mask files
8 |
9 | 0.1.10
10 |
11 | * Add: TransUNet, ViT, UNETR, SegFormer
12 | * Add layer: Rearrange, PatchEmbedding, InstanceNorm ND, DropPath
13 | * Add utils: trunc_normal_
14 |
15 | 0.1.9
16 |
17 | * Add LSTimeClassificationDataset for time series classification
18 |
19 | 0.1.7
20 |
21 | * Add Dice, IoU score
22 |
23 | 0.1.6
24 |
25 | * Add API Documentation
26 | * Add utils.count_parameters
27 | * Update backend support
28 |
29 | 0.1.5
30 |
31 | * Remove numpy from requirements.txt
32 | * Extract make_ntuple from torchvision into fusionlab.utils
33 |
34 | 0.1.4
35 |
36 | * Add EfficientNetB0 ~ B7
37 | * Add ConvNeXt Tiny ~ XLarge
38 |
39 | 0.1.3
40 |
41 | * Add
42 | * ECGCSVClassificationDataset(cinc2017)
43 | * LUDBDataset
44 | * LSTimeSegDataset
45 | * HFDataset
46 |
47 | 0.1.2
48 |
49 | * Add LUDB dataset
50 |
51 |
52 | 0.0.52
53 |
54 | * Tversky Loss for Torch and TF
55 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops
2 | tqdm
3 | traceback2
4 |
5 | pandas>=1.5.3
6 | xmltodict
7 | xlwt
8 | scipy
9 | wfdb
--------------------------------------------------------------------------------
/scripts/build_pip.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # ref: https://datainpoint.substack.com/p/c0a
3 |
4 | python ../setup.py sdist bdist_wheel
5 | python -m twine upload dist/*
6 |
7 | # pip install -e .
8 |
9 | # build
10 | # python setup.py sdist bdist_wheel
11 | # pypi
12 | # twine upload dist/*
13 | # testpypi
14 | # twine upload --repository testpypi dist/*
15 | 
16 | # testpypi install (replace ? with the target version)
17 | # pip install --upgrade fusionlab==?
18 | # pip install -i https://test.pypi.org/simple/ --upgrade fusionlab==?
19 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # Ref: https://github.com/qubvel/segmentation_models.pytorch
5 |
6 | # Note: To use the 'upload' functionality of this file, you must:
7 | # $ pip install twine
8 |
9 | import io
10 | import os
11 | import sys
12 | from shutil import rmtree
13 |
14 | from setuptools import find_packages, setup, Command
15 |
16 | # Package meta-data.
17 | NAME = "fusionlab"
18 | DESCRIPTION = "Useful packages for DL"
19 | URL = "https://github.com/taipingeric/fusionlab"
20 | EMAIL = "taipingeric@gmail.com"
21 | AUTHOR = "Chih-Yang Li"
22 | REQUIRES_PYTHON = ">=3.8.0"
23 | VERSION = None
24 |
25 | # The rest you shouldn't have to touch too much :)
26 | # ------------------------------------------------
27 | # Except, perhaps the License and Trove Classifiers!
28 | # If you do change the License, remember to change the Trove Classifier for that!
29 |
30 | here = os.path.abspath(os.path.dirname(__file__))
31 |
32 | # What packages are required for this module to be executed?
33 | try:
34 | with open(os.path.join(here, "requirements.txt"), encoding="utf-8") as f:
35 |         REQUIRED = [line for line in f.read().splitlines() if line.strip()]
36 | except FileNotFoundError:
37 | REQUIRED = []
38 |
39 | # What packages are optional?
40 | EXTRAS = {
41 | "test": [
42 | "pytest",
43 | "mock",
44 | "pre-commit",
45 | "black==22.3.0",
46 | "flake8==4.0.1",
47 | "flake8-docstrings==1.6.0",
48 | "torchvision",
49 | "tensorflow",
50 | ],
51 | }
52 |
53 | # Import the README and use it as the long-description.
54 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file!
55 | try:
56 | with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f:
57 | long_description = "\n" + f.read()
58 | except FileNotFoundError:
59 | long_description = DESCRIPTION
60 |
61 | # Load the package's __version__.py module as a dictionary.
62 | about = {}
63 | if not VERSION:
64 | with open(os.path.join(here, NAME, "__version__.py")) as f:
65 | exec(f.read(), about)
66 | else:
67 | about["__version__"] = VERSION
68 |
69 |
70 | class UploadCommand(Command):
71 | """Support setup.py upload."""
72 |
73 | description = "Build and publish the package."
74 | user_options = []
75 |
76 | @staticmethod
77 | def status(s):
78 | """Prints things in bold."""
79 | print(s)
80 |
81 | def initialize_options(self):
82 | pass
83 |
84 | def finalize_options(self):
85 | pass
86 |
87 | def run(self):
88 | try:
89 | self.status("Removing previous builds...")
90 | rmtree(os.path.join(here, "dist"))
91 | except OSError:
92 | pass
93 |
94 | self.status("Building Source and Wheel (universal) distribution...")
95 | os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable))
96 |
97 | self.status("Uploading the package to PyPI via Twine...")
98 | os.system("twine upload dist/*")
99 |
100 | self.status("Pushing git tags...")
101 | os.system("git tag v{0}".format(about["__version__"]))
102 | os.system("git push --tags")
103 |
104 | sys.exit()
105 |
106 |
107 | # Where the magic happens:
108 | setup(
109 | name=NAME,
110 | version=about["__version__"],
111 | description=DESCRIPTION,
112 | long_description=long_description,
113 | long_description_content_type="text/markdown",
114 | author=AUTHOR,
115 | author_email=EMAIL,
116 | python_requires=REQUIRES_PYTHON,
117 | url=URL,
118 | packages=find_packages(exclude=("tests", "docs", "images")),
119 | # If your package is a single module, use this instead of 'packages':
120 | # py_modules=['mypackage'],
121 | # entry_points={
122 | # 'console_scripts': ['mycli=mymodule:cli'],
123 | # },
124 | install_requires=REQUIRED,
125 | extras_require=EXTRAS,
126 | include_package_data=True,
127 | license="MIT",
128 | classifiers=[
129 | # Trove classifiers
130 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
131 | "License :: OSI Approved :: MIT License",
132 | # "Programming Language :: Python",
133 | "Programming Language :: Python :: 3",
134 | "Operating System :: OS Independent",
135 | # "Programming Language :: Python :: Implementation :: CPython",
136 | # "Programming Language :: Python :: Implementation :: PyPy",
137 | ],
138 | # $ setup.py publish support.
139 | # cmdclass={
140 | # "upload": UploadCommand,
141 | # },
142 | )
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taipingeric/fusionlab/de28f6a98279b729464d39f01fcb33e21e9a00af/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_classification.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from fusionlab.classification import HFClassificationModel
3 | from fusionlab.classification import VGG16Classifier
4 | from fusionlab.classification import LSTMClassifier
5 |
6 | H = W = 224
7 | cout = 5
8 | inputs = torch.normal(0, 1, (1, 3, W))
9 | # Test VGG16Classifier with 1D input
10 | model = VGG16Classifier(3, cout, spatial_dims=1)
11 | hf_model = HFClassificationModel(model, cout)
12 | output = hf_model(inputs)
13 | print(output['logits'].shape)
14 | assert list(output.keys()) == ['loss', 'logits', 'hidden_states']
15 |
16 | inputs = torch.normal(0, 1, (1, 3, H, W))
17 | # Test VGG16Classifier with 2D input
18 | model = VGG16Classifier(3, cout, spatial_dims=2)
19 | hf_model = HFClassificationModel(model, cout)
20 | output = hf_model(inputs)
21 | print(output['logits'].shape)
22 | assert list(output.keys()) == ['loss', 'logits', 'hidden_states']
23 |
24 | inputs = torch.normal(0, 1, (1, 3, H))
25 | model = LSTMClassifier(3, cout)
26 | hf_model = HFClassificationModel(model, cout)
27 | output = hf_model(inputs)
28 | print(output['logits'].shape)
29 | assert list(output.keys()) == ['loss', 'logits', 'hidden_states']
--------------------------------------------------------------------------------
/tests/test_datasets.py:
--------------------------------------------------------------------------------
1 | from fusionlab import datasets
2 |
3 | class TestHFDataset:
4 | def test_HFDataset(self):
5 | import torch
6 | print("Test HFDataset")
7 |
8 | NUM_DATA = 20
9 | NUM_FEATURES = 16
10 | ds = torch.utils.data.TensorDataset(
11 | torch.zeros(NUM_DATA, NUM_FEATURES),
12 | torch.zeros(NUM_DATA))
13 | for x, y in ds:
14 | assert list(x.shape) == [NUM_FEATURES]
15 | pass
16 | hf_ds = datasets.HFDataset(ds)
17 | assert hf_ds is not None
18 | for i in range(len(hf_ds)):
19 | data_dict = hf_ds[i]
20 | assert data_dict.keys() == set(['x', 'labels'])
21 | assert list(data_dict['x'].shape) == [NUM_FEATURES]
22 | pass
23 |
24 | class TestLSTimeSegDataset:
25 | def test_LSTimeSegDataset(self, tmpdir):
26 | import numpy as np
27 | import os
28 | import pandas as pd
29 | import json
30 | import torch
31 | filename = "29be6360-12lead.csv"
32 |         annotation_path = os.path.join(tmpdir, "12.json")
33 | annotation = [
34 | {
35 | "csv": f"/data/upload/12/{filename}",
36 | "label": [
37 | {
38 | "start": 0.004,
39 | "end": 0.764,
40 | "instant": False,
41 | "timeserieslabels": ["N"]
42 | },
43 | {
44 | "start": 0.762,
45 | "end": 1.468,
46 | "instant": False,
47 | "timeserieslabels": ["p"]
48 | },
49 | {
50 | "start": 1.466,
51 | "end": 2.5,
52 | "instant": False,
53 | "timeserieslabels": ["t"]
54 | }
55 | ],
56 | "number": [{"number": 500}],
57 | }
58 | ]
59 |         with open(annotation_path, "w") as f:
60 | json.dump(annotation, f)
61 | col_names = ['i', 'ii', 'iii', 'avr', 'avl', 'avf', 'v1', 'v2', 'v3', 'v4', 'v5', 'v6']
62 | df = pd.DataFrame()
63 | sample_rate = 500
64 | num_samples = sample_rate * 10
65 | df['time'] = np.arange(num_samples) / sample_rate
66 | for col_name in col_names:
67 | df[col_name] = np.random.randn(num_samples)
68 | df.to_csv(os.path.join(tmpdir, filename), index=False)
69 | ds = datasets.LSTimeSegDataset(data_dir=tmpdir,
70 |                                annotation_path=annotation_path,
71 | class_map={"N": 1, "p": 2, "t": 3},
72 | column_names=col_names)
73 | signals, mask = ds[0]
74 | assert signals.shape == (len(col_names), num_samples)
75 | assert mask.shape == (num_samples, )
76 | assert type(signals) == torch.Tensor
77 | assert type(mask) == torch.Tensor
78 | assert len(ds) == 1
--------------------------------------------------------------------------------
/tests/test_factories.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from fusionlab.layers.factories import (
3 | ConvND,
4 | ConvT,
5 | Upsample,
6 | BatchNorm,
7 | MaxPool,
8 | AvgPool,
9 | AdaptiveMaxPool,
10 | AdaptiveAvgPool,
11 | ReplicationPad,
12 | ConstantPad
13 | )
14 |
15 |
16 |
17 | # Test Code for ConvND
18 | inputs = torch.randn(1, 3, 16) # create random input tensor
19 | layer = ConvND(spatial_dims=1, in_channels=3, out_channels=2, kernel_size=5) # create model instance
20 | outputs = layer(inputs) # pass input through model
21 | print(outputs.shape)
22 | assert list(outputs.shape) == [1, 2, 16] # check output shape is correct
23 |
24 | # Test code for ConvT
25 | inputs = torch.randn(1, 3, 16) # create random input tensor
26 | layer = ConvT(spatial_dims=1, in_channels=3, out_channels=2, kernel_size=5) # create model instance
27 | outputs = layer(inputs) # pass input through model
28 | print(outputs.shape)
29 | assert list(outputs.shape) == [1, 2, 16] # check output shape is correct
30 |
31 | # Test code for Upsample
32 | inputs = torch.randn(1, 3, 16) # create random input tensor
33 | layer = Upsample(spatial_dims=1, scale_factor=2) # create model instance
34 | outputs = layer(inputs) # pass input through model
35 | print(outputs.shape)
36 | assert list(outputs.shape) == [1, 3, 32] # check output shape is correct
37 |
38 | # Test code for BatchNormND
39 | inputs = torch.randn(1, 3, 16) # create random input tensor
40 | layer = BatchNorm(spatial_dims=1, num_features=3) # create model instance
41 | outputs = layer(inputs) # pass input through model
42 |
43 | # Test code for MaxPool
44 | for Module in [MaxPool, AvgPool]:
45 | inputs = torch.randn(1, 3, 16) # create random input tensor
46 | layer = Module(spatial_dims=1, kernel_size=2) # create model instance
47 | outputs = layer(inputs) # pass input through model
48 | print(outputs.shape)
49 | assert list(outputs.shape) == [1, 3, 8] # check output shape is correct
50 |
51 | # Test code for Pool
52 | for Module in [AdaptiveMaxPool, AdaptiveAvgPool]:
53 | inputs = torch.randn(1, 3, 16) # create random input tensor
54 | layer = Module(spatial_dims=1, output_size=8) # create model instance
55 | outputs = layer(inputs) # pass input through model
56 | print(outputs.shape)
57 | assert list(outputs.shape) == [1, 3, 8] # check output shape is correct
58 |
59 | # Test code for Padding
60 | inputs = torch.randn(1, 3, 16) # create random input tensor
61 | layer = ReplicationPad(spatial_dims=1, padding=2) # create model instance
62 | outputs = layer(inputs) # pass input through model
63 | print(outputs.shape)
64 | assert list(outputs.shape) == [1, 3, 20] # check output shape is correct
65 |
66 | layer = ConstantPad(spatial_dims=1, padding=2, value=0) # create model instance
67 | outputs = layer(inputs) # pass input through model
68 | print(outputs.shape)
69 | assert list(outputs.shape) == [1, 3, 20] # check output shape is correct
70 |
--------------------------------------------------------------------------------
/tests/test_layers.py:
--------------------------------------------------------------------------------
1 | from fusionlab.layers import ConvNormAct
2 |
3 | class TestConvNormAct:
4 | def test_ConvNormAct(self):
5 | import torch
6 | inputs = torch.randn(1, 3, 16)
7 | l = ConvNormAct(1, 3, 4, 3, 1, 1)
8 | outputs = l(inputs)
9 | assert outputs.shape == torch.Size([1, 4, 16])
10 |
11 | inputs = torch.randn(1, 3, 16, 16)
12 | l = ConvNormAct(2, 3, 4, 3, 1, 1)
13 | outputs = l(inputs)
14 | assert outputs.shape == torch.Size([1, 4, 16, 16])
15 |
16 | inputs = torch.randn(1, 3, 16, 16, 16)
17 | l = ConvNormAct(3, 3, 4, 3, 1, 1)
18 | outputs = l(inputs)
19 | assert outputs.shape == torch.Size([1, 4, 16, 16, 16])
20 |
21 | class TestSqueezeExcitation:
22 | def test_se(self):
23 | import torch
24 | from fusionlab.layers import SEModule
25 | cin = 64
26 | squeeze_channels = 16
27 |
28 | for i in range(1, 4):
29 | size = tuple([1, cin] + [16] * i)
30 | inputs = torch.randn(size)
31 | layer = SEModule(cin, squeeze_channels, spatial_dims=i)
32 | outputs = layer(inputs)
33 | assert outputs.shape == size
--------------------------------------------------------------------------------
/tests/test_losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tensorflow as tf
3 | from einops import rearrange
4 | from fusionlab.losses import *
5 | from pytest import approx
6 |
7 | EPS = 1e-6
8 |
9 |
10 | class Data:
11 | def __init__(self):
12 | self.pred = [[
13 | [1., 2., 3., 4.],
14 | [2., 6., 4., 4.],
15 | [9., 6., 3., 4.]
16 | ]]
17 | self.target = [[2, 1, 0, 2]]
18 |
19 |
20 | class BinaryData:
21 | def __init__(self):
22 | self.pred = [0.4, 0.2, 0.3, 0.5]
23 | self.target = [0, 1, 0, 1]
24 |
25 |
26 | class TestSegLoss:
27 | def test_dice_loss(self):
28 | from fusionlab.losses.diceloss.tfdice import TFDiceLoss
29 | # Multi class
30 | data = Data()
31 | true_loss = 0.5519775748252869
32 | # PT
33 | pred = torch.tensor(data.pred).view(1, 3, 4)
34 | true = torch.tensor(data.target).view(1, 4)
35 | loss = DiceLoss("multiclass", from_logits=True)(pred, true)
36 | assert loss == approx(true_loss, EPS)
37 | # TF
38 | pred = tf.convert_to_tensor(data.pred)
39 | pred = rearrange(pred, "N C H -> N H C")
40 | true = tf.convert_to_tensor(data.target)
41 | loss = TFDiceLoss("multiclass", from_logits=True)(true, pred)
42 | assert float(loss) == approx(true_loss, EPS)
43 |
44 | # Binary Loss
45 | data = BinaryData()
46 | true_loss = 0.46044695377349854
47 |
48 | pred = torch.tensor(data.pred).reshape(1, 1, 2, 2)
49 | true = torch.tensor(data.target).reshape(1, 2, 2)
50 | # PT
51 | loss = DiceLoss("binary", from_logits=True)(pred, true)
52 | assert loss == approx(true_loss, EPS)
53 | # TF
54 | pred = tf.convert_to_tensor(data.pred)
55 | pred = tf.reshape(pred, [1, 2, 2, 1])
56 | true = tf.convert_to_tensor(data.target)
57 | true = tf.reshape(true, [1, 2, 2])
58 | loss = TFDiceLoss("binary", from_logits=True)(true, pred)
59 | assert float(loss) == approx(true_loss, EPS)
60 |
61 | # Binary Log loss
62 | true_loss = 0.6170141696929932
63 |
64 | # PT
65 | pred = torch.tensor(data.pred).reshape(1, 1, 2, 2)
66 | true = torch.tensor(data.target).reshape(1, 2, 2)
67 | loss = DiceLoss("binary", from_logits=True, log_loss=True)(pred, true)
68 | assert loss == approx(true_loss, EPS)
69 | # TF
70 | pred = tf.convert_to_tensor(data.pred)
71 | pred = tf.reshape(pred, [1, 2, 2, 1])
72 | true = tf.convert_to_tensor(data.target)
73 | true = tf.reshape(true, [1, 2, 2])
74 | loss = TFDiceLoss("binary", from_logits=True, log_loss=True)(true, pred)
75 | assert float(loss) == approx(true_loss, EPS)
76 |
77 | def test_iou_loss(self):
78 | from fusionlab.losses.iouloss.tfiou import TFIoULoss
79 | # multi class
80 | true_loss = 0.6969285607337952
81 | data = Data()
82 |
83 | # PT
84 | pred = torch.tensor(data.pred).view(1, 3, 4)
85 | true = torch.tensor(data.target).view(1, 4)
86 | loss = IoULoss("multiclass", from_logits=True)(pred, true)
87 | assert loss == approx(true_loss, EPS)
88 | # TF
89 | pred = tf.convert_to_tensor(data.pred)
90 | pred = rearrange(pred, "N C H -> N H C")
91 | true = tf.convert_to_tensor(data.target)
92 | loss = TFIoULoss("multiclass", from_logits=True)(true, pred)
93 | assert float(loss) == approx(true_loss, EPS)
94 |
95 | # Binary
96 | data = BinaryData()
97 | true_loss = 0.6305561661720276
98 | # PT
99 | pred = torch.tensor(data.pred).reshape(1, 1, 2, 2)
100 | true = torch.tensor(data.target).reshape(1, 2, 2)
101 | loss = IoULoss("binary", from_logits=True)(pred, true)
102 | assert loss == approx(true_loss, EPS)
103 | # TF
104 | pred = tf.convert_to_tensor(data.pred)
105 | pred = tf.reshape(pred, [1, 2, 2, 1])
106 | true = tf.convert_to_tensor(data.target)
107 | true = tf.reshape(true, [1, 2, 2])
108 | loss = TFIoULoss("binary", from_logits=True)(true, pred)
109 | assert float(loss) == approx(true_loss, EPS)
110 |
111 | # Binary Log loss
112 | data = BinaryData()
113 | true_loss = 0.9957565665245056
114 | # PT
115 | pred = torch.tensor(data.pred).reshape(1, 1, 2, 2)
116 | true = torch.tensor(data.target).reshape(1, 2, 2)
117 | loss = IoULoss("binary", from_logits=True, log_loss=True)(pred, true)
118 | assert loss == approx(true_loss, EPS)
119 | # TF
120 | pred = tf.convert_to_tensor(data.pred)
121 | pred = tf.reshape(pred, [1, 2, 2, 1])
122 | true = tf.convert_to_tensor(data.target)
123 | true = tf.reshape(true, [1, 2, 2])
124 | loss = TFIoULoss("binary", from_logits=True, log_loss=True)(true, pred)
125 | assert float(loss) == approx(true_loss, EPS)
--------------------------------------------------------------------------------
/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tensorflow as tf
3 | from einops import rearrange
4 | from fusionlab.metrics import (
5 | DiceScore,
6 | IoUScore,
7 | )
8 | from pytest import approx
9 |
10 | EPS = 1e-6
11 |
12 | class Data:
13 | def __init__(self):
14 | self.pred = [[
15 | [1., 2., 3., 4.],
16 | [2., 6., 4., 4.],
17 | [9., 6., 3., 4.]
18 | ]]
19 | self.target = [[2, 1, 0, 2]]
20 |
21 |
22 | class BinaryData:
23 | def __init__(self):
24 | self.pred = [0.4, 0.2, 0.3, 0.5]
25 | self.target = [0, 1, 0, 1]
26 |
27 |
28 | class TestMetrics:
29 | def test_dice_score(self):
30 | # Multi class
31 | data = Data()
32 | true_scores = [0.27264893, 0.41188607, 0.65953231]
33 | true_score_mean = 0.44802245
34 | # PT
35 | pred = torch.tensor(data.pred).view(1, 3, 4)
36 | true = torch.tensor(data.target).view(1, 4)
37 | scores = DiceScore("multiclass")(pred, true)
38 | assert scores.mean() == approx(true_score_mean, EPS)
39 | assert scores == approx(true_scores, EPS)
40 |
41 | # TODO: Add binary class test
42 | data = BinaryData()
43 | true_loss = 0.46044695377349854
44 |
45 | def test_iou_score(self):
46 | # multi class
47 | true_loss = 0.6969285607337952
48 | true_scores = [0.15784223, 0.25935552, 0.49201655]
49 | true_score_mean = 0.30307144
50 | data = Data()
51 |
52 | # PT
53 | pred = torch.tensor(data.pred).view(1, 3, 4)
54 | true = torch.tensor(data.target).view(1, 4)
55 | scores = IoUScore("multiclass", from_logits=True)(pred, true)
56 | assert scores.mean() == approx(true_score_mean, EPS)
57 | assert scores == approx(true_scores, EPS)
58 |
59 | # TODO: Add binary class test
60 |
61 |
--------------------------------------------------------------------------------
/tests/test_seg.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def generate_inputs(img_size, spatial_dims):
4 | size = tuple([1, 3] + [img_size] * spatial_dims)
5 | return torch.randn(size)
6 |
7 | class TestSeg:
8 | def test_unet(self):
9 | from fusionlab.segmentation import UNet
10 | for i in range(1, 4):
11 | inputs = generate_inputs(64, i)
12 | model = UNet(3, 2, spatial_dims=i)
13 | outputs = model(inputs)
14 | assert outputs.shape == tuple([1, 2] + [64] * i)
15 |
16 | def test_resunet(self):
17 | from fusionlab.segmentation import ResUNet
18 | for i in range(1, 4):
19 | inputs = generate_inputs(64, i)
20 | model = ResUNet(3, 2, spatial_dims=i)
21 | outputs = model(inputs)
22 | assert outputs.shape == tuple([1, 2] + [64] * i)
23 |
24 | def test_unet2plus(self):
25 | from fusionlab.segmentation import UNet2plus
26 | for i in range(1, 4):
27 | inputs = generate_inputs(64, i)
28 | model = UNet2plus(3, 2, 16, spatial_dims=i)
29 | outputs = model(inputs)
30 | assert outputs.shape == tuple([1, 2] + [64] * i)
31 |
32 | def test_transunet(self):
33 | from fusionlab.segmentation import TransUNet
34 | inputs = generate_inputs(64, 2)
35 | model = TransUNet(
36 | in_channels=3,
37 | img_size=64,
38 | num_classes=2,
39 | )
40 | outputs = model(inputs)
41 | assert outputs.shape == tuple([1, 2] + [64] * 2)
42 |
43 | def test_unetr(self):
44 | from fusionlab.segmentation import UNETR
45 | for i in range(1, 4):
46 | inputs = generate_inputs(64, i)
47 | model = UNETR(3, 2, 64, spatial_dims=i)
48 | outputs = model(inputs)
49 | assert outputs.shape == tuple([1, 2] + [64] * i)
50 |
51 | def test_segformer(self):
52 | from fusionlab.segmentation import SegFormer
53 | inputs = torch.rand(1, 3, 64, 64)
54 | for i in range(6):
55 | mit_type = f'B{i}'
56 | model = SegFormer(num_classes=2, mit_encoder_type=mit_type)
57 | outputs = model(inputs)
58 | assert outputs.shape == (1, 2, 64, 64)
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | class TestPlotChannels:
2 | def test_plot_channels(self):
3 | from fusionlab.utils import plot_channels
4 | import numpy as np
5 | import matplotlib
6 | """Test that plot_channels() returns a figure."""
7 | signals = np.random.randn(500, 12)
8 | fig = plot_channels(signals, show=False)
9 | assert isinstance(fig, matplotlib.figure.Figure)
--------------------------------------------------------------------------------