├── .dockerignore
├── .github
│   └── workflows
│       ├── pre-commit.yml
│       └── unit-test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .pylintrc
├── LICENSE
├── MELBA_Symposium_2024.pdf
├── Makefile
├── README.md
├── docker
│   ├── Dockerfile
│   ├── Dockerfile.tpu
│   ├── environment.yml
│   ├── environment_mac_m1.yml
│   ├── environment_tpu.yml
│   └── requirements.txt
├── examples
│   └── segmentation
│       ├── BB_anon_348_1.png
│       ├── BB_anon_348_1_mask.png
│       ├── config.yaml
│       ├── files
│       │   └── ckpt
│       │       └── checkpoint_1900
│       │           └── checkpoint
│       └── inference.ipynb
├── images
│   ├── diffusion_class_diagram.png
│   ├── diffusion_training_strategy_diagram.png
│   └── melba_graphic_abstract.png
├── imgx
│   ├── __init__.py
│   ├── conf
│   │   ├── __init__.py
│   │   ├── config.yaml
│   │   ├── data
│   │   │   ├── amos_ct.yaml
│   │   │   ├── brats2021_mr.yaml
│   │   │   ├── male_pelvic_mr.yaml
│   │   │   └── muscle_us.yaml
│   │   └── task
│   │       ├── gaussian_diff_seg.yaml
│   │       └── seg.yaml
│   ├── config.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── augmentation
│   │   │   ├── __init__.py
│   │   │   ├── affine.py
│   │   │   ├── affine_test.py
│   │   │   ├── flip.py
│   │   │   ├── flip_test.py
│   │   │   ├── intensity.py
│   │   │   ├── intensity_test.py
│   │   │   ├── patch.py
│   │   │   └── patch_test.py
│   │   ├── iterator.py
│   │   ├── iterator_test.py
│   │   ├── util.py
│   │   ├── util_test.py
│   │   ├── warp.py
│   │   └── warp_test.py
│   ├── datasets
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── amos_ct
│   │   │   ├── __init__.py
│   │   │   └── amos_ct_dataset_builder.py
│   │   ├── brats2021_mr
│   │   │   ├── __init__.py
│   │   │   └── brats2021_mr_dataset_builder.py
│   │   ├── constant.py
│   │   ├── dataset_info.py
│   │   ├── male_pelvic_mr
│   │   │   ├── __init__.py
│   │   │   └── male_pelvic_mr_dataset_builder.py
│   │   ├── muscle_us
│   │   │   ├── __init__.py
│   │   │   └── muscle_us_dataset_builder.py
│   │   ├── preprocess.py
│   │   ├── preprocess_test.py
│   │   ├── save.py
│   │   ├── tests
│   │   │   ├── conftest.py
│   │   │   ├── fixtures
│   │   │   │   ├── BB_anon_1789_3_mask_pred.png
│   │   │   │   ├── BB_anon_1789_3_mask_pred_postprocessed_0.5.png
│   │   │   │   ├── BB_anon_1789_3_mask_pred_postprocessed_0.75.png
│   │   │   │   ├── BB_anon_425_2_mask_pred.png
│   │   │   │   ├── BB_anon_425_2_mask_pred_postprocessed.png
│   │   │   │   ├── GM_anon_780_3_mask_pred.png
│   │   │   │   └── GM_anon_780_3_mask_pred_postprocessed.png
│   │   │   └── test_muscle_us.py
│   │   ├── util.py
│   │   └── util_test.py
│   ├── device.py
│   ├── diffusion
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── diffusion.py
│   │   ├── gaussian
│   │   │   ├── __init__.py
│   │   │   ├── gaussian_diffusion.py
│   │   │   ├── gaussian_diffusion_test.py
│   │   │   ├── sampler.py
│   │   │   ├── variance_schedule.py
│   │   │   └── variance_schedule_test.py
│   │   ├── time_sampler.py
│   │   ├── time_sampler_test.py
│   │   ├── util.py
│   │   └── util_test.py
│   ├── experiment.py
│   ├── integration_test.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── cross_entropy.py
│   │   ├── cross_entropy_test.py
│   │   ├── deformation.py
│   │   ├── deformation_test.py
│   │   ├── dice.py
│   │   ├── dice_test.py
│   │   ├── segmentation.py
│   │   ├── segmentation_test.py
│   │   └── similarity.py
│   ├── metric
│   │   ├── __init__.py
│   │   ├── area.py
│   │   ├── area_test.py
│   │   ├── centroid.py
│   │   ├── centroid_test.py
│   │   ├── deformation.py
│   │   ├── deformation_test.py
│   │   ├── dice.py
│   │   ├── dice_test.py
│   │   ├── distribution.py
│   │   ├── segmentation.py
│   │   ├── segmentation_test.py
│   │   ├── similarity.py
│   │   ├── similarity_test.py
│   │   ├── smoothing.py
│   │   ├── smoothing_test.py
│   │   ├── surface_distance.py
│   │   ├── surface_distance_test.py
│   │   └── util.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── attention
│   │   │   ├── __init__.py
│   │   │   ├── efficient_attention.py
│   │   │   ├── efficient_attention_test.py
│   │   │   ├── transformer.py
│   │   │   └── transformer_test.py
│   │   ├── basic.py
│   │   ├── basic_test.py
│   │   ├── conv.py
│   │   ├── conv_test.py
│   │   ├── slice.py
│   │   ├── slice_test.py
│   │   └── unet
│   │       ├── __init__.py
│   │       ├── bottom_encoder.py
│   │       ├── bottom_encoder_test.py
│   │       ├── downsample_encoder.py
│   │       ├── unet.py
│   │       ├── unet_test.py
│   │       └── upsample_decoder.py
│   ├── run_test.py
│   ├── run_train.py
│   ├── run_valid.py
│   ├── task
│   │   ├── __init__.py
│   │   ├── diffusion_segmentation
│   │   │   ├── __init__.py
│   │   │   ├── diffusion.py
│   │   │   ├── diffusion_step.py
│   │   │   ├── experiment.py
│   │   │   ├── gaussian_diffusion.py
│   │   │   ├── gaussian_diffusion_test.py
│   │   │   ├── recycling_step.py
│   │   │   ├── save.py
│   │   │   ├── self_conditioning_step.py
│   │   │   └── train_state.py
│   │   ├── segmentation
│   │   │   ├── __init__.py
│   │   │   ├── experiment.py
│   │   │   └── save.py
│   │   └── util.py
│   └── train_state.py
└── pyproject.toml
/.dockerignore:
--------------------------------------------------------------------------------
1 | wandb/
2 | outputs/
3 | tensorflow_datasets/
4 | *.zip
5 |
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yml:
--------------------------------------------------------------------------------
1 | name: pre-commit
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches: [main]
7 |
8 | jobs:
9 | pre-commit:
10 | runs-on: ubuntu-latest
11 | strategy:
12 | matrix:
13 | python-version: ["3.9"]
14 | steps:
15 | - name: Checkout repository
16 | uses: actions/checkout@v4
17 | - name: Set up Python ${{ matrix.python-version }}
18 | uses: actions/setup-python@v5
19 | with:
20 | python-version: ${{ matrix.python-version }}
21 | check-latest: true
22 | - name: Install Wily
23 | run: pip install wily
24 | - name: Build cache and diff
25 | run: wily build .
26 | - name: Run pre-commit
27 | uses: pre-commit/action@v3.0.0
28 |
--------------------------------------------------------------------------------
/.github/workflows/unit-test.yml:
--------------------------------------------------------------------------------
1 | name: unit-test
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches: [main]
7 |
8 | jobs:
9 | build:
10 | runs-on: ubuntu-20.04
11 | timeout-minutes: 30
12 | strategy:
13 | matrix:
14 | group: [1, 2, 3, 4]
15 | python-version: ["3.9"]
16 |
17 | steps:
18 | - name: Checkout repository
19 | uses: actions/checkout@v4
20 | - name: Set up Python ${{ matrix.python-version }}
21 | uses: actions/setup-python@v5
22 | with:
23 | check-latest: true
24 | python-version: ${{ matrix.python-version }}
25 | cache: "pip"
26 | cache-dependency-path: |
27 | docker/environment.yml
28 | docker/environment_mac_m1.yml
29 | docker/Dockerfile
30 | docker/requirements.txt
31 | pyproject.toml
32 | - name: Install dependencies
33 | run: |
34 | pip install tensorflow-cpu==2.12.0
35 | pip install jax==0.4.20
36 | pip install jaxlib==0.4.20
37 | pip install -r docker/requirements.txt
38 | pip install -e .
39 | - name: Test with pytest
40 | run: |
41 | pytest --splits 4 --group ${{ matrix.group }} --randomly-seed=0 -k "not slow and not integration"
42 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage*
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # pycharm
132 | .idea/
133 |
134 | # Mac
135 | .DS_Store
136 |
137 | # notebook
138 | *.ipynb
139 | notebooks/
140 | !examples/segmentation/inference.ipynb
141 |
142 | # hydra outputs
143 | outputs/
144 |
145 | # wandb outputs
146 | wandb/
147 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v4.5.0
6 | hooks:
7 | - id: check-added-large-files
8 | args: ["--maxkb=15000"]
9 | - id: check-ast
10 | - id: check-byte-order-marker
11 | - id: check-builtin-literals
12 | - id: check-case-conflict
13 | - id: check-docstring-first
14 | - id: check-merge-conflict
15 | - id: check-symlinks
16 | - id: check-toml
17 | - id: check-yaml
18 | - id: debug-statements
19 | - id: destroyed-symlinks
20 | - id: end-of-file-fixer
21 | - id: fix-byte-order-marker
22 | - id: mixed-line-ending
23 | - id: file-contents-sorter
24 | files: "docker/requirements.txt"
25 | - id: trailing-whitespace
26 | - repo: https://github.com/PyCQA/isort
27 | rev: 5.13.2
28 | hooks:
29 | - id: isort
30 | - repo: https://github.com/psf/black
31 | rev: 23.12.1
32 | hooks:
33 | - id: black
34 | args:
35 | - --line-length=100
36 | - repo: https://github.com/pre-commit/mirrors-mypy
37 | rev: v1.8.0
38 | hooks: # https://github.com/python/mypy/issues/4008#issuecomment-582458665
39 | - id: mypy
40 | name: mypy
41 | pass_filenames: false
42 | args:
43 | [
44 | --strict-equality,
45 | --disallow-untyped-calls,
46 | --disallow-untyped-defs,
47 | --disallow-incomplete-defs,
48 | --disallow-any-generics,
49 | --check-untyped-defs,
50 | --disallow-untyped-decorators,
51 | --warn-redundant-casts,
52 | --warn-unused-ignores,
53 | --no-warn-no-return,
54 | --warn-unreachable,
55 | ]
56 | - repo: https://github.com/pre-commit/mirrors-prettier
57 | rev: v4.0.0-alpha.8
58 | hooks:
59 | - id: prettier
60 | args:
61 | - --print-width=100
62 | - --prose-wrap=always
63 | - --tab-width=2
64 | - repo: https://github.com/charliermarsh/ruff-pre-commit
65 | rev: "v0.1.9"
66 | hooks:
67 | - id: ruff
68 | - repo: https://github.com/pre-commit/mirrors-pylint
69 | rev: v3.0.0a5
70 | hooks:
71 | - id: pylint
72 | - repo: https://github.com/asottile/pyupgrade
73 | rev: v3.15.0
74 | hooks:
75 | - id: pyupgrade
76 | args:
77 | - --py39-plus
78 | - repo: https://github.com/pycqa/pydocstyle
79 | rev: 6.3.0
80 | hooks:
81 | - id: pydocstyle
82 | args:
83 | - --convention=google
84 | - repo: local
85 | hooks:
86 | - id: wily
87 | name: wily
88 | entry: wily diff
89 | verbose: false
90 | language: python
91 | additional_dependencies: [wily]
92 |
--------------------------------------------------------------------------------
/MELBA_Symposium_2024.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/MELBA_Symposium_2024.pdf
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | pip:
2 | pip install -e .
3 |
4 | test:
5 | pytest --cov=imgx -n 4 imgx
6 |
7 | build_dataset:
8 | tfds build imgx/datasets/male_pelvic_mr
9 | tfds build imgx/datasets/amos_ct
10 | tfds build imgx/datasets/muscle_us
11 | tfds build imgx/datasets/brats2021_mr
12 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/deepmind/alphafold/blob/main/docker/Dockerfile
2 | # If changing CUDA/CUDNN versions, also update the corresponding versions in
3 | # the second last command
4 | ARG CUDA=11.8.0
5 | ARG CUDNN=8.6.0
6 | FROM nvidia/cuda:${CUDA}-cudnn8-runtime-ubuntu20.04
7 |
8 | # FROM directive resets ARGS, so we specify again (the value is retained if
9 | # previously set).
10 | ARG CUDA
11 | ARG CUDNN
12 |
13 | ARG HOST_UID
14 | ARG HOST_GID
15 |
16 | ENV USER=app
17 |
18 | # Ensure ARGs are set
19 | RUN test -n "$HOST_UID" && test -n "$HOST_GID"
20 |
21 | # Use bash to support string substitution.
22 | SHELL ["/bin/bash", "-c"]
23 |
24 | # Create group and user, add -f to skip the command without error if it exists already
25 | RUN groupadd --force --gid $HOST_GID $USER && \
26 | useradd -r -m --uid $HOST_UID --gid $HOST_GID $USER
27 |
28 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
29 | vim \
30 | unzip \
31 | build-essential \
32 | cmake \
33 | python3-opencv \
34 | cuda-command-line-tools-$(cut -f1,2 -d- <<< ${CUDA//./-}) \
35 | git \
36 | tzdata \
37 | wget \
38 | make \
39 | && rm -rf /var/lib/apt/lists/*
40 |
41 | # Add SETUID bit to the ldconfig binary so that non-root users can run it.
42 | RUN chmod u+s /sbin/ldconfig.real
43 |
44 | # We need to run `ldconfig` first to ensure GPUs are visible, due to some quirk
45 | # with Debian. See https://github.com/NVIDIA/nvidia-docker/issues/1399 for
46 | # details.
47 | RUN echo $'#!/bin/bash\nldconfig'
48 |
49 | RUN mkdir -p /${USER}/tmp
50 | RUN mkdir -p /${USER}/ImgX
51 | RUN mkdir -p /${USER}/tensorflow_datasets
52 | RUN chgrp -R ${USER} /${USER} && \
53 | chmod -R g+rwx /${USER} && \
54 | chown -R ${USER} /${USER}
55 |
56 | USER ${USER}
57 |
58 | # Install Miniconda package manager.
59 | RUN wget -q -P /${USER}/tmp https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
60 | RUN bash /${USER}/tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /${USER}/conda
61 | RUN rm /${USER}/tmp/Miniconda3-latest-Linux-x86_64.sh
62 |
63 | # https://anaconda.org/nvidia/cuda-toolkit
64 | ENV PATH="/${USER}/conda/bin:$PATH"
65 | RUN conda update -qy conda \
66 | && conda install -y -n base conda-libmamba-solver \
67 | && conda config --set solver libmamba \
68 | && conda install -y --channel "nvidia/label/cuda-${CUDA}" cuda-toolkit \
69 | && conda install -y -c conda-forge \
70 | pip \
71 | python=3.9
72 |
73 | # Install pip packages.
74 | COPY docker/requirements.txt /${USER}/requirements.txt
75 |
76 | RUN /${USER}/conda/bin/pip3 install --upgrade pip \
77 | && /${USER}/conda/bin/pip3 install \
78 | jax==0.4.20 \
79 | jaxlib==0.4.20+cuda11.cudnn86 \
80 | -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
81 | && /${USER}/conda/bin/pip3 install tensorflow-cpu==2.14.0 \
82 | && /${USER}/conda/bin/pip3 install -r /${USER}/requirements.txt
83 |
84 | RUN git config --global --add safe.directory /${USER}/ImgX
85 |
86 | WORKDIR /${USER}/ImgX
87 |
--------------------------------------------------------------------------------
/docker/Dockerfile.tpu:
--------------------------------------------------------------------------------
1 | FROM mambaorg/micromamba:1.5.1 as conda
2 |
3 | # Speed up the build, and avoid unnecessary writes to disk
4 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1
5 | ENV PIPENV_VENV_IN_PROJECT=true PIP_NO_CACHE_DIR=false PIP_DISABLE_PIP_VERSION_CHECK=1
6 |
7 | COPY docker/environment_tpu.yml /tmp/environment_tpu.yml
8 | COPY docker/requirements.txt /tmp/requirements.txt
9 |
10 | RUN micromamba create -y --file /tmp/environment_tpu.yml \
11 | && micromamba clean --all --yes \
12 | && find /opt/conda/ -follow -type f -name '*.pyc' -delete
13 |
14 | FROM debian:bullseye-slim as test-image
15 |
16 | COPY --from=conda /opt/conda/envs/. /opt/conda/envs/
17 | ENV PATH=/opt/conda/envs/imgx/bin/:$PATH APP_FOLDER=/app
18 | ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH
19 | ENV TZ=Europe/London
20 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
21 |
22 | WORKDIR $APP_FOLDER
23 |
24 | ARG USER_ID=1000
25 | ARG GROUP_ID=1000
26 | ENV USER=app
27 | ENV GROUP=app
28 |
29 | USER root
30 | RUN apt-get update && apt-get install -y git vim unzip apt-transport-https ca-certificates gnupg curl make python3-opencv
31 | RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
32 | RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
33 | RUN apt-get update && apt-get install -y google-cloud-cli
34 | RUN git config --global --add safe.directory /${USER}/ImgX
35 |
36 | ENV TF_CUDNN_DETERMINISTIC=1
37 |
38 | FROM test-image as run-image
39 | # The run-image (default) is the same as the dev-image with some files directly
40 | # copied inside
41 |
42 | WORKDIR /${USER}/ImgX
43 |
--------------------------------------------------------------------------------
/docker/environment.yml:
--------------------------------------------------------------------------------
1 | name: imgx
2 | channels:
3 | - defaults
4 | dependencies:
5 | - python=3.9
6 | - pip=23.3.1
7 | - pip:
8 | - tensorflow-cpu==2.14.0
9 | - jax==0.4.20
10 | - jaxlib==0.4.20
11 | - -r requirements.txt
12 |
--------------------------------------------------------------------------------
/docker/environment_mac_m1.yml:
--------------------------------------------------------------------------------
1 | name: imgx
2 | channels:
3 | - defaults
4 | dependencies:
5 | - python=3.9
6 | - pip=23.3.1
7 | - pip:
8 | - tensorflow-macos==2.14.0
9 | - tensorflow-metal==1.1.0
10 | - jax==0.4.20
11 | - jaxlib==0.4.20
12 | - -r requirements.txt
13 |
--------------------------------------------------------------------------------
/docker/environment_tpu.yml:
--------------------------------------------------------------------------------
1 | name: imgx
2 | channels:
3 | - defaults
4 | - conda-forge
5 | dependencies:
6 | - python=3.9
7 | - pip=23.3.1
8 | - pip:
9 | - --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
10 | - tensorflow-cpu==2.14.0
11 | - jax[tpu]==0.4.20
12 | - jaxlib==0.4.20
13 | - -r requirements.txt
14 |
--------------------------------------------------------------------------------
/docker/requirements.txt:
--------------------------------------------------------------------------------
1 | SimpleITK==2.3.1
2 | chex==0.1.8
3 | coverage==7.3.3
4 | flax==0.7.5
5 | hydra-core==1.3.2
6 | kaggle==1.5.16
7 | matplotlib==3.8.2
8 | nbmake==1.4.6
9 | numpy==1.26.2
10 | opencv-python==4.8.1.78
11 | optax==0.1.7
12 | pandas==2.1.4
13 | pre-commit==3.6.0
14 | protobuf==3.20.3 # https://github.com/tensorflow/datasets/issues/4858
15 | pytest-cov==4.1.0
16 | pytest-mock==3.12.0
17 | pytest-randomly==3.15.0
18 | pytest-split==0.8.1
19 | pytest-xdist==3.5.0
20 | pytest==7.4.3
21 | rdkit-pypi==2022.9.5
22 | rich==13.7.0
23 | ruff==0.1.8
24 | tensorflow-datasets==4.9.3
25 | tomli==2.0.1
26 | torch==2.1.2 # for testing only
27 | wandb==0.16.1
28 | wily==1.25.0
29 |
--------------------------------------------------------------------------------
/examples/segmentation/BB_anon_348_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/examples/segmentation/BB_anon_348_1.png
--------------------------------------------------------------------------------
/examples/segmentation/BB_anon_348_1_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/examples/segmentation/BB_anon_348_1_mask.png
--------------------------------------------------------------------------------
/examples/segmentation/config.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | name: muscle_us
3 | loader:
4 | max_num_samples_per_split: -1
5 | patch_shape:
6 | - 480
7 | - 512
8 | patch_overlap:
9 | - 0
10 | - 0
11 | data_augmentation:
12 | max_rotation: 30
13 | max_zoom: 0.2
14 | max_shear: 30
15 | max_shift: 0.3
16 | max_log_gamma: 0.3
17 | v_min: 0.0
18 | v_max: 1.0
19 | p: 0.5
20 | trainer:
21 | max_num_samples: 512000
22 | batch_size: 64
23 | batch_size_per_replica: 8
24 | num_devices_per_replica: 1
25 | patch_size:
26 | - 2
27 | - 2
28 | scale_factor:
29 | - 2
30 | - 2
31 | task:
32 | name: segmentation
33 | model:
34 | _target_: imgx.model.Unet
35 | remat: true
36 | num_spatial_dims: 2
37 | patch_size:
38 | - 2
39 | - 2
40 | scale_factor:
41 | - 2
42 | - 2
43 | num_res_blocks: 2
44 | num_channels:
45 | - 8
46 | - 16
47 | - 32
48 | - 64
49 | out_channels: 2
50 | num_heads: 8
51 | widening_factor: 4
52 | num_transform_layers: 1
53 | dropout: 0.1
54 | loss:
55 | dice: 1.0
56 | cross_entropy: 0.0
57 | focal: 1.0
58 | early_stopping:
59 | metric: mean_binary_dice_score_without_background
60 | mode: max
61 | min_delta: 0.0001
62 | patience: 10
63 | debug: false
64 | seed: 0
65 | half_precision: true
66 | optimizer:
67 | name: adamw
68 | kwargs:
69 | b1: 0.9
70 | b2: 0.999
71 | weight_decay: 1.0e-08
72 | grad_norm: 1.0
73 | lr_schedule:
74 | warmup_steps: 100
75 | decay_steps: 10000
76 | init_value: 1.0e-05
77 | peak_value: 0.0008
78 | end_value: 5.0e-05
79 | logging:
80 | root_dir: null
81 | log_freq: 10
82 | save_freq: 100
83 | wandb:
84 | project: imgx
85 | entity: entity
86 |
--------------------------------------------------------------------------------
/examples/segmentation/files/ckpt/checkpoint_1900/checkpoint:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/examples/segmentation/files/ckpt/checkpoint_1900/checkpoint
--------------------------------------------------------------------------------
/images/diffusion_class_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/images/diffusion_class_diagram.png
--------------------------------------------------------------------------------
/images/diffusion_training_strategy_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/images/diffusion_training_strategy_diagram.png
--------------------------------------------------------------------------------
/images/melba_graphic_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/images/melba_graphic_abstract.png
--------------------------------------------------------------------------------
/imgx/__init__.py:
--------------------------------------------------------------------------------
1 | """A Jax-based DL toolkit for biomedical and bioinformatics applications."""
2 | EPS = 1.0e-5 # small epsilon to avoid numerical errors
3 |
4 | # jax device
5 | # one model can be stored across multiple shards/slices
6 | # given 8 devices, it can be grouped into 4x2
7 | # if num_devices_per_replica = 2, then one model is stored across 2 devices
8 | # so the replica_axis would be of size 4
9 | SHARD_AXIS = "shard_axis"
10 | REPLICA_AXIS = "replica_axis"
11 |
--------------------------------------------------------------------------------
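
The axis names above are intended for a JAX device mesh. Below is a minimal sketch of the 4x2
grouping described in the comment, assuming 8 visible devices; it is illustrative only and not the
repository's own device-setup code (which presumably lives in `imgx/device.py`).

```python
import jax
import numpy as np
from jax.sharding import Mesh

from imgx import REPLICA_AXIS, SHARD_AXIS

# 8 devices grouped as 4x2: each model is sharded over 2 devices
# (shard axis) and replicated 4 times (replica axis).
devices = np.asarray(jax.devices()).reshape(4, 2)
mesh = Mesh(devices, axis_names=(REPLICA_AXIS, SHARD_AXIS))
print(dict(mesh.shape))  # {'replica_axis': 4, 'shard_axis': 2}
```
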
/imgx/conf/__init__.py:
--------------------------------------------------------------------------------
1 | """Package for config files."""
2 |
--------------------------------------------------------------------------------
/imgx/conf/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - data: muscle_us
3 | - task: gaussian_diff_seg
4 | # config below overwrites the values above
5 | # https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/default_composition_order/
6 | - _self_
7 |
8 | debug: False
9 | seed: 0
10 | half_precision: True
11 |
12 | optimizer:
13 | name: "adamw"
14 | kwargs:
15 | b1: 0.9
16 | b2: 0.999
17 | weight_decay: 1e-08
18 | grad_norm: 1.0
19 | lr_schedule:
20 | warmup_steps: 100
21 | decay_steps: 10_000
22 | init_value: 1e-05
23 | peak_value: 8e-04
24 | end_value: 5e-05
25 |
26 | logging:
27 | root_dir:
28 | log_freq: 10
29 | save_freq: 500
30 | wandb:
31 | project: imgx
32 | entity: wandb_entity
33 |
--------------------------------------------------------------------------------
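
The `defaults` list composes one config from the `data` group and one from the `task` group, and
`_self_` lets the values in this file take precedence. Below is a small sketch using Hydra's
compose API to inspect the composed config; it assumes hydra-core is installed and that the snippet
runs from the repository root, and it is not how the repository itself loads configs.

```python
from hydra import compose, initialize

# compose the config with two command-line-style overrides
with initialize(version_base=None, config_path="imgx/conf"):
    cfg = compose(config_name="config", overrides=["data=amos_ct", "task=seg"])

print(cfg.data.name)       # amos_ct
print(cfg.task.name)       # segmentation
print(cfg.optimizer.name)  # adamw, from this file via _self_
```
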
/imgx/conf/data/amos_ct.yaml:
--------------------------------------------------------------------------------
1 | name: amos_ct
2 |
3 | loader:
4 | max_num_samples_per_split: -1
5 | patch_shape: [128, 128, 128]
6 | patch_overlap: [64, 0, 0] # image shape is [192, 128, 128]
7 | data_augmentation:
8 | max_rotation: 30 # degrees
9 | max_zoom: 0.2 # as a fraction of the image size
10 | max_shear: 30 # degrees
11 | max_shift: 0.3 # as a fraction of the image size
12 | max_log_gamma: 0.3
13 | v_min: 0.0 # minimum value for intensity
14 | v_max: 1.0 # maximum value for intensity
15 | p: 0.5 # probability of applying each augmentation
16 |
17 | trainer:
18 | max_num_samples: 100_000
19 | batch_size: 8 # all model replicas are updated every `batch_size` samples
20 | batch_size_per_replica: 1 # each model replica takes `batch_size_per_replica` samples per step
21 | num_devices_per_replica: 1 # model is split into num_devices_per_replica shards/slices
22 |
23 | patch_size: [2, 2, 2]
24 | scale_factor: [2, 2, 2]
25 |
--------------------------------------------------------------------------------
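
The three trainer values relate as follows: each replica sees `batch_size_per_replica` samples per
step, and parameters are updated once every `batch_size` samples, so the remaining factor is covered
by the number of replicas and, if needed, gradient accumulation. The sanity check below is a sketch
of that arithmetic; the formula for the number of replicas is an assumption based on the comments
above, not taken from the repository code.

```python
import jax

batch_size = 8                 # effective batch per parameter update
batch_size_per_replica = 1     # samples per replica per step
num_devices_per_replica = 1    # devices used to shard one replica

num_replicas = jax.device_count() // num_devices_per_replica
# steps of gradient accumulation needed per update (assumed >= 1)
accum_steps = max(1, batch_size // (batch_size_per_replica * num_replicas))
print(num_replicas, accum_steps)
```
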
/imgx/conf/data/brats2021_mr.yaml:
--------------------------------------------------------------------------------
1 | name: brats2021_mr
2 |
3 | loader:
4 | max_num_samples_per_split: -1
5 | patch_shape: [128, 128, 128]
6 | patch_overlap: [0, 0, 32] # image shape is [179, 219, 155]
7 | data_augmentation:
8 | max_rotation: 30 # degrees
9 | max_zoom: 0.2 # as a fraction of the image size
10 | max_shear: 30 # degrees
11 | max_shift: 0.3 # as a fraction of the image size
12 | max_log_gamma: 0.3
13 | v_min: 0.0 # minimum value for intensity
14 | v_max: 1.0 # maximum value for intensity
15 | p: 0.5 # probability of applying each augmentation
16 |
17 | trainer:
18 | max_num_samples: 100_000
19 | batch_size: 8 # all model replicas are updated every `batch_size` samples
20 | batch_size_per_replica: 1 # each model replica takes `batch_size_per_replica` samples per step
21 | num_devices_per_replica: 1 # model is split into num_devices_per_replica shards/slices
22 |
23 | patch_size: [2, 2, 2]
24 | scale_factor: [2, 2, 2]
25 |
--------------------------------------------------------------------------------
/imgx/conf/data/male_pelvic_mr.yaml:
--------------------------------------------------------------------------------
1 | name: male_pelvic_mr
2 |
3 | loader:
4 | max_num_samples_per_split: -1
5 | patch_shape: [256, 256, 32]
6 | patch_overlap: [0, 0, 16] # image shape is [256, 256, 48]
7 | data_augmentation:
8 | max_rotation: 30 # degrees
9 | max_zoom: 0.2 # as a fraction of the image size
10 | max_shear: 30 # degrees
11 | max_shift: 0.3 # as a fraction of the image size
12 | max_log_gamma: 0.3
13 | v_min: 0.0 # minimum value for intensity
14 | v_max: 1.0 # maximum value for intensity
15 | p: 0.5 # probability of applying each augmentation
16 |
17 | trainer:
18 | max_num_samples: 100_000
19 | batch_size: 8 # all model replicas are updated every `batch_size` samples
20 | batch_size_per_replica: 1 # each model replica takes `batch_size_per_replica` samples per step
21 | num_devices_per_replica: 1 # model is split into num_devices_per_replica shards/slices
22 |
23 | patch_size: [2, 2, 1] # do not downsample z axis
24 | scale_factor: [2, 2, 2]
25 |
--------------------------------------------------------------------------------
/imgx/conf/data/muscle_us.yaml:
--------------------------------------------------------------------------------
1 | name: muscle_us
2 |
3 | loader:
4 | max_num_samples_per_split: -1
5 | patch_shape: [480, 512]
6 | patch_overlap: [0, 0]
7 | data_augmentation:
8 | max_rotation: 30 # degrees
9 | max_zoom: 0.2 # as a fraction of the image size
10 | max_shear: 30 # degrees
11 | max_shift: 0.3 # as a fraction of the image size
12 | max_log_gamma: 0.3
13 | v_min: 0.0 # minimum value for intensity
14 | v_max: 1.0 # maximum value for intensity
15 | p: 0.5 # probability of applying each augmentation
16 |
17 | trainer:
18 | max_num_samples: 512_000
19 | batch_size: 64 # all model replicas are updated every `batch_size` samples
20 | batch_size_per_replica: 8 # each model replica takes `batch_size_per_replica` samples per step
21 | num_devices_per_replica: 1 # model is split into num_devices_per_replica shards/slices
22 |
23 | patch_size: [2, 2]
24 | scale_factor: [2, 2]
25 |
--------------------------------------------------------------------------------
/imgx/conf/task/gaussian_diff_seg.yaml:
--------------------------------------------------------------------------------
1 | name: "diffusion_segmentation"
2 |
3 | recycling:
4 | use: True
5 | # max means the previous step is num_timesteps - 1
6 | # next means the previous step is min(t+1, num_timesteps - 1)
7 | prev_step: "max" # max or next
8 | reverse_step: False
9 |
10 | self_conditioning:
11 | use: False
12 | probability: 0.5
13 | prev_step: "next" # same or next
14 |
15 | uniform_time_sampling: False # False for importance sampling
16 |
17 | diffusion:
18 | num_timesteps: 1001
19 | num_timesteps_beta: 1001
20 | beta_schedule: "linear" # linear, quadratic, cosine, warmup10, warmup50
21 | beta_start: 0.0001
22 | beta_end: 0.02
23 | model_out_type: "x_start" # x_start, noise
24 | model_var_type: "fixed_small" # fixed_small, fixed_large, learned, learned_range
25 |
26 | sampler:
27 | name: "DDPM" # DDPM, DDIM
28 | num_inference_timesteps: 5
29 |
30 | model:
31 | _target_: imgx.model.Unet
32 | remat: True
33 | num_spatial_dims: 3
34 | patch_size: MISSING # data dependent, will be set after loading config
35 | scale_factor: MISSING # data dependent, will be set after loading config
36 | num_res_blocks: 2
37 | num_channels: [32, 64, 128, 256]
38 | out_channels: MISSING # data dependent, will be set after loading config
39 | num_heads: 8
40 | widening_factor: 4
41 | dropout: 0.1
42 |
43 | loss:
44 | dice: 1.0
45 | cross_entropy: 0.0
46 | focal: 1.0
47 | mse: 0.0
48 | vlb: 0.0
49 |
50 | early_stopping: # used on validation set
51 | metric: "mean_binary_dice_score_without_background"
52 | mode: "max"
53 | min_delta: 0.0001
54 | patience: 10
55 |
--------------------------------------------------------------------------------
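
The `prev_step` comments above define how the "previous" timestep is picked for recycling and
self-conditioning. The two rules written out below are a direct transcription of those comments;
the function names are hypothetical and not part of the repository.

```python
import jax.numpy as jnp


def prev_step_max(t: jnp.ndarray, num_timesteps: int) -> jnp.ndarray:
    """'max': the previous step is always num_timesteps - 1."""
    return jnp.full_like(t, num_timesteps - 1)


def prev_step_next(t: jnp.ndarray, num_timesteps: int) -> jnp.ndarray:
    """'next': the previous step is min(t + 1, num_timesteps - 1)."""
    return jnp.minimum(t + 1, num_timesteps - 1)


t = jnp.array([0, 500, 1000])
print(prev_step_max(t, 1001))   # [1000 1000 1000]
print(prev_step_next(t, 1001))  # [   1  501 1000]
```
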
/imgx/conf/task/seg.yaml:
--------------------------------------------------------------------------------
1 | name: "segmentation"
2 |
3 | model:
4 | _target_: imgx.model.Unet
5 | remat: True
6 | num_spatial_dims: 3
7 | patch_size: MISSING # data dependent, will be set after loading config
8 | scale_factor: MISSING # data dependent, will be set after loading config
9 | num_res_blocks: 2
10 | num_channels: [32, 64, 128, 256]
11 | out_channels: MISSING # data dependent, will be set after loading config
12 | num_heads: 8
13 | widening_factor: 4
14 | num_transform_layers: 1
15 | dropout: 0.1
16 |
17 | loss:
18 | dice: 1.0
19 | cross_entropy: 0.0
20 | focal: 1.0
21 |
22 | early_stopping: # used on validation set
23 | metric: "mean_binary_dice_score_without_background"
24 | mode: "max"
25 | min_delta: 0.0001
26 | patience: 10
27 |
--------------------------------------------------------------------------------
/imgx/config.py:
--------------------------------------------------------------------------------
1 | """Module for configuration related functions."""
2 |
3 |
4 | def flatten_dict(d: dict, parent_key: str = "", sep: str = "_") -> dict: # type:ignore[type-arg]
5 | """Flat a nested dict.
6 |
7 | Args:
8 | d: dict to flatten.
9 | parent_key: key of the parent.
10 | sep: separator string.
11 |
12 | Returns:
13 | flattened dict.
14 | """
15 | items = {}
16 | for k, v in d.items():
17 | new_key = parent_key + sep + k if parent_key else k
18 | if isinstance(v, dict):
19 | items.update(flatten_dict(d=v, parent_key=new_key, sep=sep))
20 | else:
21 | items[new_key] = v
22 | return dict(items)
23 |
--------------------------------------------------------------------------------
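
A short usage example for `flatten_dict`, e.g. to turn a nested config into flat keys for logging;
the nested dict here is made up for illustration.

```python
from imgx.config import flatten_dict

nested = {"optimizer": {"name": "adamw", "kwargs": {"b1": 0.9}}, "seed": 0}
flat = flatten_dict(nested)
# nested keys are joined with the separator, "_" by default
assert flat == {"optimizer_name": "adamw", "optimizer_kwargs_b1": 0.9, "seed": 0}
```
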
/imgx/data/__init__.py:
--------------------------------------------------------------------------------
1 | """Module to handle data."""
2 |
--------------------------------------------------------------------------------
/imgx/data/augmentation/__init__.py:
--------------------------------------------------------------------------------
1 | """Data augmentation module."""
2 | from __future__ import annotations
3 |
4 | from collections.abc import Sequence
5 | from typing import Callable
6 |
7 | import jax
8 | from jax import numpy as jnp
9 |
10 | AugmentationFn = Callable[[jax.Array, dict[str, jnp.ndarray]], dict[str, jnp.ndarray]]
11 |
12 |
13 | def chain_aug_fns(
14 | fns: Sequence[AugmentationFn],
15 | ) -> AugmentationFn:
16 | """Combine a list of data augmentation functions.
17 |
18 | Args:
19 | fns: sequence of data augmentation functions to apply in order.
20 |
21 | Returns:
22 | A data augmentation function.
23 | """
24 |
25 | def aug_fn(key: jax.Array, batch: dict[str, jnp.ndarray]) -> dict[str, jnp.ndarray]:
26 | keys = jax.random.split(key, num=len(fns))
27 | for k, fn in zip(keys, fns):
28 | batch = fn(k, batch)
29 | return batch
30 |
31 | return aug_fn
32 |
--------------------------------------------------------------------------------
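
A minimal sketch of composing augmentations with `chain_aug_fns`; the toy functions below only
illustrate the expected `(key, batch) -> batch` signature and are not the repository's
augmentations.

```python
import jax
import jax.numpy as jnp

from imgx.data.augmentation import chain_aug_fns


def add_noise(key: jax.Array, batch: dict) -> dict:
    # toy augmentation: add Gaussian noise to every array in the batch
    return {k: v + 0.01 * jax.random.normal(key, v.shape) for k, v in batch.items()}


def identity(key: jax.Array, batch: dict) -> dict:
    # toy augmentation: leave the batch unchanged
    return batch


aug_fn = chain_aug_fns([add_noise, identity])
batch = {"image": jnp.zeros((2, 4, 4))}
out = aug_fn(jax.random.PRNGKey(0), batch)
print(out["image"].shape)  # (2, 4, 4)
```
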
/imgx/data/augmentation/flip.py:
--------------------------------------------------------------------------------
1 | """Flip augmentation for image and label."""
2 | from __future__ import annotations
3 |
4 | from functools import partial
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | from jax import lax
9 | from omegaconf import DictConfig
10 |
11 | from imgx.data.augmentation import AugmentationFn
12 | from imgx.data.util import get_batch_size
13 | from imgx.datasets import INFO_MAP
14 | from imgx.datasets.constant import FOREGROUND_RANGE, IMAGE, LABEL
15 |
16 |
17 | def random_flip(x: jnp.ndarray, to_flip: jnp.ndarray) -> jnp.ndarray:
18 | """Flip an array along an axis.
19 |
20 | Args:
21 | x: of shape (d1, ..., dn) or (d1, ..., dn, c)
22 | to_flip: (n, ), with boolean values, True means to flip along that axis.
23 |
24 | Returns:
25 | Flipped array.
26 | """
27 | for i in range(to_flip.size):
28 | x = lax.select(
29 | to_flip[i],
30 | jnp.flip(x, axis=i),
31 | x,
32 | )
33 | return x
34 |
35 |
36 | def batch_random_flip(
37 | key: jax.Array, batch: dict[str, jnp.ndarray], num_spatial_dims: int, p: float
38 | ) -> dict[str, jnp.ndarray]:
39 | """Flip an array along an axis.
40 |
41 | Args:
42 | key: jax random key.
43 | batch: dict having images or labels, or foreground_range.
44 | images have shape (batch, d1, ..., dn) or (batch, d1, ..., dn, c)
45 | labels have shape (batch, d1, ..., dn)
46 | batch should not have other keys such as UID.
47 | if foreground_range exists, it is pre-calculated from the label,
48 | because the nonzero function is not jittable.
49 | num_spatial_dims: number of spatial dimensions.
50 | p: probability to flip for each axis.
51 |
52 | Returns:
53 | Flipped batch.
54 | """
55 | batch_size = get_batch_size(batch)
56 | to_flip = jax.random.uniform(key=key, shape=(batch_size, num_spatial_dims)) < p
57 |
58 | random_flip_vmap = jax.vmap(random_flip)
59 | flipped_batch = {}
60 | for k, v in batch.items():
61 | if (LABEL in k) or (IMAGE in k):
62 | flipped_batch[k] = random_flip_vmap(x=v, to_flip=to_flip)
63 | elif k == FOREGROUND_RANGE:
64 | flipped_batch[k] = v
65 | else:
66 | raise ValueError(f"Unknown key {k} in batch.")
67 | return flipped_batch
68 |
69 |
70 | def get_random_flip_augmentation_fn(config: DictConfig) -> AugmentationFn:
71 | """Return a data augmentation function for random flip.
72 |
73 | Args:
74 | config: entire config.
75 |
76 | Returns:
77 | A data augmentation function.
78 | """
79 | dataset_info = INFO_MAP[config.data.name]
80 | da_config = config.data.loader.data_augmentation
81 | return partial(
82 | batch_random_flip,
83 | num_spatial_dims=len(dataset_info.image_spatial_shape),
84 | p=da_config.p,
85 | )
86 |
--------------------------------------------------------------------------------
/imgx/data/augmentation/flip_test.py:
--------------------------------------------------------------------------------
1 | """Test the flip functions."""
2 |
3 |
4 | from functools import partial
5 |
6 | import chex
7 | import jax
8 | import numpy as np
9 | from absl.testing import parameterized
10 | from chex._src import fake
11 |
12 | from imgx.data.augmentation.flip import batch_random_flip, random_flip
13 | from imgx.datasets.constant import IMAGE
14 |
15 |
16 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
17 | def setUpModule() -> None: # pylint: disable=invalid-name
18 | """Fake multi-devices."""
19 | fake.set_n_cpu_devices(2)
20 |
21 |
22 | class TestRandomFlip(chex.TestCase):
23 | """Test random_flip."""
24 |
25 | @chex.all_variants()
26 | @parameterized.named_parameters(
27 | ("1d - flip", np.array([1, 2, 3, 4]), np.array([True]), np.array([4, 3, 2, 1])),
28 | ("1d - no flip", np.array([1, 2, 3, 4]), np.array([False]), np.array([1, 2, 3, 4])),
29 | (
30 | "2d - no flip",
31 | np.array([[1, 2, 3, 4], [5, 6, 7, 8]]),
32 | np.array([False, False]),
33 | np.array([[1, 2, 3, 4], [5, 6, 7, 8]]),
34 | ),
35 | (
36 | "2d - flip the first axis",
37 | np.array([[1, 2, 3, 4], [5, 6, 7, 8]]),
38 | np.array([True, False]),
39 | np.array([[5, 6, 7, 8], [1, 2, 3, 4]]),
40 | ),
41 | (
42 | "2d - flip the second axis",
43 | np.array([[1, 2, 3, 4], [5, 6, 7, 8]]),
44 | np.array([False, True]),
45 | np.array([[4, 3, 2, 1], [8, 7, 6, 5]]),
46 | ),
47 | (
48 | "2d - flip both axes",
49 | np.array([[1, 2, 3, 4], [5, 6, 7, 8]]),
50 | np.array([True, True]),
51 | np.array([[8, 7, 6, 5], [4, 3, 2, 1]]),
52 | ),
53 | )
54 | def test_values(
55 | self,
56 | x: np.ndarray,
57 | to_flip: np.ndarray,
58 | expected: np.ndarray,
59 | ) -> None:
60 | """Test output shapes.
61 |
62 | Args:
63 | x: input array.
64 | to_flip: (n, ), with boolean values, True means to flip along that axis.
65 | expected: expected output array.
66 | """
67 | got = self.variant(random_flip)(
68 | x=x,
69 | to_flip=to_flip,
70 | )
71 |
72 | chex.assert_trees_all_equal(got, expected)
73 |
74 | @chex.all_variants()
75 | @parameterized.product(
76 | image_shape=[(3, 4, 5, 6), (3, 4, 5), (3, 4), (3,)],
77 | )
78 | def test_shapes(
79 | self,
80 | image_shape: tuple[int, ...],
81 | ) -> None:
82 | """Test output shapes.
83 |
84 | Args:
85 | image_shape: image spatial shape.
86 | """
87 | key = jax.random.PRNGKey(0)
88 | key_flip, key_image = jax.random.split(key, 2)
89 | image = jax.random.uniform(key=key_image, shape=image_shape, minval=0, maxval=1)
90 | to_flip = jax.random.uniform(key=key_flip, shape=(len(image_shape),)) < 1 / 2
91 | got = self.variant(random_flip)(
92 | x=image,
93 | to_flip=to_flip,
94 | )
95 |
96 | chex.assert_shape(got, image_shape)
97 |
98 |
99 | class TestBatchRandomFlip(chex.TestCase):
100 | """Test batch_random_flip."""
101 |
102 | batch_size = 2
103 |
104 | @chex.all_variants()
105 | @parameterized.product(
106 | image_shape=[(3, 4, 5, 6), (3, 4, 5), (3, 4), (3,)],
107 | p=[0.0, 0.5, 1.0],
108 | )
109 | def test_shapes(
110 | self,
111 | image_shape: tuple[int, ...],
112 | p: float,
113 | ) -> None:
114 | """Test output shapes.
115 |
116 | Args:
117 | image_shape: image spatial shape.
118 | p: probability to flip for each axis.
119 | """
120 | key = jax.random.PRNGKey(0)
121 | key_image, key = jax.random.split(key)
122 | num_spatial_dims = len(image_shape)
123 | image = jax.random.uniform(
124 | key=key_image, shape=(self.batch_size, *image_shape), minval=0, maxval=1
125 | )
126 | batch = {IMAGE: image}
127 | got = self.variant(partial(batch_random_flip, num_spatial_dims=num_spatial_dims, p=p))(
128 | key, batch
129 | )
130 |
131 | chex.assert_shape(got[IMAGE], (self.batch_size, *image_shape))
132 | if p == 0:
133 | chex.assert_trees_all_equal(got, batch)
134 |
--------------------------------------------------------------------------------
/imgx/data/augmentation/intensity.py:
--------------------------------------------------------------------------------
1 | """Intensity related data augmentation functions."""
2 | from __future__ import annotations
3 |
4 | from functools import partial
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | from omegaconf import DictConfig
9 |
10 | from imgx.data.augmentation import AugmentationFn
11 | from imgx.data.util import get_batch_size
12 | from imgx.datasets.constant import IMAGE
13 |
14 |
15 | def adjust_gamma(
16 | x: jnp.ndarray,
17 | gamma: jnp.ndarray,
18 | gain: float = 1.0,
19 | ) -> jnp.ndarray:
20 | """Adjust gamma of input images.
21 |
22 | https://github.com/tensorflow/tensorflow/blob/v2.14.0/tensorflow/python/ops/image_ops_impl.py#L2303-L2366
23 |
24 | Args:
25 | x: input image.
26 | gamma: non-negative real number.
27 | gain: the constant multiplier.
28 |
29 | Returns:
30 | Adjusted image.
31 | """
32 | return gain * x**gamma
33 |
34 |
35 | def batch_random_adjust_gamma(
36 | key: jax.Array,
37 | batch: dict[str, jnp.ndarray],
38 | max_log_gamma: float,
39 | p: float = 0.5,
40 | ) -> dict[str, jnp.ndarray]:
41 | """Perform random gamma adjustment on images in batch.
42 |
43 | https://torchio.readthedocs.io/_modules/torchio/transforms/augmentation/intensity/random_gamma.html
44 |
45 | Args:
46 | key: jax random key.
47 | batch: dict having images or labels, or foreground_range.
48 | max_log_gamma: maximum log gamma.
49 | p: probability of performing the augmentation.
50 |
51 | Returns:
52 | Augmented dict having image and label, shapes are not changed.
53 | """
54 | batch_size = get_batch_size(batch)
55 |
56 | adjusted_batch = {}
57 | for k, v in batch.items():
58 | if IMAGE in k:
59 | key_gamma, key_act, key = jax.random.split(key, 3)
60 | log_gamma = jax.random.uniform(
61 | key=key_gamma,
62 | shape=(batch_size,),
63 | minval=-max_log_gamma,
64 | maxval=max_log_gamma,
65 | )
66 | gamma = jnp.exp(log_gamma)
67 | gamma = jnp.where(
68 | jax.random.uniform(key=key_act, shape=gamma.shape) < p,
69 | gamma,
70 | jnp.ones_like(gamma),
71 | )
72 | adjusted_batch[k] = jax.vmap(adjust_gamma)(v, gamma)
73 | else:
74 | adjusted_batch[k] = v
75 | return adjusted_batch
76 |
77 |
78 | def get_random_gamma_augmentation_fn(config: DictConfig) -> AugmentationFn:
79 | """Return a data augmentation function for random gamma transformation.
80 |
81 | Args:
82 | config: entire config.
83 |
84 | Returns:
85 | A data augmentation function.
86 | """
87 | da_config = config.data.loader.data_augmentation
88 | return partial(
89 | batch_random_adjust_gamma,
90 | max_log_gamma=da_config.max_log_gamma,
91 | p=da_config.p,
92 | )
93 |
94 |
95 | def rescale_intensity(
96 | x: jnp.ndarray,
97 | v_min: float = 0.0,
98 | v_max: float = 1.0,
99 | ) -> jnp.ndarray:
100 | """Adjust intensity linearly to the desired range.
101 |
102 | Args:
103 | x: input image, (batch, *spatial_dims, channel).
104 | v_min: minimum intensity.
105 | v_max: maximum intensity.
106 |
107 | Returns:
108 | Adjusted image.
109 | """
110 | reduction_axes = tuple(range(x.ndim)[slice(1, -1)])
111 | x_min = jnp.min(x, axis=reduction_axes, keepdims=True)
112 | x_max = jnp.max(x, axis=reduction_axes, keepdims=True)
113 | x = (x - x_min) / (x_max - x_min)
114 | x = x * (v_max - v_min) + v_min
115 | return x
116 |
117 |
118 | def batch_rescale_intensity(
119 | key: jax.Array, # noqa: ARG001, pylint: disable=unused-argument
120 | batch: dict[str, jnp.ndarray],
121 | v_min: float = 0.0,
122 | v_max: float = 1.0,
123 | ) -> dict[str, jnp.ndarray]:
124 | """Perform intensity scaling on images in batch.
125 |
126 | Args:
127 | key: jax random key.
128 | batch: dict having images or labels, or foreground_range.
129 | v_min: minimum intensity.
130 | v_max: maximum intensity.
131 |
132 | Returns:
133 | Augmented dict having image and label, shapes are not changed.
134 | """
135 | adjusted_batch = {}
136 | for k, v in batch.items():
137 | if IMAGE in k:
138 | adjusted_batch[k] = rescale_intensity(v, v_min=v_min, v_max=v_max)
139 | else:
140 | adjusted_batch[k] = v
141 | return adjusted_batch
142 |
143 |
144 | def get_rescale_intensity_fn(config: DictConfig) -> AugmentationFn:
145 | """Return a data augmentation function for intensity scaling.
146 |
147 | Args:
148 | config: entire config.
149 |
150 | Returns:
151 | A data augmentation function.
152 | """
153 | da_config = config.data.loader.data_augmentation
154 | return partial(
155 | batch_rescale_intensity,
156 | v_min=da_config.v_min,
157 | v_max=da_config.v_max,
158 | )
159 |
--------------------------------------------------------------------------------
/imgx/data/augmentation/intensity_test.py:
--------------------------------------------------------------------------------
1 | """Test function for intensity data augmentation."""
2 |
3 | import chex
4 | import jax
5 | from absl.testing import parameterized
6 | from chex._src import fake
7 |
8 | from imgx.data.augmentation.intensity import batch_random_adjust_gamma, batch_rescale_intensity
9 | from imgx.datasets.constant import IMAGE
10 |
11 |
12 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
13 | def setUpModule() -> None: # pylint: disable=invalid-name
14 | """Fake multi-devices."""
15 | fake.set_n_cpu_devices(2)
16 |
17 |
18 | class TestRandomAdjustGamma(chex.TestCase):
19 | """Test batch_random_adjust_gamma."""
20 |
21 | @chex.all_variants()
22 | @parameterized.product(
23 | image_shape=[(8, 12, 6), (8, 12)],
24 | max_log_gamma=[0.0, 0.3],
25 | batch_size=[4, 1],
26 | )
27 | def test_shapes(
28 | self,
29 | batch_size: int,
30 | max_log_gamma: float,
31 | image_shape: tuple[int, ...],
32 | ) -> None:
33 | """Test output shapes.
34 |
35 | Args:
36 | batch_size: number of samples in batch.
37 | max_log_gamma: maximum log gamma.
38 | image_shape: image spatial shape.
39 | """
40 | key = jax.random.PRNGKey(0)
41 | image = jax.random.uniform(key=key, shape=(batch_size, *image_shape), minval=0, maxval=1)
42 | batch = {IMAGE: image}
43 | got = self.variant(batch_random_adjust_gamma)(
44 | key=key,
45 | batch=batch,
46 | max_log_gamma=max_log_gamma,
47 | )
48 |
49 | assert len(got) == 1
50 | chex.assert_shape(got[IMAGE], (batch_size, *image_shape))
51 |
52 |
53 | class TestRescaleIntensity(chex.TestCase):
54 | """Test batch_rescale_intensity."""
55 |
56 | @chex.all_variants()
57 | @parameterized.product(
58 | image_shape=[(8, 12, 6), (8, 12)],
59 | v_min=[0.0, 0.3],
60 | v_max=[1.0, 0.5],
61 | batch_size=[4, 1],
62 | )
63 | def test_shapes(
64 | self,
65 | batch_size: int,
66 | v_min: float,
67 | v_max: float,
68 | image_shape: tuple[int, ...],
69 | ) -> None:
70 | """Test output shapes.
71 |
72 | Args:
73 | batch_size: number of samples in batch.
74 | v_min: minimum intensity.
75 | v_max: maximum intensity.
76 | image_shape: image spatial shape.
77 | """
78 | key = jax.random.PRNGKey(0)
79 | image = jax.random.uniform(key=key, shape=(batch_size, *image_shape), minval=0, maxval=1)
80 | batch = {IMAGE: image}
81 | got = self.variant(batch_rescale_intensity)(
82 | key=key,
83 | batch=batch,
84 | v_min=v_min,
85 | v_max=v_max,
86 | )
87 |
88 | assert len(got) == 1
89 | chex.assert_shape(got[IMAGE], (batch_size, *image_shape))
90 |
--------------------------------------------------------------------------------
/imgx/data/util_test.py:
--------------------------------------------------------------------------------
1 | """Tests for image utils of datasets."""
2 |
3 |
4 | import chex
5 | import numpy as np
6 | import pytest
7 | from chex._src import fake
8 |
9 | from imgx.data.util import get_foreground_range
10 |
11 |
12 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
13 | def setUpModule() -> None: # pylint: disable=invalid-name
14 | """Fake multi-devices."""
15 | fake.set_n_cpu_devices(2)
16 |
17 |
18 | @pytest.mark.parametrize(
19 | ("label", "expected"),
20 | [
21 | (
22 | np.array([0, 1, 2, 3]),
23 | np.array([[1, 3]]),
24 | ),
25 | (
26 | np.array([1, 2, 3, 0]),
27 | np.array([[0, 2]]),
28 | ),
29 | (
30 | np.array([1, 2, 3, 4]),
31 | np.array([[0, 3]]),
32 | ),
33 | (
34 | np.array([0, 1, 2, 3, 4, 0, 0]),
35 | np.array([[1, 4]]),
36 | ),
37 | (
38 | np.array([[0, 1, 2, 3], [0, 1, 2, 3], [0, 0, 0, 0]]),
39 | np.array([[0, 1], [1, 3]]),
40 | ),
41 | ],
42 | ids=[
43 | "1d-left",
44 | "1d-right",
45 | "1d-none",
46 | "1d-both",
47 | "2d",
48 | ],
49 | )
50 | def test_get_foreground_range(
51 | label: np.ndarray,
52 | expected: np.ndarray,
53 | ) -> None:
54 | """Test get_translation_range return values.
55 |
56 | Args:
57 | label: label with int values, not one-hot.
58 | expected: expected range.
59 | """
60 | got = get_foreground_range(
61 | label=label,
62 | )
63 | chex.assert_trees_all_equal(got, expected)
64 |
--------------------------------------------------------------------------------
/imgx/data/warp.py:
--------------------------------------------------------------------------------
1 | """Module for image/lavel warping."""
2 | from __future__ import annotations
3 |
4 | from functools import partial
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | from jax._src.scipy.ndimage import map_coordinates
9 |
10 |
11 | def get_coordinate_grid(shape: tuple[int, ...]) -> jnp.ndarray:
12 | """Generate a grid with given shape.
13 |
14 | This function is not jittable as the output depends on the value of shape.
15 |
16 | Args:
17 | shape: shape of the grid, (d1, ..., dn).
18 |
19 | Returns:
20 | grid: grid coordinates, of shape (n, d1, ..., dn).
21 | grid[:, i1, ..., in] = [i1, ..., in]
22 | """
23 | return jnp.stack(
24 | jnp.meshgrid(
25 | *(jnp.arange(d) for d in shape),
26 | indexing="ij",
27 | ),
28 | axis=0,
29 | dtype=jnp.float32,
30 | )
31 |
32 |
33 | def batch_grid_sample(
34 | x: jnp.ndarray,
35 | grid: jnp.ndarray,
36 | order: int,
37 | constant_values: float = 0.0,
38 | ) -> jnp.ndarray:
39 | """Apply sampling to input.
40 |
41 | https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
42 |
43 | Args:
44 | x: shape (batch, d1, ..., dn) or (batch, d1, ..., dn, c).
45 | grid: grid coordinates, of shape (batch, n, d1, ..., dn).
46 | order: interpolation order, 0 for nearest, 1 for linear.
47 | constant_values: constant value for out of bound coordinates.
48 |
49 | Returns:
50 | Same shape as x.
51 | """
52 | if x.ndim not in [grid.ndim - 1, grid.ndim]:
53 | raise ValueError(f"Input x has shape {x.shape}, grid has shape {grid.shape}.")
54 |
55 | # vmap on batch axis
56 | sample_vmap = jax.vmap(
57 | partial(
58 | map_coordinates,
59 | order=order,
60 | mode="constant",
61 | cval=constant_values,
62 | ),
63 | in_axes=(0, 0),
64 | )
65 | if x.ndim == grid.ndim:
66 | # vmap on channel axis
67 | ch_axis = x.ndim - 1
68 | sample_vmap = jax.vmap(
69 | sample_vmap,
70 | in_axes=(ch_axis, None),
71 | out_axes=ch_axis,
72 | )
73 | return sample_vmap(x, grid)
74 |
75 |
76 | def warp_image(
77 | x: jnp.ndarray,
78 | ddf: jnp.ndarray,
79 | order: int,
80 | ) -> jnp.ndarray:
81 | """Warp the image with the deformation field.
82 |
83 | TODO: grid is a constant, can be precomputed.
84 |
85 | Args:
86 | x: shape (batch, d1, ..., dn) or (batch, d1, ..., dn, c).
87 | ddf: deformation field, of shape (batch, d1, ..., dn, n).
88 | order: interpolation order, 0 for nearest, 1 for linear.
89 |
90 | Returns:
91 | warped image, of shape (batch, d1, ..., dn) or (batch, d1, ..., dn, c).
92 | """
93 | # (batch, n, d1, ..., dn)
94 | grid = get_coordinate_grid(shape=ddf.shape[1:-1])
95 | grid += jnp.moveaxis(ddf, -1, 1)
96 | return batch_grid_sample(x, grid, order=order)
97 |
--------------------------------------------------------------------------------
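
A small usage sketch for `warp_image`: with an all-zero deformation field the output should equal
the input, which makes for a convenient sanity check. This example is illustrative and not taken
from the repository's tests.

```python
import jax
import jax.numpy as jnp

from imgx.data.warp import warp_image

# batch of 2 channel-less 3D images, shape (batch, d1, d2, d3)
x = jax.random.uniform(jax.random.PRNGKey(0), shape=(2, 8, 8, 8))
# zero displacement field, shape (batch, d1, d2, d3, n) with n = 3
ddf = jnp.zeros((2, 8, 8, 8, 3))

warped = warp_image(x, ddf, order=1)
print(warped.shape)  # (2, 8, 8, 8)
assert jnp.allclose(warped, x, atol=1e-5)
```
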
/imgx/datasets/README.md:
--------------------------------------------------------------------------------
1 | # ImgX Datasets
2 |
3 | A [TFDS](https://www.tensorflow.org/datasets/add_dataset)-based python package for data set
4 | building.
5 |
6 | Currently supported data sets are listed below. Use the following command to (re)build all data sets.
7 |
8 | ```bash
9 | make build_dataset
10 | ```
11 |
12 | ## Male pelvic MR
13 |
14 | ### Description
15 |
16 | This data set from [Li et al. 2022](https://zenodo.org/record/7013610#.Y1U95-zMKrM) contains 589
17 | labeled T2-weighted images, which are split into training, validation, and test sets.
18 |
19 | ### Download and Build
20 |
21 | Use the following commands at the root of this repository (i.e. under `ImgX/`) to automatically
22 | download and build the data set, which will be built under `~/tensorflow_datasets` folder.
23 | Optionally, add flag `--overwrite` to rebuild/overwrite the data set.
24 |
25 | ```bash
26 | tfds build imgx/datasets/male_pelvic_mr
27 | ```
28 |
29 | ## AMOS CT
30 |
31 | ### Description
32 |
33 | This data set from [Ji et al. 2022](https://zenodo.org/record/7155725#.ZAN4BuzP2rO) contains 500
34 | labeled CT images, which have been split into 200, 100, and 200 images for training, validation,
35 | and test sets. However, the test set labels were not released, so the validation set is further
36 | split into 10 and 90 images for validation and testing.
37 |
38 | ### Download and Build
39 |
40 | Use the following commands at the root of this repository (i.e. under `ImgX/`) to automatically
41 | download and build the data set, which will be built under `~/tensorflow_datasets` folder.
42 | Optionally, add flag `--overwrite` to rebuild/overwrite the data set.
43 |
44 | ```bash
45 | tfds build imgx/datasets/amos_ct
46 | ```
47 |
48 | ## Muscle Ultrasound
49 |
50 | ### Description
51 |
52 | This data set from [Marzola et al. 2021](https://data.mendeley.com/datasets/3jykz7wz8d/1) contains
53 | 3910 labeled images, which have been split into 2531, 666, and 713 images for training, validation,
54 | and test sets.
55 |
56 | ### Download and Build
57 |
58 | Use the following commands at the root of this repository (i.e. under `ImgX/`) to automatically
59 | download and build the data set, which will be built under `~/tensorflow_datasets` folder.
60 | Optionally, add flag `--overwrite` to rebuild/overwrite the data set.
61 |
62 | ```bash
63 | tfds build imgx/datasets/muscle_us
64 | ```
65 |
66 | ## Brain MR
67 |
68 | ### Description
69 |
70 | This data set from [Baid et al. 2021](https://arxiv.org/abs/2107.02314) contains 1251 labeled images
71 | which are split into training, validation, and test sets.
72 |
73 | ### Download and Build
74 |
75 | #### Manual Download
76 |
77 | This data set requires manually downloading the data from
78 | [Kaggle](https://www.kaggle.com/datasets/dschettler8845/brats-2021-task1) using the
79 | [Kaggle API](https://www.kaggle.com/docs/api). The
80 | [authentication token](https://www.kaggle.com/docs/api#getting-started-installation-&-authentication)
81 | should be obtained and stored under `~/.kaggle/kaggle.json`.
82 |
83 | Then, execute the following commands to download and unzip the files. Afterward, return to the
84 | `ImgX/` folder (`/app/ImgX` for Docker).
85 |
86 | ```bash
87 | mkdir -p ~/tensorflow_datasets/downloads/manual/BraTS2021_Kaggle/BraTS2021_Training_Data/
88 | cd ~/tensorflow_datasets/downloads/manual/BraTS2021_Kaggle/BraTS2021_Training_Data/
89 | kaggle datasets download -d dschettler8845/brats-2021-task1
90 | unzip brats-2021-task1.zip
91 | tar xf BraTS2021_Training_Data.tar
92 | rm BraTS2021_00495.tar
93 | rm BraTS2021_00621.tar
94 | rm BraTS2021_Training_Data.tar
95 | rm brats-2021-task1.zip
96 | ```
97 |
98 | Afterward, `BraTS2021_Kaggle/` contains one folder per sample. For example, the files corresponding
99 | to UID `BraTS2021_01666` should be located at
100 | `~/tensorflow_datasets/downloads/manual/BraTS2021_Kaggle/BraTS2021_Training_Data/BraTS2021_01666/`
101 | under which there are five files:
102 |
103 | - `BraTS2021_01666_flair.nii.gz`,
104 | - `BraTS2021_01666_t1.nii.gz`,
105 | - `BraTS2021_01666_t1ce.nii.gz`,
106 | - `BraTS2021_01666_t2.nii.gz`,
107 | - `BraTS2021_01666_seg.nii.gz`.
108 |
109 | #### Automatic Build
110 |
111 | Use the following command at the root of this repository (i.e. under `ImgX/`) to automatically
112 | build the data set, which will be placed under the `~/tensorflow_datasets` folder. Optionally, add flag
113 | `--overwrite` to rebuild/overwrite the data set.
114 |
115 | ```bash
116 | tfds build imgx/datasets/brats2021_mr
117 | ```
118 |
119 | ## Automated Cardiac Diagnosis Challenge (ACDC)
120 |
121 | ### Description
122 |
123 | This data set from [Bernard et al. 2018](https://ieeexplore.ieee.org/document/8360453) contains 150
124 | samples. Samples are split into 100 and 50 for training and test sets. Each sample contains
125 |
126 | - a 4D image (a sequence of 3D MR images)
127 | - a 3D image and corresponding segmentation label for end-diastolic (ED) frame
128 | - a 3D image and corresponding segmentation label for end-systolic (ES) frame
129 |
130 | ### Download and Build
131 |
132 | Use the following command at the root of this repository (i.e. under `ImgX/`) to automatically
133 | download and build the data set, which will be placed under the `~/tensorflow_datasets` folder.
134 | Optionally, add flag `--overwrite` to rebuild/overwrite the data set.
135 |
136 | ```bash
137 | tfds build imgx/datasets/acdc_mr
138 | ```
139 |
--------------------------------------------------------------------------------
/imgx/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | """Dataset module to build tensorflow datasets."""
2 |
3 | from imgx.datasets.amos_ct.amos_ct_dataset_builder import AMOS_CT_INFO
4 | from imgx.datasets.brats2021_mr.brats2021_mr_dataset_builder import BRATS2021_MR_INFO
5 | from imgx.datasets.dataset_info import DatasetInfo
6 | from imgx.datasets.male_pelvic_mr.male_pelvic_mr_dataset_builder import MALE_PELVIR_MR_INFO
7 | from imgx.datasets.muscle_us.muscle_us_dataset_builder import MUSCLE_US_INFO
8 |
9 | # supported datasets
10 | MALE_PELVIC_MR = "male_pelvic_mr"
11 | AMOS_CT = "amos_ct"
12 | MUSCLE_US = "muscle_us"
13 | BRATS2021_MR = "brats2021_mr"
14 |
15 | # dataset info
16 | INFO_MAP: dict[str, DatasetInfo] = {
17 | MALE_PELVIC_MR: MALE_PELVIR_MR_INFO,
18 | AMOS_CT: AMOS_CT_INFO,
19 | MUSCLE_US: MUSCLE_US_INFO,
20 | BRATS2021_MR: BRATS2021_MR_INFO,
21 | }
22 |
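A minimal usage sketch (not part of the repository) showing how a dataset's metadata can be looked
up from `INFO_MAP`; only dataclass fields and properties defined in `dataset_info.py` are used.

```python
from imgx.datasets import INFO_MAP, MUSCLE_US

info = INFO_MAP[MUSCLE_US]
print(info.name)                 # dataset name string
print(info.image_spatial_shape)  # spatial shape, e.g. (height, width) for 2D data
print(info.ndim)                 # number of spatial dimensions
```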
--------------------------------------------------------------------------------
/imgx/datasets/amos_ct/__init__.py:
--------------------------------------------------------------------------------
1 | """AMOS dataset.
2 |
3 | https://arxiv.org/abs/2206.08023
4 |
5 | https://zenodo.org/record/7262581
6 | """
7 |
--------------------------------------------------------------------------------
/imgx/datasets/brats2021_mr/__init__.py:
--------------------------------------------------------------------------------
1 | """BRaTS 2021 MR dataset.
2 |
3 | https://www.kaggle.com/datasets/dschettler8845/brats-2021-task1
4 | """
5 |
--------------------------------------------------------------------------------
/imgx/datasets/constant.py:
--------------------------------------------------------------------------------
1 | """Constants for datasets.
2 |
3 | Cannot be defined in __init__.py because of circular import.
4 | __init__.py imports from each data set, which imports constants from this file.
5 | """
6 | from pathlib import Path
7 |
8 | # splits
9 | TRAIN_SPLIT = "train"
10 | VALID_SPLIT = "valid"
11 | TEST_SPLIT = "test"
12 |
13 | # data dict keys
14 | UID = "uid"
15 | IMAGE = "image" # in a batch, keys containing "image" are also treated as images
16 | LABEL = "label" # in a batch, keys containing "label" are also treated as labels
17 |
18 | # prediction
19 | LABEL_PRED = "label_pred"
20 |
21 | TFDS_DIR: Path = Path.home() / "tensorflow_datasets"
22 | TFDS_EXTRACTED_DIR: Path = TFDS_DIR / "downloads" / "extracted"
23 | TFDS_MANUAL_DIR: Path = TFDS_DIR / "downloads" / "manual"
24 |
25 | # segmentation task
26 | FOREGROUND_RANGE = "foreground_range" # added during pre-processing in tf for augmentation
27 |
--------------------------------------------------------------------------------
/imgx/datasets/dataset_info.py:
--------------------------------------------------------------------------------
1 | """Data set class for datasets."""
2 | import dataclasses
3 | from pathlib import Path
4 |
5 | import jax
6 | import jax.numpy as jnp
7 |
8 |
9 | @dataclasses.dataclass
10 | class DatasetInfo:
11 | """Data set class for datasets."""
12 |
13 | name: str
14 | tfds_preprocessed_dir: Path
15 | image_spacing: tuple[float, ...]
16 | image_spatial_shape: tuple[int, ...]
17 | image_channels: int
18 | class_names: tuple[str, ...] # for segmentation label only
19 | classes_are_exclusive: bool # for segmentation label only
20 |
21 | @property
22 | def input_image_shape(self) -> tuple[int, ...]:
23 | """Input shape of image."""
24 | return (*self.image_spatial_shape, self.image_channels)
25 |
26 | @property
27 | def label_shape(self) -> tuple[int, ...]:
28 | """Shape of label."""
29 | return self.image_spatial_shape
30 |
31 | @property
32 | def ndim(self) -> int:
33 | """Number of dimensions."""
34 | return len(self.image_spatial_shape)
35 |
36 | @property
37 | def num_classes(self) -> int:
38 | """Number of classes for segmentation."""
39 | raise NotImplementedError
40 |
41 | def logits_to_label(self, x: jnp.ndarray, axis: int) -> jnp.ndarray:
42 | """Transform logits to label with integers."""
43 | raise NotImplementedError
44 |
45 | def label_to_mask(
46 | self, x: jnp.ndarray, axis: int, dtype: jnp.dtype = jnp.float32
47 | ) -> jnp.ndarray:
48 | """Transform label to boolean mask."""
49 | raise NotImplementedError
50 |
51 | def logits_to_label_with_post_process(self, x: jnp.ndarray, axis: int) -> jnp.ndarray:
52 | """Transform logits to label with post-processing."""
53 | return self.post_process_label(self.logits_to_label(x, axis=axis))
54 |
55 | def post_process_label(self, label: jnp.ndarray) -> jnp.ndarray:
56 | """Label post-processing."""
57 | return label
58 |
59 |
60 | class OneHotLabeledDatasetInfo(DatasetInfo):
61 | """Data set with mutually exclusive labels."""
62 |
63 | @property
64 | def num_classes(self) -> int:
65 | """Number of classes including background."""
66 | return len(self.class_names) + 1
67 |
68 | def logits_to_label(self, x: jnp.ndarray, axis: int) -> jnp.ndarray:
69 | """Transform logits to label with integers.
70 |
71 | Args:
72 | x: logits.
73 | axis: axis of num_classes.
74 |
75 | Returns:
76 | Label with integers.
77 | """
78 | return jnp.argmax(x, axis=axis)
79 |
80 | def label_to_mask(
81 | self, x: jnp.ndarray, axis: int, dtype: jnp.dtype = jnp.float32
82 | ) -> jnp.ndarray:
83 | """Transform label to boolean mask.
84 |
85 | Args:
86 | x: label.
87 | axis: axis of num_classes.
88 | dtype: dtype of output.
89 |
90 | Returns:
91 | One hot mask.
92 | """
93 | return jax.nn.one_hot(
94 | x=x,
95 | num_classes=self.num_classes,
96 | axis=axis,
97 | dtype=dtype,
98 | )
99 |
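A small sketch (not repository code) of how `OneHotLabeledDatasetInfo` converts logits to integer
labels and back to one-hot masks; the toy field values below are assumptions for illustration only.

```python
from pathlib import Path

import jax.numpy as jnp

from imgx.datasets.dataset_info import OneHotLabeledDatasetInfo

info = OneHotLabeledDatasetInfo(
    name="toy",
    tfds_preprocessed_dir=Path("/tmp/toy"),
    image_spacing=(1.0, 1.0),
    image_spatial_shape=(4, 4),
    image_channels=1,
    class_names=("foreground",),
    classes_are_exclusive=True,
)
# num_classes counts the background, so one class name gives 2 classes
logits = jnp.zeros((2, 4, 4, info.num_classes))  # (batch, h, w, num_classes)
label = info.logits_to_label(logits, axis=-1)    # (2, 4, 4) integer labels
mask = info.label_to_mask(label, axis=-1)        # (2, 4, 4, 2) one-hot mask
```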
--------------------------------------------------------------------------------
/imgx/datasets/male_pelvic_mr/__init__.py:
--------------------------------------------------------------------------------
1 | """Male pelvic MR dataset.
2 |
3 | https://arxiv.org/abs/2209.05160
4 | """
5 |
--------------------------------------------------------------------------------
/imgx/datasets/muscle_us/__init__.py:
--------------------------------------------------------------------------------
1 | """Muscle ultrasound dataset.
2 |
3 | https://www.sciencedirect.com/science/article/pii/S0010482521004170
4 | https://data.mendeley.com/datasets/3jykz7wz8d/1
5 | """
6 |
--------------------------------------------------------------------------------
/imgx/datasets/save.py:
--------------------------------------------------------------------------------
1 | """IO related functions (file cannot be named as io).
2 |
3 | https://stackoverflow.com/questions/26569828/pycharm-py-initialize-cant-initialize-sys-standard-streams
4 | """
5 | from __future__ import annotations
6 |
7 | from pathlib import Path
8 |
9 | import numpy as np
10 | import pandas as pd
11 | import SimpleITK as sitk # noqa: N813
12 | from absl import logging
13 | from PIL import Image
14 |
15 |
16 | def save_uids(
17 | train_uids: list[str],
18 | valid_uids: list[str],
19 | test_uids: list[str],
20 | out_dir: Path,
21 | ) -> None:
22 | """Save uids to csv files.
23 |
24 | Args:
25 | train_uids: list of training uids.
26 | valid_uids: list of validation uids.
27 | test_uids: list of test uids.
28 | out_dir: directory to save the csv files.
29 | """
30 | pd.DataFrame({"uid": train_uids}).to_csv(out_dir / "train_uids.csv", index=False)
31 | pd.DataFrame({"uid": valid_uids}).to_csv(out_dir / "valid_uids.csv", index=False)
32 | pd.DataFrame({"uid": test_uids}).to_csv(out_dir / "test_uids.csv", index=False)
33 | logging.info(f"There are {len(train_uids)} training samples.")
34 | logging.info(f"There are {len(valid_uids)} validation samples.")
35 | logging.info(f"There are {len(test_uids)} test samples.")
36 |
37 |
38 | def save_2d_grayscale_image(
39 | image: np.ndarray,
40 | out_path: Path,
41 | ) -> None:
42 | """Save grayscale 2d images.
43 |
44 | Args:
45 | image: (height, width), with values in [0, 1].
46 | out_path: output path.
47 | """
48 | out_path.parent.mkdir(parents=True, exist_ok=True)
49 | image = np.asarray(image * 255, dtype="uint8")
50 | Image.fromarray(image, "L").save(str(out_path))
51 |
52 |
53 | def load_2d_grayscale_image(
54 | image_path: Path,
55 | dtype: np.dtype = np.uint8,
56 | ) -> np.ndarray:
57 | """Load 2d images.
58 |
59 | Args:
60 | image_path: path to the mask.
61 | dtype: data type of the output.
62 |
63 | Returns:
64 | mask: (height, width), the values are between [0, 1].
65 | """
66 | mask = Image.open(str(image_path)).convert("L") # value [0, 255]
67 | mask = np.asarray(mask) / 255 # value [0, 1]
68 | mask = np.asarray(mask, dtype=dtype)
69 | return mask
70 |
71 |
72 | def save_3d_image(
73 | image: np.ndarray,
74 | reference_image: sitk.Image,
75 | out_path: Path,
76 | ) -> None:
77 | """Save 3d image.
78 |
79 | Args:
80 | image: (depth, height, width), the values are integers.
81 | reference_image: reference image for copy meta data.
82 | out_path: output path.
83 | """
84 | out_path.parent.mkdir(parents=True, exist_ok=True)
85 | image = sitk.GetImageFromArray(image)
86 | image.CopyInformation(reference_image)
87 | # output
88 | sitk.WriteImage(
89 | image=image,
90 | fileName=out_path,
91 | useCompression=True,
92 | )
93 |
94 |
95 | def save_image(
96 | image: np.ndarray,
97 | reference_image: sitk.Image,
98 | out_path: Path,
99 | dtype: np.dtype,
100 | ) -> None:
101 | """Save 2d or 3d image.
102 |
103 | Args:
104 | image: (width, height, depth) for 3D or (height, width) for 2D; 3D arrays are transposed to (depth, height, width) before saving, 2D arrays are saved as-is.
105 | reference_image: reference image for copy metadata.
106 | out_path: output path.
107 | dtype: data type of the output.
108 | """
109 | out_path.parent.mkdir(parents=True, exist_ok=True)
110 | if image.ndim not in [2, 3]:
111 | raise ValueError(
112 | f"Image should be 2D or 3D, but {image.ndim}D is given with shape {image.shape}."
113 | )
114 | if image.ndim == 2:
115 | save_2d_grayscale_image(
116 | image=image.astype(dtype=dtype),
117 | out_path=out_path,
118 | )
119 | else:
120 | # (width, height, depth) -> (depth, height, width)
121 | image = np.transpose(image, axes=[2, 1, 0]).astype(dtype=dtype)
122 | save_3d_image(
123 | image=image,
124 | reference_image=reference_image,
125 | out_path=out_path,
126 | )
127 |
128 |
129 | def save_ddf(
130 | ddf: np.ndarray,
131 | reference_image: sitk.Image,
132 | out_path: Path,
133 | dtype: np.dtype = np.float64,
134 | ) -> None:
135 | """Save ddf for 3d volumes.
136 |
137 | Args:
138 | ddf: (width, height, depth, 3), unit is 1 without spacing.
139 | reference_image: reference image for copy metadata.
140 | out_path: output path.
141 | dtype: data type of the output.
142 | """
143 | if ddf.ndim != 4:
144 | raise ValueError(f"ddf should be 4D, but {ddf.ndim}D is given.")
145 | out_path.parent.mkdir(parents=True, exist_ok=True)
146 |
147 | # ddf is scaled by spacing
148 | ddf = np.transpose(ddf, axes=[2, 1, 0, 3]).astype(dtype=dtype)
149 | ddf *= np.expand_dims(reference_image.GetSpacing(), axis=list(range(ddf.ndim - 1)))
150 |
151 | ddf_volume = sitk.GetImageFromArray(ddf, isVector=True)
152 | ddf_volume.SetSpacing(reference_image.GetSpacing())
153 | ddf_volume.CopyInformation(reference_image)
154 | tx = sitk.DisplacementFieldTransform(ddf_volume)
155 | sitk.WriteTransform(tx, out_path)
156 |
--------------------------------------------------------------------------------
/imgx/datasets/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """Fixtures for tests."""
2 | from pathlib import Path
3 |
4 | import pytest
5 |
6 |
7 | @pytest.fixture(scope="session")
8 | def fixture_path() -> Path:
9 | """Directory path containing fixture data.
10 |
11 | Returns:
12 | Folder path containing the data.
13 | """
14 | return Path(__file__).resolve().parent / "fixtures"
15 |
--------------------------------------------------------------------------------
/imgx/datasets/tests/fixtures/BB_anon_1789_3_mask_pred.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/imgx/datasets/tests/fixtures/BB_anon_1789_3_mask_pred.png
--------------------------------------------------------------------------------
/imgx/datasets/tests/fixtures/BB_anon_1789_3_mask_pred_postprocessed_0.5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/imgx/datasets/tests/fixtures/BB_anon_1789_3_mask_pred_postprocessed_0.5.png
--------------------------------------------------------------------------------
/imgx/datasets/tests/fixtures/BB_anon_1789_3_mask_pred_postprocessed_0.75.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/imgx/datasets/tests/fixtures/BB_anon_1789_3_mask_pred_postprocessed_0.75.png
--------------------------------------------------------------------------------
/imgx/datasets/tests/fixtures/BB_anon_425_2_mask_pred.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/imgx/datasets/tests/fixtures/BB_anon_425_2_mask_pred.png
--------------------------------------------------------------------------------
/imgx/datasets/tests/fixtures/BB_anon_425_2_mask_pred_postprocessed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/imgx/datasets/tests/fixtures/BB_anon_425_2_mask_pred_postprocessed.png
--------------------------------------------------------------------------------
/imgx/datasets/tests/fixtures/GM_anon_780_3_mask_pred.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/imgx/datasets/tests/fixtures/GM_anon_780_3_mask_pred.png
--------------------------------------------------------------------------------
/imgx/datasets/tests/fixtures/GM_anon_780_3_mask_pred_postprocessed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/imgx/datasets/tests/fixtures/GM_anon_780_3_mask_pred_postprocessed.png
--------------------------------------------------------------------------------
/imgx/datasets/util.py:
--------------------------------------------------------------------------------
1 | """Util functions for image.
2 |
3 | Some are adapted from
4 | https://github.com/google-research/scenic/blob/03735eb81f64fd1241c4efdb946ea6de3d326fe1/scenic/dataset_lib/dataset_utils.py
5 | """
6 | from __future__ import annotations
7 |
8 | import numpy as np
9 |
10 |
11 | def get_center_pad_shape(
12 | current_shape: tuple[int, ...], target_shape: tuple[int, ...]
13 | ) -> tuple[tuple[int, ...], tuple[int, ...]]:
14 | """Get pad sizes for sitk.ConstantPad.
15 |
16 | The padding is added symmetrically.
17 |
18 | Args:
19 | current_shape: current shape of the image.
20 | target_shape: target shape of the image.
21 |
22 | Returns:
23 | pad_lower: shape to pad on the lower side.
24 | pad_upper: shape to pad on the upper side.
25 | """
26 | pad_lower = []
27 | pad_upper = []
28 | for i, size_i in enumerate(current_shape):
29 | pad_i = max(target_shape[i] - size_i, 0)
30 | pad_lower_i = pad_i // 2
31 | pad_upper_i = pad_i - pad_lower_i
32 | pad_lower.append(pad_lower_i)
33 | pad_upper.append(pad_upper_i)
34 | return tuple(pad_lower), tuple(pad_upper)
35 |
36 |
37 | def get_center_crop_shape(
38 | current_shape: tuple[int, ...], target_shape: tuple[int, ...]
39 | ) -> tuple[tuple[int, ...], tuple[int, ...]]:
40 | """Get crop sizes for sitk.Crop.
41 |
42 | The crop is performed symmetrically.
43 |
44 | Args:
45 | current_shape: current shape of the image.
46 | target_shape: target shape of the image.
47 |
48 | Returns:
49 | crop_lower: shape to crop on the lower side.
50 | crop_upper: shape to crop on the upper side.
51 | """
52 | crop_lower = []
53 | crop_upper = []
54 | for i, size_i in enumerate(current_shape):
55 | crop_i = max(size_i - target_shape[i], 0)
56 | crop_lower_i = crop_i // 2
57 | crop_upper_i = crop_i - crop_lower_i
58 | crop_lower.append(crop_lower_i)
59 | crop_upper.append(crop_upper_i)
60 | return tuple(crop_lower), tuple(crop_upper)
61 |
62 |
63 | def try_to_get_center_crop_shape(
64 | label_min: int, label_max: int, current_length: int, target_length: int
65 | ) -> tuple[int, int]:
66 | """Try to crop at the center of label, 1D.
67 |
68 | Args:
69 | label_min: label index minimum, inclusive.
70 | label_max: label index maximum, exclusive.
71 | current_length: current image length.
72 | target_length: target image length.
73 |
74 | Returns:
75 | crop_lower: shape to crop on the lower side.
76 | crop_upper: shape to crop on the upper side.
77 |
78 | Raises:
79 | ValueError: if label min max is out of range.
80 | """
81 | if label_min < 0 or label_max > current_length:
82 | raise ValueError("Label index out of range.")
83 |
84 | if current_length <= target_length:
85 | # no need of crop
86 | return 0, 0
87 | # attempt to perform the crop centered at the label center
88 | label_center = (label_max - 1 + label_min) / 2.0
89 | bbox_lower = int(np.ceil(label_center - target_length / 2.0))
90 | bbox_upper = bbox_lower + target_length
91 | # if lower is negative, then have to shift the window to right
92 | bbox_lower = max(bbox_lower, 0)
93 | # if upper is too large, then have to shift the window to left
94 | if bbox_upper > current_length:
95 | bbox_lower -= bbox_upper - current_length
96 | # calculate crop
97 | crop_lower = bbox_lower # bbox index starts at 0
98 | crop_upper = current_length - target_length - crop_lower
99 | return crop_lower, crop_upper
100 |
101 |
102 | def get_center_crop_shape_from_bbox(
103 | bbox_min: tuple[int, ...] | np.ndarray,
104 | bbox_max: tuple[int, ...] | np.ndarray,
105 | current_shape: tuple[int, ...],
106 | target_shape: tuple[int, ...],
107 | ) -> tuple[tuple[int, ...], tuple[int, ...]]:
108 | """Get crop sizes for sitk.Crop from label bounding box.
109 |
110 | The crop is not necessarily performed symmetrically.
111 |
112 | Args:
113 | bbox_min: [start_in_1st_spatial_dim, ...], inclusive, starts at zero.
114 | bbox_max: [end_in_1st_spatial_dim, ...], exclusive, starts at zero.
115 | current_shape: current shape of the image.
116 | target_shape: target shape of the image.
117 |
118 | Returns:
119 | crop_lower: shape to crop on the lower side.
120 | crop_upper: shape to crop on the upper side.
121 | """
122 | crop_lower = []
123 | crop_upper = []
124 | for i, current_length in enumerate(current_shape):
125 | crop_lower_i, crop_upper_i = try_to_get_center_crop_shape(
126 | label_min=bbox_min[i],
127 | label_max=bbox_max[i],
128 | current_length=current_length,
129 | target_length=target_shape[i],
130 | )
131 | crop_lower.append(crop_lower_i)
132 | crop_upper.append(crop_upper_i)
133 | return tuple(crop_lower), tuple(crop_upper)
134 |
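A worked example (not repository code) of `try_to_get_center_crop_shape`: with a label spanning
indices [2, 6) in an axis of length 10 and a target length of 4, the crop keeps indices 2 to 5
around the label centre.

```python
from imgx.datasets.util import try_to_get_center_crop_shape

# label occupies indices 2..5 (min inclusive, max exclusive)
crop_lower, crop_upper = try_to_get_center_crop_shape(
    label_min=2, label_max=6, current_length=10, target_length=4
)
# 10 - 2 - 4 = 4 indices are kept, namely 2, 3, 4, 5
assert (crop_lower, crop_upper) == (2, 4)
```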
--------------------------------------------------------------------------------
/imgx/device.py:
--------------------------------------------------------------------------------
1 | """Module to handle multi-devices."""
2 | from __future__ import annotations
3 |
4 | import chex
5 | import jax
6 | import jax.numpy as jnp
7 | from jax import lax
8 |
9 |
10 | def broadcast_to_local_devices(value: chex.ArrayTree) -> chex.ArrayTree:
11 | """Broadcasts an object to all local devices.
12 |
13 | Args:
14 | value: value to be broadcast.
15 |
16 | Returns:
17 | broadcast value.
18 | """
19 | devices = jax.local_devices()
20 | return jax.tree_map(lambda v: jax.device_put_sharded(len(devices) * [v], devices), value)
21 |
22 |
23 | def get_first_replica_values(value: chex.ArrayTree) -> chex.ArrayTree:
24 | """Gets values from the first replica.
25 |
26 | Args:
27 | value: broadcast value.
28 |
29 | Returns:
30 | value of the first replica.
31 | """
32 | return jax.tree_map(lambda x: x[0], value)
33 |
34 |
35 | def bind_rng_to_host_or_device(
36 | rng: jnp.ndarray,
37 | bind_to: str | None = None,
38 | axis_name: str | tuple[str, ...] | None = None,
39 | ) -> jnp.ndarray:
40 | """Binds a rng to the host or device.
41 |
42 | https://github.com/google-research/scenic/blob/main/scenic/train_lib/train_utils.py#L577
43 |
44 | Must be called from within a pmapped function. Note that when binding to
45 | "device", we also bind the rng to hosts, as we fold_in the rng with
46 | axis_index, which is unique for devices across all hosts.
47 |
48 | Args:
49 | rng: A jax.random.PRNGKey.
50 | bind_to: Must be one of the 'host' or 'device'. None means no binding.
51 | axis_name: The axis of the devices we are binding rng across, necessary
52 | if bind_to is device.
53 |
54 | Returns:
55 | jax.random.PRNGKey specialized to host/device.
56 | """
57 | if bind_to is None:
58 | return rng
59 | if bind_to == "host":
60 | return jax.random.fold_in(rng, jax.process_index())
61 | if bind_to == "device":
62 | return jax.random.fold_in(rng, lax.axis_index(axis_name))
63 | raise ValueError("`bind_to` should be one of the `[None, 'host', 'device']`")
64 |
65 |
66 | def shard(
67 | pytree: chex.ArrayTree,
68 | num_replicas: int,
69 | ) -> chex.ArrayTree:
70 | """Reshapes all arrays in the pytree to add a leading shard dimension.
71 |
72 | We assume that all arrays in the pytree have a leading dimension
73 | divisible by num_replicas.
74 |
75 | Args:
76 | pytree: A pytree of arrays to be sharded.
77 | num_replicas: number of model replicas.
78 |
79 | Returns:
80 | Sharded data.
81 | """
82 |
83 | def _shard_array(array: jnp.ndarray) -> jnp.ndarray:
84 | return array.reshape((num_replicas, -1) + array.shape[1:])
85 |
86 | return jax.tree_map(_shard_array, pytree)
87 |
88 |
89 | def unshard(pytree: chex.ArrayTree, device: jax.Device) -> chex.ArrayTree:
90 | """Reshapes arrays from [ndev, bs, ...] to [host_bs, ...].
91 |
92 | Args:
93 | pytree: A pytree of sharded arrays to be unsharded.
94 | device: device to put.
95 |
96 | Returns:
97 | Unsharded data.
98 | """
99 |
100 | def _unshard_array(array: jnp.ndarray) -> jnp.ndarray:
101 | ndev, bs = array.shape[:2]
102 | return array.reshape((ndev * bs,) + array.shape[2:])
103 |
104 | pytree = jax.device_put(pytree, device)
105 | return jax.tree_map(_unshard_array, pytree)
106 |
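A small sketch (not repository code) of `shard` and `unshard`: an array with a host-level batch
dimension of 8 is reshaped into a per-replica layout and back. The CPU device lookup is an
assumption made only for the example.

```python
import jax
import jax.numpy as jnp

from imgx.device import shard, unshard

batch = {"image": jnp.ones((8, 4, 4))}
sharded = shard(batch, num_replicas=2)   # image becomes (2, 4, 4, 4)
cpu = jax.devices("cpu")[0]
restored = unshard(sharded, device=cpu)  # image back to (8, 4, 4)
assert restored["image"].shape == (8, 4, 4)
```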
--------------------------------------------------------------------------------
/imgx/diffusion/README.md:
--------------------------------------------------------------------------------
1 | # Diffusion
2 |
3 | The class diagram is illustrated below, where
4 |
5 | - red blocks are classes having abstract methods.
6 | - yellow blocks are classes used during training, with `sample()` not being implemented.
7 | - green blocks are classes used during inference, with `sample()` implemented.
8 |
9 |
10 | *(diffusion class diagram image)*
11 |
12 |
--------------------------------------------------------------------------------
/imgx/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | """Diffusion related functions."""
2 | from imgx.diffusion.diffusion import Diffusion
3 |
4 | __all__ = [
5 | "Diffusion",
6 | ]
7 |
--------------------------------------------------------------------------------
/imgx/diffusion/diffusion.py:
--------------------------------------------------------------------------------
1 | """Base diffusion class."""
2 | from __future__ import annotations
3 |
4 | from collections.abc import Sequence
5 | from dataclasses import dataclass
6 | from typing import Callable
7 |
8 | import jax.numpy as jnp
9 | import jax.random
10 |
11 |
12 | @dataclass
13 | class Diffusion:
14 | """Base class for diffusion."""
15 |
16 | num_timesteps: int
17 | noise_fn: Callable[..., jnp.ndarray]
18 |
19 | def sample_noise(self, key: jax.Array, shape: Sequence[int], dtype: jnp.dtype) -> jnp.ndarray:
20 | """Return noise with the given shape and dtype.
21 |
22 | Define this function to avoid storing a random key in the class.
23 |
24 | Args:
25 | key: random key.
26 | shape: array shape.
27 | dtype: data type.
28 |
29 | Returns:
30 | Noise with the requested shape and dtype.
31 | """
32 | return self.noise_fn(key=key, shape=shape, dtype=dtype)
33 |
34 | def t_index_to_t(self, t_index: jnp.ndarray) -> jnp.ndarray:
35 | """Convert t_index to t.
36 |
37 | t_index = 0 corresponds to t = 1 / num_timesteps.
38 | t_index = num_timesteps - 1 corresponds to t = 1.
39 |
40 | Args:
41 | t_index: t_index, shape (batch, ).
42 |
43 | Returns:
44 | t: t, shape (batch, ).
45 | """
46 | return jnp.asarray(t_index + 1, jnp.float32) / self.num_timesteps
47 |
48 | def q_sample(
49 | self,
50 | x_start: jnp.ndarray,
51 | noise: jnp.ndarray,
52 | t_index: jnp.ndarray,
53 | ) -> jnp.ndarray:
54 | """Sample from q(x_t | x_0).
55 |
56 | Args:
57 | x_start: noiseless input.
58 | noise: same shape as x_start.
59 | t_index: storing index values < self.num_timesteps.
60 |
61 | Returns:
62 | Noisy array with same shape as x_start.
63 | """
64 | raise NotImplementedError
65 |
66 | def predict_xprev_from_xstart_xt(
67 | self, x_start: jnp.ndarray, x_t: jnp.ndarray, t_index: jnp.ndarray
68 | ) -> jnp.ndarray:
69 | """Get x_{t-1} from x_0 and x_t.
70 |
71 | Args:
72 | x_start: noiseless input, shape (batch, ...).
73 | x_t: noisy input, same shape as x_start.
74 | t_index: storing index values < self.num_timesteps, shape (batch, ).
75 |
76 | Returns:
77 | predicted x_{t-1}, same shape as x_t.
78 | """
79 | raise NotImplementedError
80 |
81 | def predict_xstart_from_model_out_xt(
82 | self,
83 | model_out: jnp.ndarray,
84 | x_t: jnp.ndarray,
85 | t_index: jnp.ndarray,
86 | ) -> jnp.ndarray:
87 | """Predict x_0 from model output and x_t.
88 |
89 | Args:
90 | model_out: model output.
91 | x_t: noisy input.
92 | t_index: storing index values < self.num_timesteps.
93 |
94 | Returns:
95 | x_start, same shape as x_t.
96 | """
97 | raise NotImplementedError
98 |
99 | def variational_lower_bound(
100 | self,
101 | model_out: jnp.ndarray,
102 | x_start: jnp.ndarray,
103 | x_t: jnp.ndarray,
104 | t_index: jnp.ndarray,
105 | ) -> tuple[jnp.ndarray, jnp.ndarray]:
106 | """Variational lower-bound, ELBO, smaller is better.
107 |
108 | Args:
109 | model_out: raw model output, may contain additional parameters.
110 | x_start: noiseless input.
111 | x_t: noisy input, same shape as x_start.
112 | t_index: storing index values < self.num_timesteps.
113 |
114 | Returns:
115 | - lower bounds, shape (batch, ).
116 | - model_out with the same shape as x_start.
117 | """
118 | raise NotImplementedError
119 |
120 | def sample(
121 | self,
122 | key: jax.Array,
123 | model_out: jnp.ndarray,
124 | x_t: jnp.ndarray,
125 | t_index: jnp.ndarray,
126 | ) -> tuple[jnp.ndarray, jnp.ndarray]:
127 | """Sample x_{t-1} ~ p(x_{t-1} | x_t).
128 |
129 | Args:
130 | key: random key.
131 | model_out: model predicted output.
132 | If model estimates variance, the last axis will be split.
133 | x_t: noisy x at time t.
134 | t_index: storing index values < self.num_timesteps.
135 |
136 | Returns:
137 | sample: x_{t-1}, same shape as x_t.
138 | x_start_pred: same shape as x_t.
139 | """
140 | raise NotImplementedError
141 |
142 | def diffusion_loss(
143 | self,
144 | x_start: jnp.ndarray,
145 | x_t: jnp.ndarray,
146 | t_index: jnp.ndarray,
147 | noise: jnp.ndarray,
148 | model_out: jnp.ndarray,
149 | ) -> tuple[dict[str, jnp.ndarray], jnp.ndarray]:
150 | """Diffusion-specific loss function.
151 |
152 | Args:
153 | x_start: noiseless input.
154 | x_t: noisy input.
155 | t_index: storing index values < self.num_timesteps.
156 | noise: sampled noise, same shape as x_t.
157 | model_out: model output, may contain additional parameters.
158 |
159 | Returns:
160 | scalars: dict of losses, each with shape (batch, ).
161 | model_out: same shape as x_start.
162 | """
163 | raise NotImplementedError
164 |
--------------------------------------------------------------------------------
/imgx/diffusion/gaussian/__init__.py:
--------------------------------------------------------------------------------
1 | """Gaussian based continuous diffusion models."""
2 |
--------------------------------------------------------------------------------
/imgx/diffusion/gaussian/sampler.py:
--------------------------------------------------------------------------------
1 | """Module for sampling."""
2 | from __future__ import annotations
3 |
4 | import jax
5 | import jax.numpy as jnp
6 |
7 | from imgx.diffusion.gaussian.gaussian_diffusion import GaussianDiffusion
8 | from imgx.diffusion.util import expand, extract_and_expand
9 |
10 |
11 | class DDPMSampler(GaussianDiffusion):
12 | """DDPM https://arxiv.org/abs/2006.11239."""
13 |
14 | def sample(
15 | self,
16 | key: jax.Array,
17 | model_out: jnp.ndarray,
18 | x_t: jnp.ndarray,
19 | t_index: jnp.ndarray,
20 | ) -> tuple[jnp.ndarray, jnp.ndarray]:
21 | """Sample x_{t-1} ~ p(x_{t-1} | x_t) using DDPM.
22 |
23 | https://arxiv.org/abs/2006.11239
24 |
25 | Args:
26 | key: random key.
27 | model_out: model predicted output.
28 | If model estimates variance, the last axis will be split.
29 | x_t: noisy input, shape (batch, ...).
30 | t_index: storing index values < self.num_timesteps,
31 | shape (batch, ) or broadcast-compatible to x_start shape.
32 |
33 | Returns:
34 | sample: x_{t-1}, same shape as x_t.
35 | x_start_pred: same shape as x_t.
36 | """
37 | x_start_pred, mean, log_variance = self.p_mean_variance(
38 | model_out=model_out,
39 | x_t=x_t,
40 | t_index=t_index,
41 | )
42 | noise = self.sample_noise(key=key, shape=x_t.shape, dtype=x_t.dtype)
43 |
44 | # no noise when t=0
45 | # mean + exp(log(sigma**2)/2) * noise = mean + sigma * noise
46 | nonzero_mask = jnp.array(t_index != 0, dtype=noise.dtype)
47 | nonzero_mask = expand(nonzero_mask, noise.ndim)
48 | sample = mean + nonzero_mask * jnp.exp(0.5 * log_variance) * noise
49 |
50 | return sample, x_start_pred
51 |
52 |
53 | class DDIMSampler(GaussianDiffusion):
54 | """DDIM https://arxiv.org/abs/2010.02502.
55 |
56 | https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim/pipeline_ddim.py
57 | """
58 |
59 | def sample(
60 | self,
61 | key: jax.Array,
62 | model_out: jnp.ndarray,
63 | x_t: jnp.ndarray,
64 | t_index: jnp.ndarray,
65 | eta: float = 0.0,
66 | ) -> tuple[jnp.ndarray, jnp.ndarray]:
67 | """Sample x_{t-1} ~ p(x_{t-1} | x_t) using DDIM.
68 |
69 | https://arxiv.org/abs/2010.02502
70 |
71 | Args:
72 | key: random key.
73 | model_out: model predicted output.
74 | If model estimates variance, the last axis will be split.
75 | x_t: noisy input, shape (batch, ...).
76 | t_index: storing index values < self.num_timesteps,
77 | shape (batch, ) or broadcast-compatible to x_start shape.
78 | eta: control the noise level in sampling.
79 |
80 | Returns:
81 | sample: x_{t-1}, same shape as x_t.
82 | x_start_pred: same shape as x_t.
83 | """
84 | # prepare constants
85 | x_start_pred, _ = self.p_mean(
86 | model_out=model_out,
87 | x_t=x_t,
88 | t_index=t_index,
89 | )
90 | noise = self.predict_noise_from_xstart_xt(x_t=x_t, x_start=x_start_pred, t_index=t_index)
91 | alphas_cumprod_prev = extract_and_expand(
92 | self.alphas_cumprod_prev, t_index=t_index, ndim=x_t.ndim
93 | )
94 | coeff_start = jnp.sqrt(alphas_cumprod_prev)
95 | log_variance = (
96 | extract_and_expand(
97 | self.posterior_log_variance_clipped,
98 | t_index=t_index,
99 | ndim=x_t.ndim,
100 | )
101 | * eta
102 | )
103 | coeff_noise = jnp.sqrt(1.0 - alphas_cumprod_prev - log_variance**2)
104 | mean = coeff_start * x_start_pred + coeff_noise * noise
105 |
106 | # no noise is added when t_index == 0; with eta == 0 sampling is fully deterministic
107 | nonzero_mask = jnp.array(t_index != 0, dtype=x_t.dtype)
108 | nonzero_mask = expand(nonzero_mask, x_t.ndim)
109 |
110 | # sample
111 | noise = self.sample_noise(key=key, shape=x_t.shape, dtype=x_t.dtype)
112 | sample = mean + nonzero_mask * log_variance * noise
113 | return sample, x_start_pred
114 |
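For reference, a hedged reading of the update implemented in `DDPMSampler.sample` above, written
here as an editorial note (the standard DDPM ancestral sampling step):

$$
x_{t-1} = \mu_\theta(x_t, t) + \mathbb{1}[t_\mathrm{index} > 0]\,\exp\!\left(\tfrac{1}{2}\log\sigma_t^2\right) z, \qquad z \sim \mathcal{N}(0, I),
$$

where $\mu_\theta$ and $\log\sigma_t^2$ are returned by `p_mean_variance`, and the indicator drops
the noise term at the final denoising step.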
--------------------------------------------------------------------------------
/imgx/diffusion/gaussian/variance_schedule.py:
--------------------------------------------------------------------------------
1 | """Variance schedule for diffusion models."""
2 | from __future__ import annotations
3 |
4 | import jax.numpy as jnp
5 | import numpy as np
6 |
7 |
8 | def get_beta_schedule(
9 | num_timesteps: int,
10 | beta_schedule: str,
11 | beta_start: float,
12 | beta_end: float,
13 | ) -> jnp.ndarray:
14 | """Get variance (beta) schedule for q(x_t | x_{t-1}).
15 |
16 | Args:
17 | num_timesteps: number of time steps in total, T.
18 | beta_schedule: schedule for beta.
19 | beta_start: beta for t=0.
20 | beta_end: beta for t=T-1.
21 |
22 | Returns:
23 | Shape (num_timesteps,) array of beta values, for t=0, ..., T-1.
24 | Values are in ascending order.
25 |
26 | Raises:
27 | ValueError: for unknown schedule.
28 | """
29 | if beta_schedule == "linear":
30 | return jnp.linspace(
31 | beta_start,
32 | beta_end,
33 | num_timesteps,
34 | )
35 | if beta_schedule == "quadradic":
36 | return (
37 | jnp.linspace(
38 | beta_start**0.5,
39 | beta_end**0.5,
40 | num_timesteps,
41 | )
42 | ** 2
43 | )
44 | if beta_schedule == "cosine":
45 |
46 | def f(t: float) -> float:
47 | """Eq 17 in https://arxiv.org/abs/2102.09672.
48 |
49 | Args:
50 | t: time step with values in [0, 1].
51 |
52 | Returns:
53 | Cumulative product of alpha.
54 | """
55 | return np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2
56 |
57 | betas = [0.0]
58 | alphas_cumprod_prev = 1.0
59 | for i in range(1, num_timesteps):
60 | t = i / (num_timesteps - 1)
61 | alphas_cumprod = f(t)
62 | beta = 1 - alphas_cumprod / alphas_cumprod_prev
63 | betas.append(beta)
64 | return jnp.array(betas) * (beta_end - beta_start) + beta_start
65 |
66 | if beta_schedule == "warmup10":
67 | num_timesteps_warmup = max(num_timesteps // 10, 1)
68 | betas_warmup = (
69 | jnp.linspace(
70 | beta_start**0.5,
71 | beta_end**0.5,
72 | num_timesteps_warmup,
73 | )
74 | ** 2
75 | )
76 | return jnp.concatenate(
77 | [
78 | betas_warmup,
79 | jnp.ones((num_timesteps - num_timesteps_warmup,)) * beta_end,
80 | ]
81 | )
82 | if beta_schedule == "warmup50":
83 | num_timesteps_warmup = max(num_timesteps // 2, 1)
84 | betas_warmup = (
85 | jnp.linspace(
86 | beta_start**0.5,
87 | beta_end**0.5,
88 | num_timesteps_warmup,
89 | )
90 | ** 2
91 | )
92 | return jnp.concatenate(
93 | [
94 | betas_warmup,
95 | jnp.ones((num_timesteps - num_timesteps_warmup,)) * beta_end,
96 | ]
97 | )
98 | raise ValueError(f"Unknown beta_schedule {beta_schedule}.")
99 |
100 |
101 | def downsample_beta_schedule(
102 | betas: jnp.ndarray,
103 | num_timesteps: int,
104 | num_timesteps_to_keep: int,
105 | ) -> jnp.ndarray:
106 | """Down-sample beta schedule.
107 |
108 | After down-sampling, the first and last values of alphas_cumprod are kept.
109 |
110 | Args:
111 | betas: beta schedule, shape (num_timesteps,).
112 | Values are in ascending order.
113 | num_timesteps: number of time steps in total, T.
114 | num_timesteps_to_keep: number of time steps to keep.
115 |
116 | Returns:
117 | Down-sampled beta schedule, shape (num_timesteps_to_keep,).
118 | """
119 | if betas.shape != (num_timesteps,):
120 | raise ValueError(
121 | f"betas.shape ({betas.shape}) must be equal to (num_timesteps,)=({num_timesteps},)"
122 | )
123 | if num_timesteps_to_keep > num_timesteps:
124 | raise ValueError(
125 | f"num_timesteps_to_keep ({num_timesteps_to_keep}) "
126 | f"must be <= num_timesteps ({num_timesteps})"
127 | )
128 | if (num_timesteps - 1) % (num_timesteps_to_keep - 1) != 0:
129 | raise ValueError(
130 | f"num_timesteps-1={num_timesteps-1} can't be evenly divided by "
131 | f"num_timesteps_to_keep-1={num_timesteps_to_keep-1}."
132 | )
133 | if num_timesteps_to_keep < 2:
134 | raise ValueError(f"num_timesteps_to_keep ({num_timesteps_to_keep}) must be >= 2.")
135 | if num_timesteps_to_keep == num_timesteps:
136 | return betas
137 | step_scale = (num_timesteps - 1) // (num_timesteps_to_keep - 1)
138 | alphas_cumprod = jnp.cumprod(1.0 - betas)
139 | # (num_timesteps_to_keep,)
140 | alphas_cumprod = alphas_cumprod[::step_scale]
141 |
142 | # recompute betas
143 | # (num_timesteps_to_keep-1,)
144 | alphas = alphas_cumprod[1:] / alphas_cumprod[:-1]
145 | # (num_timesteps_to_keep,)
146 | alphas = jnp.concatenate([alphas_cumprod[:1], alphas])
147 | betas = 1 - alphas
148 | return betas
149 |
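A brief usage sketch (not repository code) combining `get_beta_schedule` and
`downsample_beta_schedule`; the chosen numbers satisfy the divisibility constraint
`(num_timesteps - 1) % (num_timesteps_to_keep - 1) == 0` enforced above.

```python
from imgx.diffusion.gaussian.variance_schedule import (
    downsample_beta_schedule,
    get_beta_schedule,
)

betas = get_beta_schedule(
    num_timesteps=1001, beta_schedule="linear", beta_start=1e-4, beta_end=0.02
)
# (1001 - 1) is divisible by (5 - 1), so 5 time steps can be kept
betas_short = downsample_beta_schedule(betas, num_timesteps=1001, num_timesteps_to_keep=5)
assert betas_short.shape == (5,)
```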
--------------------------------------------------------------------------------
/imgx/diffusion/gaussian/variance_schedule_test.py:
--------------------------------------------------------------------------------
1 | """Test Gaussian diffusion related classes and functions."""
2 |
3 |
4 | import chex
5 | import jax.numpy as jnp
6 | from absl.testing import parameterized
7 | from chex._src import fake
8 |
9 | from imgx.diffusion.gaussian.variance_schedule import downsample_beta_schedule, get_beta_schedule
10 |
11 |
12 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
13 | def setUpModule() -> None: # pylint: disable=invalid-name
14 | """Fake multi-devices."""
15 | fake.set_n_cpu_devices(2)
16 |
17 |
18 | class TestGetBetaSchedule(chex.TestCase):
19 | """Test get_beta_schedule."""
20 |
21 | @parameterized.product(
22 | num_timesteps=[1, 4],
23 | beta_schedule=[
24 | "linear",
25 | "quadradic",
26 | "cosine",
27 | "warmup10",
28 | "warmup50",
29 | ],
30 | )
31 | def test_shapes(
32 | self,
33 | num_timesteps: int,
34 | beta_schedule: str,
35 | ) -> None:
36 | """Test output shape."""
37 | beta_start = 0.0
38 | beta_end = 0.2
39 | got = get_beta_schedule(
40 | num_timesteps=num_timesteps,
41 | beta_schedule=beta_schedule,
42 | beta_start=beta_start,
43 | beta_end=beta_end,
44 | )
45 | chex.assert_shape(got, (num_timesteps,))
46 |
47 | assert got[0] == beta_start
48 | if num_timesteps > 1:
49 | chex.assert_trees_all_close(got[-1], beta_end)
50 |
51 |
52 | class TestDownsampleBetaSchedule(chex.TestCase):
53 | """Test downsample_beta_schedule."""
54 |
55 | @parameterized.named_parameters(
56 | ("same 1001 steps", 1001, 1001),
57 | ("same 101 steps", 101, 101),
58 | ("downsample 21 to 5", 21, 5),
59 | ("downsample 101 to 5", 101, 5),
60 | ("downsample 11 to 3", 11, 3),
61 | )
62 | def test_values(
63 | self,
64 | num_timesteps: int,
65 | num_timesteps_to_keep: int,
66 | ) -> None:
67 | """Test output values and shapes."""
68 | betas = jnp.linspace(1e-4, 0.02, num_timesteps)
69 | alphas_cumprod = jnp.cumprod(1.0 - betas)
70 | got = downsample_beta_schedule(betas, num_timesteps, num_timesteps_to_keep)
71 | alphas_cumprod_got = jnp.cumprod(1.0 - got)
72 | chex.assert_shape(got, (num_timesteps_to_keep,))
73 | chex.assert_trees_all_close(alphas_cumprod_got[0], alphas_cumprod[0])
74 | chex.assert_trees_all_close(alphas_cumprod_got[-1], alphas_cumprod[-1])
75 |
--------------------------------------------------------------------------------
/imgx/diffusion/util.py:
--------------------------------------------------------------------------------
1 | """Utility functions for diffusion models."""
2 | from __future__ import annotations
3 |
4 | import jax.numpy as jnp
5 |
6 |
7 | def extract_and_expand(arr: jnp.ndarray, t_index: jnp.ndarray, ndim: int) -> jnp.ndarray:
8 | """Extract values from a 1D array and expand.
9 |
10 | This function is not jittable.
11 |
12 | Args:
13 | arr: 1D of shape (num_timesteps, ).
14 | t_index: storing index values < self.num_timesteps,
15 | shape (batch, ) or has ndim dimension.
16 | ndim: number of dimensions for an array of shape (batch, ...).
17 |
18 | Returns:
19 | Expanded array of shape (batch, ...), expanded axes have dim 1.
20 | """
21 | if arr.ndim != 1:
22 | raise ValueError(f"arr must be 1D, got {arr.ndim}D.")
23 | x = arr[t_index]
24 | return expand(x, ndim)
25 |
26 |
27 | def expand(x: jnp.ndarray, ndim: int) -> jnp.ndarray:
28 | """Expand.
29 |
30 | This function is not jittable.
31 |
32 | Args:
33 | x: a 1D or nD array.
34 | ndim: number of dimensions as output.
35 |
36 | Returns:
37 | Expanded array, expanded axes have dim 1.
38 | """
39 | if x.ndim == 1:
40 | return jnp.expand_dims(x, axis=tuple(range(1, ndim)))
41 | if x.ndim == ndim:
42 | return x
43 | raise ValueError(f"x must be 1D or {ndim}D, got {x.ndim}D.")
44 |
--------------------------------------------------------------------------------
/imgx/diffusion/util_test.py:
--------------------------------------------------------------------------------
1 | """Test Gaussian diffusion related classes and functions."""
2 |
3 |
4 | import chex
5 | import jax
6 | import jax.numpy as jnp
7 | from absl.testing import parameterized
8 | from chex._src import fake
9 |
10 | from imgx.diffusion.util import extract_and_expand
11 |
12 |
13 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
14 | def setUpModule() -> None: # pylint: disable=invalid-name
15 | """Fake multi-devices."""
16 | fake.set_n_cpu_devices(2)
17 |
18 |
19 | class TestExtractAndExpand(chex.TestCase):
20 | """Test extract_and_expand."""
21 |
22 | @chex.variants(without_jit=True, with_device=True, without_device=True)
23 | @parameterized.named_parameters(
24 | (
25 | "1d",
26 | 1,
27 | ),
28 | (
29 | "2d",
30 | 2,
31 | ),
32 | (
33 | "3d",
34 | 3,
35 | ),
36 | )
37 | def test_shapes(
38 | self,
39 | ndim: int,
40 | ) -> None:
41 | """Test output shape.
42 |
43 | Args:
44 | ndim: number of dimensions.
45 | """
46 | batch_size = 2
47 | betas = jnp.array([0, 0.2, 0.5, 1.0])
48 | num_timesteps = len(betas)
49 | rng = jax.random.PRNGKey(0)
50 | t_index = jax.random.randint(rng, shape=(batch_size,), minval=0, maxval=num_timesteps)
51 | got = self.variant(extract_and_expand)(arr=betas, t_index=t_index, ndim=ndim)
52 | expected_shape = (batch_size,) + (1,) * (ndim - 1)
53 | chex.assert_shape(got, expected_shape)
54 |
--------------------------------------------------------------------------------
/imgx/experiment.py:
--------------------------------------------------------------------------------
1 | """Experiment interface."""
2 | from __future__ import annotations
3 |
4 | from collections.abc import Iterator
5 | from pathlib import Path
6 |
7 | import chex
8 | import jax
9 | import jax.numpy as jnp
10 | import numpy as np
11 | import tensorflow as tf
12 | from absl import logging
13 | from omegaconf import DictConfig
14 |
15 | from imgx.datasets import INFO_MAP
16 | from imgx.metric.util import merge_aggregated_metrics
17 | from imgx.train_state import TrainState
18 |
19 |
20 | class Experiment:
21 | """Experiment for supervised training."""
22 |
23 | def __init__(self, config: DictConfig) -> None:
24 | """Initializes experiment.
25 |
26 | Args:
27 | config: experiment config.
28 | """
29 | # Do not use accelerators in data pipeline.
30 | try:
31 | tf.config.set_visible_devices([], device_type="GPU")
32 | tf.config.set_visible_devices([], device_type="TPU")
33 | except RuntimeError:
34 | logging.error(
35 | f"Failed to set visible devices, data set may be using GPU/TPUs. "
36 | f"Visible GPU devices: {tf.config.get_visible_devices('GPU')}. "
37 | f"Visible TPU devices: {tf.config.get_visible_devices('TPU')}."
38 | )
39 |
40 | self.config = config
41 | self.dataset_info = INFO_MAP[self.config.data.name]
42 | self.p_train_step = None # To be defined in train_init
43 | self.p_eval_step = None # To be defined in train_init
44 |
45 | def train_init(
46 | self, batch: dict[str, jnp.ndarray], ckpt_dir: Path | None = None, step: int | None = None
47 | ) -> tuple[TrainState, int]:
48 | """Initialize data loader, loss, networks for training.
49 |
50 | Args:
51 | batch: training data.
52 | ckpt_dir: checkpoint directory to restore from.
53 | step: checkpoint step to restore from, if None use the latest one.
54 |
55 | Returns:
56 | initialized training state.
57 | """
58 | raise NotImplementedError
59 |
60 | def train_step(
61 | self, train_state: TrainState, batch: dict[str, jnp.ndarray], key: jax.Array
62 | ) -> tuple[TrainState, chex.ArrayTree]:
63 | """Perform a training step.
64 |
65 | Args:
66 | train_state: training state.
67 | batch: training data.
68 | key: random key.
69 |
70 | Returns:
71 | - new training state.
72 | - metric dict, with values converted to Python scalars.
73 | The random key is folded in inside the pmapped function, so it is not returned.
74 | """
75 | # key is updated/fold inside pmap function
76 | # to ensure a different key is used per step
77 | train_state, metrics = self.p_train_step( # pylint: disable=not-callable
78 | train_state,
79 | batch,
80 | key,
81 | )
82 | metrics = merge_aggregated_metrics(metrics)
83 | metrics = jax.tree_map(lambda x: x.item(), metrics) # tensor to values
84 | return train_state, metrics
85 |
86 | def eval_step(
87 | self,
88 | train_state: TrainState,
89 | iterator: Iterator[dict[str, jnp.ndarray]],
90 | num_steps: int,
91 | key: jax.Array,
92 | out_dir: Path | None = None,
93 | ) -> dict[str, jnp.ndarray]:
94 | """Evaluation on entire validation/test data set.
95 |
96 | Args:
97 | train_state: training state.
98 | iterator: data iterator.
99 | num_steps: number of steps for evaluation.
100 | key: random key.
101 | out_dir: output directory, if not None, predictions will be saved.
102 |
103 | Returns:
104 | metric dict.
105 | """
106 | raise NotImplementedError
107 |
108 | def eval_batch(
109 | self,
110 | train_state: TrainState,
111 | key: jax.Array,
112 | batch: dict[str, jnp.ndarray],
113 | uids: list[str],
114 | device_cpu: jax.Device,
115 | ) -> tuple[list[str], dict[str, np.ndarray], dict[str, np.ndarray]]:
116 | """Evaluate a batch.
117 |
118 | Args:
119 | train_state: training state.
120 | key: random key.
121 | batch: batch data without uid.
122 | uids: uids in the batch, potentially including padded samples.
123 | device_cpu: cpu device.
124 |
125 | Returns:
126 | uids: uids in the batch, excluding padded samples.
127 | metrics: each item has shape (num_samples,).
128 | prediction dict: each item has shape (num_samples, ...).
129 | """
130 | raise NotImplementedError
131 |
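A schematic sketch (not repository code) of how the `Experiment` interface above is intended to be
driven; `experiment`, `train_iter`, and `num_steps` are assumptions standing in for the actual
training script.

```python
import jax


def run_training(experiment, train_iter, num_steps: int):
    """Drive an Experiment: initialise once, then repeat train_step."""
    batch = next(train_iter)
    train_state, start_step = experiment.train_init(batch)
    key = jax.random.PRNGKey(0)  # folded per step inside the pmapped function
    metrics = {}
    for _ in range(start_step, num_steps):
        batch = next(train_iter)
        train_state, metrics = experiment.train_step(train_state, batch, key)
    return train_state, metrics
```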
--------------------------------------------------------------------------------
/imgx/integration_test.py:
--------------------------------------------------------------------------------
1 | """Test experiments train, valid, and test.
2 |
3 | mocker.patch, https://pytest-mock.readthedocs.io/en/latest/
4 | """
5 | import shutil
6 | from tempfile import TemporaryDirectory
7 |
8 | import pytest
9 | from pytest_mock import MockFixture
10 |
11 | from imgx.run_test import main as run_test
12 | from imgx.run_train import main as run_train
13 | from imgx.run_valid import main as run_valid
14 |
15 |
16 | @pytest.mark.integration()
17 | @pytest.mark.parametrize(
18 | "dataset",
19 | ["muscle_us", "amos_ct"],
20 | )
21 | def test_segmentation_train_valid_test(mocker: MockFixture, dataset: str) -> None:
22 | """Test train, valid, and test.
23 |
24 | Args:
25 | mocker: mocker, a wrapper of unittest.mock.
26 | dataset: dataset name.
27 | """
28 | with TemporaryDirectory() as temp_dir:
29 | mocker.resetall()
30 | mocker.patch.dict("os.environ", {"WANDB_MODE": "offline"})
31 | mocker.patch(
32 | "sys.argv",
33 | ["pytest", "debug=true", "task=seg", f"data={dataset}", f"logging.root_dir={temp_dir}"],
34 | )
35 | run_train() # pylint: disable=no-value-for-parameter
36 | mocker.patch(
37 | "sys.argv",
38 | ["pytest", f"--log_dir={temp_dir}/wandb/latest-run"],
39 | )
40 | run_valid()
41 | run_test()
42 | shutil.rmtree(temp_dir)
43 |
44 |
45 | @pytest.mark.integration()
46 | @pytest.mark.parametrize(
47 | "dataset",
48 | ["muscle_us", "amos_ct"],
49 | )
50 | def test_diffusion_segmentation_train_valid_test(mocker: MockFixture, dataset: str) -> None:
51 | """Test train, valid, and test.
52 |
53 | Args:
54 | mocker: mocker, a wrapper of unittest.mock.
55 | dataset: dataset name.
56 | """
57 | with TemporaryDirectory() as temp_dir:
58 | mocker.resetall()
59 | mocker.patch.dict("os.environ", {"WANDB_MODE": "offline"})
60 | mocker.patch(
61 | "sys.argv",
62 | [
63 | "pytest",
64 | "debug=true",
65 | "task=gaussian_diff_seg",
66 | f"data={dataset}",
67 | f"logging.root_dir={temp_dir}",
68 | ],
69 | )
70 | run_train() # pylint: disable=no-value-for-parameter
71 | mocker.patch(
72 | "sys.argv",
73 | [
74 | "pytest",
75 | "--num_timesteps=2",
76 | "--sampler=DDPM",
77 | f"--log_dir={temp_dir}/wandb/latest-run",
78 | ],
79 | )
80 | run_valid()
81 | run_test()
82 | shutil.rmtree(temp_dir)
83 |
--------------------------------------------------------------------------------
/imgx/loss/__init__.py:
--------------------------------------------------------------------------------
1 | """Package for loss functions."""
2 | from imgx.loss.cross_entropy import cross_entropy, focal_loss
3 | from imgx.loss.deformation import bending_energy_loss, gradient_norm_loss, jacobian_loss
4 | from imgx.loss.dice import dice_loss, dice_loss_from_masks
5 | from imgx.loss.similarity import nrmsd_loss, psnr_loss
6 |
7 | __all__ = [
8 | "cross_entropy",
9 | "focal_loss",
10 | "dice_loss",
11 | "dice_loss_from_masks",
12 | "psnr_loss",
13 | "nrmsd_loss",
14 | "bending_energy_loss",
15 | "gradient_norm_loss",
16 | "jacobian_loss",
17 | ]
18 |
--------------------------------------------------------------------------------
/imgx/loss/cross_entropy.py:
--------------------------------------------------------------------------------
1 | """Loss functions for classification."""
2 | import jax
3 | import jax.numpy as jnp
4 | import optax
5 | from jax import lax
6 |
7 |
8 | def cross_entropy(
9 | logits: jnp.ndarray,
10 | mask_true: jnp.ndarray,
11 | classes_are_exclusive: bool,
12 | ) -> jnp.ndarray:
13 | """Cross entropy, supporting soft label.
14 |
15 | optax.softmax_cross_entropy returns (batch, ...).
16 | optax.sigmoid_binary_cross_entropy returns (batch, ..., num_classes).
17 |
18 | Args:
19 | logits: unscaled prediction, (batch, ..., num_classes).
20 | mask_true: probabilities per class, (batch, ..., num_classes).
21 | classes_are_exclusive: if False, each element can be assigned to multiple classes.
22 |
23 | Returns:
24 | Cross entropy loss value of shape (batch, ).
25 | """
26 | mask_true = mask_true.astype(logits.dtype)
27 | loss = lax.cond(
28 | classes_are_exclusive,
29 | optax.softmax_cross_entropy,
30 | lambda *args: jnp.sum(optax.sigmoid_binary_cross_entropy(*args), axis=-1),
31 | logits,
32 | mask_true,
33 | )
34 | return jnp.mean(loss, axis=range(1, loss.ndim))
35 |
36 |
37 | def softmax_focal_loss(
38 | logits: jnp.ndarray,
39 | mask_true: jnp.ndarray,
40 | gamma: float = 2.0,
41 | ) -> jnp.ndarray:
42 | """Focal loss with one-hot / mutually exclusive classes.
43 |
44 | https://arxiv.org/abs/1708.02002
45 | Implementation is similar to optax.softmax_cross_entropy.
46 |
47 | Args:
48 | logits: unscaled prediction, (batch, ..., num_classes).
49 | mask_true: probabilities per class, (batch, ..., num_classes).
50 | gamma: adjust class imbalance, 0 is equivalent to cross entropy.
51 |
52 | Returns:
53 | Loss of shape (batch, ...).
54 | """
55 | log_p = jax.nn.log_softmax(logits)
56 | p = jnp.exp(log_p)
57 | return -jnp.sum(((1 - p) ** gamma) * log_p * mask_true, axis=-1)
58 |
59 |
60 | def sigmoid_focal_loss(
61 | logits: jnp.ndarray,
62 | mask_true: jnp.ndarray,
63 | gamma: float = 2.0,
64 | ) -> jnp.ndarray:
65 | """Focal loss with multi-hot / non-mutually-exclusive classes.
66 |
67 | https://arxiv.org/abs/1708.02002
68 | Implementation is similar to optax.sigmoid_binary_cross_entropy.
69 |
70 | Args:
71 | logits: unscaled prediction, (batch, ..., num_classes).
72 | mask_true: probabilities per class, (batch, ..., num_classes).
73 | gamma: adjust class imbalance, 0 is equivalent to cross entropy.
74 |
75 | Returns:
76 | Focal loss value of shape (batch, ..., num_classes).
77 | """
78 | log_p = jax.nn.log_sigmoid(logits)
79 | log_not_p = jax.nn.log_sigmoid(-logits)
80 | p = jnp.exp(log_p)
81 | return -((1 - p) ** gamma) * log_p * mask_true - (p**gamma) * log_not_p * (1 - mask_true)
82 |
83 |
84 | def focal_loss(
85 | logits: jnp.ndarray,
86 | mask_true: jnp.ndarray,
87 | classes_are_exclusive: bool,
88 | gamma: float = 2.0,
89 | ) -> jnp.ndarray:
90 | """Focal loss.
91 |
92 | https://arxiv.org/abs/1708.02002
93 | softmax_focal_loss returns (batch, ...).
94 | sigmoid_focal_loss returns (batch, ..., num_classes).
95 |
96 | Args:
97 | logits: unscaled prediction, (batch, ..., num_classes).
98 | mask_true: probabilities per class, (batch, ..., num_classes).
99 | classes_are_exclusive: if False, each element can be assigned to multiple classes.
100 | gamma: adjust class imbalance, 0 is equivalent to cross entropy.
101 |
102 | Returns:
103 | Focal loss value of shape (batch, ).
104 | """
105 | mask_true = mask_true.astype(logits.dtype)
106 | loss = lax.cond(
107 | classes_are_exclusive,
108 | softmax_focal_loss,
109 | lambda *args: jnp.sum(sigmoid_focal_loss(*args), axis=-1),
110 | logits,
111 | mask_true,
112 | gamma,
113 | )
114 | return jnp.mean(loss, axis=range(1, loss.ndim))
115 |
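A minimal usage sketch (not repository code) of `cross_entropy` and `focal_loss` on dummy one-hot
targets; the shapes follow the docstrings above.

```python
import jax
import jax.numpy as jnp

from imgx.loss import cross_entropy, focal_loss

logits = jnp.zeros((2, 4, 4, 3))  # (batch, ..., num_classes)
labels = jnp.zeros((2, 4, 4), dtype=jnp.int32)
mask_true = jax.nn.one_hot(labels, num_classes=3)  # (batch, ..., num_classes)

ce = cross_entropy(logits, mask_true, classes_are_exclusive=True)  # shape (2,)
fl = focal_loss(logits, mask_true, classes_are_exclusive=True)     # shape (2,)
```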
--------------------------------------------------------------------------------
/imgx/loss/deformation.py:
--------------------------------------------------------------------------------
1 | """Deformation losses for ddf."""
2 | from functools import partial
3 |
4 | import jax
5 | import jax.numpy as jnp
6 |
7 | from imgx.metric.deformation import gradient, jacobian_det
8 |
9 |
10 | def gradient_norm_loss(x: jnp.ndarray, norm_ord: int, spacing: jnp.ndarray) -> jnp.ndarray:
11 | """Calculate norm of gradients of x using central finite difference.
12 |
13 | Args:
14 | x: shape = (batch, d1, d2, ..., dn, channel).
15 | norm_ord: 1 for L1 or 2 for L2.
16 | spacing: spacing between each pixel/voxel, shape = (n,).
17 |
18 | Returns:
19 | shape = (batch,).
20 | """
21 | # (batch, d1, d2, ..., dn, channel, n)
22 | grad = gradient(x, spacing)
23 | if norm_ord == 1:
24 | return jnp.mean(jnp.abs(grad), axis=tuple(range(1, grad.ndim)))
25 | if norm_ord == 2:
26 | return jnp.mean(grad**2, axis=tuple(range(1, grad.ndim)))
27 | raise ValueError(f"norm_ord = {norm_ord} is not supported.")
28 |
29 |
30 | def bending_energy_loss(x: jnp.ndarray, spacing: jnp.ndarray) -> jnp.ndarray:
31 | """Calculate bending energy (L2 norm of the second-order gradient) using central finite difference.
32 |
33 | Args:
34 | x: shape = (batch, d1, d2, ..., dn, channel).
35 | spacing: spacing between each pixel/voxel, shape = (n,).
36 |
37 | Returns:
38 | shape = (batch,).
39 | """
40 | # (batch, d1, d2, ..., dn, channel, n)
41 | grad_1d = gradient(x, spacing)
42 | # (batch, d1, d2, ..., dn, channel, n, n)
43 | grad_2d = jax.vmap(partial(gradient, spacing=spacing), in_axes=-1, out_axes=-1)(grad_1d)
44 | # (batch,)
45 | return jnp.mean(grad_2d**2, axis=tuple(range(1, grad_2d.ndim)))
46 |
47 |
48 | def jacobian_loss(x: jnp.ndarray, spacing: jnp.ndarray) -> jnp.ndarray:
49 | """Calculate Jacobian loss.
50 |
51 | https://arxiv.org/abs/1907.00068
52 |
53 | If the Jacobian determinant is <0, the transformation is folding, thus not volume preserving.
54 | Do not penalize the Jacobian determinant if it is >0.
55 |
56 | Args:
57 | x: shape = (batch, d1, d2, ..., dn, n).
58 | spacing: spacing between each pixel/voxel, shape = (n,).
59 |
60 | Returns:
61 | shape = (batch,).
62 | """
63 | # (batch, d1, d2, ..., dn)
64 | det = jacobian_det(x, spacing)
65 | det = jnp.clip(det, a_max=0.0) # negative values are folding
66 | # (batch,)
67 | return jnp.mean(-det, axis=tuple(range(1, det.ndim))) # reverse sign
68 |
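A minimal usage sketch for the three deformation losses (illustrative shapes and spacing only):

import jax
import jax.numpy as jnp

from imgx.loss.deformation import bending_energy_loss, gradient_norm_loss, jacobian_loss

# toy dense displacement field of shape (batch, d1, d2, d3, 3)
ddf = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 8, 3))
spacing = jnp.array([1.0, 1.0, 2.0])  # voxel spacing in mm

grad_l2 = gradient_norm_loss(ddf, norm_ord=2, spacing=spacing)  # (batch,)
bending = bending_energy_loss(ddf, spacing)                     # (batch,)
folding = jacobian_loss(ddf, spacing)                           # (batch,), 0 if no folding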
--------------------------------------------------------------------------------
/imgx/loss/deformation_test.py:
--------------------------------------------------------------------------------
1 | """Test deformation loss functions."""
2 | from functools import partial
3 |
4 | import chex
5 | import jax.numpy as jnp
6 | import numpy as np
7 | from absl.testing import parameterized
8 | from chex._src import fake
9 |
10 | from imgx.loss.deformation import bending_energy_loss, gradient_norm_loss, jacobian_loss
11 |
12 |
13 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
14 | def setUpModule() -> None: # pylint: disable=invalid-name
15 | """Fake multi-devices."""
16 | fake.set_n_cpu_devices(2)
17 |
18 |
19 | class TestGradientNormLoss(chex.TestCase):
20 | """Test gradient_norm_loss."""
21 |
22 | batch = 2
23 |
24 | @chex.all_variants()
25 | @parameterized.product(
26 | shape=[(2, 3), (2, 3, 4), (2, 3, 4, 5)],
27 | norm_ord=[1, 2],
28 | )
29 | def test_shapes(
30 | self,
31 | shape: tuple[int, ...],
32 | norm_ord: int,
33 | ) -> None:
34 | """Test return shapes.
35 |
36 | Args:
37 | shape: input shape.
38 | norm_ord: 1 for L1 or 2 for L2.
39 | """
40 | x = jnp.ones((self.batch, *shape))
41 | spacing = jnp.ones((len(shape),))
42 | got = self.variant(partial(gradient_norm_loss, norm_ord=norm_ord, spacing=spacing))(x)
43 | chex.assert_shape(got, (self.batch,))
44 |
45 | @chex.all_variants()
46 | @parameterized.named_parameters(
47 | (
48 | "1d - L1 norm",
49 | np.array([[[-2.3], [1.1], [-0.5]]]),
50 | 1,
51 | # norm [1.7, 0.9, -0.8]
52 | np.array([3.4 / 3]),
53 | ),
54 | (
55 | "1d - L2 norm",
56 | np.array([[[-2.3], [1.1], [-0.5]]]),
57 | 2,
58 | # norm [1.7, 0.9, -0.8]
59 | np.array([(1.7 * 1.7 + 0.9 * 0.9 + 0.8 * 0.8) / 3]),
60 | ),
61 | (
62 | "batch 1d - L1 norm",
63 | np.array([[[2.0], [1.0], [0.1]], [[0.0], [-1.0], [-2.0]], [[0.2], [-1.0], [-2.0]]]),
64 | 1,
65 | # norm (singleton axis is removed)
66 | # [[-0.5, -0.95, -0.45],
67 | # [-0.5, -1.0, -0.5],
68 | # [-0.6, -1.1, -0.5]],
69 | np.array([1.9 / 3, 2.0 / 3, 2.2 / 3]),
70 | ),
71 | (
72 | "2d - L1 norm",
73 | # (1,3,1,3)
74 | np.array([[[[2.0, 1.0, 0.1]], [[0.0, -1.0, -2.0]], [[0.2, -1.0, -2.0]]]]),
75 | 1,
76 | # dx norm (singleton axis is removed)
77 | # [[-1.0, -1.0, -1.05],
78 | # [-0.9, -1.0, -1.05],
79 | # [0.1, 0.0, 0.0]]
80 | # dy norm are all zeros
81 | np.array([6.1 / 18]),
82 | ),
83 | )
84 | def test_values(
85 | self,
86 | x: jnp.ndarray,
87 | norm_ord: int,
88 | expected: jnp.ndarray,
89 | ) -> None:
90 | """Test return values.
91 |
92 | Args:
93 | x: input.
94 | norm_ord: norm order.
95 | expected: expected output.
96 | """
97 | spacing = jnp.ones((x.ndim - 2,))
98 | got = self.variant(partial(gradient_norm_loss, norm_ord=norm_ord, spacing=spacing))(x)
99 | chex.assert_trees_all_close(got, expected)
100 |
101 |
102 | class TestBendingEnergyLoss(chex.TestCase):
103 | """Test bending_energy_loss."""
104 |
105 | batch = 2
106 |
107 | @chex.all_variants()
108 | @parameterized.product(
109 | shape=[(2, 3), (2, 3, 4), (2, 3, 4, 5)],
110 | )
111 | def test_shapes(
112 | self,
113 | shape: tuple[int, ...],
114 | ) -> None:
115 | """Test return shapes.
116 |
117 | Args:
118 | shape: input shape.
119 | """
120 | x = jnp.ones((self.batch, *shape))
121 | spacing = jnp.ones((x.ndim - 2,))
122 | got = self.variant(bending_energy_loss)(x, spacing)
123 | chex.assert_shape(got, (self.batch,))
124 |
125 | @chex.all_variants()
126 | @parameterized.named_parameters(
127 | (
128 | "1d",
129 | # (1, 3, 1)
130 | np.array([[[-2.3], [1.1], [-0.5]]]),
131 | # dx [1.7, 0.9, -0.8]
132 | # dxx [-0.4, -1.25, -0.85]
133 | np.array([(0.4 * 0.4 + 1.25 * 1.25 + 0.85 * 0.85) / 3]),
134 | ),
135 | )
136 | def test_values(
137 | self,
138 | x: jnp.ndarray,
139 | expected: jnp.ndarray,
140 | ) -> None:
141 | """Test return values.
142 |
143 | Args:
144 | x: input.
145 | expected: expected output.
146 | """
147 | spacing = jnp.ones((x.ndim - 2,))
148 | got = self.variant(bending_energy_loss)(x, spacing)
149 | chex.assert_trees_all_close(got, expected)
150 |
151 |
152 | class TestJacobianLoss(chex.TestCase):
153 | """Test jacobian_loss."""
154 |
155 | batch = 2
156 |
157 | @chex.all_variants()
158 | @parameterized.product(
159 | shape=[(2, 3), (2, 3, 4), (2, 3, 4, 5)],
160 | )
161 | def test_shapes(
162 | self,
163 | shape: tuple[int, ...],
164 | ) -> None:
165 | """Test return shapes.
166 |
167 | Args:
168 | shape: input shape.
169 | """
170 | x = jnp.ones((self.batch, *shape, len(shape)))
171 | spacing = jnp.ones((x.ndim - 2,))
172 | got = self.variant(jacobian_loss)(x, spacing)
173 | chex.assert_shape(got, (self.batch,))
174 |
--------------------------------------------------------------------------------
/imgx/loss/dice.py:
--------------------------------------------------------------------------------
1 | """Loss functions for image segmentation."""
2 | import jax
3 | import jax.numpy as jnp
4 | from jax import lax
5 |
6 |
7 | def dice_loss_from_masks(
8 | mask_pred: jnp.ndarray,
9 | mask_true: jnp.ndarray,
10 | ) -> jnp.ndarray:
11 | """Mean dice loss, smaller is better.
12 |
13 | Losses are not calculated for classes that are absent from the ground truth.
14 | This avoids the need for smoothing and potential nan gradients.
15 |
16 | Args:
17 | mask_pred: binary masks, (batch, ..., num_classes).
18 | mask_true: binary masks, (batch, ..., num_classes).
19 |
20 | Returns:
21 | Dice loss value of shape (batch, num_classes).
22 | """
23 | reduce_axis = tuple(range(mask_pred.ndim))[1:-1]
24 | # (batch, num_classes)
25 | numerator = 2.0 * jnp.sum(mask_pred * mask_true, axis=reduce_axis)
26 | denominator = jnp.sum(mask_pred + mask_true, axis=reduce_axis)
27 | not_nan_mask = jnp.sum(mask_true, axis=reduce_axis) > 0
28 | # classes absent from the ground truth yield nan instead of a loss value
29 | return jnp.where(
30 | condition=not_nan_mask,
31 | x=1.0 - numerator / denominator,
32 | y=jnp.nan,
33 | )
34 |
35 |
36 | def dice_loss(
37 | logits: jnp.ndarray,
38 | mask_true: jnp.ndarray,
39 | classes_are_exclusive: bool,
40 | ) -> jnp.ndarray:
41 | """Mean dice loss, smaller is better.
42 |
43 | Losses are not calculated for classes that are absent from the ground truth.
44 | This avoids the need for smoothing and potential nan gradients.
45 |
46 | Args:
47 | logits: unscaled prediction, (batch, ..., num_classes).
48 | mask_true: binary masks, (batch, ..., num_classes).
49 | classes_are_exclusive: classes are exclusive, i.e. no overlap.
50 |
51 | Returns:
52 | Dice loss value of shape (batch, num_classes).
53 | """
54 | mask_pred = lax.cond(
55 | classes_are_exclusive,
56 | jax.nn.softmax,
57 | jax.nn.sigmoid,
58 | logits,
59 | )
60 | return dice_loss_from_masks(mask_pred, mask_true)
61 |
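A minimal usage sketch for dice_loss (illustrative shapes; the one-hot conversion is an assumption):

import jax
import jax.numpy as jnp

from imgx.loss import dice_loss

logits = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 4))  # (batch, ..., num_classes)
label = jnp.zeros((2, 8, 8), dtype=jnp.int32)
mask_true = jax.nn.one_hot(label, num_classes=4)

# (batch, num_classes); classes absent from mask_true yield nan and should be masked downstream
loss = dice_loss(logits, mask_true, classes_are_exclusive=True)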
--------------------------------------------------------------------------------
/imgx/loss/segmentation.py:
--------------------------------------------------------------------------------
1 | """Vanilla segmentation loss."""
2 | from __future__ import annotations
3 |
4 | import jax.numpy as jnp
5 | from omegaconf import DictConfig
6 |
7 | from imgx.datasets.dataset_info import DatasetInfo
8 | from imgx.loss import cross_entropy, dice_loss, focal_loss
9 | from imgx.metric import class_proportion, class_volume
10 |
11 |
12 | def segmentation_loss(
13 | logits: jnp.ndarray,
14 | label: jnp.ndarray,
15 | dataset_info: DatasetInfo,
16 | loss_config: DictConfig,
17 | ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
18 | """Calculate segmentation loss with auxiliary losses and return metrics.
19 |
20 | Args:
21 | logits: unnormalised logits of shape (batch, ..., num_classes).
22 | label: label of shape (batch, ...).
23 | dataset_info: dataset info with helper functions.
24 | loss_config: weights for the different losses.
25 |
26 | Returns:
27 | - calculated loss, of shape (batch,).
28 | - metrics, values of shape (batch,).
29 | """
30 | spacing = jnp.array(dataset_info.image_spacing)
31 | mask_true = dataset_info.label_to_mask(label, axis=-1)
32 | metrics = {}
33 |
34 | # (batch, num_classes)
35 | class_prop_batch_cls = class_proportion(mask_true)
36 | for i in range(dataset_info.num_classes):
37 | metrics[f"class_{i}_proportion_true"] = class_prop_batch_cls[:, i]
38 | class_volume_batch_cls = class_volume(mask_true, spacing)
39 | for i in range(dataset_info.num_classes):
40 | metrics[f"class_{i}_volume_true"] = class_volume_batch_cls[:, i]
41 |
42 | # total loss
43 | loss_batch = jnp.zeros((logits.shape[0],), dtype=logits.dtype)
44 | if loss_config.get("dice", 0.0) > 0:
45 | # (batch, num_classes)
46 | dice_loss_batch_cls = dice_loss(
47 | logits=logits,
48 | mask_true=mask_true,
49 | classes_are_exclusive=dataset_info.classes_are_exclusive,
50 | )
51 | # (batch, )
52 | # without background
53 | # mask out non-existing classes
54 | dice_loss_batch = jnp.mean(
55 | dice_loss_batch_cls[:, 1:], axis=-1, where=class_prop_batch_cls[:, 1:] > 0
56 | )
57 | metrics["dice_loss"] = dice_loss_batch
58 | for i in range(dice_loss_batch_cls.shape[-1]):
59 | metrics[f"dice_loss_class_{i}"] = dice_loss_batch_cls[:, i]
60 | loss_batch += dice_loss_batch * loss_config["dice"]
61 |
62 | if loss_config.get("cross_entropy", 0.0) > 0:
63 | # (batch, )
64 | ce_loss_batch = cross_entropy(
65 | logits=logits,
66 | mask_true=mask_true,
67 | classes_are_exclusive=dataset_info.classes_are_exclusive,
68 | )
69 | metrics["cross_entropy_loss"] = ce_loss_batch
70 | loss_batch += ce_loss_batch * loss_config["cross_entropy"]
71 |
72 | if loss_config.get("focal", 0.0) > 0:
73 | # (batch, )
74 | focal_loss_batch = focal_loss(
75 | logits=logits,
76 | mask_true=mask_true,
77 | classes_are_exclusive=dataset_info.classes_are_exclusive,
78 | )
79 | metrics["focal_loss"] = focal_loss_batch
80 | loss_batch += focal_loss_batch * loss_config["focal"]
81 | metrics["total_loss"] = loss_batch
82 | return loss_batch, metrics
83 |
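A minimal usage sketch mirroring segmentation_test.py; picking the first registered dataset is an illustrative assumption, any entry of INFO_MAP works:

import jax.numpy as jnp
from omegaconf import DictConfig

from imgx.datasets import INFO_MAP
from imgx.loss.segmentation import segmentation_loss

dataset_info = next(iter(INFO_MAP.values()))  # any registered dataset
loss_config = DictConfig({"dice": 1.0, "cross_entropy": 1.0, "focal": 1.0})

shape = tuple(max(x // 16, 2) for x in dataset_info.image_spatial_shape)
logits = jnp.ones((2, *shape, dataset_info.num_classes), dtype=jnp.float32)
label = jnp.ones((2, *shape), dtype=jnp.int32)

# loss_batch has shape (batch,), metrics is a dict of (batch,) arrays
loss_batch, metrics = segmentation_loss(logits, label, dataset_info, loss_config)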
--------------------------------------------------------------------------------
/imgx/loss/segmentation_test.py:
--------------------------------------------------------------------------------
1 | """Test segmentation loss."""
2 | from functools import partial
3 |
4 | import chex
5 | import jax.numpy as jnp
6 | from absl.testing import parameterized
7 | from chex._src import fake
8 | from omegaconf import DictConfig
9 |
10 | from imgx.datasets import INFO_MAP
11 | from imgx.loss.segmentation import segmentation_loss
12 |
13 |
14 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
15 | def setUpModule() -> None: # pylint: disable=invalid-name
16 | """Fake multi-devices."""
17 | fake.set_n_cpu_devices(2)
18 |
19 |
20 | class TestSegmentationLoss(chex.TestCase):
21 | """Test segmentation_loss."""
22 |
23 | batch_size = 2
24 |
25 | @chex.all_variants()
26 | @parameterized.product(
27 | dataset_name=sorted(INFO_MAP.keys()),
28 | loss_config=[
29 | {
30 | "cross_entropy": 1.0,
31 | "dice": 1.0,
32 | "focal": 1.0,
33 | },
34 | {
35 | "dice": 1.0,
36 | },
37 | ],
38 | )
39 | def test_shapes(
40 | self,
41 | dataset_name: str,
42 | loss_config: dict[str, float],
43 | ) -> None:
44 | """Test return shapes.
45 |
46 | Args:
47 | dataset_name: dataset name.
48 | loss_config: loss config.
49 | """
50 | dataset_info = INFO_MAP[dataset_name]
51 | shape = dataset_info.image_spatial_shape
52 | shape = tuple(max(x // 16, 2) for x in shape) # reduce shape to speed up test
53 | logits = jnp.ones(
54 | (self.batch_size, *shape, dataset_info.num_classes),
55 | dtype=jnp.float32,
56 | )
57 | label = jnp.ones((self.batch_size, *shape), dtype=jnp.int32)
58 |
59 | got_loss_batch, got_metrics = self.variant(
60 | partial(
61 | segmentation_loss,
62 | dataset_info=dataset_info,
63 | loss_config=DictConfig(loss_config),
64 | )
65 | )(logits, label)
66 | chex.assert_shape(got_loss_batch, (self.batch_size,))
67 | for v in got_metrics.values():
68 | chex.assert_shape(v, (self.batch_size,))
69 |
--------------------------------------------------------------------------------
/imgx/loss/similarity.py:
--------------------------------------------------------------------------------
1 | """Image similarity loss functions."""
2 | import jax.numpy as jnp
3 |
4 | from imgx import EPS
5 | from imgx.metric import nrmsd, psnr
6 |
7 |
8 | def psnr_loss(
9 | image1: jnp.ndarray,
10 | image2: jnp.ndarray,
11 | value_range: float = 1.0,
12 | eps: float = EPS,
13 | ) -> jnp.ndarray:
14 | """Peak signal-to-noise ratio (PSNR) loss.
15 |
16 | Args:
17 | image1: image of shape (batch, ..., channels).
18 | image2: image of shape (batch, ..., channels).
19 | value_range: value range of input images.
20 | eps: epsilon, if two images are identical, MSE=0.
21 |
22 | Returns:
23 | PSNR loss of shape (batch,).
24 | """
25 | return -psnr(
26 | image1=image1,
27 | image2=image2,
28 | value_range=value_range,
29 | eps=eps,
30 | )
31 |
32 |
33 | def nrmsd_loss(
34 | image_pred: jnp.ndarray,
35 | image_true: jnp.ndarray,
36 | eps: float = EPS,
37 | ) -> jnp.ndarray:
38 | """Normalized root-mean-square-deviation (NRMSD) loss.
39 |
40 | Args:
41 | image_pred: predicted image of shape (batch, ..., channels).
42 | image_true: ground truth image of shape (batch, ..., channels).
43 | eps: epsilon, if two images are identical, MSE=0.
44 |
45 | Returns:
46 | NRMSD loss of shape (batch,).
47 | """
48 | return nrmsd(
49 | image_pred=image_pred,
50 | image_true=image_true,
51 | eps=eps,
52 | )
53 |
--------------------------------------------------------------------------------
/imgx/metric/__init__.py:
--------------------------------------------------------------------------------
1 | """Module for metrics."""
2 | from imgx.metric.area import class_proportion, class_volume
3 | from imgx.metric.centroid import centroid_distance
4 | from imgx.metric.deformation import jacobian_det
5 | from imgx.metric.dice import dice_score, iou, stability
6 | from imgx.metric.similarity import nrmsd, psnr, ssim
7 | from imgx.metric.smoothing import gaussian_smooth_label, smooth_label
8 | from imgx.metric.surface_distance import (
9 | aggregated_surface_distance,
10 | average_surface_distance,
11 | hausdorff_distance,
12 | normalized_surface_dice,
13 | normalized_surface_dice_from_distances,
14 | )
15 |
16 | __all__ = [
17 | "dice_score",
18 | "iou",
19 | "average_surface_distance",
20 | "aggregated_surface_distance",
21 | "normalized_surface_dice",
22 | "normalized_surface_dice_from_distances",
23 | "hausdorff_distance",
24 | "centroid_distance",
25 | "class_proportion",
26 | "class_volume",
27 | "stability",
28 | "ssim",
29 | "psnr",
30 | "nrmsd",
31 | "jacobian_det",
32 | "gaussian_smooth_label",
33 | "smooth_label",
34 | ]
35 |
--------------------------------------------------------------------------------
/imgx/metric/area.py:
--------------------------------------------------------------------------------
1 | """Metrics to measure foreground area."""
2 | import jax
3 | import numpy as np
4 | from jax import numpy as jnp
5 |
6 | MM3_TO_ML = 0.001
7 |
8 |
9 | def class_proportion(mask: jnp.ndarray) -> jnp.ndarray:
10 | """Calculate proportion per class.
11 |
12 | This metric does not consider spacing.
13 |
14 | Args:
15 | mask: shape = (batch, d1, ..., dn, num_classes).
16 |
17 | Returns:
18 | Proportion, shape = (batch, num_classes).
19 | """
20 | reduce_axes = tuple(range(1, mask.ndim - 1))
21 | volume = jnp.float32(np.prod(mask.shape[1:-1]))
22 | sqrt_volume = jnp.sqrt(volume)
23 | mask = jnp.float32(mask)
24 | # attempt to avoid over/underflow
25 | return jnp.sum(mask / sqrt_volume, axis=reduce_axes) / sqrt_volume
26 |
27 |
28 | def get_volume(label: jnp.ndarray, spacing: jnp.ndarray) -> jnp.ndarray:
29 | """Calculate volume from binary label/mask.
30 |
31 | This metric considers spacing.
32 |
33 | Args:
34 | label: binary label, of shape (batch, d1, ..., dn).
35 | spacing: (n,), spacing of each dimension, in mm.
36 |
37 | Returns:
38 | volume: volume in ml, (batch,).
39 | """
40 | volume_per_voxel = jnp.prod(spacing) * MM3_TO_ML # ml
41 | return jnp.sum(label, axis=list(range(1, label.ndim))) * volume_per_voxel
42 |
43 |
44 | def class_volume(mask: jnp.ndarray, spacing: jnp.ndarray) -> jnp.ndarray:
45 | """Calculate volume per class.
46 |
47 | This metric does consider spacing.
48 |
49 | Args:
50 | mask: shape = (batch, d1, ..., dn, num_classes).
51 | spacing: (n,), spacing of each dimension, in mm.
52 |
53 | Returns:
54 | volume: volume in ml, (batch, num_classes).
55 | """
56 | return jax.vmap(get_volume, in_axes=(-1, None), out_axes=-1)(mask, spacing)
57 |
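A minimal usage sketch for the area metrics (illustrative shapes and spacing):

import jax.numpy as jnp

from imgx.metric import class_proportion, class_volume

# boolean mask of shape (batch, d1, d2, d3, num_classes)
mask = jnp.zeros((2, 8, 8, 8, 3), dtype=jnp.bool_)
mask = mask.at[:, 2:4, 2:4, 2:4, 1].set(True)
spacing = jnp.array([1.0, 1.0, 2.0])  # mm per voxel

prop = class_proportion(mask)      # (batch, num_classes), fraction of voxels per class
vol = class_volume(mask, spacing)  # (batch, num_classes), volume in ml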
--------------------------------------------------------------------------------
/imgx/metric/area_test.py:
--------------------------------------------------------------------------------
1 | """Test area functions."""
2 |
3 | import chex
4 | import numpy as np
5 | from absl.testing import parameterized
6 | from chex._src import fake
7 |
8 | from imgx.metric.area import class_proportion, class_volume, get_volume
9 |
10 |
11 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
12 | def setUpModule() -> None: # pylint: disable=invalid-name
13 | """Fake multi-devices."""
14 | fake.set_n_cpu_devices(2)
15 |
16 |
17 | class TestClassProportion(chex.TestCase):
18 | """Test class_proportion."""
19 |
20 | @chex.all_variants()
21 | @parameterized.named_parameters(
22 | (
23 | "1d-1class",
24 | np.asarray([False, True, True, False])[..., None],
25 | np.asarray([0.5])[..., None],
26 | ),
27 | (
28 | "1d-1class-empty",
29 | np.asarray([False, False, False, False])[..., None],
30 | np.asarray([0.0])[..., None],
31 | ),
32 | (
33 | "1d-2classes",
34 | np.asarray([[False, True], [True, True], [True, False], [False, False]]),
35 | np.asarray([[0.5, 0.5]]),
36 | ),
37 | (
38 | "2d-1class",
39 | np.array(
40 | [
41 | [False, False, True, False, False],
42 | [False, True, False, True, False],
43 | [False, False, False, False, False],
44 | [False, False, False, False, False],
45 | ]
46 | )[..., None],
47 | np.asarray([3.0 / 20.0])[..., None],
48 | ),
49 | )
50 | def test_values(self, mask: np.ndarray, expected: np.ndarray) -> None:
51 | """Test exact values.
52 |
53 | Args:
54 | mask: shape = (batch, d1, ..., dn, num_classes).
55 | expected: expected proportions.
56 | """
57 | got = self.variant(class_proportion)(
58 | mask=mask[None, ...],
59 | )
60 | chex.assert_trees_all_close(got, expected)
61 |
62 |
63 | class TestGetVolume(chex.TestCase):
64 | """Test get_volume."""
65 |
66 | @chex.all_variants()
67 | @parameterized.product(
68 | image_shape=[(8, 12, 6), (8, 12)],
69 | batch_size=[4, 1],
70 | )
71 | def test_shapes(self, image_shape: tuple[int, ...], batch_size: int) -> None:
72 | """Test output shapes.
73 |
74 | Args:
75 | image_shape: spatial shape of images.
76 | batch_size: number of samples in batch.
77 | """
78 | mask = np.zeros((batch_size, *image_shape), dtype=np.bool_)
79 | spacing = np.ones(len(image_shape))
80 | got = self.variant(get_volume)(mask, spacing)
81 | chex.assert_shape(got, (batch_size,))
82 |
83 |
84 | class TestClassVolume(chex.TestCase):
85 | """Test class_volume."""
86 |
87 | @chex.all_variants()
88 | @parameterized.product(
89 | image_shape=[(8, 12, 6), (8, 12)],
90 | num_classes=[1, 2, 3],
91 | batch_size=[4, 1],
92 | )
93 | def test_shapes(self, image_shape: tuple[int, ...], num_classes: int, batch_size: int) -> None:
94 | """Test output shapes.
95 |
96 | Args:
97 | image_shape: spatial shape of images.
98 | num_classes: number of classes.
99 | batch_size: number of samples in batch.
100 | """
101 | mask = np.zeros((batch_size, *image_shape, num_classes), dtype=np.bool_)
102 | spacing = np.ones(len(image_shape))
103 | got = self.variant(class_volume)(mask, spacing)
104 | chex.assert_shape(got, (batch_size, num_classes))
105 |
--------------------------------------------------------------------------------
/imgx/metric/centroid.py:
--------------------------------------------------------------------------------
1 | """Metric centroid distance."""
2 | from __future__ import annotations
3 |
4 | import jax.numpy as jnp
5 |
6 |
7 | def get_centroid(
8 | mask: jnp.ndarray,
9 | grid: jnp.ndarray,
10 | ) -> tuple[jnp.ndarray, jnp.ndarray]:
11 | """Calculate the centroid of the mask.
12 |
13 | Args:
14 | mask: boolean mask of shape = (batch, d1, ..., dn, num_classes)
15 | grid: shape = (n, d1, ..., dn)
16 |
17 | Returns:
18 | centroid of shape = (batch, n, num_classes).
19 | nan mask of shape = (batch, num_classes).
20 | """
21 | mask_reduce_axes = tuple(range(1, mask.ndim - 1))
22 | grid_reduce_axes = tuple(range(2, mask.ndim))
23 | # (batch, n, d1, ..., dn, num_classes)
24 | masked_grid = jnp.expand_dims(mask, axis=1) * jnp.expand_dims(grid, axis=(0, -1))
25 | # (batch, n, num_classes)
26 | numerator = jnp.sum(masked_grid, axis=grid_reduce_axes)
27 | # (batch, num_classes)
28 | summed_mask = jnp.sum(mask, axis=mask_reduce_axes)
29 | # (batch, 1, num_classes)
30 | denominator = summed_mask[:, None, :]
31 | # if mask is not empty return real centroid, else nan
32 | centroid = jnp.where(condition=denominator > 0, x=numerator / denominator, y=jnp.nan)
33 | return centroid, jnp.array(summed_mask == 0, dtype=jnp.bool_)
34 |
35 |
36 | def centroid_distance(
37 | mask_true: jnp.ndarray,
38 | mask_pred: jnp.ndarray,
39 | grid: jnp.ndarray,
40 | spacing: jnp.ndarray | None = None,
41 | ) -> jnp.ndarray:
42 | """Calculate the L2-distance between two centroids.
43 |
44 | Args:
45 | mask_true: shape = (batch, d1, ..., dn, num_classes).
46 | mask_pred: shape = (batch, d1, ..., dn, num_classes).
47 | grid: shape = (n, d1, ..., dn).
48 | spacing: spacing of pixel/voxels along each dimension, (n,).
49 |
50 | Returns:
51 | distance, shape = (batch, num_classes).
52 | """
53 | # centroid (batch, n, num_classes) nan_mask (batch, num_classes)
54 | centroid_true, nan_mask_true = get_centroid(
55 | mask=mask_true,
56 | grid=grid,
57 | )
58 | centroid_pred, nan_mask_pred = get_centroid(
59 | mask=mask_pred,
60 | grid=grid,
61 | )
62 | nan_mask = nan_mask_true | nan_mask_pred
63 | if spacing is not None:
64 | centroid_true = jnp.where(
65 | condition=nan_mask[:, None, :],
66 | x=jnp.nan,
67 | y=centroid_true * spacing[None, :, None],
68 | )
69 | centroid_pred = jnp.where(
70 | condition=nan_mask[:, None, :],
71 | x=jnp.nan,
72 | y=centroid_pred * spacing[None, :, None],
73 | )
74 |
75 | # return nan if the centroid cannot be defined for one sample with one class
76 | return jnp.where(
77 | condition=nan_mask,
78 | x=jnp.nan,
79 | y=jnp.linalg.norm(centroid_true - centroid_pred, axis=1),
80 | )
81 |
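A minimal usage sketch for centroid_distance; the coordinate grid built with meshgrid is an assumption about how `grid` is constructed elsewhere in the repository:

import jax.numpy as jnp

from imgx.metric import centroid_distance

spatial_shape = (8, 8, 8)
# coordinate grid of shape (n, d1, ..., dn)
grid = jnp.stack(jnp.meshgrid(*(jnp.arange(d) for d in spatial_shape), indexing="ij"), axis=0)
spacing = jnp.array([1.0, 1.0, 2.0])

mask_true = jnp.zeros((2, *spatial_shape, 3), dtype=jnp.bool_).at[:, 2:4, 2:4, 2:4, 1].set(True)
mask_pred = jnp.zeros((2, *spatial_shape, 3), dtype=jnp.bool_).at[:, 3:5, 3:5, 3:5, 1].set(True)

# (batch, num_classes); nan where a class is empty in either mask
dist = centroid_distance(mask_true=mask_true, mask_pred=mask_pred, grid=grid, spacing=spacing)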
--------------------------------------------------------------------------------
/imgx/metric/deformation.py:
--------------------------------------------------------------------------------
1 | """Deformation metrics for ddf."""
2 | from __future__ import annotations
3 |
4 | from jax import numpy as jnp
5 |
6 |
7 | def gradient_along_axis(x: jnp.ndarray, axis: int, spacing: float | jnp.ndarray) -> jnp.ndarray:
8 | """Calculate gradients on one axis of using central finite difference.
9 |
10 | https://en.wikipedia.org/wiki/Finite_difference
11 | dx[i] = (x[i+1] - x[i-1]) / (2 * spacing)
12 | The edge values are padded.
13 |
14 | Args:
15 | x: shape = (d1, d2, ..., dn).
16 | axis: axis to calculate gradient.
17 | spacing: spacing between each pixel/voxel.
18 |
19 | Returns:
20 | shape = (d1, d2, ..., dn).
21 | """
22 | # repeat edge values
23 | # (d1, ..., di+2, ..., dn)
24 | x = jnp.pad(
25 | x,
26 | pad_width=[(0, 0)] * axis + [(1, 1)] + [(0, 0)] * (x.ndim - axis - 1),
27 | mode="edge",
28 | )
29 | # x[i-1]
30 | # (d1, ..., di, ..., dn)
31 | indices = jnp.arange(0, x.shape[axis] - 2)
32 | x_prev = jnp.take(x, indices=indices, axis=axis, indices_are_sorted=True)
33 | # x[i+1]
34 | # (d1, ..., di, ..., dn)
35 | indices = jnp.arange(2, x.shape[axis])
36 | x_next = jnp.take(x, indices=indices, axis=axis, indices_are_sorted=True)
37 | return (x_next - x_prev) / 2 / spacing
38 |
39 |
40 | def gradient(x: jnp.ndarray, spacing: jnp.ndarray) -> jnp.ndarray:
41 | """Calculate gradients per axis of using central finite difference.
42 |
43 | Args:
44 | x: shape = (batch, d1, d2, ..., dn, channel).
45 | spacing: spacing between each pixel/voxel, shape = (n,).
46 |
47 | Returns:
48 | shape = (batch, d1, d2, ..., dn, channel, n).
49 | """
50 | return jnp.stack(
51 | [gradient_along_axis(x, axis, spacing[axis - 1]) for axis in range(1, x.ndim - 1)], axis=-1
52 | )
53 |
54 |
55 | def jacobian_det(x: jnp.ndarray, spacing: jnp.ndarray) -> jnp.ndarray:
56 | """Calculate Jacobian matrix of ddf.
57 |
58 | https://arxiv.org/abs/1907.00068
59 |
60 | Args:
61 | x: shape = (batch, d1, d2, ..., dn, n).
62 | spacing: spacing between each pixel/voxel, shape = (n,).
63 |
64 | Returns:
65 | shape = (batch, d1, d2, ..., dn).
66 | """
67 | n = x.shape[-1]
68 | # (batch, d1, d2, ..., dn, n, n)
69 | grad = gradient(x, spacing)
70 | # (1, 1, ..., 1, n, n)
71 | eye = jnp.expand_dims(jnp.eye(n), axis=range(n + 1))
72 | # (batch, d1, d2, ..., dn)
73 | return jnp.linalg.det(grad + eye)
74 |
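A minimal usage sketch for jacobian_det (illustrative ddf):

import jax
import jax.numpy as jnp

from imgx.metric import jacobian_det

# zero displacement: the determinant of identity + zero gradient is 1 everywhere
ddf = jnp.zeros((2, 8, 8, 8, 3))
spacing = jnp.ones((3,))
det = jacobian_det(ddf, spacing)  # (batch, d1, d2, d3), all ones

# a small random displacement perturbs the determinant around 1
ddf = 0.05 * jax.random.normal(jax.random.PRNGKey(0), ddf.shape)
det = jacobian_det(ddf, spacing)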
--------------------------------------------------------------------------------
/imgx/metric/dice.py:
--------------------------------------------------------------------------------
1 | """Metric functions for image segmentation."""
2 | import jax.numpy as jnp
3 |
4 |
5 | def dice_score(
6 | mask_pred: jnp.ndarray,
7 | mask_true: jnp.ndarray,
8 | ) -> jnp.ndarray:
9 | """Soft Dice score, larger is better.
10 |
11 | Args:
12 | mask_pred: soft mask with probabilities, (batch, ..., num_classes).
13 | mask_true: one hot targets, (batch, ..., num_classes).
14 |
15 | Returns:
16 | Dice score of shape (batch, num_classes).
17 | """
18 | # additions between bools result in errors
19 | if mask_pred.dtype == jnp.bool_:
20 | mask_pred = mask_pred.astype(jnp.float32)
21 | if mask_true.dtype == jnp.bool_:
22 | mask_true = mask_true.astype(jnp.float32)
23 |
24 | reduce_axis = tuple(range(mask_pred.ndim))[1:-1]
25 | numerator = 2.0 * jnp.sum(mask_pred * mask_true, axis=reduce_axis)
26 | denominator = jnp.sum(mask_pred + mask_true, axis=reduce_axis)
27 | return jnp.where(
28 | condition=denominator > 0,
29 | x=numerator / denominator,
30 | y=jnp.nan,
31 | )
32 |
33 |
34 | def iou(
35 | mask_pred: jnp.ndarray,
36 | mask_true: jnp.ndarray,
37 | ) -> jnp.ndarray:
38 | """IOU (Intersection Over Union), or Jaccard index.
39 |
40 | Args:
41 | mask_pred: binary mask of predictions, (batch, ..., num_classes).
42 | mask_true: one hot targets, (batch, ..., num_classes).
43 |
44 | Returns:
45 | IoU of shape (batch, num_classes).
46 | """
47 | # additions between bools result in errors
48 | if mask_pred.dtype == jnp.bool_:
49 | mask_pred = mask_pred.astype(jnp.float32)
50 | if mask_true.dtype == jnp.bool_:
51 | mask_true = mask_true.astype(jnp.float32)
52 |
53 | reduce_axis = tuple(range(mask_pred.ndim))[1:-1]
54 | numerator = jnp.sum(mask_pred * mask_true, axis=reduce_axis)
55 | sum_mask = jnp.sum(mask_pred + mask_true, axis=reduce_axis)
56 | denominator = sum_mask - numerator
57 | return jnp.where(condition=sum_mask > 0, x=numerator / denominator, y=jnp.nan)
58 |
59 |
60 | def stability(
61 | logits: jnp.ndarray,
62 | threshold: float = 0.0,
63 | threshold_offset: float = 1.0,
64 | ) -> jnp.ndarray:
65 | """Calculate stability of predictions.
66 |
67 | https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py
68 |
69 | Args:
70 | logits: shape = (batch, ..., num_classes).
71 | threshold: threshold for prediction.
72 | threshold_offset: offset for threshold.
73 |
74 | Returns:
75 | Stability of shape (batch, num_classes).
76 |
77 | Raises:
78 | ValueError: if threshold_offset is negative.
79 | """
80 | if threshold_offset < 0:
81 | raise ValueError(f"threshold_offset must be non-negative, got {threshold_offset}.")
82 | # center logits so the mean over classes is 0
83 | logits -= jnp.mean(logits, axis=-1, keepdims=True)
84 | mask_high_threshold = logits >= (threshold + threshold_offset)
85 | mask_low_threshold = logits >= (threshold - threshold_offset)
86 | return iou(mask_high_threshold, mask_low_threshold)
87 |
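A minimal usage sketch for the overlap metrics (illustrative masks and logits):

import jax
import jax.numpy as jnp

from imgx.metric import dice_score, iou, stability

mask_true = jnp.zeros((2, 8, 8, 3), dtype=jnp.bool_).at[:, 2:5, 2:5, 1].set(True)
mask_pred = jnp.zeros((2, 8, 8, 3), dtype=jnp.bool_).at[:, 3:6, 3:6, 1].set(True)

dsc = dice_score(mask_pred, mask_true)  # (batch, num_classes), nan for empty classes
jac = iou(mask_pred, mask_true)         # (batch, num_classes)

logits = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 3))
stab = stability(logits)                # (batch, num_classes)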
--------------------------------------------------------------------------------
/imgx/metric/distribution.py:
--------------------------------------------------------------------------------
1 | """Metric functions for probability distributions."""
2 | import jax.numpy as jnp
3 |
4 |
5 | def normal_kl(
6 | p_mean: jnp.ndarray,
7 | p_log_variance: jnp.ndarray,
8 | q_mean: jnp.ndarray,
9 | q_log_variance: jnp.ndarray,
10 | ) -> jnp.ndarray:
11 | r"""Compute the KL divergence between two 1D normal distributions.
12 |
13 | KL[p||q] = \int p \log (p / q) dx
14 |
15 | Although the inputs are arrays, each value is considered independently.
16 | This function is not symmetric.
17 |
18 | Input array shapes should be broadcast-compatible.
19 |
20 | Args:
21 | p_mean: mean of distribution p.
22 | p_log_variance: log variance of distribution p.
23 | q_mean: mean of distribution q.
24 | q_log_variance: log variance of distribution q.
25 |
26 | Returns:
27 | KL divergence.
28 | """
29 | return 0.5 * (
30 | -1.0
31 | + q_log_variance
32 | - p_log_variance
33 | + jnp.exp(p_log_variance - q_log_variance)
34 | + ((p_mean - q_mean) ** 2) * jnp.exp(-q_log_variance)
35 | )
36 |
37 |
38 | def approx_standard_normal_cdf(x: jnp.ndarray) -> jnp.ndarray:
39 | """Approximate cumulative distribution function of standard normal.
40 |
41 | if x ~ Normal(mean, var), then cdf(z) = p(x <= z)
42 |
43 | https://www.aimspress.com/article/doi/10.3934/math.2022648#b13
44 | https://www.jstor.org/stable/2346872
45 |
46 | Args:
47 | x: array of any shape with any float values.
48 |
49 | Returns:
50 | CDF estimation.
51 | """
52 | return 0.5 * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3)))
53 |
54 |
55 | def discretized_gaussian_log_likelihood(
56 | x: jnp.ndarray,
57 | mean: jnp.ndarray,
58 | log_variance: jnp.ndarray,
59 | x_delta: float = 1.0 / 255.0,
60 | x_bound: float = 0.999,
61 | ) -> jnp.ndarray:
62 | """Log-likelihood of a normal distribution discretizing to an image.
63 |
64 | p(y=x) is approximated by p(y <= x+delta) - p(y <= x-delta).
65 |
66 | Args:
67 | x: target image, with value in [-1, 1].
68 | mean: normal distribution mean.
69 | log_variance: log of distribution variance.
70 | x_delta: discretization step, used to estimate probability.
71 | x_bound: values with abs > x_bound are calculated differently.
72 |
73 | Returns:
74 | Discretized log likelihood over 2*delta.
75 | """
76 | log_scales = 0.5 * log_variance
77 | centered_x = x - mean
78 | inv_stdv = jnp.exp(-log_scales)
79 |
80 | # let y be a variable
81 | # cdf(z+delta) = p(y <= z+delta)
82 | plus_in = inv_stdv * (centered_x + x_delta) # z
83 | cdf_plus = approx_standard_normal_cdf(plus_in)
84 | # log( p(y <= z+delta) )
85 | log_cdf_plus = jnp.log(cdf_plus.clip(min=1e-12))
86 |
87 | # cdf(z-delta) = p(y <= z-delta)
88 | minus_in = inv_stdv * (centered_x - x_delta)
89 | cdf_minus = approx_standard_normal_cdf(minus_in)
90 | # log( 1-p(y <= z-delta) ) = log( p(y > z-delta) )
91 | log_one_minus_cdf_minus = jnp.log((1.0 - cdf_minus).clip(min=1e-12))
92 |
93 | # p(z-delta < y <= z+delta)
94 | cdf_delta = cdf_plus - cdf_minus
95 | log_cdf_delta = jnp.log(cdf_delta.clip(min=1e-12))
96 |
97 | # if x < -0.999, log( p(y <= z+delta) )
98 | # if x > 0.999, log( p(y > z-delta) )
99 | # if -0.999 <= x <= 0.999, log( p(z-delta < y <= z+delta) )
100 | return jnp.where(
101 | x < -x_bound,
102 | log_cdf_plus,
103 | jnp.where(x > x_bound, log_one_minus_cdf_minus, log_cdf_delta),
104 | )
105 |
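A minimal usage sketch; the value noted in the comment follows directly from the KL formula above:

import jax.numpy as jnp

from imgx.metric.distribution import discretized_gaussian_log_likelihood, normal_kl

# element-wise KL between N(0, 1) and N(1, 1): 0.5 * (p_mean - q_mean) ** 2 = 0.5
kl = normal_kl(
    p_mean=jnp.zeros((2, 4)),
    p_log_variance=jnp.zeros((2, 4)),
    q_mean=jnp.ones((2, 4)),
    q_log_variance=jnp.zeros((2, 4)),
)

# log-likelihood of an image in [-1, 1] under a discretized standard normal
x = jnp.linspace(-1.0, 1.0, 8).reshape((1, 8))
log_prob = discretized_gaussian_log_likelihood(
    x, mean=jnp.zeros_like(x), log_variance=jnp.zeros_like(x)
)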
--------------------------------------------------------------------------------
/imgx/metric/segmentation_test.py:
--------------------------------------------------------------------------------
1 | """Test functions in imgx.metric.util."""
2 |
3 |
4 | import chex
5 | import jax
6 | import numpy as np
7 | from chex._src import fake
8 |
9 | from imgx.metric.segmentation import (
10 | get_jit_segmentation_metrics,
11 | get_non_jit_segmentation_metrics,
12 | get_non_jit_segmentation_metrics_per_step,
13 | )
14 |
15 |
16 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
17 | def setUpModule() -> None: # pylint: disable=invalid-name
18 | """Fake multi-devices."""
19 | fake.set_n_cpu_devices(2)
20 |
21 |
22 | class TestGetSegmentationMetrics(chex.TestCase):
23 | """Test get_segmentation_metrics."""
24 |
25 | batch = 2
26 | num_classes = 3
27 | spatial_shape = (4, 5, 6)
28 | spacing = np.array((0.2, 0.5, 1.0))
29 | mask_shape = (batch, *spatial_shape, num_classes)
30 |
31 | @chex.all_variants()
32 | def test_jit_shapes(self) -> None:
33 | """Test shapes."""
34 | key = jax.random.PRNGKey(0)
35 | key_pred, key_true = jax.random.split(key)
36 | mask_pred = jax.random.uniform(key_pred, shape=self.mask_shape)
37 | mask_true = jax.random.uniform(key_true, shape=self.mask_shape)
38 |
39 | got = self.variant(get_jit_segmentation_metrics)(mask_pred, mask_true, self.spacing)
40 | for _, v in got.items():
41 | chex.assert_shape(v, (self.batch,))
42 |
43 | @chex.variants(without_jit=True, with_device=True, without_device=True)
44 | def test_nonjit_shapes(self) -> None:
45 | """Test shapes."""
46 | key = jax.random.PRNGKey(0)
47 | key_pred, key_true = jax.random.split(key)
48 | mask_pred = jax.random.uniform(key_pred, shape=self.mask_shape)
49 | mask_true = jax.random.uniform(key_true, shape=self.mask_shape)
50 |
51 | got = self.variant(get_non_jit_segmentation_metrics)(mask_pred, mask_true, self.spacing)
52 | for _, v in got.items():
53 | chex.assert_shape(v, (self.batch,))
54 |
55 | @chex.variants(without_jit=True, with_device=True, without_device=True)
56 | def test_nonjit_per_step_shapes(self) -> None:
57 | """Test shapes."""
58 | num_steps = 2
59 | key = jax.random.PRNGKey(0)
60 | key_pred, key_true = jax.random.split(key)
61 | mask_pred = jax.random.uniform(key_pred, shape=(*self.mask_shape, num_steps))
62 | mask_true = jax.random.uniform(key_true, shape=self.mask_shape)
63 |
64 | got = self.variant(get_non_jit_segmentation_metrics_per_step)(
65 | mask_pred, mask_true, self.spacing
66 | )
67 | for _, v in got.items():
68 | chex.assert_shape(v, (self.batch, num_steps))
69 |
--------------------------------------------------------------------------------
/imgx/metric/similarity.py:
--------------------------------------------------------------------------------
1 | """Image similarity metrics."""
2 |
3 | import jax.numpy as jnp
4 |
5 | from imgx import EPS
6 | from imgx.metric.smoothing import get_conv
7 |
8 |
9 | def ssim(
10 | image1: jnp.ndarray,
11 | image2: jnp.ndarray,
12 | max_val: float = 1.0,
13 | kernel_sigma: float = 1.5,
14 | kernel_size: int = 11,
15 | kernel_type: str = "gaussian",
16 | k1: float = 0.01,
17 | k2: float = 0.03,
18 | ) -> jnp.ndarray:
19 | """Calculate Structural similarity index metric (SSIM).
20 |
21 | https://en.wikipedia.org/wiki/Structural_similarity
22 | https://github.com/Project-MONAI/MONAI/blob/ccd32ca5e9e84562d2f388b45b6724b5c77c1f57/monai/metrics/regression.py#L240
23 |
24 | SSIM is calculated per window. The window size is specified by `kernel_size`.
25 | The window is convolved with a Gaussian kernel specified by `kernel_sigma`.
26 |
27 | SSIM as a loss is not stable and may cause NaN gradients.
28 | https://github.com/tensorflow/tensorflow/issues/50400
29 | https://github.com/tensorflow/tensorflow/issues/57353
30 |
31 | Args:
32 | image1: image of shape (batch, ..., channels).
33 | image2: image of shape (batch, ..., channels).
34 | max_val: maximum value of input images, minimum is assumed to be zero.
35 | kernel_sigma: sigma for Gaussian kernel.
36 | kernel_size: size for kernel.
37 | kernel_type: type of kernel, "gaussian" or "uniform".
38 | k1: stability constant for luminance.
39 | k2: stability constant for contrast.
40 |
41 | Returns:
42 | SSIM of shape (batch,).
43 | """
44 | num_spatial_dims = image1.ndim - 2
45 | conv = get_conv(
46 | num_spatial_dims=num_spatial_dims,
47 | kernel_sigma=kernel_sigma,
48 | kernel_size=kernel_size,
49 | kernel_type=kernel_type,
50 | padding="VALID",
51 | )
52 |
53 | # (batch, ..., channels)
54 | # the spatial dims are reduced by the conv
55 | mean1 = conv(image1)
56 | mean2 = conv(image2)
57 | mean11 = conv(image1 * image1)
58 | mean12 = conv(image1 * image2)
59 | mean22 = conv(image2 * image2)
60 | var1 = mean11 - mean1 * mean1
61 | var2 = mean22 - mean2 * mean2
62 | covar12 = mean12 - mean1 * mean2
63 |
64 | c1 = (k1 * max_val) ** 2
65 | c2 = (k2 * max_val) ** 2
66 | numerator = (2 * mean1 * mean2 + c1) * (2 * covar12 + c2)
67 | denominator = (mean1**2 + mean2**2 + c1) * (var1 + var2 + c2)
68 |
69 | # (batch,)
70 | return jnp.mean(numerator / denominator, axis=tuple(range(1, image1.ndim)))
71 |
72 |
73 | def psnr(
74 | image1: jnp.ndarray,
75 | image2: jnp.ndarray,
76 | value_range: float = 1.0,
77 | eps: float = EPS,
78 | ) -> jnp.ndarray:
79 | """Calculate Peak signal-to-noise ratio (PSNR) metric.
80 |
81 | https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
82 | https://github.com/Project-MONAI/MONAI/blob/ccd32ca5e9e84562d2f388b45b6724b5c77c1f57/monai/metrics/regression.py#L186
83 |
84 | Args:
85 | image1: image of shape (batch, ..., channels).
86 | image2: image of shape (batch, ..., channels).
87 | value_range: value range of input images.
88 | eps: epsilon, if two images are identical, MSE=0.
89 |
90 | Returns:
91 | PSNR of shape (batch,).
92 | """
93 | mse = jnp.mean((image1 - image2) ** 2, axis=range(1, image1.ndim))
94 | mse = jnp.maximum(mse, eps)
95 | return 20 * jnp.log10(value_range) - 10 * jnp.log10(mse)
96 |
97 |
98 | def nrmsd(
99 | image_pred: jnp.ndarray,
100 | image_true: jnp.ndarray,
101 | eps: float = EPS,
102 | ) -> jnp.ndarray:
103 | """Calculate normalized root-mean-square-deviation (NRMSD) metric.
104 |
105 | The normalization is performed using the mean of the ground truth image.
106 | https://en.wikipedia.org/wiki/Root-mean-square_deviation
107 |
108 | Args:
109 | image_pred: predicted image of shape (batch, ..., channels).
110 | image_true: ground truth image of shape (batch, ..., channels).
111 | eps: epsilon, if two images are identical, MSE=0.
112 |
113 | Returns:
114 | NRMSD of shape (batch,).
115 | """
116 | mse = jnp.mean((image_pred - image_true) ** 2, axis=range(1, image_true.ndim))
117 | rmsd = jnp.sqrt(mse)
118 | denominator = jnp.maximum(jnp.mean(image_true, axis=range(1, image_true.ndim)), eps)
119 | return rmsd / denominator
120 |
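A minimal usage sketch for the image similarity metrics (illustrative images in [0, 1]):

import jax
import jax.numpy as jnp

from imgx.metric import nrmsd, psnr, ssim

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
image1 = jax.random.uniform(key1, (2, 32, 32, 1))  # (batch, h, w, channels)
image2 = jnp.clip(image1 + 0.05 * jax.random.normal(key2, image1.shape), 0.0, 1.0)

s = ssim(image1, image2)                         # (batch,), 1.0 for identical images
p = psnr(image1, image2)                         # (batch,), in dB
n = nrmsd(image_pred=image2, image_true=image1)  # (batch,)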
--------------------------------------------------------------------------------
/imgx/metric/similarity_test.py:
--------------------------------------------------------------------------------
1 | """Test similarity functions."""
2 |
3 | from __future__ import annotations
4 |
5 | import chex
6 | import jax.numpy as jnp
7 | import numpy as np
8 | from absl.testing import parameterized
9 | from chex._src import fake
10 |
11 | from imgx.metric.similarity import nrmsd, psnr, ssim
12 |
13 |
14 | def setUpModule() -> None: # pylint: disable=invalid-name
15 | """Fake multi-devices."""
16 | fake.set_n_cpu_devices(2)
17 |
18 |
19 | class TestSSIM(chex.TestCase):
20 | """Test SSIM."""
21 |
22 | @parameterized.product(
23 | (
24 | {"image_shape": (13,), "in_channels": 1},
25 | {"image_shape": (11, 13), "in_channels": 1},
26 | {"image_shape": (11, 13), "in_channels": 3},
27 | {"image_shape": (11, 13, 15), "in_channels": 2},
28 | ),
29 | kernel_type=("gaussian", "uniform"),
30 | )
31 | def test_shapes(
32 | self,
33 | image_shape: tuple[int, ...],
34 | in_channels: int,
35 | kernel_type: str,
36 | ) -> None:
37 | """Test return shapes.
38 |
39 | Args:
40 | image_shape: image shapes.
41 | in_channels: number of input channels.
42 | kernel_type: type of kernel, "gaussian" or "uniform".
43 | """
44 | batch_size = 2
45 |
46 | image = jnp.ones((batch_size, *image_shape, in_channels))
47 | got = ssim(image, image, kernel_type=kernel_type)
48 |
49 | chex.assert_shape(got, (batch_size,))
50 | chex.assert_trees_all_close(got, jnp.ones((batch_size,)))
51 |
52 | @parameterized.named_parameters(
53 | (
54 | "1d",
55 | # x
56 | # mean(x)=0.6
57 | # mean(x**2)=0.6
58 | # var(x)=0.24
59 | np.array([0.0, 1.0, 0.0, 1.0, 1.0]),
60 | # mean(y)=0.4
61 | # mean(y**2)=0.252
62 | # var(y)=0.092
63 | # mean(xy)=0.36
64 | # covar(x,y)=0.12
65 | np.array([0.2, 0.8, 0, 0.3, 0.7]),
66 | 5,
67 | "uniform",
68 | (2 * 0.6 * 0.4 + 0.0001)
69 | * (2 * 0.12 + 0.0009)
70 | / ((0.6**2 + 0.4**2 + 0.0001) * (0.24 + 0.092 + 0.0009)),
71 | ),
72 | )
73 | def test_values(
74 | self,
75 | image1: np.ndarray,
76 | image2: np.ndarray,
77 | kernel_size: int,
78 | kernel_type: str,
79 | expected: float,
80 | ) -> None:
81 | """Test return values.
82 |
83 | Args:
84 | image1: image 1.
85 | image2: image 2.
86 | kernel_size: kernel size.
87 | kernel_type: type of kernel, "gaussian" or "uniform".
88 | expected: expected SSIM value.
89 | """
90 | image1 = image1[None, ..., None]
91 | image2 = image2[None, ..., None]
92 | got = ssim(
93 | jnp.array(image1), jnp.array(image2), kernel_size=kernel_size, kernel_type=kernel_type
94 | )
95 | chex.assert_trees_all_close(got, jnp.array([expected]))
96 |
97 |
98 | class TestPSNR(chex.TestCase):
99 | """Test PSNR."""
100 |
101 | @parameterized.product(
102 | (
103 | {"image_shape": (13,), "in_channels": 1},
104 | {"image_shape": (11, 13), "in_channels": 1},
105 | {"image_shape": (11, 13), "in_channels": 3},
106 | {"image_shape": (11, 13, 15), "in_channels": 2},
107 | ),
108 | )
109 | def test_shapes(
110 | self,
111 | image_shape: tuple[int, ...],
112 | in_channels: int,
113 | ) -> None:
114 | """Test return shapes.
115 |
116 | Args:
117 | image_shape: image shapes.
118 | in_channels: number of input channels.
120 | """
121 | batch_size = 2
122 |
123 | image = jnp.ones((batch_size, *image_shape, in_channels))
124 | got = psnr(image, image)
125 |
126 | chex.assert_shape(got, (batch_size,))
127 | chex.assert_tree_all_finite(got)
128 |
129 |
130 | class TestNRMSD(chex.TestCase):
131 | """Test NRMSD."""
132 |
133 | @parameterized.product(
134 | (
135 | {"image_shape": (13,), "in_channels": 1},
136 | {"image_shape": (11, 13), "in_channels": 1},
137 | {"image_shape": (11, 13), "in_channels": 3},
138 | {"image_shape": (11, 13, 15), "in_channels": 2},
139 | ),
140 | )
141 | def test_shapes(
142 | self,
143 | image_shape: tuple[int, ...],
144 | in_channels: int,
145 | ) -> None:
146 | """Test return shapes.
147 |
148 | Args:
149 | image_shape: image shapes.
150 | in_channels: number of input channels.
152 | """
153 | batch_size = 2
154 |
155 | image = jnp.zeros((batch_size, *image_shape, in_channels))
156 | got = nrmsd(image, image)
157 |
158 | chex.assert_shape(got, (batch_size,))
159 | chex.assert_tree_all_finite(got)
160 |
--------------------------------------------------------------------------------
/imgx/metric/smoothing.py:
--------------------------------------------------------------------------------
1 | """Label/image smoothing functions."""
2 | from typing import Callable
3 |
4 | import jax
5 | from jax import lax
6 | from jax import numpy as jnp
7 |
8 |
9 | def gaussian_kernel(
10 | num_spatial_dims: int,
11 | kernel_sigma: float,
12 | kernel_size: int,
13 | ) -> jnp.ndarray:
14 | """Gaussian kernel for convolution.
15 |
16 | Args:
17 | num_spatial_dims: number of spatial dimensions.
18 | kernel_sigma: sigma for Gaussian kernel.
19 | kernel_size: size for Gaussian kernel.
20 |
21 | Returns:
22 | Gaussian kernel of shape (kernel_size, ..., kernel_size), ndim=num_spatial_dims.
23 | """
24 | if kernel_size % 2 == 0:
25 | raise ValueError(f"kernel_size = {kernel_size} must be odd.")
26 | # (kernel_size,)
27 | dist = jnp.arange((1 - kernel_size) / 2, (1 + kernel_size) / 2)
28 | kernel_1d = jnp.exp(-jnp.power(dist / kernel_sigma, 2) / 2)
29 | kernel_1d /= kernel_1d.sum()
30 | if num_spatial_dims == 1:
31 | return kernel_1d
32 |
33 | kernel_nd = jnp.ones((kernel_size,) * num_spatial_dims)
34 | for i in range(num_spatial_dims):
35 | kernel_nd *= jnp.expand_dims(kernel_1d, axis=[j for j in range(num_spatial_dims) if j != i])
36 | return kernel_nd
37 |
38 |
39 | def get_conv(
40 | num_spatial_dims: int,
41 | kernel_sigma: float,
42 | kernel_size: int,
43 | kernel_type: str,
44 | padding: str,
45 | ) -> Callable[[jnp.ndarray], jnp.ndarray]:
46 | """Get Gaussian convolution function.
47 |
48 | The function performs convolution on the spatial dimensions for each feature channel.
49 |
50 | Args:
51 | num_spatial_dims: number of spatial dimensions.
52 | kernel_sigma: sigma for Gaussian kernel.
53 | kernel_size: size for Gaussian kernel.
54 | kernel_type: type of kernel, "gaussian" or "uniform".
55 | padding: padding type, "SAME" or "VALID".
56 |
57 | Returns:
58 | Gaussian convolution function that takes input of shape (batch, ..., channels).
59 | """
60 | if num_spatial_dims > 4:
61 | raise ValueError(f"num_spatial_dims = {num_spatial_dims} must be <= 4.")
62 | if kernel_type == "gaussian":
63 | # (kernel_size, ..., kernel_size), ndim=num_spatial_dims
64 | kernel = gaussian_kernel(
65 | num_spatial_dims,
66 | kernel_sigma=kernel_sigma,
67 | kernel_size=kernel_size,
68 | )
69 | elif kernel_type == "uniform":
70 | kernel = jnp.ones((kernel_size,) * num_spatial_dims) / kernel_size**num_spatial_dims
71 | else:
72 | raise ValueError(f"kernel_type = {kernel_type} must be 'gaussian' or 'uniform'.")
73 | # (1, 1, kernel_size, ..., kernel_size), ndim=num_spatial_dims+2
74 | kernel = kernel[None, None, ...]
75 |
76 | spatial_dilation = "WHDT"[:num_spatial_dims]
77 | lhs_spec = f"N{spatial_dilation}C"
78 | rhs_spec = f"OI{spatial_dilation}"
79 | out_spec = f"N{spatial_dilation}C"
80 |
81 | def conv(x: jnp.ndarray) -> jnp.ndarray:
82 | """Convolution function.
83 |
84 | Args:
85 | x: (batch, ...)
86 |
87 | Returns:
88 | (batch, ...)
89 | """
90 | return lax.conv_general_dilated(
91 | lhs=x[..., None],
92 | rhs=kernel.astype(x.dtype),
93 | dimension_numbers=(lhs_spec, rhs_spec, out_spec),
94 | window_strides=(1,) * num_spatial_dims,
95 | padding=padding,
96 | )[..., 0]
97 |
98 | return jax.vmap(conv, in_axes=-1, out_axes=-1)
99 |
100 |
101 | def smooth_label(
102 | mask: jnp.ndarray, classes_are_exclusive: bool, label_smoothing: float = 0.0
103 | ) -> jnp.ndarray:
104 | """Label smoothing.
105 |
106 | If classes are exclusive, the even distribution has p = 1 / num_classes per class.
107 | If classes are not exclusive, the even distribution has p = 0.5 per class.
108 |
109 | Args:
110 | mask: probabilities per class, (batch, ..., num_classes).
111 | classes_are_exclusive: if False, each element can be assigned to multiple classes.
112 | label_smoothing: label smoothing factor between 0 and 1, 0.0 means no smoothing.
113 |
114 | Returns:
115 | Label smoothed mask.
116 | """
117 | p_even = lax.select(classes_are_exclusive, 1.0 / mask.shape[-1], 0.5)
118 | return mask * (1 - label_smoothing) + p_even * label_smoothing
119 |
120 |
121 | def gaussian_smooth_label(
122 | mask: jnp.ndarray,
123 | classes_are_exclusive: bool,
124 | kernel_sigma: float = 1.5,
125 | kernel_size: int = 3,
126 | ) -> jnp.ndarray:
127 | """Label smoothing using Gaussian kernels.
128 |
129 | The smoothing is performed on the spatial dimensions per class.
130 | https://arxiv.org/abs/2104.05788
131 |
132 | Args:
133 | mask: probabilities per class, (batch, ..., num_classes).
134 | classes_are_exclusive: if False, each element can be assigned to multiple classes.
135 | kernel_sigma: sigma for Gaussian kernel.
136 | kernel_size: size for kernel.
137 |
138 | Returns:
139 | Label smoothed mask.
140 | """
141 | mask = get_conv(
142 | num_spatial_dims=mask.ndim - 2,
143 | kernel_sigma=kernel_sigma,
144 | kernel_size=kernel_size,
145 | kernel_type="gaussian",
146 | padding="SAME",
147 | )(mask)
148 | # if classes are exclusive, the sum should be 1
149 | mask = lax.cond(
150 | classes_are_exclusive,
151 | lambda x: x / jnp.sum(x, axis=-1, keepdims=True),
152 | lambda x: x,
153 | mask,
154 | )
155 | return mask
156 |
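A minimal usage sketch for the label smoothing helpers (illustrative shapes; the one-hot conversion is an assumption):

import jax
import jax.numpy as jnp

from imgx.metric import gaussian_smooth_label, smooth_label

label = jnp.zeros((2, 16, 16), dtype=jnp.int32)
mask = jax.nn.one_hot(label, num_classes=4)  # (batch, h, w, num_classes)

# blend towards the even distribution, here 1 / num_classes per class
mask_smoothed = smooth_label(mask, classes_are_exclusive=True, label_smoothing=0.1)

# spatial Gaussian smoothing per class, re-normalised over classes
mask_gaussian = gaussian_smooth_label(
    mask, classes_are_exclusive=True, kernel_sigma=1.5, kernel_size=3
)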
--------------------------------------------------------------------------------
/imgx/metric/util.py:
--------------------------------------------------------------------------------
1 | """Util functions."""
2 | from __future__ import annotations
3 |
4 | import chex
5 | import jax.numpy as jnp
6 |
7 |
8 | def aggregate_metrics(metrics: dict[str, chex.ArrayTree]) -> dict[str, chex.ArrayTree]:
9 | """Calculate min/mean/max statistics for each key.
10 |
11 | Args:
12 | metrics: metric dict.
13 |
14 | Returns:
15 | aggregated metric dict.
16 | """
17 | agg_metrics = {}
18 | for key, value in metrics.items():
19 | agg_metrics[f"mean_{key}"] = jnp.nanmean(value)
20 | agg_metrics[f"min_{key}"] = jnp.nanmin(value)
21 | agg_metrics[f"max_{key}"] = jnp.nanmax(value)
22 | return agg_metrics
23 |
24 |
25 | def aggregate_metrics_for_diffusion(
26 | metrics: dict[str, jnp.ndarray],
27 | t_index: jnp.ndarray,
28 | ) -> dict[str, jnp.ndarray]:
29 | """Aggregate metrics for diffusion.
30 |
31 | Args:
32 | metrics: dict of metrics, each of shape (batch,).
33 | t_index: time of shape (batch, ...), values in [0, num_timesteps).
34 |
35 | Returns:
36 | aggregated metrics.
37 | """
38 | mask_t_min = t_index == jnp.min(t_index)
39 | mask_t_max = t_index == jnp.max(t_index)
40 | metrics_diff = {}
41 | for key, value in metrics.items():
42 | metrics_diff[f"{key}_t_min"] = jnp.nanmean(value, where=mask_t_min)
43 | metrics_diff[f"{key}_t_max"] = jnp.nanmean(value, where=mask_t_max)
44 | return metrics_diff
45 |
46 |
47 | def merge_aggregated_metrics(metrics: dict[str, chex.ArrayTree]) -> dict[str, chex.ArrayTree]:
48 | """Merge/aggregate aggregated metrics.
49 |
50 | Args:
51 | metrics: metric dict containing aggregated metrics.
52 |
53 | Returns:
54 | Merged aggregated metric dict.
55 | """
56 | min_metrics = {}
57 | max_metrics = {}
58 | mean_metrics = {}
59 | for k in metrics:
60 | if k.startswith("min_"):
61 | min_metrics[k] = jnp.nanmin(metrics[k])
62 | elif k.startswith("max_"):
63 | max_metrics[k] = jnp.nanmax(metrics[k])
64 | else:
65 | mean_metrics[k] = jnp.nanmean(metrics[k])
66 | return {
67 | **min_metrics,
68 | **max_metrics,
69 | **mean_metrics,
70 | }
71 |
72 |
73 | def flatten_diffusion_metrics(metrics: dict[str, jnp.ndarray]) -> dict[str, jnp.ndarray]:
74 | """Flatten metrics dict for diffusion models.
75 |
76 | Args:
77 | metrics: dict of metrics, each value of shape (batch, num_steps).
78 |
79 | Returns:
80 | metrics: dict of metrics, each value of shape (batch, ).
81 | """
82 | metrics_flatten = {}
83 | for k, v in metrics.items():
84 | for i in range(v.shape[-1]):
85 | metrics_flatten[f"{k}_step_{i}"] = v[..., i]
86 | metrics_flatten[k] = v[..., -1]
87 | return metrics_flatten
88 |
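A minimal usage sketch for the aggregation helpers (illustrative per-sample metrics):

import jax.numpy as jnp

from imgx.metric.util import aggregate_metrics, merge_aggregated_metrics

# per-sample metrics of shape (batch,); nan entries are ignored by the nan-aware reductions
metrics = {"dice_loss": jnp.array([0.2, 0.4, jnp.nan])}

agg = aggregate_metrics(metrics)
# {"mean_dice_loss": 0.3, "min_dice_loss": 0.2, "max_dice_loss": 0.4}

merged = merge_aggregated_metrics(agg)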
--------------------------------------------------------------------------------
/imgx/model/__init__.py:
--------------------------------------------------------------------------------
1 | """Package for models."""
2 | from imgx.model.unet.unet import Unet # noqa: F401
3 |
4 | SUPPORTED_VISION_MODELS = [
5 | "Unet",
6 | ]
7 |
8 | __all__ = SUPPORTED_VISION_MODELS
9 |
--------------------------------------------------------------------------------
/imgx/model/attention/__init__.py:
--------------------------------------------------------------------------------
1 | """Attention and transformer related modules."""
2 | from imgx.model.attention.efficient_attention import dot_product_attention_with_qkv_chunks
3 | from imgx.model.attention.transformer import TransformerEncoder
4 |
5 | __all__ = [
6 | "dot_product_attention_with_qkv_chunks",
7 | "TransformerEncoder",
8 | ]
9 |
--------------------------------------------------------------------------------
/imgx/model/attention/transformer.py:
--------------------------------------------------------------------------------
1 | """Generic transformer related functions.
2 |
3 | https://github.com/deepmind/dm-haiku/blob/main/examples/transformer/model.py
4 | https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py
5 | """
6 | from __future__ import annotations
7 |
8 | import flax.linen as nn
9 | import jax.numpy as jnp
10 | import numpy as np
11 |
12 | from imgx.model.attention.efficient_attention import dot_product_attention_with_qkv_chunks
13 | from imgx.model.basic import MLP
14 |
15 |
16 | class TransformerEncoder(nn.Module):
17 | """A transformer encoder/decoder.
18 |
19 | The architecture is from
20 | https://github.com/google-deepmind/dm-haiku/blob/main/examples/transformer/model.py
21 | """
22 |
23 | num_heads: int
24 | num_layers: int = 1
25 | autoregressive: bool = False
26 | widening_factor: int = 4
27 | add_position_embedding: bool = False
28 | dropout: float = 0.0
29 | remat: bool = True # if True also uses efficient attention
30 | dtype: jnp.dtype = jnp.float32
31 |
32 | @nn.compact
33 | def __call__(
34 | self,
35 | is_train: bool,
36 | x: jnp.ndarray,
37 | mask: jnp.ndarray | None = None,
38 | ) -> jnp.ndarray:
39 | """Transformer encoder forward pass.
40 |
41 | Args:
42 | is_train: whether in training mode.
43 | x: shape (batch, ..., model_size).
44 | mask: shape (batch, ...) or None.
45 | Tokens with False values are ignored.
46 |
47 | Returns:
48 | (batch, ..., model_size).
49 | """
50 | batch, *spatial_shape, channels = x.shape
51 | if len(spatial_shape) > 1:
52 | # x: shape(batch, seq_len, model_size).
53 | # mask: shape(batch, seq_len) or None.
54 | x = x.reshape((batch, -1, channels))
55 | if mask is not None:
56 | mask = mask.reshape((batch, -1))
57 |
58 | kernel_init = nn.initializers.variance_scaling(
59 | scale=2 / self.num_layers,
60 | mode="fan_in",
61 | distribution="truncated_normal",
62 | )
63 |
64 | # define classes
65 | attn_cls = nn.MultiHeadDotProductAttention
66 | attention_fn = nn.dot_product_attention
67 | if self.remat:
68 | attn_cls = nn.remat(attn_cls)
69 | attention_fn = dot_product_attention_with_qkv_chunks
70 |
71 | _, seq_len, model_size = x.shape
72 |
73 | if model_size % self.widening_factor != 0:
74 | raise ValueError(
75 | f"Model size {model_size} is not divisible by widening factor "
76 | f"{self.widening_factor}"
77 | )
78 |
79 | # compute mask if provided
80 | if mask is not None:
81 | if self.autoregressive:
82 | # compute causal mask for autoregressive sequence modelling.
83 | # (1, 1, seq_len, seq_len)
84 | causal_mask = np.tril(np.ones((1, 1, seq_len, seq_len)))
85 | # (batch, 1, seq_len, seq_len)
86 | mask = mask[:, None, None, :] * causal_mask
87 | else:
88 | # (batch, 1, 1, seq_len) * (batch, 1, seq_len, 1)
89 | # -> (batch, 1, seq_len, seq_len)
90 | mask = mask[:, None, None, :] * mask[:, None, :, None]
91 |
92 | # embed the input tokens and positions.
93 | if self.add_position_embedding:
94 | positional_embeddings = self.param(
95 | "transformer_positional_embeddings",
96 | nn.initializers.truncated_normal(stddev=0.02),
97 | (1, seq_len, model_size),
98 | )
99 | x += positional_embeddings
100 | x = nn.Dropout(rate=self.dropout, deterministic=not is_train)(x)
101 |
102 | h = x
103 | for _ in range(self.num_layers):
104 | h_attn = nn.LayerNorm(dtype=self.dtype)(h)
105 | h_attn = attn_cls(
106 | num_heads=self.num_heads,
107 | # head_dim = qkv_features // num_heads
108 | qkv_features=model_size // self.widening_factor * self.num_heads,
109 | out_features=model_size,
110 | attention_fn=attention_fn,
111 | kernel_init=kernel_init,
112 | dtype=self.dtype,
113 | )(inputs_q=h_attn, inputs_k=h_attn, inputs_v=h_attn, mask=mask)
114 | h_attn = nn.Dropout(rate=self.dropout, deterministic=not is_train)(h_attn)
115 | h += h_attn
116 |
117 | h_dense = nn.LayerNorm(dtype=self.dtype)(h)
118 | h_dense = MLP(
119 | emb_size=model_size * self.widening_factor,
120 | output_size=model_size,
121 | dtype=self.dtype,
122 | kernel_init=kernel_init,
123 | remat=self.remat,
124 | )(h_dense)
125 | h_dense = nn.Dropout(rate=self.dropout, deterministic=not is_train)(h_dense)
126 | h += h_dense
127 |
128 | h = nn.LayerNorm(dtype=self.dtype)(h)
129 |
130 | if len(spatial_shape) > 1:
131 | # (batch, spatial_shape, model_size).
132 | h = h.reshape((batch, *spatial_shape, model_size))
133 |
134 | return h
135 |
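A minimal usage sketch for TransformerEncoder; remat is disabled here so the standard flax attention path is used (shapes are illustrative):

import jax
import jax.numpy as jnp

from imgx.model.attention import TransformerEncoder

model = TransformerEncoder(num_heads=4, num_layers=2, remat=False)
x = jnp.ones((2, 16, 64))                  # (batch, seq_len, model_size)
mask = jnp.ones((2, 16), dtype=jnp.bool_)  # tokens with False are ignored

variables = model.init(jax.random.PRNGKey(0), is_train=False, x=x, mask=mask)
out = model.apply(variables, is_train=False, x=x, mask=mask)  # (2, 16, 64)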
--------------------------------------------------------------------------------
/imgx/model/basic.py:
--------------------------------------------------------------------------------
1 | """Basic functions and modules."""
2 | from __future__ import annotations
3 |
4 | from typing import Callable
5 |
6 | import flax.linen as nn
7 | import jax
8 | import jax.numpy as jnp
9 |
10 |
11 | class Identity(nn.Module):
12 | """Identity module."""
13 |
14 | dtype: jnp.dtype = jnp.float32
15 |
16 | @nn.compact
17 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
18 | """Forward pass.
19 |
20 | Args:
21 | x: input.
22 |
23 | Returns:
24 | input.
25 | """
26 | return x
27 |
28 |
29 | class InstanceNorm(nn.Module):
30 | """Instance norm.
31 |
32 | The norm is calculated on axes excluding batch and features.
33 | """
34 |
35 | dtype: jnp.dtype = jnp.float32
36 |
37 | @nn.compact
38 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
39 | """Forward pass.
40 |
41 | Args:
42 | x: input with batch axis, (batch, ..., channel).
43 |
44 | Returns:
45 | Normalised input.
46 | """
 47 |         reduction_axes = tuple(range(1, x.ndim - 1))
48 | return nn.LayerNorm(
49 | reduction_axes=reduction_axes,
50 | )(x)
51 |
52 |
53 | def sinusoidal_positional_embedding(
54 | x: jnp.ndarray,
55 | dim: int,
56 | max_period: int = 10000,
57 | dtype: jnp.dtype = jnp.float32,
58 | ) -> jnp.ndarray:
59 | """Create sinusoidal timestep embeddings.
60 |
 61 |     Half defined by cos, half by sin.
 62 |     For a normalised position x in [0, 1], the embeddings are (for i = 0,...,half_dim-1)
 63 |         cos(x * max_period ** (1 - i/half_dim))
 64 |         sin(x * max_period ** (1 - i/half_dim))
65 |
66 | Args:
67 | x: (..., ), with values in [0, 1].
 68 |         dim: embedding dimension, assumed to be divisible by two.
69 | max_period: controls the minimum frequency of the embeddings.
70 | dtype: dtype of the embeddings.
71 |
72 | Returns:
73 | Embedding of size (..., dim).
74 | """
75 | ndim_x = len(x.shape)
76 | if dim % 2 != 0:
 77 |         raise ValueError(f"dim must be divisible by two, got {dim}.")
78 | half_dim = dim // 2
79 | # (half_dim,)
80 | freq = jnp.arange(0, half_dim, dtype=dtype)
81 | freq = jnp.exp(-jnp.log(max_period) * freq / half_dim)
82 | # (..., half_dim)
83 | freq = jnp.expand_dims(freq, axis=tuple(range(ndim_x)))
84 | args = x[..., None] * max_period * freq
85 | # (..., dim)
86 | return jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
87 |
88 |
89 | class MLP(nn.Module):
90 | """Two-layer MLP."""
91 |
92 | emb_size: int
93 | output_size: int
94 | activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.gelu
95 | kernel_init: Callable[
96 | [jax.Array, jnp.shape, jnp.dtype], jnp.ndarray
97 | ] = nn.initializers.lecun_normal()
98 | remat: bool = True
99 | dtype: jnp.dtype = jnp.float32
100 |
101 | @nn.compact
102 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
103 | """Forward pass.
104 |
105 | Args:
106 | x: shape (..., in_size)
107 |
108 | Returns:
109 | shape (..., out_size)
110 | """
111 | dense_cls = nn.remat(nn.Dense) if self.remat else nn.Dense
112 | x = dense_cls(
113 | self.emb_size,
114 | kernel_init=self.kernel_init,
115 | dtype=self.dtype,
116 | )(x)
117 | x = self.activation(x)
118 | x = dense_cls(
119 | self.output_size,
120 | kernel_init=self.kernel_init,
121 | dtype=self.dtype,
122 | )(x)
123 | return x
124 |
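A hedged usage sketch (not repository code) tying together the helpers above; sizes are illustrative. A normalised timestep is embedded with sinusoidal_positional_embedding, projected by the two-layer MLP, and a feature map is normalised with InstanceNorm.

import jax
import jax.numpy as jnp

from imgx.model.basic import MLP, InstanceNorm, sinusoidal_positional_embedding

t = jnp.array([0.1, 0.5])  # (batch,), values in [0, 1]
t_emb = sinusoidal_positional_embedding(t, dim=8)  # (batch, 8)

mlp = MLP(emb_size=16, output_size=8)
t_emb, _ = mlp.init_with_output({"params": jax.random.PRNGKey(0)}, t_emb)  # (batch, 8)

x = jnp.ones((2, 4, 4, 3))  # (batch, h, w, channels)
x, _ = InstanceNorm().init_with_output({"params": jax.random.PRNGKey(1)}, x)  # (2, 4, 4, 3)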
--------------------------------------------------------------------------------
/imgx/model/basic_test.py:
--------------------------------------------------------------------------------
1 | """Test basic functions for model."""
2 |
3 |
4 | from functools import partial
5 |
6 | import chex
7 | import jax
8 | import jax.numpy as jnp
9 | from absl.testing import parameterized
10 | from chex._src import fake
11 |
12 | from imgx.model.basic import MLP, InstanceNorm, sinusoidal_positional_embedding
13 |
14 |
15 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
16 | def setUpModule() -> None: # pylint: disable=invalid-name
17 | """Fake multi-devices."""
18 | fake.set_n_cpu_devices(2)
19 |
20 |
21 | class TestInstanceNorm(chex.TestCase):
 22 |     """Test InstanceNorm."""
23 |
24 | @chex.all_variants()
25 | @parameterized.named_parameters(
26 | (
27 | "1d",
28 | (2,),
29 | ),
30 | (
31 | "2d",
32 | (2, 3),
33 | ),
34 | )
35 | def test_shapes(
36 | self,
37 | in_shape: tuple[int, ...],
38 | ) -> None:
 39 |         """Test output shapes under different device conditions.
40 |
41 | Args:
42 | in_shape: input shape.
43 | """
44 | rng = {"params": jax.random.PRNGKey(0)}
45 | norm = InstanceNorm()
46 | x = jax.random.uniform(
47 | jax.random.PRNGKey(0),
48 | shape=in_shape,
49 | )
50 | out, _ = self.variant(norm.init_with_output)(rng, x)
51 | chex.assert_shape(out, in_shape)
52 |
53 |
54 | class TestSinusoidalPositionalEmbedding(chex.TestCase):
55 | """Test the function sinusoidal_positional_embedding."""
56 |
57 | @chex.all_variants()
58 | @parameterized.named_parameters(
59 | ("1d case 1", (2,), 4, 5),
60 | (
61 | "1d case 2",
62 | (2,),
63 | 8,
64 | 10000,
65 | ),
66 | (
67 | "2d",
68 | (2, 3),
69 | 8,
70 | 10000,
71 | ),
72 | )
73 | def test_shapes(self, in_shape: tuple[int, ...], dim: int, max_period: int) -> None:
 74 |         """Test output shapes under different device conditions.
75 |
76 | Args:
77 | in_shape: input shape.
78 | dim: embedding dimension, assume to be evenly divided by two.
79 | max_period: controls the minimum frequency of the embeddings.
80 | """
81 | rng = jax.random.PRNGKey(0)
82 | x = jax.random.uniform(
83 | rng,
84 | shape=in_shape,
85 | )
86 | out = self.variant(
87 | partial(sinusoidal_positional_embedding, dim=dim, max_period=max_period)
88 | )(x)
89 | chex.assert_shape(out, (*in_shape, dim))
90 |
91 |
92 | class TestMLP(chex.TestCase):
93 | """Test MLP."""
94 |
95 | emb_size: int = 4
96 | output_size: int = 8
97 |
98 | @chex.all_variants()
99 | @parameterized.product(
100 | in_shape=[(3, 4, 5), (3, 4), (3,)],
101 | remat=[True, False],
102 | )
103 | def test_shapes(
104 | self,
105 | in_shape: tuple[int, ...],
106 | remat: bool,
107 | ) -> None:
108 | """Test output shapes."""
109 | rng = {"params": jax.random.PRNGKey(0)}
110 | mlp = MLP(
111 | emb_size=self.emb_size,
112 | output_size=self.output_size,
113 | remat=remat,
114 | )
115 | out, _ = self.variant(mlp.init_with_output)(rng, jnp.ones(in_shape))
116 | chex.assert_shape(out, (*in_shape[:-1], self.output_size))
117 |
--------------------------------------------------------------------------------
/imgx/model/conv_test.py:
--------------------------------------------------------------------------------
1 | """Test conv layers."""
2 |
3 | import chex
4 | import jax
5 | from absl.testing import parameterized
6 | from chex._src import fake
7 |
8 | from imgx.model.conv import ConvDownSample, ConvResBlock, ConvUpSample
9 |
10 |
11 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
12 | def setUpModule() -> None: # pylint: disable=invalid-name
13 | """Fake multi-devices."""
14 | fake.set_n_cpu_devices(2)
15 |
16 |
17 | class TestConvDownSample(chex.TestCase):
18 | """Test ConvDownSample."""
19 |
20 | batch = 2
21 |
22 | @chex.all_variants()
23 | @parameterized.named_parameters(
24 | (
25 | "1d",
26 | (12,),
27 | (2,),
28 | (6,),
29 | ),
30 | (
31 | "2d",
32 | (12, 13),
33 | (2, 2),
34 | (6, 7),
35 | ),
36 | (
37 | "2d - different scale factors",
38 | (12, 13),
39 | (4, 2),
40 | (3, 7),
41 | ),
42 | (
43 | "3d - large scale factor",
44 | (2, 4, 8),
45 | (4, 4, 4),
46 | (1, 1, 2),
47 | ),
48 | (
49 | "3d",
50 | (12, 13, 14),
51 | (2, 2, 2),
52 | (6, 7, 7),
53 | ),
54 | )
55 | def test_shapes(
56 | self,
57 | in_shape: tuple[int, ...],
58 | scale_factor: tuple[int, ...],
59 | out_shape: tuple[int, ...],
60 | ) -> None:
 61 |         """Test output shapes under different device conditions.
62 |
63 | Args:
64 | in_shape: input shape, without batch, channel.
65 | scale_factor: downsample factor.
66 | out_shape: output shape, without batch, channel.
67 | """
68 | in_channels = 1
69 | out_channels = 1
70 | rng = {"params": jax.random.PRNGKey(0)}
71 | conv = ConvDownSample(
72 | out_channels=out_channels,
73 | scale_factor=scale_factor,
74 | )
75 | x = jax.random.uniform(
76 | jax.random.PRNGKey(0),
77 | shape=(self.batch, *in_shape, in_channels),
78 | )
79 | out, _ = self.variant(conv.init_with_output)(rng, x)
80 | chex.assert_shape(out, (self.batch, *out_shape, out_channels))
81 |
82 |
83 | class TestConvUpSample(chex.TestCase):
84 | """Test ConvUpSample."""
85 |
86 | batch = 2
87 |
88 | @chex.all_variants()
89 | @parameterized.named_parameters(
90 | (
91 | "1d",
92 | (3,),
93 | (2,),
94 | (6,),
95 | ),
96 | (
97 | "2d",
98 | (3, 4),
99 | (2, 2),
100 | (6, 8),
101 | ),
102 | (
103 | "2d - different scale factors",
104 | (3, 4),
105 | (4, 2),
106 | (12, 8),
107 | ),
108 | (
109 | "3d",
110 | (2, 3, 4),
111 | (2, 2, 2),
112 | (4, 6, 8),
113 | ),
114 | )
115 | def test_shapes(
116 | self,
117 | in_shape: tuple[int, ...],
118 | scale_factor: tuple[int, ...],
119 | out_shape: tuple[int, ...],
120 | ) -> None:
 121 |         """Test output shapes under different device conditions.
122 |
123 | Args:
124 | in_shape: input shape, without batch, channel.
 125 |             scale_factor: up-sampling factor.
126 | out_shape: output shape, without batch, channel.
127 | """
128 | in_channels = 1
129 | out_channels = 1
130 | rng = {"params": jax.random.PRNGKey(0)}
131 | conv = ConvUpSample(
132 | out_channels=out_channels,
133 | scale_factor=scale_factor,
134 | )
135 | x = jax.random.uniform(
136 | jax.random.PRNGKey(0),
137 | shape=(self.batch, *in_shape, in_channels),
138 | )
139 | out, _ = self.variant(conv.init_with_output)(rng, x)
140 | chex.assert_shape(out, (self.batch, *out_shape, out_channels))
141 |
142 |
143 | class TestConvResBlock(chex.TestCase):
144 | """Test ConvResBlock."""
145 |
146 | batch = 2
147 |
148 | @chex.all_variants()
149 | @parameterized.product(
150 | in_shape=[(12,), (12, 13), (12, 13, 14)],
151 | has_t=[True, False],
152 | dropout=[0.0, 0.5, 1.0],
153 | is_train=[True, False],
154 | remat=[True, False],
155 | )
156 | def test_shapes(
157 | self,
158 | in_shape: tuple[int, ...],
159 | has_t: bool,
160 | dropout: float,
161 | is_train: bool,
162 | remat: bool,
163 | ) -> None:
164 | """Test output shapes.
165 |
166 | Args:
167 | in_shape: input shape, without batch, channel.
168 | has_t: whether has time embedding.
169 | dropout: dropout rate.
170 | is_train: whether in training mode.
171 | remat: remat or not.
172 | """
173 | kernel_size = (3,) * len(in_shape)
174 | in_channels = 1
175 | t_channels = 2
176 | out_channels = 1
177 | rng = {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}
178 | conv = ConvResBlock(
179 | out_channels=out_channels,
180 | kernel_size=kernel_size,
181 | dropout=dropout,
182 | remat=remat,
183 | )
184 | x = jax.random.uniform(
185 | jax.random.PRNGKey(0),
186 | shape=(self.batch, *in_shape, in_channels),
187 | )
188 | t_emb = None
189 | if has_t:
190 | t_emb = jax.random.uniform(
191 | jax.random.PRNGKey(0),
192 | shape=(self.batch, t_channels),
193 | )
194 | out, _ = self.variant(conv.init_with_output, static_argnums=(1,))(rng, is_train, x, t_emb)
195 | chex.assert_shape(out, (self.batch, *in_shape, out_channels))
196 |
--------------------------------------------------------------------------------
/imgx/model/slice.py:
--------------------------------------------------------------------------------
1 | """Functions for slicing images."""
2 | from __future__ import annotations
3 |
4 | import jax.numpy as jnp
5 |
6 |
7 | def merge_spatial_dim_into_batch(x: jnp.ndarray, num_spatial_dims: int) -> jnp.ndarray:
8 | """Merge spatial dimensions into batch dimension.
9 |
10 | Args:
11 | x: array with original shape (batch, ..., in_channels).
12 | num_spatial_dims: target number of spatial dimensions.
13 |
14 | Returns:
15 | array with ndim=num_spatial_dims+2,
16 | shape = (extended_batch, ..., in_channels).
17 | """
18 | # e.g. if x.shape = (batch, h, w, d, in_channels)
19 | # then x.ndim == 5, num_spatial_dims = 2
20 | # axes = (0, 3, 1, 2, 4)
21 | axes = (
22 | 0,
23 | *range(num_spatial_dims + 1, x.ndim - 1),
24 | *range(1, num_spatial_dims + 1),
25 | x.ndim - 1,
26 | )
27 | # move extra dims to front
28 | # e.g. (batch, h, w, d, in_channels) -> (batch, d, h, w, in_channels)
29 | x = jnp.transpose(x, axes)
30 | # e.g. (batch, d, h, w, in_channels) -> (batch*d, h, w, in_channels)
31 | return jnp.reshape(x, (-1, *x.shape[x.ndim - num_spatial_dims - 1 :]))
32 |
33 |
34 | def split_spatial_dim_from_batch(
35 | x: jnp.ndarray,
36 | num_spatial_dims: int,
37 | batch_size: int,
38 | spatial_shape: tuple[int, ...],
39 | ) -> jnp.ndarray:
40 | """Remove spatial dimensions from batch axis.
41 |
42 | Args:
43 | x: array with merged shape (batch, ..., in_channels),
44 | x.ndim=num_spatial_dims+2.
45 | num_spatial_dims: current number of spatial dimensions.
46 | batch_size: batch size.
47 | spatial_shape: original spatial shape.
48 |
49 | Returns:
50 | array with original shape (batch, ..., in_channels).
51 | """
52 | # e.g. (batch*d, h, w, out_channels) -> (batch, d, h, w, out_channels)
53 | x = jnp.reshape(x, (batch_size, *spatial_shape[num_spatial_dims:], *x.shape[1:]))
54 |
55 | # e.g. if x.shape = (batch, d, h, w, out_channels)
56 | # then x.ndim == 5, num_spatial_dims = 2
 57 |         # axes = (0, 2, 3, 1, 4)
58 | axes = (
59 | 0,
60 | *range(x.ndim - 1 - num_spatial_dims, x.ndim - 1),
61 | *range(1, x.ndim - 1 - num_spatial_dims),
62 | x.ndim - 1,
63 | )
64 | # e.g. (batch, d, h, w, out_channels) -> (batch, h, w, d, out_channels)
65 | return jnp.transpose(x, axes)
66 |
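A hedged round-trip sketch (not repository code) for the two helpers above, e.g. treating the depth axis of a 3-D volume as extra batch entries so it can be processed slice-wise in 2-D; shapes are illustrative.

import jax.numpy as jnp

from imgx.model.slice import merge_spatial_dim_into_batch, split_spatial_dim_from_batch

x = jnp.ones((2, 8, 8, 4, 3))  # (batch, h, w, d, channels)
merged = merge_spatial_dim_into_batch(x, num_spatial_dims=2)  # (2*4, 8, 8, 3)
restored = split_spatial_dim_from_batch(
    merged, num_spatial_dims=2, batch_size=2, spatial_shape=(8, 8, 4)
)
assert restored.shape == x.shape  # (2, 8, 8, 4, 3)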
--------------------------------------------------------------------------------
/imgx/model/slice_test.py:
--------------------------------------------------------------------------------
1 | """Test slicing functions."""
2 |
3 | from functools import partial
4 |
5 | import chex
6 | import jax.numpy as jnp
7 | from absl.testing import parameterized
8 | from chex._src import fake
9 |
10 | from imgx.model.slice import merge_spatial_dim_into_batch, split_spatial_dim_from_batch
11 |
12 |
13 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
14 | def setUpModule() -> None: # pylint: disable=invalid-name
15 | """Fake multi-devices."""
16 | fake.set_n_cpu_devices(2)
17 |
18 |
19 | class TestMergeSplitSpatialDims(chex.TestCase):
20 | """Test merge_spatial_dim_into_batch and split_spatial_dim_from_batch."""
21 |
22 | @chex.all_variants()
23 | @parameterized.named_parameters(
24 | ("2d-1", (2, 3, 4, 5), 1, (8, 3, 5)),
25 | ("2d-2", (2, 3, 4, 5), 2, (2, 3, 4, 5)),
26 | ("3d-1", (2, 3, 4, 5, 6), 1, (40, 3, 6)),
27 | ("3d-2", (2, 3, 4, 5, 6), 2, (10, 3, 4, 6)),
28 | ("3d-3", (2, 3, 4, 5, 6), 3, (2, 3, 4, 5, 6)),
29 | )
30 | def test_merge_spatial_dim_into_batch(
31 | self,
32 | in_shape: tuple[int, ...],
33 | num_spatial_dims: int,
34 | expected_shape: tuple[int, ...],
35 | ) -> None:
36 | """Test merge_spatial_dim_into_batch.
37 |
38 | Args:
39 | in_shape: input shape.
40 | num_spatial_dims: number of spatial dimensions.
41 | expected_shape: expected output shape.
42 | """
43 | x = jnp.ones(in_shape)
44 | x = self.variant(partial(merge_spatial_dim_into_batch, num_spatial_dims=num_spatial_dims))(
45 | x
46 | )
47 | chex.assert_shape(x, expected_shape)
48 |
49 | @chex.all_variants()
50 | @parameterized.named_parameters(
51 | ("2d-1", (8, 3, 5), 1, 2, (3, 4), (2, 3, 4, 5)),
52 | ("2d-2", (2, 3, 4, 5), 2, 2, (3, 4), (2, 3, 4, 5)),
53 | ("3d-1", (40, 3, 6), 1, 2, (3, 4, 5), (2, 3, 4, 5, 6)),
54 | ("3d-2", (10, 3, 4, 6), 2, 2, (3, 4, 5), (2, 3, 4, 5, 6)),
55 | ("3d-3", (2, 3, 4, 5, 6), 3, 2, (3, 4, 5), (2, 3, 4, 5, 6)),
56 | )
57 | def test_split_spatial_dim_from_batch(
58 | self,
59 | in_shape: tuple[int, ...],
60 | num_spatial_dims: int,
61 | batch_size: int,
62 | spatial_shape: tuple[int, ...],
63 | expected_shape: tuple[int, ...],
64 | ) -> None:
65 | """Test split_spatial_dim_from_batch.
66 |
67 | Args:
68 | in_shape: input shape, with certain spatial axes merged.
69 | num_spatial_dims: number of spatial dimensions.
70 | batch_size: batch size.
71 | spatial_shape: spatial shape.
72 | expected_shape: expected output shape.
73 | """
74 | x = jnp.ones(in_shape)
75 | x = self.variant(
76 | partial(
77 | split_spatial_dim_from_batch,
78 | num_spatial_dims=num_spatial_dims,
79 | spatial_shape=spatial_shape,
80 | batch_size=batch_size,
81 | )
82 | )(x)
83 | chex.assert_shape(x, expected_shape)
84 |
--------------------------------------------------------------------------------
/imgx/model/unet/__init__.py:
--------------------------------------------------------------------------------
1 | """Module to build a UNet model."""
2 |
--------------------------------------------------------------------------------
/imgx/model/unet/bottom_encoder.py:
--------------------------------------------------------------------------------
1 | """Image encoder for unet."""
2 | from __future__ import annotations
3 |
4 | import flax.linen as nn
5 | import jax.numpy as jnp
6 |
7 | from imgx.model.attention import TransformerEncoder
8 | from imgx.model.conv import ConvResBlock
9 |
10 |
11 | class BottomImageEncoderUnet(nn.Module):
12 | """Image encoder module with convolutions for unet."""
13 |
14 | kernel_size: tuple[int, ...] # convolution layer kernel size
15 | dropout: float = 0.0 # for resnet block
16 | num_heads: int = 8 # for multi head attention
17 | num_layers: int = 1 # for transformer encoder
18 | widening_factor: int = 4 # for key size in MHA
19 | remat: bool = True # reduces memory cost at cost of compute speed
20 | dtype: jnp.dtype = jnp.float32
21 |
22 | @nn.compact
23 | def __call__(
24 | self,
25 | is_train: bool,
26 | image_emb: jnp.ndarray,
27 | t_emb: jnp.ndarray | None,
28 | ) -> jnp.ndarray:
 29 |         """Encode the image.
30 |
31 | Args:
32 | is_train: whether in training mode.
33 | image_emb: shape (batch, *spatial_shape, model_size).
34 | t_emb: time embedding, (batch, t_channels).
35 |
36 | Returns:
37 | image_emb: (batch, *spatial_shape, model_size).
38 | """
39 | model_size = image_emb.shape[-1]
40 |
41 | # conv before attention
42 | # image_emb (batch, *spatial_shape, image_emb_size)
43 | image_emb = ConvResBlock(
44 | out_channels=model_size,
45 | kernel_size=self.kernel_size,
46 | dropout=self.dropout,
47 | remat=self.remat,
48 | )(is_train, image_emb, t_emb)
49 |
50 | # attention
51 | image_emb = TransformerEncoder(
52 | num_heads=self.num_heads,
53 | widening_factor=self.widening_factor,
54 | dropout=self.dropout,
55 | remat=self.remat,
56 | dtype=self.dtype,
57 | )(is_train, image_emb)
58 |
59 | # conv after attention
60 | # image_emb (batch, *spatial_shape, image_emb_size)
61 | image_emb = ConvResBlock(
62 | out_channels=model_size,
63 | kernel_size=self.kernel_size,
64 | dropout=self.dropout,
65 | remat=self.remat,
66 | )(is_train, image_emb, t_emb)
67 |
68 | return image_emb
69 |
--------------------------------------------------------------------------------
/imgx/model/unet/bottom_encoder_test.py:
--------------------------------------------------------------------------------
1 | """Test Unet bottom encoder related classes and functions."""
2 |
3 | import chex
4 | import jax
5 | import jax.numpy as jnp
6 | from absl.testing import parameterized
7 | from chex._src import fake
8 |
9 | from imgx.model.unet.bottom_encoder import BottomImageEncoderUnet
10 |
11 |
12 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
13 | def setUpModule() -> None: # pylint: disable=invalid-name
14 | """Fake multi-devices."""
15 | fake.set_n_cpu_devices(2)
16 |
17 |
18 | class TestBottomEncoderUnet(chex.TestCase):
 19 |     """Test the class BottomImageEncoderUnet."""
20 |
21 | batch_size = 2
22 | model_size = 16
23 | num_heads = 2
24 | t_size = 3
25 |
26 | @chex.all_variants()
27 | @parameterized.product(
28 | spatial_shape=[(5, 6), (5, 6, 7)],
29 | with_time=[True, False],
30 | is_train=[True, False],
31 | )
32 | def test_output_shape(
33 | self,
34 | spatial_shape: tuple[int, ...],
35 | with_time: bool,
36 | is_train: bool,
37 | ) -> None:
38 | """Test output shape."""
39 | rng = {"params": jax.random.PRNGKey(0)}
40 | kernel_size = (3,) * len(spatial_shape)
41 | encoder = BottomImageEncoderUnet(num_heads=self.num_heads, kernel_size=kernel_size)
42 | image_emb = jnp.ones((self.batch_size, *spatial_shape, self.model_size))
43 | t_emb = jnp.ones((self.batch_size, self.t_size)) if with_time else None
44 | out, _ = self.variant(encoder.init_with_output, static_argnums=(1,))(
45 | rng, is_train, image_emb, t_emb
46 | )
47 | chex.assert_shape(out, (self.batch_size, *spatial_shape, self.model_size))
48 |
--------------------------------------------------------------------------------
/imgx/model/unet/downsample_encoder.py:
--------------------------------------------------------------------------------
1 | """Downsample encoder for unet."""
2 | from __future__ import annotations
3 |
4 | import flax.linen as nn
5 | import jax.numpy as jnp
6 |
7 | from imgx.model.conv import ConvDownSample, ConvNormAct, ConvResBlock
8 |
9 |
10 | class DownsampleEncoder(nn.Module):
11 | """Down-sample encoder module with convolutions for unet."""
12 |
13 | num_channels: tuple[int, ...] # channel at each depth, including the bottom
14 | patch_size: tuple[int, ...] # first down sampling layer
15 | scale_factor: tuple[int, ...] # spatial down-sampling/up-sampling
16 | kernel_size: tuple[int, ...] # convolution layer kernel size
17 | num_res_blocks: int = 2 # number of residual blocks
18 | dropout: float = 0.0 # for resnet block
19 | remat: bool = True # reduces memory cost at cost of compute speed
20 | dtype: jnp.dtype = jnp.float32
21 |
22 | @nn.compact
23 | def __call__(
24 | self,
25 | is_train: bool,
26 | x: jnp.ndarray,
27 | t_emb: jnp.ndarray | None,
28 | ) -> list[jnp.ndarray]:
 29 |         """Encode the image.
 30 | 
 31 |         If batch_size = 2, image_shape = (256, 256, 32), num_channels = (1, 2, 4),
 32 |         num_res_blocks = 2, patch_size = 4 and scale_factor = 2,
 33 |         the embeddings' shapes are:
 34 |             (2, 256, 256, 32, 1), from the initial ConvNormAct block before the loop
35 | (2, 256, 256, 32, 1), from residual block, i=0
36 | (2, 256, 256, 32, 1), from residual block, i=0
37 | (2, 64, 64, 8, 2), from down-sampling block, i=0
38 | (2, 64, 64, 8, 2), from residual block, i=1
39 | (2, 64, 64, 8, 2), from residual block, i=1
40 | (2, 32, 32, 4, 4), from down-sampling block, i=1
41 | (2, 32, 32, 4, 4), from residual block, i=2
42 | (2, 32, 32, 4, 4), from residual block, i=2
43 |
44 | Args:
45 | is_train: whether in training mode.
46 | x: array of shape (batch, *spatial_shape, in_channels).
47 | t_emb: array of shape (batch, t_channels).
48 |
49 | Returns:
50 | List of embeddings from each layer.
51 | """
52 | conv_down_sample_cls = nn.remat(ConvDownSample) if self.remat else ConvDownSample
53 |
54 | # encoder raw input
55 | x = ConvNormAct(
56 | out_channels=self.num_channels[0],
57 | kernel_size=self.kernel_size,
58 | remat=self.remat,
59 | )(x)
60 |
61 | # encoding
62 | embeddings = [x]
63 | for i, ch in enumerate(self.num_channels):
64 | # residual blocks
 65 |             # spatial shape has been down-sampled i times (by patch_size, then scale_factor)
66 | for _ in range(self.num_res_blocks):
67 | x = ConvResBlock(
68 | out_channels=ch,
69 | kernel_size=self.kernel_size,
70 | dropout=self.dropout,
71 | remat=self.remat,
72 | )(is_train, x, t_emb)
73 | embeddings.append(x)
74 |
75 | # down-sampling for non-bottom layers
 76 |             # spatial shape gets down-sampled once more
77 | if i < len(self.num_channels) - 1:
78 | x = conv_down_sample_cls(
79 | out_channels=self.num_channels[i + 1],
80 | scale_factor=self.patch_size if i == 0 else self.scale_factor,
81 | )(x)
82 | embeddings.append(x)
83 |
84 | return embeddings
85 |
--------------------------------------------------------------------------------
/imgx/model/unet/upsample_decoder.py:
--------------------------------------------------------------------------------
 1 | """Upsample decoder for unet."""
2 | from __future__ import annotations
3 |
4 | import flax.linen as nn
5 | import jax
6 | import jax.numpy as jnp
7 | from jax import lax
8 |
9 | from imgx.model.conv import ConvResBlock, ConvUpSample
10 |
11 |
12 | class UpsampleDecoder(nn.Module):
13 | """Upsample decoder module with convolutions for unet."""
14 |
15 | out_channels: int
16 | num_channels: tuple[int, ...] # channel at each depth, including the bottom
17 | patch_size: tuple[int, ...] # first down sampling layer
18 | scale_factor: tuple[int, ...] # spatial down-sampling/up-sampling
19 | kernel_size: tuple[int, ...] # convolution layer kernel size
20 | num_res_blocks: int = 2 # number of residual blocks
21 | dropout: float = 0.0 # for resnet block
22 | out_kernel_init: jax.nn.initializers.Initializer = nn.linear.default_kernel_init
23 | remat: bool = True # remat reduces memory cost at cost of compute speed
24 | dtype: jnp.dtype = jnp.float32
25 |
26 | @nn.compact
27 | def __call__(
28 | self,
29 | is_train: bool,
30 | embeddings: list[jnp.ndarray],
31 | t_emb: jnp.ndarray | None,
 32 |     ) -> jnp.ndarray:
33 | """Decode the embedding and perform prediction.
34 |
35 | Args:
36 | is_train: whether in training mode.
37 | embeddings: list of embeddings from each layer.
38 | Starting with the first layer.
39 | t_emb: array of shape (batch, t_channels).
40 |
41 | Returns:
42 | Logits (batch, ..., out_channels).
43 | """
44 | if len(embeddings) != len(self.num_channels) * (self.num_res_blocks + 1) + 1:
 45 |             raise ValueError("UpsampleDecoder received an unexpected number of embeddings")
46 | num_spatial_dims = len(self.kernel_size)
47 | conv_up_sample_cls = nn.remat(ConvUpSample) if self.remat else ConvUpSample
48 | conv_cls = nn.remat(nn.Conv) if self.remat else nn.Conv
49 |
 50 |         # spatial shape has been down-sampled len(self.num_channels)-1 times
51 | # channel = self.num_channels[-1]
52 | x = embeddings.pop()
53 |
54 | for i, ch in enumerate(self.num_channels[::-1]):
 55 |             # spatial shape has been down-sampled len(self.num_channels)-1-i times
56 | # channel = ch
57 | for _ in range(self.num_res_blocks + 1):
58 | # add skipped
59 | # use addition instead of concatenation to reduce memory cost
60 | skipped = embeddings.pop()
61 | x += skipped
62 |
63 | # conv
64 | x = ConvResBlock(
65 | out_channels=ch,
66 | kernel_size=self.kernel_size,
67 | dropout=self.dropout,
68 | remat=self.remat,
69 | )(is_train, x, t_emb)
70 |
71 | if i < len(self.num_channels) - 1:
72 | # up-sampling
73 | # skipped.shape <= up-scaled shape
74 | # as padding may be added when down-sampling
75 | skipped_shape = embeddings[-1].shape[1:-1]
76 | # deconv and pad to make emb of same shape as skipped
77 | x = conv_up_sample_cls(
78 | out_channels=self.num_channels[-i - 2],
79 | scale_factor=self.patch_size
80 | if i == len(self.num_channels) - 2
81 | else self.scale_factor,
82 | )(x)
83 | x = lax.dynamic_slice(
84 | x,
85 | start_indices=(0,) * (num_spatial_dims + 2),
86 | slice_sizes=(x.shape[0], *skipped_shape, x.shape[-1]),
87 | )
88 | out = conv_cls(
89 | features=self.out_channels,
90 | kernel_size=(1,) * num_spatial_dims,
91 | kernel_init=self.out_kernel_init,
92 | )(x)
93 | return out
94 |
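A hedged sketch (not repository code) of how DownsampleEncoder, BottomImageEncoderUnet and UpsampleDecoder appear to compose into a UNet. The wrapper class TinyUnet and all sizes are assumptions for illustration; the decoder expects len(num_channels) * (num_res_blocks + 1) + 1 embeddings, one more than the encoder returns, so the bottom-encoder output is appended as that extra entry.

import flax.linen as nn
import jax
import jax.numpy as jnp

from imgx.model.unet.bottom_encoder import BottomImageEncoderUnet
from imgx.model.unet.downsample_encoder import DownsampleEncoder
from imgx.model.unet.upsample_decoder import UpsampleDecoder


class TinyUnet(nn.Module):  # hypothetical wrapper, for illustration only
    @nn.compact
    def __call__(self, is_train: bool, x: jnp.ndarray) -> jnp.ndarray:
        num_channels = (4, 8)
        kernel_size = (3, 3)
        patch_size = (2, 2)
        scale_factor = (2, 2)
        # encoder: ConvNormAct + residual/down-sampling blocks, returns skip embeddings
        embeddings = DownsampleEncoder(
            num_channels=num_channels,
            patch_size=patch_size,
            scale_factor=scale_factor,
            kernel_size=kernel_size,
        )(is_train, x, None)
        # bottom: conv + attention + conv at the coarsest resolution, appended as the
        # extra embedding expected by the decoder's length check
        embeddings.append(
            BottomImageEncoderUnet(kernel_size=kernel_size, num_heads=2)(
                is_train, embeddings[-1], None
            )
        )
        # decoder: consumes the embeddings as skip connections and predicts logits
        return UpsampleDecoder(
            out_channels=3,
            num_channels=num_channels,
            patch_size=patch_size,
            scale_factor=scale_factor,
            kernel_size=kernel_size,
        )(is_train, embeddings, None)


out, _ = TinyUnet().init_with_output(
    {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)},
    False,
    jnp.ones((1, 8, 8, 1)),
)  # logits of shape (1, 8, 8, 3)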
--------------------------------------------------------------------------------
/imgx/run_valid.py:
--------------------------------------------------------------------------------
 1 | """Script to launch evaluation on the validation set."""
2 | import argparse
3 | import json
4 | from pathlib import Path
5 |
6 | import jax
7 | from absl import logging
8 | from flax import jax_utils
9 | from flax.training import common_utils
10 | from omegaconf import DictConfig, OmegaConf
11 |
12 | from imgx.data.iterator import get_image_tfds_dataset
13 | from imgx.run_train import build_experiment
14 |
15 | logging.set_verbosity(logging.INFO)
16 |
17 |
18 | def get_checkpoint_steps(
19 | log_dir: Path,
20 | ) -> list[int]:
21 | """Get the steps of all available checkpoints.
22 |
23 | Args:
24 | log_dir: Directory of entire log.
25 |
26 | Returns:
27 | A list of available steps.
28 |
 29 |     Note:
 30 |         Checkpoint directories without a checkpoint file are skipped.
31 | """
32 | ckpt_dir = log_dir / "files" / "ckpt"
33 | steps = []
34 | for step_dir in ckpt_dir.glob("checkpoint_*/"):
35 | if not step_dir.is_dir():
36 | continue
37 | ckpt_path = step_dir / "checkpoint"
38 | if not ckpt_path.exists():
39 | continue
40 | steps.append(int(step_dir.stem.split("_")[-1]))
41 | return sorted(steps, reverse=True)
42 |
43 |
44 | def parse_args() -> argparse.Namespace:
45 | """Parse arguments."""
46 | parser = argparse.ArgumentParser(description=__doc__)
47 | parser.add_argument(
48 | "--log_dir",
49 | type=Path,
50 | help="Folder of wandb.",
51 | default=None,
52 | )
53 | parser.add_argument(
54 | "--num_timesteps",
55 | type=int,
56 | help="Number of sampling steps for diffusion_segmentation.",
57 | default=-1,
58 | )
59 | parser.add_argument(
60 | "--sampler",
61 | type=str,
62 | help="Sampling algorithm for diffusion_segmentation.",
63 | default="",
64 | choices=["", "DDPM", "DDIM"],
65 | )
66 | args = parser.parse_args()
67 |
68 | return args
69 |
70 |
71 | def load_and_parse_config(
72 | log_dir: Path,
73 | num_timesteps: int,
74 | sampler: str,
75 | ) -> DictConfig:
76 | """Load and parse config.
77 |
78 | Args:
79 | log_dir: Directory of entire log.
80 | num_timesteps: Number of sampling steps for diffusion_segmentation.
81 | sampler: Sampling algorithm for diffusion_segmentation.
82 |
83 | Returns:
84 | Loaded config.
85 | """
86 | config = OmegaConf.load(log_dir / "files" / "config_backup.yaml")
87 | if config.task.name == "diffusion_segmentation":
88 | if num_timesteps <= 0:
89 | raise ValueError("num_timesteps required for diffusion.")
90 | config.task.sampler.num_inference_timesteps = num_timesteps
91 | logging.info(f"Sampling {num_timesteps} steps.")
92 | if not sampler:
93 | raise ValueError("sampler required for diffusion.")
94 | config.task.sampler.name = sampler
95 | logging.info(f"Using sampler {sampler}.")
96 | return config
97 |
98 |
99 | def main() -> None:
100 | """Main function."""
101 | args = parse_args()
102 | logging.info(f"Local devices are: {jax.local_devices()}")
103 |
104 | # load config
105 | config = load_and_parse_config(
106 | log_dir=args.log_dir, num_timesteps=args.num_timesteps, sampler=args.sampler
107 | )
108 |
109 | # find all available checkpoints
110 | steps = get_checkpoint_steps(log_dir=args.log_dir)
111 |
112 | key = jax.random.PRNGKey(config.seed)
113 | key = common_utils.shard_prng_key(key) # each replica has a different key
114 |
115 | # init data
116 | dataset = get_image_tfds_dataset(
117 | dataset_name=config.data.name,
118 | config=config,
119 | )
120 | train_iter = dataset.train_iter
121 | valid_iter = dataset.valid_iter
122 | platform = jax.local_devices()[0].platform
123 | if platform not in ["cpu", "tpu"]:
124 | train_iter = jax_utils.prefetch_to_device(train_iter, 2)
125 | valid_iter = jax_utils.prefetch_to_device(valid_iter, 2)
126 |
127 | # evaluate
128 | ckpt_dir = args.log_dir / "files" / "ckpt"
129 | run = build_experiment(config=config)
130 | for step in steps:
131 | logging.info(f"Starting valid split evaluation for step {step}.")
132 |
133 | # load checkpoint
134 | batch = next(train_iter)
135 | train_state, _ = run.train_init(batch=batch, ckpt_dir=ckpt_dir, step=step)
136 |
137 | # evaluation
138 | val_metrics = run.eval_step(
139 | train_state=train_state, iterator=valid_iter, num_steps=dataset.num_valid_steps, key=key
140 | )
141 |
142 | # save metrics
143 | out_dir = ckpt_dir / f"checkpoint_{step}"
144 | if config.task.name == "diffusion_segmentation":
145 | out_dir = out_dir / config.task.sampler.name
146 | out_dir.mkdir(parents=True, exist_ok=True)
147 | with open(out_dir / "mean_metrics.json", "w", encoding="utf-8") as f:
148 | json.dump(val_metrics, f, sort_keys=True, indent=4)
149 |
150 |
151 | if __name__ == "__main__":
152 | main()
153 |
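A hedged invocation sketch: with the package installed, the imgx_valid console script defined in pyproject.toml points at main() above, so evaluating all checkpoints of an existing run (the log directory path is illustrative) could look like

    imgx_valid --log_dir wandb/run_dir --num_timesteps 5 --sampler DDPM

where --num_timesteps and --sampler are only required for diffusion_segmentation tasks; for plain segmentation only --log_dir is needed.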
--------------------------------------------------------------------------------
/imgx/task/__init__.py:
--------------------------------------------------------------------------------
1 | """Module for different learning tasks."""
2 |
--------------------------------------------------------------------------------
/imgx/task/diffusion_segmentation/__init__.py:
--------------------------------------------------------------------------------
1 | """Diffusion-based segmentation task."""
2 |
--------------------------------------------------------------------------------
/imgx/task/diffusion_segmentation/diffusion.py:
--------------------------------------------------------------------------------
1 | """Module for diffusion segmentation."""
2 | from __future__ import annotations
3 |
4 | from dataclasses import dataclass
5 |
6 | import jax.numpy as jnp
7 |
8 | from imgx.diffusion.diffusion import Diffusion
9 |
10 |
11 | @dataclass
12 | class DiffusionSegmentation(Diffusion):
 13 |     """Base class for diffusion-based segmentation."""
14 |
15 | def mask_to_x(self, mask: jnp.ndarray) -> jnp.ndarray:
16 | """Convert mask to x.
17 |
18 | Args:
19 | mask: boolean segmentation mask.
20 |
21 | Returns:
22 | array in diffusion space.
23 | """
24 | raise NotImplementedError
25 |
26 | def x_to_mask(self, x: jnp.ndarray) -> jnp.ndarray:
27 | """Convert x to mask.
28 |
29 | Args:
30 | x: array in diffusion space.
31 |
32 | Returns:
33 | boolean segmentation mask.
34 | """
35 | raise NotImplementedError
36 |
37 | def x_to_logits(self, x: jnp.ndarray) -> jnp.ndarray:
38 | """Convert x into model output space, which is logits.
39 |
40 | Args:
41 | x: array in diffusion space.
42 |
43 | Returns:
44 | unnormalised logits.
45 | """
46 | raise NotImplementedError
47 |
48 | def model_out_to_logits_start(
49 | self, model_out: jnp.ndarray, x_t: jnp.ndarray, t_index: jnp.ndarray
50 | ) -> jnp.ndarray:
51 | """Convert model outputs to logits at time 0, noiseless.
52 |
53 | Args:
54 | model_out: model outputs.
55 | x_t: noisy x at time t.
56 | t_index: storing index values < self.num_timesteps.
57 |
58 | Returns:
59 | logits.
60 | """
61 | raise NotImplementedError
62 |
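A hedged illustration (not repository code) of how the conversion hooks above could be satisfied for exclusive classes: masks are boolean one-hot arrays, the diffusion variable lives in [-1, 1], and logits are read directly from x. The subclass name is hypothetical and model_out_to_logits_start is left out of the sketch.

from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class ScaledOneHotSegmentation(DiffusionSegmentation):
    """Hypothetical subclass, for illustration only."""

    def mask_to_x(self, mask: jnp.ndarray) -> jnp.ndarray:
        # boolean one-hot mask (batch, ..., num_classes) -> values in {-1, 1}
        return mask.astype(jnp.float32) * 2.0 - 1.0

    def x_to_mask(self, x: jnp.ndarray) -> jnp.ndarray:
        # most likely class per voxel, re-binarised to a boolean one-hot mask
        return jax.nn.one_hot(jnp.argmax(x, axis=-1), x.shape[-1], dtype=jnp.bool_)

    def x_to_logits(self, x: jnp.ndarray) -> jnp.ndarray:
        # treat the (scaled) diffusion variable directly as unnormalised logits
        return x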
--------------------------------------------------------------------------------
/imgx/task/diffusion_segmentation/gaussian_diffusion_test.py:
--------------------------------------------------------------------------------
1 | """Test Gaussian diffusion related classes and functions."""
2 |
3 |
4 | import chex
5 | from chex._src import fake
6 |
7 | from imgx.task.diffusion_segmentation.gaussian_diffusion import GaussianDiffusionSegmentation
8 |
9 |
10 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
11 | def setUpModule() -> None: # pylint: disable=invalid-name
12 | """Fake multi-devices."""
13 | fake.set_n_cpu_devices(2)
14 |
15 |
16 | class TestGaussianDiffusionSegmentation(chex.TestCase):
 17 |     """Test the class GaussianDiffusionSegmentation."""
18 |
19 | batch_size = 2
20 |
21 | # unet
22 | in_channels = 1
23 | num_classes = 2
24 | num_channels = (1, 2)
25 |
26 | num_timesteps = 5
27 | num_timesteps_beta = 1001
28 | beta_schedule = "linear"
29 | beta_start = 0.0001
30 | beta_end = 0.02
31 |
32 | def test_attributes(
33 | self,
34 | ) -> None:
35 | """Test attribute shape."""
36 | gd = GaussianDiffusionSegmentation.create(
37 | classes_are_exclusive=True,
38 | num_timesteps=self.num_timesteps,
39 | num_timesteps_beta=self.num_timesteps_beta,
40 | beta_schedule=self.beta_schedule,
41 | beta_start=self.beta_start,
42 | beta_end=self.beta_end,
43 | model_out_type="x_start",
44 | model_var_type="fixed_large",
45 | )
46 |
47 | chex.assert_shape(gd.betas, (self.num_timesteps,))
48 | chex.assert_shape(gd.alphas_cumprod, (self.num_timesteps,))
49 | chex.assert_shape(gd.alphas_cumprod_prev, (self.num_timesteps,))
50 | chex.assert_shape(gd.alphas_cumprod_next, (self.num_timesteps,))
51 | chex.assert_shape(gd.sqrt_alphas_cumprod, (self.num_timesteps,))
52 | chex.assert_shape(gd.sqrt_one_minus_alphas_cumprod, (self.num_timesteps,))
53 | chex.assert_shape(gd.log_one_minus_alphas_cumprod, (self.num_timesteps,))
54 | chex.assert_shape(gd.sqrt_recip_alphas_cumprod, (self.num_timesteps,))
55 | chex.assert_shape(gd.sqrt_recip_alphas_cumprod_minus_one, (self.num_timesteps,))
56 | chex.assert_shape(gd.posterior_mean_coeff_start, (self.num_timesteps,))
57 | chex.assert_shape(gd.posterior_mean_coeff_t, (self.num_timesteps,))
58 | chex.assert_shape(gd.posterior_variance, (self.num_timesteps,))
59 | chex.assert_shape(gd.posterior_log_variance_clipped, (self.num_timesteps,))
60 |
--------------------------------------------------------------------------------
/imgx/task/diffusion_segmentation/save.py:
--------------------------------------------------------------------------------
 1 | """Segmentation-related IO (the module cannot be named io).
2 |
3 | https://stackoverflow.com/questions/26569828/pycharm-py-initialize-cant-initialize-sys-standard-streams
4 | """
5 | from __future__ import annotations
6 |
7 | from pathlib import Path
8 |
9 | import numpy as np
10 |
11 | from imgx.task.segmentation.save import save_segmentation_prediction
12 |
13 |
14 | def save_diffusion_segmentation_prediction(
15 | label_pred: np.ndarray,
16 | uids: list[str],
17 | out_dir: Path | None,
18 | tfds_dir: Path,
19 | reference_suffix: str = "mask_preprocessed",
20 | output_suffix: str = "mask_pred",
21 | ) -> None:
22 | """Save segmentation predictions.
23 |
24 | Args:
25 | label_pred: (num_samples, ..., num_timesteps), the values are integers.
26 | uids: (num_samples,).
27 | out_dir: output directory.
28 | tfds_dir: directory saving preprocessed images and labels.
29 | reference_suffix: suffix of reference image.
30 | output_suffix: suffix of output image.
31 | """
32 | if out_dir is None:
33 | return
34 | num_timesteps = label_pred.shape[-1]
35 | for i in range(num_timesteps):
36 | save_segmentation_prediction(
37 | label_pred=label_pred[..., i],
38 | uids=uids,
39 | out_dir=out_dir / f"step_{i}",
40 | tfds_dir=tfds_dir,
41 | reference_suffix=reference_suffix,
42 | output_suffix=output_suffix,
43 | )
44 |
--------------------------------------------------------------------------------
/imgx/task/diffusion_segmentation/train_state.py:
--------------------------------------------------------------------------------
1 | """Training state and checkpoints.
2 |
3 | https://github.com/google/flax/blob/main/examples/imagenet/train.py
4 | """
5 | from __future__ import annotations
6 |
7 | from typing import Callable
8 |
9 | import chex
10 | import flax.linen as nn
11 | import jax
12 | import jax.numpy as jnp
13 | from absl import logging
14 | from flax.training import dynamic_scale as dynamic_scale_lib
15 | from omegaconf import DictConfig
16 |
17 | from imgx.train_state import TrainState as BaseTrainState
18 | from imgx.train_state import init_optimizer
19 |
20 |
21 | class TrainState(BaseTrainState):
22 | """Train state.
23 |
24 | If using nn.BatchNorm, batch_stats needs to be tracked.
25 | https://flax.readthedocs.io/en/latest/guides/batch_norm.html
26 | https://github.com/google/flax/blob/main/examples/imagenet/train.py
27 | """
28 |
29 | loss_count_hist: jnp.ndarray # mutable
30 | loss_sq_hist: jnp.ndarray # mutable
31 |
32 |
33 | def create_train_state(
34 | key: jax.Array,
35 | batch: dict[str, jnp.ndarray],
36 | model: nn.Module,
37 | config: DictConfig,
38 | initialized: Callable[[jax.Array, chex.ArrayTree, nn.Module], chex.ArrayTree],
39 | ) -> TrainState:
40 | """Create initial training state.
41 |
42 | Args:
43 | key: random key.
44 | batch: batch data for determining input shapes.
45 | model: model.
46 | config: entire configuration.
47 | initialized: function to get initialized model parameters.
48 |
49 | Returns:
50 | initial training state.
51 | """
52 | dynamic_scale = None
53 | platform = jax.local_devices()[0].platform
54 | if config.half_precision and platform == "gpu":
55 | dynamic_scale = dynamic_scale_lib.DynamicScale()
56 |
57 | params = initialized(key, batch, model)
58 |
59 | # count params
60 | params_count = sum(x.size for x in jax.tree_util.tree_leaves(params))
61 | logging.info(f"The model has {params_count:,} parameters.")
62 |
63 | # diffusion related
64 | num_timesteps = config.task.diffusion.num_timesteps
65 | loss_count_hist = jnp.zeros((num_timesteps,), dtype=jnp.int32)
66 | loss_sq_hist = jnp.zeros((num_timesteps,), dtype=jnp.float32)
67 |
68 | tx = init_optimizer(config=config)
69 | state = TrainState.create(
70 | apply_fn=model.apply,
71 | params=params,
72 | tx=tx,
73 | dynamic_scale=dynamic_scale,
74 | loss_count_hist=loss_count_hist,
75 | loss_sq_hist=loss_sq_hist,
76 | )
77 | return state
78 |
--------------------------------------------------------------------------------
/imgx/task/segmentation/__init__.py:
--------------------------------------------------------------------------------
1 | """Segmentation task."""
2 |
--------------------------------------------------------------------------------
/imgx/task/segmentation/save.py:
--------------------------------------------------------------------------------
 1 | """Segmentation-related IO (the module cannot be named io).
2 |
3 | https://stackoverflow.com/questions/26569828/pycharm-py-initialize-cant-initialize-sys-standard-streams
4 | """
5 | from __future__ import annotations
6 |
7 | from pathlib import Path
8 |
9 | import numpy as np
10 | import SimpleITK as sitk # noqa: N813
11 |
12 | from imgx.datasets.save import save_image
13 |
14 |
15 | def save_segmentation_prediction(
16 | label_pred: np.ndarray,
17 | uids: list[str],
18 | out_dir: Path | None,
19 | tfds_dir: Path,
20 | reference_suffix: str = "mask_preprocessed",
21 | output_suffix: str = "mask_pred",
22 | ) -> None:
23 | """Save segmentation predictions.
24 |
25 | Args:
26 | label_pred: (num_samples, ...), the values are integers.
27 | uids: (num_samples,).
28 | out_dir: output directory.
29 | tfds_dir: directory saving preprocessed images and labels.
30 | reference_suffix: suffix of reference image.
31 | output_suffix: suffix of output image.
32 | """
33 | if out_dir is None:
34 | return
35 | if label_pred.ndim == 3 and np.max(label_pred) > 1:
36 | raise ValueError(
37 | f"Prediction values should be 0 or 1, but "
38 | f"max value is {np.max(label_pred)}. "
 39 |             f"Multi-class segmentation for 2D images is not supported."
40 | )
41 | if label_pred.ndim not in [3, 4]:
42 | raise ValueError(
43 | f"Prediction should be 3D or 4D with num_samples axis, but {label_pred.ndim}D is given."
44 | )
45 | file_suffix = "nii.gz" if label_pred.ndim == 4 else "png"
46 | out_dir.mkdir(parents=True, exist_ok=True)
47 | for i, uid in enumerate(uids):
48 | reference_image = sitk.ReadImage(tfds_dir / f"{uid}_{reference_suffix}.{file_suffix}")
49 | save_image(
50 | image=label_pred[i, ...],
51 | reference_image=reference_image,
52 | out_path=out_dir / f"{uid}_{output_suffix}.{file_suffix}",
53 | dtype=np.uint8,
54 | )
55 |
--------------------------------------------------------------------------------
/imgx/task/util.py:
--------------------------------------------------------------------------------
1 | """Shared utility functions."""
2 | from __future__ import annotations
3 |
4 | from jax import numpy as jnp
5 |
6 |
7 | def decode_uids(uids: jnp.ndarray) -> list[str]:
8 | """Decode uids.
9 |
10 | Args:
11 | uids: uids in bytes or int.
12 |
13 | Returns:
14 | decoded uids.
15 | """
16 | decoded = []
17 | for x in uids.tolist():
18 | if isinstance(x, bytes):
19 | decoded.append(x.decode("utf-8"))
20 | elif x == 0:
21 | # the batch was not complete, padded with zero
22 | decoded.append("")
23 | else:
24 | raise ValueError(f"uid {x} is not supported.")
25 | return decoded
26 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # package
2 | [build-system]
3 | requires = ["setuptools", "setuptools-scm"]
4 | build-backend = "setuptools.build_meta"
5 |
6 | [project]
7 | name = "imgx"
8 | authors = [
9 | {name = "Yunguan Fu", email = "yunguan.fu.18@ucl.ac.uk"},
10 | ]
 11 | description = "A JAX-based deep learning toolkit for biomedical applications."
12 | requires-python = ">=3.9"
13 | license = {text = "Apache-2.0"}
14 | version = "0.3.2"
15 |
16 | [project.scripts]
17 | imgx_train="imgx.run_train:main"
18 | imgx_valid="imgx.run_valid:main"
19 | imgx_test="imgx.run_test:main"
20 |
21 | [tool.setuptools]
22 | packages = ["imgx"]
23 | package-dir = {"imgx"="./imgx"}
24 |
25 | # pytest
26 | [tool.pytest.ini_options]
27 | markers = [
28 | "slow", # slow unit tests
29 | "integration", # integration tests
30 | ]
31 |
32 | # pre-commit
33 | [tool.isort]
34 | py_version=39
35 | known_third_party = ["SimpleITK","absl","chex","hydra","jax","numpy","omegaconf","optax",
36 | "pandas","pytest","ray","rdkit","setuptools","tensorflow","tensorflow_datasets","tree","wandb"]
37 | multi_line_output = 3
38 | force_grid_wrap = 0
39 | line_length = 100
40 | include_trailing_comma = true
41 | use_parentheses = true
42 |
43 | [tool.mypy]
44 | python_version = "3.9"
45 | warn_unused_configs = true
46 | ignore_missing_imports = true
47 | disable_error_code = ["misc","attr-defined","call-arg"]
48 | show_error_codes = true
49 | files = "**/*.py"
50 |
51 | [tool.ruff]
52 | line-length = 100
53 | # Enable Pyflakes `E` and `F` codes by default.
54 | select = [
55 | "F", # Pyflakes
56 | "E", "W", # pycodestyle
57 | "UP", # pyupgrade
58 | "N", # pep8-naming
59 | # flake8
60 | "YTT",
61 | "ANN",
62 | "S",
63 | "BLE",
64 | "B",
65 | "A",
66 | "C4",
67 | "T10",
68 | "EM",
69 | "ISC",
70 | "ICN",
71 | "T20",
72 | "PT",
73 | "Q",
74 | "RET",
75 | "SIM",
76 | "TID",
77 | "ARG",
78 | "DTZ",
79 | "PIE",
80 | "PGH", # pygrep-hooks
81 | "RUF", # ruff
82 | "PLC", "PLE", "PLR", "PLW", # Pylint
83 | ]
84 | ignore = [
85 | "ANN002", # MissingTypeArgs
86 | "ANN003", # MissingTypeKwargs
87 | "ANN101", # MissingTypeSelf
88 | "EM101", # Exception must not use a string literal, assign to variable first
89 | "EM102", # Exception must not use an f-string literal, assign to variable first
90 | "RET504", # Unnecessary variable assignment before `return` statement
91 | "S301", # `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
92 | "PLR0913", # Too many arguments to function call
93 | "PLR0915", # Too many statements
94 | "PLE0605", # Invalid format for `__all__`, must be `tuple` or `list`
95 | "PLR0912", # Too many branches
96 | ]
97 | # Exclude a variety of commonly ignored directories.
98 | exclude = [
99 | ".bzr",
100 | ".direnv",
101 | ".eggs",
102 | ".git",
103 | ".hg",
104 | ".mypy_cache",
105 | ".nox",
106 | ".pants.d",
107 | ".ruff_cache",
108 | ".svn",
109 | ".tox",
110 | ".venv",
111 | "__pypackages__",
112 | "_build",
113 | "buck-out",
114 | "build",
115 | "dist",
116 | "node_modules",
117 | "venv",
118 | ]
119 | # Allow unused variables when underscore-prefixed.
120 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
121 | # Assume Python 3.9
122 | target-version = "py39"
123 |
124 | [tool.ruff.mccabe]
125 | # Unlike Flake8, default to a complexity level of 10.
126 | max-complexity = 10
127 |
128 | [tool.ruff.per-file-ignores]
129 | "test_*.py" = ["S101"]
130 | "*_test.py" = ["S101"]
131 |
132 | [tool.ruff.pydocstyle]
133 | # Use Google-style docstrings.
134 | convention = "google"
135 |
136 | [tool.ruff.pylint]
137 | allow-magic-value-types = ["int", "str"]
138 |
--------------------------------------------------------------------------------