├── .dockerignore ├── .github └── workflows │ ├── pre-commit.yml │ └── unit-test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pylintrc ├── LICENSE ├── MELBA_Symposium_2024.pdf ├── Makefile ├── README.md ├── docker ├── Dockerfile ├── Dockerfile.tpu ├── environment.yml ├── environment_mac_m1.yml ├── environment_tpu.yml └── requirements.txt ├── examples └── segmentation │ ├── BB_anon_348_1.png │ ├── BB_anon_348_1_mask.png │ ├── config.yaml │ ├── files │ └── ckpt │ │ └── checkpoint_1900 │ │ └── checkpoint │ └── inference.ipynb ├── images ├── diffusion_class_diagram.png ├── diffusion_training_strategy_diagram.png └── melba_graphic_abstract.png ├── imgx ├── __init__.py ├── conf │ ├── __init__.py │ ├── config.yaml │ ├── data │ │ ├── amos_ct.yaml │ │ ├── brats2021_mr.yaml │ │ ├── male_pelvic_mr.yaml │ │ └── muscle_us.yaml │ └── task │ │ ├── gaussian_diff_seg.yaml │ │ └── seg.yaml ├── config.py ├── data │ ├── __init__.py │ ├── augmentation │ │ ├── __init__.py │ │ ├── affine.py │ │ ├── affine_test.py │ │ ├── flip.py │ │ ├── flip_test.py │ │ ├── intensity.py │ │ ├── intensity_test.py │ │ ├── patch.py │ │ └── patch_test.py │ ├── iterator.py │ ├── iterator_test.py │ ├── util.py │ ├── util_test.py │ ├── warp.py │ └── warp_test.py ├── datasets │ ├── README.md │ ├── __init__.py │ ├── amos_ct │ │ ├── __init__.py │ │ └── amos_ct_dataset_builder.py │ ├── brats2021_mr │ │ ├── __init__.py │ │ └── brats2021_mr_dataset_builder.py │ ├── constant.py │ ├── dataset_info.py │ ├── male_pelvic_mr │ │ ├── __init__.py │ │ └── male_pelvic_mr_dataset_builder.py │ ├── muscle_us │ │ ├── __init__.py │ │ └── muscle_us_dataset_builder.py │ ├── preprocess.py │ ├── preprocess_test.py │ ├── save.py │ ├── tests │ │ ├── conftest.py │ │ ├── fixtures │ │ │ ├── BB_anon_1789_3_mask_pred.png │ │ │ ├── BB_anon_1789_3_mask_pred_postprocessed_0.5.png │ │ │ ├── BB_anon_1789_3_mask_pred_postprocessed_0.75.png │ │ │ ├── BB_anon_425_2_mask_pred.png │ │ │ ├── BB_anon_425_2_mask_pred_postprocessed.png │ │ │ ├── GM_anon_780_3_mask_pred.png │ │ │ └── GM_anon_780_3_mask_pred_postprocessed.png │ │ └── test_muscle_us.py │ ├── util.py │ └── util_test.py ├── device.py ├── diffusion │ ├── README.md │ ├── __init__.py │ ├── diffusion.py │ ├── gaussian │ │ ├── __init__.py │ │ ├── gaussian_diffusion.py │ │ ├── gaussian_diffusion_test.py │ │ ├── sampler.py │ │ ├── variance_schedule.py │ │ └── variance_schedule_test.py │ ├── time_sampler.py │ ├── time_sampler_test.py │ ├── util.py │ └── util_test.py ├── experiment.py ├── integration_test.py ├── loss │ ├── __init__.py │ ├── cross_entropy.py │ ├── cross_entropy_test.py │ ├── deformation.py │ ├── deformation_test.py │ ├── dice.py │ ├── dice_test.py │ ├── segmentation.py │ ├── segmentation_test.py │ └── similarity.py ├── metric │ ├── __init__.py │ ├── area.py │ ├── area_test.py │ ├── centroid.py │ ├── centroid_test.py │ ├── deformation.py │ ├── deformation_test.py │ ├── dice.py │ ├── dice_test.py │ ├── distribution.py │ ├── segmentation.py │ ├── segmentation_test.py │ ├── similarity.py │ ├── similarity_test.py │ ├── smoothing.py │ ├── smoothing_test.py │ ├── surface_distance.py │ ├── surface_distance_test.py │ └── util.py ├── model │ ├── __init__.py │ ├── attention │ │ ├── __init__.py │ │ ├── efficient_attention.py │ │ ├── efficient_attention_test.py │ │ ├── transformer.py │ │ └── transformer_test.py │ ├── basic.py │ ├── basic_test.py │ ├── conv.py │ ├── conv_test.py │ ├── slice.py │ ├── slice_test.py │ └── unet │ │ ├── __init__.py │ │ ├── bottom_encoder.py │ │ ├── bottom_encoder_test.py │ │ ├── 
downsample_encoder.py │ │ ├── unet.py │ │ ├── unet_test.py │ │ └── upsample_decoder.py ├── run_test.py ├── run_train.py ├── run_valid.py ├── task │ ├── __init__.py │ ├── diffusion_segmentation │ │ ├── __init__.py │ │ ├── diffusion.py │ │ ├── diffusion_step.py │ │ ├── experiment.py │ │ ├── gaussian_diffusion.py │ │ ├── gaussian_diffusion_test.py │ │ ├── recycling_step.py │ │ ├── save.py │ │ ├── self_conditioning_step.py │ │ └── train_state.py │ ├── segmentation │ │ ├── __init__.py │ │ ├── experiment.py │ │ └── save.py │ └── util.py └── train_state.py └── pyproject.toml /.dockerignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | outputs/ 3 | tensorflow_datasets/ 4 | *.zip 5 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: ["3.9"] 14 | steps: 15 | - name: Checkout repository 16 | uses: actions/checkout@v4 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | check-latest: true 22 | - name: Install Wily 23 | run: pip install wily 24 | - name: Build cache and diff 25 | run: wily build . 26 | - name: Run pre-commit 27 | uses: pre-commit/action@v3.0.0 28 | -------------------------------------------------------------------------------- /.github/workflows/unit-test.yml: -------------------------------------------------------------------------------- 1 | name: unit-test 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-20.04 11 | timeout-minutes: 30 12 | strategy: 13 | matrix: 14 | group: [1, 2, 3, 4] 15 | python-version: ["3.9"] 16 | 17 | steps: 18 | - name: Checkout repository 19 | uses: actions/checkout@v4 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v5 22 | with: 23 | check-latest: true 24 | python-version: ${{ matrix.python-version }} 25 | cache: "pip" 26 | cache-dependency-path: | 27 | docker/environment.yml 28 | docker/environment_mac_m1.yml 29 | docker/Dockerfile 30 | docker/requirements.txt 31 | pyproject.toml 32 | - name: Install dependencies 33 | run: | 34 | pip install tensorflow-cpu==2.12.0 35 | pip install jax==0.4.20 36 | pip install jaxlib==0.4.20 37 | pip install -r docker/requirements.txt 38 | pip install -e . 
39 | - name: Test with pytest 40 | run: | 41 | pytest --splits 4 --group ${{ matrix.group }} --randomly-seed=0 -k "not slow and not integration" 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage* 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # pycharm 132 | .idea/ 133 | 134 | # Mac 135 | .DS_Store 136 | 137 | # notebook 138 | *.ipynb 139 | notebooks/ 140 | !examples/segmentation/inference.ipynb 141 | 142 | # hydra outputs 143 | outputs/ 144 | 145 | # wandb outputs 146 | wandb/ 147 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.5.0 6 | hooks: 7 | - id: check-added-large-files 8 | args: ["--maxkb=15000"] 9 | - id: check-ast 10 | - id: check-byte-order-marker 11 | - id: check-builtin-literals 12 | - id: check-case-conflict 13 | - id: check-docstring-first 14 | - id: check-merge-conflict 15 | - id: check-symlinks 16 | - id: check-toml 17 | - id: check-yaml 18 | - id: debug-statements 19 | - id: destroyed-symlinks 20 | - id: end-of-file-fixer 21 | - id: fix-byte-order-marker 22 | - id: mixed-line-ending 23 | - id: file-contents-sorter 24 | files: "docker/requirements.txt" 25 | - id: trailing-whitespace 26 | - repo: https://github.com/PyCQA/isort 27 | rev: 5.13.2 28 | hooks: 29 | - id: isort 30 | - repo: https://github.com/psf/black 31 | rev: 23.12.1 32 | hooks: 33 | - id: black 34 | args: 35 | - --line-length=100 36 | - repo: https://github.com/pre-commit/mirrors-mypy 37 | rev: v1.8.0 38 | hooks: # https://github.com/python/mypy/issues/4008#issuecomment-582458665 39 | - id: mypy 40 | name: mypy 41 | pass_filenames: false 42 | args: 43 | [ 44 | --strict-equality, 45 | --disallow-untyped-calls, 46 | --disallow-untyped-defs, 47 | --disallow-incomplete-defs, 48 | --disallow-any-generics, 49 | --check-untyped-defs, 50 | --disallow-untyped-decorators, 51 | --warn-redundant-casts, 52 | --warn-unused-ignores, 53 | --no-warn-no-return, 54 | --warn-unreachable, 55 | ] 56 | - repo: https://github.com/pre-commit/mirrors-prettier 57 | rev: v4.0.0-alpha.8 58 | hooks: 59 | - id: prettier 60 | args: 61 | - --print-width=100 62 | - --prose-wrap=always 63 | - --tab-width=2 64 | - repo: https://github.com/charliermarsh/ruff-pre-commit 65 | rev: "v0.1.9" 66 | hooks: 67 | - id: ruff 68 | - repo: https://github.com/pre-commit/mirrors-pylint 69 | rev: v3.0.0a5 70 | hooks: 71 | - id: pylint 72 | - repo: https://github.com/asottile/pyupgrade 73 | rev: v3.15.0 74 | hooks: 75 | - id: pyupgrade 76 | args: 77 | - --py39-plus 78 | - repo: https://github.com/pycqa/pydocstyle 79 | rev: 6.3.0 80 | hooks: 81 | - id: pydocstyle 82 | args: 83 | - --convention=google 84 | - repo: local 85 | hooks: 86 | - id: wily 87 | name: wily 88 | entry: wily diff 89 | verbose: false 90 | language: python 91 | additional_dependencies: [wily] 92 | -------------------------------------------------------------------------------- /MELBA_Symposium_2024.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/MELBA_Symposium_2024.pdf -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | pip: 2 | pip install -e . 3 | 4 | test: 5 | pytest --cov=imgx -n 4 imgx 6 | 7 | build_dataset: 8 | tfds build imgx/datasets/male_pelvic_mr 9 | tfds build imgx/datasets/amos_ct 10 | tfds build imgx/datasets/muscle_us 11 | tfds build imgx/datasets/brats2021_mr 12 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/deepmind/alphafold/blob/main/docker/Dockerfile 2 | # If changing CUDA/CUDNN versions, also update the corresponding versions in 3 | # the second last command 4 | ARG CUDA=11.8.0 5 | ARG CUDNN=8.6.0 6 | FROM nvidia/cuda:${CUDA}-cudnn8-runtime-ubuntu20.04 7 | 8 | # FROM directive resets ARGS, so we specify again (the value is retained if 9 | # previously set). 10 | ARG CUDA 11 | ARG CUDNN 12 | 13 | ARG HOST_UID 14 | ARG HOST_GID 15 | 16 | ENV USER=app 17 | 18 | # Ensure ARGs are sets 19 | RUN test -n "$HOST_UID" && test -n "$HOST_GID" 20 | 21 | # Use bash to support string substitution. 22 | SHELL ["/bin/bash", "-c"] 23 | 24 | # Create group and user, add -f to skip the command without error if it exists already 25 | RUN groupadd --force --gid $HOST_GID $USER && \ 26 | useradd -r -m --uid $HOST_UID --gid $HOST_GID $USER 27 | 28 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ 29 | vim \ 30 | unzip \ 31 | build-essential \ 32 | cmake \ 33 | python3-opencv \ 34 | cuda-command-line-tools-$(cut -f1,2 -d- <<< ${CUDA//./-}) \ 35 | git \ 36 | tzdata \ 37 | wget \ 38 | make \ 39 | && rm -rf /var/lib/apt/lists/* 40 | 41 | # Add SETUID bit to the ldconfig binary so that non-root users can run it. 42 | RUN chmod u+s /sbin/ldconfig.real 43 | 44 | # We need to run `ldconfig` first to ensure GPUs are visible, due to some quirk 45 | # with Debian. See https://github.com/NVIDIA/nvidia-docker/issues/1399 for 46 | # details. 47 | RUN echo $'#!/bin/bash\nldconfig' 48 | 49 | RUN mkdir -p /${USER}/tmp 50 | RUN mkdir -p /${USER}/ImgX 51 | RUN mkdir -p /${USER}/tensorflow_datasets 52 | RUN chgrp -R ${USER} /${USER} && \ 53 | chmod -R g+rwx /${USER} && \ 54 | chown -R ${USER} /${USER} 55 | 56 | USER ${USER} 57 | 58 | # Install Miniconda package manager. 59 | RUN wget -q -P /${USER}/tmp https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 60 | RUN bash /${USER}/tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /${USER}/conda 61 | RUN rm /${USER}/tmp/Miniconda3-latest-Linux-x86_64.sh 62 | 63 | # https://anaconda.org/nvidia/cuda-toolkit 64 | ENV PATH="/${USER}/conda/bin:$PATH" 65 | RUN conda update -qy conda \ 66 | && conda install -y -n base conda-libmamba-solver \ 67 | && conda config --set solver libmamba \ 68 | && conda install -y --channel "nvidia/label/cuda-${CUDA}" cuda-toolkit \ 69 | && conda install -y -c conda-forge \ 70 | pip \ 71 | python=3.9 72 | 73 | # Install pip packages. 
74 | COPY docker/requirements.txt /${USER}/requirements.txt 75 | 76 | RUN /${USER}/conda/bin/pip3 install --upgrade pip \ 77 | && /${USER}/conda/bin/pip3 install \ 78 | jax==0.4.20 \ 79 | jaxlib==0.4.20+cuda11.cudnn86 \ 80 | -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ 81 | && /${USER}/conda/bin/pip3 install tensorflow-cpu==2.14.0 \ 82 | && /${USER}/conda/bin/pip3 install -r /${USER}/requirements.txt 83 | 84 | RUN git config --global --add safe.directory /${USER}/ImgX 85 | 86 | WORKDIR /${USER}/ImgX 87 | -------------------------------------------------------------------------------- /docker/Dockerfile.tpu: -------------------------------------------------------------------------------- 1 | FROM mambaorg/micromamba:1.5.1 as conda 2 | 3 | # Speed up the build, and avoid unnecessary writes to disk 4 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1 5 | ENV PIPENV_VENV_IN_PROJECT=true PIP_NO_CACHE_DIR=false PIP_DISABLE_PIP_VERSION_CHECK=1 6 | 7 | COPY docker/environment_tpu.yml /tmp/environment_tpu.yml 8 | COPY docker/requirements.txt /tmp/requirements.txt 9 | 10 | RUN micromamba create -y --file /tmp/environment_tpu.yml \ 11 | && micromamba clean --all --yes \ 12 | && find /opt/conda/ -follow -type f -name '*.pyc' -delete 13 | 14 | FROM debian:bullseye-slim as test-image 15 | 16 | COPY --from=conda /opt/conda/envs/. /opt/conda/envs/ 17 | ENV PATH=/opt/conda/envs/imgx/bin/:$PATH APP_FOLDER=/app 18 | ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH 19 | ENV TZ=Europe/London 20 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone 21 | 22 | WORKDIR $APP_FOLDER 23 | 24 | ARG USER_ID=1000 25 | ARG GROUP_ID=1000 26 | ENV USER=app 27 | ENV GROUP=app 28 | 29 | USER root 30 | RUN apt-get update && apt-get install -y git vim unzip apt-transport-https ca-certificates gnupg curl make python3-opencv 31 | RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list 32 | RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - 33 | RUN apt-get update && apt-get install google-cloud-cli 34 | RUN git config --global --add safe.directory /${USER}/ImgX 35 | 36 | ENV TF_CUDNN_DETERMINISTIC=1 37 | 38 | FROM test-image as run-image 39 | # The run-image (default) is the same as the dev-image with the some files directly 40 | # copied inside 41 | 42 | WORKDIR /${USER}/ImgX 43 | -------------------------------------------------------------------------------- /docker/environment.yml: -------------------------------------------------------------------------------- 1 | name: imgx 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.9 6 | - pip=23.3.1 7 | - pip: 8 | - tensorflow-cpu==2.14.0 9 | - jax==0.4.20 10 | - jaxlib==0.4.20 11 | - -r requirements.txt 12 | -------------------------------------------------------------------------------- /docker/environment_mac_m1.yml: -------------------------------------------------------------------------------- 1 | name: imgx 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.9 6 | - pip=23.3.1 7 | - pip: 8 | - tensorflow-macos==2.14.0 9 | - tensorflow-metal==1.1.0 10 | - jax==0.4.20 11 | - jaxlib==0.4.20 12 | - -r requirements.txt 13 | -------------------------------------------------------------------------------- /docker/environment_tpu.yml: 
-------------------------------------------------------------------------------- 1 | name: imgx 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - python=3.9 7 | - pip=23.3.1 8 | - pip: 9 | - --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html 10 | - tensorflow-cpu==2.14.0 11 | - jax[tpu]==0.4.20 12 | - jaxlib==0.4.20 13 | - -r requirements.txt 14 | -------------------------------------------------------------------------------- /docker/requirements.txt: -------------------------------------------------------------------------------- 1 | SimpleITK==2.3.1 2 | chex==0.1.8 3 | coverage==7.3.3 4 | flax==0.7.5 5 | hydra-core==1.3.2 6 | kaggle==1.5.16 7 | matplotlib==3.8.2 8 | nbmake==1.4.6 9 | numpy==1.26.2 10 | opencv-python==4.8.1.78 11 | optax==0.1.7 12 | pandas==2.1.4 13 | pre-commit==3.6.0 14 | protobuf==3.20.3 # https://github.com/tensorflow/datasets/issues/4858 15 | pytest-cov==4.1.0 16 | pytest-mock==3.12.0 17 | pytest-randomly==3.15.0 18 | pytest-split==0.8.1 19 | pytest-xdist==3.5.0 20 | pytest==7.4.3 21 | rdkit-pypi==2022.9.5 22 | rich==13.7.0 23 | ruff==0.1.8 24 | tensorflow-datasets==4.9.3 25 | tomli==2.0.1 26 | torch==2.1.2 # for testing only 27 | wandb==0.16.1 28 | wily==1.25.0 29 | -------------------------------------------------------------------------------- /examples/segmentation/BB_anon_348_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/examples/segmentation/BB_anon_348_1.png -------------------------------------------------------------------------------- /examples/segmentation/BB_anon_348_1_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/examples/segmentation/BB_anon_348_1_mask.png -------------------------------------------------------------------------------- /examples/segmentation/config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: muscle_us 3 | loader: 4 | max_num_samples_per_split: -1 5 | patch_shape: 6 | - 480 7 | - 512 8 | patch_overlap: 9 | - 0 10 | - 0 11 | data_augmentation: 12 | max_rotation: 30 13 | max_zoom: 0.2 14 | max_shear: 30 15 | max_shift: 0.3 16 | max_log_gamma: 0.3 17 | v_min: 0.0 18 | v_max: 1.0 19 | p: 0.5 20 | trainer: 21 | max_num_samples: 512000 22 | batch_size: 64 23 | batch_size_per_replica: 8 24 | num_devices_per_replica: 1 25 | patch_size: 26 | - 2 27 | - 2 28 | scale_factor: 29 | - 2 30 | - 2 31 | task: 32 | name: segmentation 33 | model: 34 | _target_: imgx.model.Unet 35 | remat: true 36 | num_spatial_dims: 2 37 | patch_size: 38 | - 2 39 | - 2 40 | scale_factor: 41 | - 2 42 | - 2 43 | num_res_blocks: 2 44 | num_channels: 45 | - 8 46 | - 16 47 | - 32 48 | - 64 49 | out_channels: 2 50 | num_heads: 8 51 | widening_factor: 4 52 | num_transform_layers: 1 53 | dropout: 0.1 54 | loss: 55 | dice: 1.0 56 | cross_entropy: 0.0 57 | focal: 1.0 58 | early_stopping: 59 | metric: mean_binary_dice_score_without_background 60 | mode: max 61 | min_delta: 0.0001 62 | patience: 10 63 | debug: false 64 | seed: 0 65 | half_precision: true 66 | optimizer: 67 | name: adamw 68 | kwargs: 69 | b1: 0.9 70 | b2: 0.999 71 | weight_decay: 1.0e-08 72 | grad_norm: 1.0 73 | lr_schedule: 74 | warmup_steps: 100 75 | decay_steps: 10000 76 | init_value: 1.0e-05 77 | peak_value: 0.0008 78 | 
end_value: 5.0e-05 79 | logging: 80 | root_dir: null 81 | log_freq: 10 82 | save_freq: 100 83 | wandb: 84 | project: imgx 85 | entity: entity 86 | -------------------------------------------------------------------------------- /examples/segmentation/files/ckpt/checkpoint_1900/checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/examples/segmentation/files/ckpt/checkpoint_1900/checkpoint -------------------------------------------------------------------------------- /images/diffusion_class_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/images/diffusion_class_diagram.png -------------------------------------------------------------------------------- /images/diffusion_training_strategy_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/images/diffusion_training_strategy_diagram.png -------------------------------------------------------------------------------- /images/melba_graphic_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/images/melba_graphic_abstract.png -------------------------------------------------------------------------------- /imgx/__init__.py: -------------------------------------------------------------------------------- 1 | """A Jax-based DL toolkit for biomedical and bioinformatics applications.""" 2 | EPS = 1.0e-5 # machine error 3 | 4 | # jax device 5 | # one model can be stored across multiple shards/slices 6 | # given 8 devices, it can be grouped into 4x2 7 | # if num_devices_per_replica = 2, then one model is stored across 2 devices 8 | # so the replica_axis would be of size 4 9 | SHARD_AXIS = "shard_axis" 10 | REPLICA_AXIS = "replica_axis" 11 | -------------------------------------------------------------------------------- /imgx/conf/__init__.py: -------------------------------------------------------------------------------- 1 | """Package for config files.""" 2 | -------------------------------------------------------------------------------- /imgx/conf/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - data: muscle_us 3 | - task: gaussian_diff_seg 4 | # config below overwrites the values above 5 | # https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/default_composition_order/ 6 | - _self_ 7 | 8 | debug: False 9 | seed: 0 10 | half_precision: True 11 | 12 | optimizer: 13 | name: "adamw" 14 | kwargs: 15 | b1: 0.9 16 | b2: 0.999 17 | weight_decay: 1e-08 18 | grad_norm: 1.0 19 | lr_schedule: 20 | warmup_steps: 100 21 | decay_steps: 10_000 22 | init_value: 1e-05 23 | peak_value: 8e-04 24 | end_value: 5e-05 25 | 26 | logging: 27 | root_dir: 28 | log_freq: 10 29 | save_freq: 500 30 | wandb: 31 | project: imgx 32 | entity: wandb_entity 33 | -------------------------------------------------------------------------------- /imgx/conf/data/amos_ct.yaml: -------------------------------------------------------------------------------- 1 | name: amos_ct 2 | 3 | loader: 4 | max_num_samples_per_split: -1 5 | patch_shape: [128, 128, 128] 6 | 
patch_overlap: [64, 0, 0] # image shape is [192, 128, 128] 7 | data_augmentation: 8 | max_rotation: 30 # degrees 9 | max_zoom: 0.2 # as a fraction of the image size 10 | max_shear: 30 # degrees 11 | max_shift: 0.3 # as a fraction of the image size 12 | max_log_gamma: 0.3 13 | v_min: 0.0 # minimum value for intensity 14 | v_max: 1.0 # maximum value for intensity 15 | p: 0.5 # probability of applying each augmentation 16 | 17 | trainer: 18 | max_num_samples: 100_000 19 | batch_size: 8 # all model replicas are updated every `batch_size` samples 20 | batch_size_per_replica: 1 # each model replicate takes `batch_size_per_replica` samples per step 21 | num_devices_per_replica: 1 # model is split into num_devices_per_replica shards/slices 22 | 23 | patch_size: [2, 2, 2] 24 | scale_factor: [2, 2, 2] 25 | -------------------------------------------------------------------------------- /imgx/conf/data/brats2021_mr.yaml: -------------------------------------------------------------------------------- 1 | name: brats2021_mr 2 | 3 | loader: 4 | max_num_samples_per_split: -1 5 | patch_shape: [128, 128, 128] 6 | patch_overlap: [0, 0, 32] # image shape is [179, 219, 155] 7 | data_augmentation: 8 | max_rotation: 30 # degrees 9 | max_zoom: 0.2 # as a fraction of the image size 10 | max_shear: 30 # degrees 11 | max_shift: 0.3 # as a fraction of the image size 12 | max_log_gamma: 0.3 13 | v_min: 0.0 # minimum value for intensity 14 | v_max: 1.0 # maximum value for intensity 15 | p: 0.5 # probability of applying each augmentation 16 | 17 | trainer: 18 | max_num_samples: 100_000 19 | batch_size: 8 # all model replicas are updated every `batch_size` samples 20 | batch_size_per_replica: 1 # each model replicate takes `batch_size_per_replica` samples per step 21 | num_devices_per_replica: 1 # model is split into num_devices_per_replica shards/slices 22 | 23 | patch_size: [2, 2, 2] 24 | scale_factor: [2, 2, 2] 25 | -------------------------------------------------------------------------------- /imgx/conf/data/male_pelvic_mr.yaml: -------------------------------------------------------------------------------- 1 | name: male_pelvic_mr 2 | 3 | loader: 4 | max_num_samples_per_split: -1 5 | patch_shape: [256, 256, 32] 6 | patch_overlap: [0, 0, 16] # image shape is [256, 256, 48] 7 | data_augmentation: 8 | max_rotation: 30 # degrees 9 | max_zoom: 0.2 # as a fraction of the image size 10 | max_shear: 30 # degrees 11 | max_shift: 0.3 # as a fraction of the image size 12 | max_log_gamma: 0.3 13 | v_min: 0.0 # minimum value for intensity 14 | v_max: 1.0 # maximum value for intensity 15 | p: 0.5 # probability of applying each augmentation 16 | 17 | trainer: 18 | max_num_samples: 100_000 19 | batch_size: 8 # all model replicas are updated every `batch_size` samples 20 | batch_size_per_replica: 1 # each model replicate takes `batch_size_per_replica` samples per step 21 | num_devices_per_replica: 1 # model is split into num_devices_per_replica shards/slices 22 | 23 | patch_size: [2, 2, 1] # do not downsample z axis 24 | scale_factor: [2, 2, 2] 25 | -------------------------------------------------------------------------------- /imgx/conf/data/muscle_us.yaml: -------------------------------------------------------------------------------- 1 | name: muscle_us 2 | 3 | loader: 4 | max_num_samples_per_split: -1 5 | patch_shape: [480, 512] 6 | patch_overlap: [0, 0] 7 | data_augmentation: 8 | max_rotation: 30 # degrees 9 | max_zoom: 0.2 # as a fraction of the image size 10 | max_shear: 30 # degrees 11 | max_shift: 0.3 # as a 
fraction of the image size 12 | max_log_gamma: 0.3 13 | v_min: 0.0 # minimum value for intensity 14 | v_max: 1.0 # maximum value for intensity 15 | p: 0.5 # probability of applying each augmentation 16 | 17 | trainer: 18 | max_num_samples: 512_000 19 | batch_size: 64 # all model replicas are updated every `batch_size` samples 20 | batch_size_per_replica: 8 # each model replicate takes `batch_size_per_replica` samples per step 21 | num_devices_per_replica: 1 # model is split into num_devices_per_replica shards/slices 22 | 23 | patch_size: [2, 2] 24 | scale_factor: [2, 2] 25 | -------------------------------------------------------------------------------- /imgx/conf/task/gaussian_diff_seg.yaml: -------------------------------------------------------------------------------- 1 | name: "diffusion_segmentation" 2 | 3 | recycling: 4 | use: True 5 | # max means the previous step is num_timesteps - 1 6 | # next means the previous step is min(t+1, num_timesteps - 1) 7 | prev_step: "max" # max or next 8 | reverse_step: False 9 | 10 | self_conditioning: 11 | use: False 12 | probability: 0.5 13 | prev_step: "next" # same or next 14 | 15 | uniform_time_sampling: False # False for importance sampling 16 | 17 | diffusion: 18 | num_timesteps: 1001 19 | num_timesteps_beta: 1001 20 | beta_schedule: "linear" # linear, quadratic, cosine, warmup10, warmup50 21 | beta_start: 0.0001 22 | beta_end: 0.02 23 | model_out_type: "x_start" # x_start, noise 24 | model_var_type: "fixed_small" # fixed_small, fixed_large, learned, learned_range 25 | 26 | sampler: 27 | name: "DDPM" # DDPM, DDIM 28 | num_inference_timesteps: 5 29 | 30 | model: 31 | _target_: imgx.model.Unet 32 | remat: True 33 | num_spatial_dims: 3 34 | patch_size: MISSING # data dependent, will be set after loading config 35 | scale_factor: MISSING # data dependent, will be set after loading config 36 | num_res_blocks: 2 37 | num_channels: [32, 64, 128, 256] 38 | out_channels: MISSING # data dependent, will be set after loading config 39 | num_heads: 8 40 | widening_factor: 4 41 | dropout: 0.1 42 | 43 | loss: 44 | dice: 1.0 45 | cross_entropy: 0.0 46 | focal: 1.0 47 | mse: 0.0 48 | vlb: 0.0 49 | 50 | early_stopping: # used on validation set 51 | metric: "mean_binary_dice_score_without_background" 52 | mode: "max" 53 | min_delta: 0.0001 54 | patience: 10 55 | -------------------------------------------------------------------------------- /imgx/conf/task/seg.yaml: -------------------------------------------------------------------------------- 1 | name: "segmentation" 2 | 3 | model: 4 | _target_: imgx.model.Unet 5 | remat: True 6 | num_spatial_dims: 3 7 | patch_size: MISSING # data dependent, will be set after loading config 8 | scale_factor: MISSING # data dependent, will be set after loading config 9 | num_res_blocks: 2 10 | num_channels: [32, 64, 128, 256] 11 | out_channels: MISSING # data dependent, will be set after loading config 12 | num_heads: 8 13 | widening_factor: 4 14 | num_transform_layers: 1 15 | dropout: 0.1 16 | 17 | loss: 18 | dice: 1.0 19 | cross_entropy: 0.0 20 | focal: 1.0 21 | 22 | early_stopping: # used on validation set 23 | metric: "mean_binary_dice_score_without_background" 24 | mode: "max" 25 | min_delta: 0.0001 26 | patience: 10 27 | -------------------------------------------------------------------------------- /imgx/config.py: -------------------------------------------------------------------------------- 1 | """Module for configuration related functions.""" 2 | 3 | 4 | def flatten_dict(d: dict, parent_key: str = "", sep: str 
= "_") -> dict: # type:ignore[type-arg] 5 | """Flat a nested dict. 6 | 7 | Args: 8 | d: dict to flat. 9 | parent_key: key of the parent. 10 | sep: separation string. 11 | 12 | Returns: 13 | flatten dict. 14 | """ 15 | items = {} 16 | for k, v in d.items(): 17 | new_key = parent_key + sep + k if parent_key else k 18 | if isinstance(v, dict): 19 | items.update(flatten_dict(d=v, parent_key=new_key, sep=sep)) 20 | else: 21 | items[new_key] = v 22 | return dict(items) 23 | -------------------------------------------------------------------------------- /imgx/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Module to handle data.""" 2 | -------------------------------------------------------------------------------- /imgx/data/augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | """Data augmentation module.""" 2 | from __future__ import annotations 3 | 4 | from collections.abc import Sequence 5 | from typing import Callable 6 | 7 | import jax 8 | from jax import numpy as jnp 9 | 10 | AugmentationFn = Callable[[jax.Array, dict[str, jnp.ndarray]], dict[str, jnp.ndarray]] 11 | 12 | 13 | def chain_aug_fns( 14 | fns: Sequence[AugmentationFn], 15 | ) -> AugmentationFn: 16 | """Combine a list of data augmentation functions. 17 | 18 | Args: 19 | fns: entire config. 20 | 21 | Returns: 22 | A data augmentation function. 23 | """ 24 | 25 | def aug_fn(key: jax.Array, batch: dict[str, jnp.ndarray]) -> dict[str, jnp.ndarray]: 26 | keys = jax.random.split(key, num=len(fns)) 27 | for k, fn in zip(keys, fns): 28 | batch = fn(k, batch) 29 | return batch 30 | 31 | return aug_fn 32 | -------------------------------------------------------------------------------- /imgx/data/augmentation/flip.py: -------------------------------------------------------------------------------- 1 | """Flip augmentation for image and label.""" 2 | from __future__ import annotations 3 | 4 | from functools import partial 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | from jax import lax 9 | from omegaconf import DictConfig 10 | 11 | from imgx.data.augmentation import AugmentationFn 12 | from imgx.data.util import get_batch_size 13 | from imgx.datasets import INFO_MAP 14 | from imgx.datasets.constant import FOREGROUND_RANGE, IMAGE, LABEL 15 | 16 | 17 | def random_flip(x: jnp.ndarray, to_flip: jnp.ndarray) -> jnp.ndarray: 18 | """Flip an array along an axis. 19 | 20 | Args: 21 | x: of shape (d1, ..., dn) or (d1, ..., dn, c) 22 | to_flip: (n, ), with boolean values, True means to flip along that axis. 23 | 24 | Returns: 25 | Flipped array. 26 | """ 27 | for i in range(to_flip.size): 28 | x = lax.select( 29 | to_flip[i], 30 | jnp.flip(x, axis=i), 31 | x, 32 | ) 33 | return x 34 | 35 | 36 | def batch_random_flip( 37 | key: jax.Array, batch: dict[str, jnp.ndarray], num_spatial_dims: int, p: float 38 | ) -> dict[str, jnp.ndarray]: 39 | """Flip an array along an axis. 40 | 41 | Args: 42 | key: jax random key. 43 | batch: dict having images or labels, or foreground_range. 44 | images have shape (batch, d1, ..., dn) or (batch, d1, ..., dn, c) 45 | labels have shape (batch, d1, ..., dn) 46 | batch should not have other keys such as UID. 47 | if foreground_range exists, it's pre-calculated based on label, it's 48 | pre-calculated because nonzero function is not jittable. 49 | num_spatial_dims: number of spatial dimensions. 50 | p: probability to flip for each axis. 51 | 52 | Returns: 53 | Flipped batch. 
54 | """ 55 | batch_size = get_batch_size(batch) 56 | to_flip = jax.random.uniform(key=key, shape=(batch_size, num_spatial_dims)) < p 57 | 58 | random_flip_vmap = jax.vmap(random_flip) 59 | flipped_batch = {} 60 | for k, v in batch.items(): 61 | if (LABEL in k) or (IMAGE in k): 62 | flipped_batch[k] = random_flip_vmap(x=v, to_flip=to_flip) 63 | elif k == FOREGROUND_RANGE: 64 | flipped_batch[k] = v 65 | else: 66 | raise ValueError(f"Unknown key {k} in batch.") 67 | return flipped_batch 68 | 69 | 70 | def get_random_flip_augmentation_fn(config: DictConfig) -> AugmentationFn: 71 | """Return a data augmentation function for random flip. 72 | 73 | Args: 74 | config: entire config. 75 | 76 | Returns: 77 | A data augmentation function. 78 | """ 79 | dataset_info = INFO_MAP[config.data.name] 80 | da_config = config.data.loader.data_augmentation 81 | return partial( 82 | batch_random_flip, 83 | num_spatial_dims=len(dataset_info.image_spatial_shape), 84 | p=da_config.p, 85 | ) 86 | -------------------------------------------------------------------------------- /imgx/data/augmentation/flip_test.py: -------------------------------------------------------------------------------- 1 | """Test the flip functions.""" 2 | 3 | 4 | from functools import partial 5 | 6 | import chex 7 | import jax 8 | import numpy as np 9 | from absl.testing import parameterized 10 | from chex._src import fake 11 | 12 | from imgx.data.augmentation.flip import batch_random_flip, random_flip 13 | from imgx.datasets.constant import IMAGE 14 | 15 | 16 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 17 | def setUpModule() -> None: # pylint: disable=invalid-name 18 | """Fake multi-devices.""" 19 | fake.set_n_cpu_devices(2) 20 | 21 | 22 | class TestRandomFlip(chex.TestCase): 23 | """Test random_flip.""" 24 | 25 | @chex.all_variants() 26 | @parameterized.named_parameters( 27 | ("1d - flip", np.array([1, 2, 3, 4]), np.array([True]), np.array([4, 3, 2, 1])), 28 | ("1d - no flip", np.array([1, 2, 3, 4]), np.array([False]), np.array([1, 2, 3, 4])), 29 | ( 30 | "2d - no flip", 31 | np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), 32 | np.array([False, False]), 33 | np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), 34 | ), 35 | ( 36 | "2d - flip the first axis", 37 | np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), 38 | np.array([True, False]), 39 | np.array([[5, 6, 7, 8], [1, 2, 3, 4]]), 40 | ), 41 | ( 42 | "2d - flip the second axis", 43 | np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), 44 | np.array([False, True]), 45 | np.array([[4, 3, 2, 1], [8, 7, 6, 5]]), 46 | ), 47 | ( 48 | "2d - flip both axes", 49 | np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), 50 | np.array([True, True]), 51 | np.array([[8, 7, 6, 5], [4, 3, 2, 1]]), 52 | ), 53 | ) 54 | def test_values( 55 | self, 56 | x: np.ndarray, 57 | to_flip: np.ndarray, 58 | expected: np.ndarray, 59 | ) -> None: 60 | """Test output shapes. 61 | 62 | Args: 63 | x: input array. 64 | to_flip: (n, ), with boolean values, True means to flip along that axis. 65 | expected: expected output array. 66 | """ 67 | got = self.variant(random_flip)( 68 | x=x, 69 | to_flip=to_flip, 70 | ) 71 | 72 | chex.assert_trees_all_equal(got, expected) 73 | 74 | @chex.all_variants() 75 | @parameterized.product( 76 | image_shape=[(3, 4, 5, 6), (3, 4, 5), (3, 4), (3,)], 77 | ) 78 | def test_shapes( 79 | self, 80 | image_shape: tuple[int, ...], 81 | ) -> None: 82 | """Test output shapes. 83 | 84 | Args: 85 | image_shape: image spatial shape. 
86 | """ 87 | key = jax.random.PRNGKey(0) 88 | key_flip, key_image = jax.random.split(key, 2) 89 | image = jax.random.uniform(key=key_image, shape=image_shape, minval=0, maxval=1) 90 | to_flip = jax.random.uniform(key=key_flip, shape=(len(image_shape),)) < 1 / 2 91 | got = self.variant(random_flip)( 92 | x=image, 93 | to_flip=to_flip, 94 | ) 95 | 96 | chex.assert_shape(got, image_shape) 97 | 98 | 99 | class TestBatchRandomFlip(chex.TestCase): 100 | """Test batch_random_flip.""" 101 | 102 | batch_size = 2 103 | 104 | @chex.all_variants() 105 | @parameterized.product( 106 | image_shape=[(3, 4, 5, 6), (3, 4, 5), (3, 4), (3,)], 107 | p=[0.0, 0.5, 1.0], 108 | ) 109 | def test_shapes( 110 | self, 111 | image_shape: tuple[int, ...], 112 | p: float, 113 | ) -> None: 114 | """Test output shapes. 115 | 116 | Args: 117 | image_shape: image spatial shape. 118 | p: probability to flip for each axis. 119 | """ 120 | key = jax.random.PRNGKey(0) 121 | key_image, key = jax.random.split(key) 122 | num_spatial_dims = len(image_shape) 123 | image = jax.random.uniform( 124 | key=key_image, shape=(self.batch_size, *image_shape), minval=0, maxval=1 125 | ) 126 | batch = {IMAGE: image} 127 | got = self.variant(partial(batch_random_flip, num_spatial_dims=num_spatial_dims, p=p))( 128 | key, batch 129 | ) 130 | 131 | chex.assert_shape(got[IMAGE], (self.batch_size, *image_shape)) 132 | if p == 0: 133 | chex.assert_trees_all_equal(got, batch) 134 | -------------------------------------------------------------------------------- /imgx/data/augmentation/intensity.py: -------------------------------------------------------------------------------- 1 | """Intensity related data augmentation functions.""" 2 | from __future__ import annotations 3 | 4 | from functools import partial 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | from omegaconf import DictConfig 9 | 10 | from imgx.data.augmentation import AugmentationFn 11 | from imgx.data.util import get_batch_size 12 | from imgx.datasets.constant import IMAGE 13 | 14 | 15 | def adjust_gamma( 16 | x: jnp.ndarray, 17 | gamma: jnp.ndarray, 18 | gain: float = 1.0, 19 | ) -> jnp.ndarray: 20 | """Adjust gamma of input images. 21 | 22 | https://github.com/tensorflow/tensorflow/blob/v2.14.0/tensorflow/python/ops/image_ops_impl.py#L2303-L2366 23 | 24 | Args: 25 | x: input image. 26 | gamma: non-negative real number. 27 | gain: the constant multiplier. 28 | 29 | Returns: 30 | Adjusted image. 31 | """ 32 | return gain * x**gamma 33 | 34 | 35 | def batch_random_adjust_gamma( 36 | key: jax.Array, 37 | batch: dict[str, jnp.ndarray], 38 | max_log_gamma: float, 39 | p: float = 0.5, 40 | ) -> dict[str, jnp.ndarray]: 41 | """Perform random gamma adjustment on images in batch. 42 | 43 | https://torchio.readthedocs.io/_modules/torchio/transforms/augmentation/intensity/random_gamma.html 44 | 45 | Args: 46 | key: jax random key. 47 | batch: dict having images or labels, or foreground_range. 48 | max_log_gamma: maximum log gamma. 49 | p: probability of performing the augmentation. 50 | 51 | Returns: 52 | Augmented dict having image and label, shapes are not changed. 
53 | """ 54 | batch_size = get_batch_size(batch) 55 | 56 | adjusted_batch = {} 57 | for k, v in batch.items(): 58 | if IMAGE in k: 59 | key_gamma, key_act, key = jax.random.split(key, 3) 60 | log_gamma = jax.random.uniform( 61 | key=key_gamma, 62 | shape=(batch_size,), 63 | minval=-max_log_gamma, 64 | maxval=max_log_gamma, 65 | ) 66 | gamma = jnp.exp(log_gamma) 67 | gamma = jnp.where( 68 | jax.random.uniform(key=key_act, shape=gamma.shape) < p, 69 | gamma, 70 | jnp.ones_like(gamma), 71 | ) 72 | adjusted_batch[k] = jax.vmap(adjust_gamma)(v, gamma) 73 | else: 74 | adjusted_batch[k] = v 75 | return adjusted_batch 76 | 77 | 78 | def get_random_gamma_augmentation_fn(config: DictConfig) -> AugmentationFn: 79 | """Return a data augmentation function for random gamma transformation. 80 | 81 | Args: 82 | config: entire config. 83 | 84 | Returns: 85 | A data augmentation function. 86 | """ 87 | da_config = config.data.loader.data_augmentation 88 | return partial( 89 | batch_random_adjust_gamma, 90 | max_log_gamma=da_config.max_log_gamma, 91 | p=da_config.p, 92 | ) 93 | 94 | 95 | def rescale_intensity( 96 | x: jnp.ndarray, 97 | v_min: float = 0.0, 98 | v_max: float = 1.0, 99 | ) -> jnp.ndarray: 100 | """Adjust intensity linearly to the desired range. 101 | 102 | Args: 103 | x: input image, (batch, *spatial_dims, channel). 104 | v_min: minimum intensity. 105 | v_max: maximum intensity. 106 | 107 | Returns: 108 | Adjusted image. 109 | """ 110 | reduction_axes = tuple(range(x.ndim)[slice(1, -1)]) 111 | x_min = jnp.min(x, axis=reduction_axes, keepdims=True) 112 | x_max = jnp.max(x, axis=reduction_axes, keepdims=True) 113 | x = (x - x_min) / (x_max - x_min) 114 | x = x * (v_max - v_min) + v_min 115 | return x 116 | 117 | 118 | def batch_rescale_intensity( 119 | key: jax.Array, # noqa: ARG001, pylint: disable=unused-argument 120 | batch: dict[str, jnp.ndarray], 121 | v_min: float = 0.0, 122 | v_max: float = 1.0, 123 | ) -> dict[str, jnp.ndarray]: 124 | """Perform intensity scaling on images in batch. 125 | 126 | Args: 127 | key: jax random key. 128 | batch: dict having images or labels, or foreground_range. 129 | v_min: minimum intensity. 130 | v_max: maximum intensity. 131 | 132 | Returns: 133 | Augmented dict having image and label, shapes are not changed. 134 | """ 135 | adjusted_batch = {} 136 | for k, v in batch.items(): 137 | if IMAGE in k: 138 | adjusted_batch[k] = rescale_intensity(v, v_min=v_min, v_max=v_max) 139 | else: 140 | adjusted_batch[k] = v 141 | return adjusted_batch 142 | 143 | 144 | def get_rescale_intensity_fn(config: DictConfig) -> AugmentationFn: 145 | """Return a data augmentation function for intensity scaling. 146 | 147 | Args: 148 | config: entire config. 149 | 150 | Returns: 151 | A data augmentation function. 
152 | """ 153 | da_config = config.data.loader.data_augmentation 154 | return partial( 155 | batch_rescale_intensity, 156 | v_min=da_config.v_min, 157 | v_max=da_config.v_max, 158 | ) 159 | -------------------------------------------------------------------------------- /imgx/data/augmentation/intensity_test.py: -------------------------------------------------------------------------------- 1 | """Test function for intensity data augmentation.""" 2 | 3 | import chex 4 | import jax 5 | from absl.testing import parameterized 6 | from chex._src import fake 7 | 8 | from imgx.data.augmentation.intensity import batch_random_adjust_gamma, batch_rescale_intensity 9 | from imgx.datasets.constant import IMAGE 10 | 11 | 12 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 13 | def setUpModule() -> None: # pylint: disable=invalid-name 14 | """Fake multi-devices.""" 15 | fake.set_n_cpu_devices(2) 16 | 17 | 18 | class TestRandomAdjustGamma(chex.TestCase): 19 | """Test batch_random_adjust_gamma.""" 20 | 21 | @chex.all_variants() 22 | @parameterized.product( 23 | image_shape=[(8, 12, 6), (8, 12)], 24 | max_log_gamma=[0.0, 0.3], 25 | batch_size=[4, 1], 26 | ) 27 | def test_shapes( 28 | self, 29 | batch_size: int, 30 | max_log_gamma: float, 31 | image_shape: tuple[int, ...], 32 | ) -> None: 33 | """Test output shapes. 34 | 35 | Args: 36 | batch_size: number of samples in batch. 37 | max_log_gamma: maximum log gamma. 38 | image_shape: image spatial shape. 39 | """ 40 | key = jax.random.PRNGKey(0) 41 | image = jax.random.uniform(key=key, shape=(batch_size, *image_shape), minval=0, maxval=1) 42 | batch = {IMAGE: image} 43 | got = self.variant(batch_random_adjust_gamma)( 44 | key=key, 45 | batch=batch, 46 | max_log_gamma=max_log_gamma, 47 | ) 48 | 49 | assert len(got) == 1 50 | chex.assert_shape(got[IMAGE], (batch_size, *image_shape)) 51 | 52 | 53 | class TestRescaleIntensity(chex.TestCase): 54 | """Test batch_rescale_intensity.""" 55 | 56 | @chex.all_variants() 57 | @parameterized.product( 58 | image_shape=[(8, 12, 6), (8, 12)], 59 | v_min=[0.0, 0.3], 60 | v_max=[1.0, 0.5], 61 | batch_size=[4, 1], 62 | ) 63 | def test_shapes( 64 | self, 65 | batch_size: int, 66 | v_min: float, 67 | v_max: float, 68 | image_shape: tuple[int, ...], 69 | ) -> None: 70 | """Test output shapes. 71 | 72 | Args: 73 | batch_size: number of samples in batch. 74 | v_min: minimum intensity. 75 | v_max: maximum intensity. 76 | image_shape: image spatial shape. 77 | """ 78 | key = jax.random.PRNGKey(0) 79 | image = jax.random.uniform(key=key, shape=(batch_size, *image_shape), minval=0, maxval=1) 80 | batch = {IMAGE: image} 81 | got = self.variant(batch_rescale_intensity)( 82 | key=key, 83 | batch=batch, 84 | v_min=v_min, 85 | v_max=v_max, 86 | ) 87 | 88 | assert len(got) == 1 89 | chex.assert_shape(got[IMAGE], (batch_size, *image_shape)) 90 | -------------------------------------------------------------------------------- /imgx/data/util_test.py: -------------------------------------------------------------------------------- 1 | """Tests for image utils of datasets.""" 2 | 3 | 4 | import chex 5 | import numpy as np 6 | import pytest 7 | from chex._src import fake 8 | 9 | from imgx.data.util import get_foreground_range 10 | 11 | 12 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 
13 | def setUpModule() -> None: # pylint: disable=invalid-name 14 | """Fake multi-devices.""" 15 | fake.set_n_cpu_devices(2) 16 | 17 | 18 | @pytest.mark.parametrize( 19 | ("label", "expected"), 20 | [ 21 | ( 22 | np.array([0, 1, 2, 3]), 23 | np.array([[1, 3]]), 24 | ), 25 | ( 26 | np.array([1, 2, 3, 0]), 27 | np.array([[0, 2]]), 28 | ), 29 | ( 30 | np.array([1, 2, 3, 4]), 31 | np.array([[0, 3]]), 32 | ), 33 | ( 34 | np.array([0, 1, 2, 3, 4, 0, 0]), 35 | np.array([[1, 4]]), 36 | ), 37 | ( 38 | np.array([[0, 1, 2, 3], [0, 1, 2, 3], [0, 0, 0, 0]]), 39 | np.array([[0, 1], [1, 3]]), 40 | ), 41 | ], 42 | ids=[ 43 | "1d-left", 44 | "1d-right", 45 | "1d-none", 46 | "1d-both", 47 | "2d", 48 | ], 49 | ) 50 | def test_get_foreground_range( 51 | label: np.ndarray, 52 | expected: np.ndarray, 53 | ) -> None: 54 | """Test get_translation_range return values. 55 | 56 | Args: 57 | label: label with int values, not one-hot. 58 | expected: expected range. 59 | """ 60 | got = get_foreground_range( 61 | label=label, 62 | ) 63 | chex.assert_trees_all_equal(got, expected) 64 | -------------------------------------------------------------------------------- /imgx/data/warp.py: -------------------------------------------------------------------------------- 1 | """Module for image/lavel warping.""" 2 | from __future__ import annotations 3 | 4 | from functools import partial 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | from jax._src.scipy.ndimage import map_coordinates 9 | 10 | 11 | def get_coordinate_grid(shape: tuple[int, ...]) -> jnp.ndarray: 12 | """Generate a grid with given shape. 13 | 14 | This function is not jittable as the output depends on the value of shapes. 15 | 16 | Args: 17 | shape: shape of the grid, (d1, ..., dn). 18 | 19 | Returns: 20 | grid: grid coordinates, of shape (n, d1, ..., dn). 21 | grid[:, i1, ..., in] = [i1, ..., in] 22 | """ 23 | return jnp.stack( 24 | jnp.meshgrid( 25 | *(jnp.arange(d) for d in shape), 26 | indexing="ij", 27 | ), 28 | axis=0, 29 | dtype=jnp.float32, 30 | ) 31 | 32 | 33 | def batch_grid_sample( 34 | x: jnp.ndarray, 35 | grid: jnp.ndarray, 36 | order: int, 37 | constant_values: float = 0.0, 38 | ) -> jnp.ndarray: 39 | """Apply sampling to input. 40 | 41 | https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html 42 | 43 | Args: 44 | x: shape (batch, d1, ..., dn) or (batch, d1, ..., dn, c). 45 | grid: grid coordinates, of shape (batch, n, d1, ..., dn). 46 | order: interpolation order, 0 for nearest, 1 for linear. 47 | constant_values: constant value for out of bound coordinates. 48 | 49 | Returns: 50 | Same shape as x. 51 | """ 52 | if x.ndim not in [grid.ndim - 1, grid.ndim]: 53 | raise ValueError(f"Input x has shape {x.shape}, grid has shape {grid.shape}.") 54 | 55 | # vmap on batch axis 56 | sample_vmap = jax.vmap( 57 | partial( 58 | map_coordinates, 59 | order=order, 60 | mode="constant", 61 | cval=constant_values, 62 | ), 63 | in_axes=(0, 0), 64 | ) 65 | if x.ndim == grid.ndim: 66 | # vmap on channel axis 67 | ch_axis = x.ndim - 1 68 | sample_vmap = jax.vmap( 69 | sample_vmap, 70 | in_axes=(ch_axis, None), 71 | out_axes=ch_axis, 72 | ) 73 | return sample_vmap(x, grid) 74 | 75 | 76 | def warp_image( 77 | x: jnp.ndarray, 78 | ddf: jnp.ndarray, 79 | order: int, 80 | ) -> jnp.ndarray: 81 | """Warp the image with the deformation field. 82 | 83 | TODO: grid is a constant, can be precomputed. 84 | 85 | Args: 86 | x: shape (batch, d1, ..., dn) or (batch, d1, ..., dn, c). 87 | ddf: deformation field, of shape (batch, d1, ..., dn, n). 
88 | order: interpolation order, 0 for nearest, 1 for linear. 89 | 90 | Returns: 91 | warped image, of shape (batch, d1, ..., dn) or (batch, d1, ..., dn, c). 92 | """ 93 | # (batch, n, d1, ..., dn) 94 | grid = get_coordinate_grid(shape=ddf.shape[1:-1]) 95 | grid += jnp.moveaxis(ddf, -1, 1) 96 | return batch_grid_sample(x, grid, order=order) 97 | -------------------------------------------------------------------------------- /imgx/datasets/README.md: -------------------------------------------------------------------------------- 1 | # ImgX Datasets 2 | 3 | A [TFDS](https://www.tensorflow.org/datasets/add_dataset)-based python package for data set 4 | building. 5 | 6 | Current supported data sets are listed below. Use the following commands to (re)build all data sets. 7 | 8 | ```bash 9 | make build_dataset 10 | ``` 11 | 12 | ## Male pelvic MR 13 | 14 | ### Description 15 | 16 | This data set from [Li et al. 2022](https://zenodo.org/record/7013610#.Y1U95-zMKrM) contains 589 17 | T2-weighted labeled images which are split for training, validation and testing respectively. 18 | 19 | ### Download and Build 20 | 21 | Use the following commands at the root of this repository (i.e. under `ImgX/`) to automatically 22 | download and build the data set, which will be built under `~/tensorflow_datasets` folder. 23 | Optionally, add flag `--overwrite` to rebuild/overwrite the data set. 24 | 25 | ```bash 26 | tfds build imgx/datasets/male_pelvic_mr 27 | ``` 28 | 29 | ## AMOS CT 30 | 31 | ### Description 32 | 33 | This data set from [Ji et al. 2022](https://zenodo.org/record/7155725#.ZAN4BuzP2rO) contains 500 CT 34 | labeled images which has been split into 200, 100, and 200 images for training, validation, and test 35 | sets. But test set labels were not released, therefore validation is further split into 10 and 90 36 | images for validation and test sets. 37 | 38 | ### Download and Build 39 | 40 | Use the following commands at the root of this repository (i.e. under `ImgX/`) to automatically 41 | download and build the data set, which will be built under `~/tensorflow_datasets` folder. 42 | Optionally, add flag `--overwrite` to rebuild/overwrite the data set. 43 | 44 | ```bash 45 | tfds build imgx/datasets/amos_ct 46 | ``` 47 | 48 | ## Muscle Ultrasound 49 | 50 | ### Description 51 | 52 | This data set from [Marzola et al. 2021](https://data.mendeley.com/datasets/3jykz7wz8d/1) contains 53 | 3910 labeled images, which has been split into 2531, 666, and 713 images for training, validation, 54 | and test sets. 55 | 56 | ### Download and Build 57 | 58 | Use the following commands at the root of this repository (i.e. under `ImgX/`) to automatically 59 | download and build the data set, which will be built under `~/tensorflow_datasets` folder. 60 | Optionally, add flag `--overwrite` to rebuild/overwrite the data set. 61 | 62 | ```bash 63 | tfds build imgx/datasets/muscle_us 64 | ``` 65 | 66 | ## Brain MR 67 | 68 | ### Description 69 | 70 | This data set from [Baid et al. 2021](https://arxiv.org/abs/2107.02314) contains 1251 labeled images 71 | which are split for training, validation and testing respectively. 72 | 73 | ### Download and Build 74 | 75 | #### Manual Download 76 | 77 | This data set requires manual data downloading from 78 | [Kaggle](https://www.kaggle.com/datasets/dschettler8845/brats-2021-task1). using 79 | [kaggle API](https://www.kaggle.com/docs/api). 
The 80 | [authentication token](https://www.kaggle.com/docs/api#getting-started-installation-&-authentication) 81 | shall be obtained and stored under `~/.kaggle/kaggle.json`. 82 | 83 | Then, execute the following commands to download and unzip files. Afterward, return to the `ImgX/` 84 | folder (`/app/ImgX` for Docker). 85 | 86 | ```bash 87 | mkdir -p ~/tensorflow_datasets/downloads/manual/BraTS2021_Kaggle/BraTS2021_Training_Data/ 88 | cd ~/tensorflow_datasets/downloads/manual/BraTS2021_Kaggle/BraTS2021_Training_Data/ 89 | kaggle datasets download -d dschettler8845/brats-2021-task1 90 | unzip brats-2021-task1.zip 91 | tar xf BraTS2021_Training_Data.tar 92 | rm BraTS2021_00495.tar 93 | rm BraTS2021_00621.tar 94 | rm BraTS2021_Training_Data.tar 95 | rm brats-2021-task1.zip 96 | ``` 97 | 98 | This way, under `BraTS2021_Kaggle/` there is one folder per sample. For example, files corresponding to uid 99 | `BraTS2021_01666` should be located at 100 | `~/tensorflow_datasets/downloads/manual/BraTS2021_Kaggle/BraTS2021_Training_Data/BraTS2021_01666/` 101 | under which there are five files: 102 | 103 | - `BraTS2021_01666_flair.nii.gz`, 104 | - `BraTS2021_01666_t1.nii.gz`, 105 | - `BraTS2021_01666_t1ce.nii.gz`, 106 | - `BraTS2021_01666_t2.nii.gz`, 107 | - `BraTS2021_01666_seg.nii.gz`. 108 | 109 | #### Automatic Build 110 | 111 | Use the following commands at the root of this repository (i.e. under `ImgX/`) to automatically 112 | build the data set, which will be built under `~/tensorflow_datasets` folder. Optionally, add flag 113 | `--overwrite` to rebuild/overwrite the data set. 114 | 115 | ```bash 116 | tfds build imgx/datasets/brats2021_mr 117 | ``` 118 | 119 | ## Automated Cardiac Diagnosis Challenge (ACDC) 120 | 121 | ### Description 122 | 123 | This data set from [Bernard et al. 2018](https://ieeexplore.ieee.org/document/8360453) contains 150 124 | samples. Samples are split into 100 and 50 for training and test sets. Each sample contains 125 | 126 | - a 4D image (a sequence of 3D MR images) 127 | - a 3D image and corresponding segmentation label for end-diastolic (ED) frame 128 | - a 3D image and corresponding segmentation label for end-systolic (ES) frame 129 | 130 | ### Download and Build 131 | 132 | Use the following commands at the root of this repository (i.e. under `ImgX/`) to automatically 133 | download and build the data set, which will be built under `~/tensorflow_datasets` folder. 134 | Optionally, add flag `--overwrite` to rebuild/overwrite the data set.
135 | 136 | ```bash 137 | tfds build imgx/datasets/acdc_mr 138 | ``` 139 | -------------------------------------------------------------------------------- /imgx/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """Dataset module to build tensorflow datasets.""" 2 | 3 | from imgx.datasets.amos_ct.amos_ct_dataset_builder import AMOS_CT_INFO 4 | from imgx.datasets.brats2021_mr.brats2021_mr_dataset_builder import BRATS2021_MR_INFO 5 | from imgx.datasets.dataset_info import DatasetInfo 6 | from imgx.datasets.male_pelvic_mr.male_pelvic_mr_dataset_builder import MALE_PELVIR_MR_INFO 7 | from imgx.datasets.muscle_us.muscle_us_dataset_builder import MUSCLE_US_INFO 8 | 9 | # supported datasets 10 | MALE_PELVIC_MR = "male_pelvic_mr" 11 | AMOS_CT = "amos_ct" 12 | MUSCLE_US = "muscle_us" 13 | BRATS2021_MR = "brats2021_mr" 14 | 15 | # dataset info 16 | INFO_MAP: dict[str, DatasetInfo] = { 17 | MALE_PELVIC_MR: MALE_PELVIR_MR_INFO, 18 | AMOS_CT: AMOS_CT_INFO, 19 | MUSCLE_US: MUSCLE_US_INFO, 20 | BRATS2021_MR: BRATS2021_MR_INFO, 21 | } 22 | -------------------------------------------------------------------------------- /imgx/datasets/amos_ct/__init__.py: -------------------------------------------------------------------------------- 1 | """AMOS dataset. 2 | 3 | https://arxiv.org/abs/2206.08023 4 | 5 | https://zenodo.org/record/7262581 6 | """ 7 | -------------------------------------------------------------------------------- /imgx/datasets/brats2021_mr/__init__.py: -------------------------------------------------------------------------------- 1 | """BRaTS 2021 MR dataset. 2 | 3 | https://www.kaggle.com/datasets/dschettler8845/brats-2021-task1 4 | """ 5 | -------------------------------------------------------------------------------- /imgx/datasets/constant.py: -------------------------------------------------------------------------------- 1 | """Constants for datasets. 2 | 3 | Cannot be defined in __init__.py because of circular import. 4 | __init__.py imports from each data set, which imports constants from this file. 5 | """ 6 | from pathlib import Path 7 | 8 | # splits 9 | TRAIN_SPLIT = "train" 10 | VALID_SPLIT = "valid" 11 | TEST_SPLIT = "test" 12 | 13 | # data dict keys 14 | UID = "uid" 15 | IMAGE = "image" # in a batch, keys having image are also considered as images 16 | LABEL = "label" # in a batch, keys having label are also considered as labels 17 | 18 | # prediction 19 | LABEL_PRED = "label_pred" 20 | 21 | TFDS_DIR: Path = Path.home() / "tensorflow_datasets" 22 | TFDS_EXTRACTED_DIR: Path = TFDS_DIR / "downloads" / "extracted" 23 | TFDS_MANUAL_DIR: Path = TFDS_DIR / "downloads" / "manual" 24 | 25 | # segmentation task 26 | FOREGROUND_RANGE = "foreground_range" # added during pre-processing in tf for augmentation 27 | -------------------------------------------------------------------------------- /imgx/datasets/dataset_info.py: -------------------------------------------------------------------------------- 1 | """Data set class for datasets.""" 2 | import dataclasses 3 | from pathlib import Path 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | 9 | @dataclasses.dataclass 10 | class DatasetInfo: 11 | """Data set class for datasets.""" 12 | 13 | name: str 14 | tfds_preprocessed_dir: Path 15 | image_spacing: tuple[float, ...] 16 | image_spatial_shape: tuple[int, ...] 17 | image_channels: int 18 | class_names: tuple[str, ...] 
# for segmentation label only 19 | classes_are_exclusive: bool # for segmentation label only 20 | 21 | @property 22 | def input_image_shape(self) -> tuple[int, ...]: 23 | """Input shape of image.""" 24 | return (*self.image_spatial_shape, self.image_channels) 25 | 26 | @property 27 | def label_shape(self) -> tuple[int, ...]: 28 | """Shape of label.""" 29 | return self.image_spatial_shape 30 | 31 | @property 32 | def ndim(self) -> int: 33 | """Number of dimensions.""" 34 | return len(self.image_spatial_shape) 35 | 36 | @property 37 | def num_classes(self) -> int: 38 | """Number of classes for segmentation.""" 39 | raise NotImplementedError 40 | 41 | def logits_to_label(self, x: jnp.ndarray, axis: int) -> jnp.ndarray: 42 | """Transform logits to label with integers.""" 43 | raise NotImplementedError 44 | 45 | def label_to_mask( 46 | self, x: jnp.ndarray, axis: int, dtype: jnp.dtype = jnp.float32 47 | ) -> jnp.ndarray: 48 | """Transform label to boolean mask.""" 49 | raise NotImplementedError 50 | 51 | def logits_to_label_with_post_process(self, x: jnp.ndarray, axis: int) -> jnp.ndarray: 52 | """Transform logits to label with post-processing.""" 53 | return self.post_process_label(self.logits_to_label(x, axis=axis)) 54 | 55 | def post_process_label(self, label: jnp.ndarray) -> jnp.ndarray: 56 | """Label post-processing.""" 57 | return label 58 | 59 | 60 | class OneHotLabeledDatasetInfo(DatasetInfo): 61 | """Data set with mutual exclusive labels.""" 62 | 63 | @property 64 | def num_classes(self) -> int: 65 | """Number of classes including background.""" 66 | return len(self.class_names) + 1 67 | 68 | def logits_to_label(self, x: jnp.ndarray, axis: int) -> jnp.ndarray: 69 | """Transform logits to label with integers. 70 | 71 | Args: 72 | x: logits. 73 | axis: axis of num_classes. 74 | 75 | Returns: 76 | Label with integers. 77 | """ 78 | return jnp.argmax(x, axis=axis) 79 | 80 | def label_to_mask( 81 | self, x: jnp.ndarray, axis: int, dtype: jnp.dtype = jnp.float32 82 | ) -> jnp.ndarray: 83 | """Transform label to boolean mask. 84 | 85 | Args: 86 | x: label. 87 | axis: axis of num_classes. 88 | dtype: dtype of output. 89 | 90 | Returns: 91 | One hot mask. 92 | """ 93 | return jax.nn.one_hot( 94 | x=x, 95 | num_classes=self.num_classes, 96 | axis=axis, 97 | dtype=dtype, 98 | ) 99 | -------------------------------------------------------------------------------- /imgx/datasets/male_pelvic_mr/__init__.py: -------------------------------------------------------------------------------- 1 | """Male pelvic MR dataset. 2 | 3 | https://arxiv.org/abs/2209.05160 4 | """ 5 | -------------------------------------------------------------------------------- /imgx/datasets/muscle_us/__init__.py: -------------------------------------------------------------------------------- 1 | """Muscle ultrasound dataset. 2 | 3 | https://www.sciencedirect.com/science/article/pii/S0010482521004170 4 | https://data.mendeley.com/datasets/3jykz7wz8d/1 5 | """ 6 | -------------------------------------------------------------------------------- /imgx/datasets/save.py: -------------------------------------------------------------------------------- 1 | """IO related functions (file cannot be named as io). 
2 | 3 | https://stackoverflow.com/questions/26569828/pycharm-py-initialize-cant-initialize-sys-standard-streams 4 | """ 5 | from __future__ import annotations 6 | 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import SimpleITK as sitk # noqa: N813 12 | from absl import logging 13 | from PIL import Image 14 | 15 | 16 | def save_uids( 17 | train_uids: list[str], 18 | valid_uids: list[str], 19 | test_uids: list[str], 20 | out_dir: Path, 21 | ) -> None: 22 | """Save uids to csv files. 23 | 24 | Args: 25 | train_uids: list of training uids. 26 | valid_uids: list of validation uids. 27 | test_uids: list of test uids. 28 | out_dir: directory to save the csv files. 29 | """ 30 | pd.DataFrame({"uid": train_uids}).to_csv(out_dir / "train_uids.csv", index=False) 31 | pd.DataFrame({"uid": valid_uids}).to_csv(out_dir / "valid_uids.csv", index=False) 32 | pd.DataFrame({"uid": test_uids}).to_csv(out_dir / "test_uids.csv", index=False) 33 | logging.info(f"There are {len(train_uids)} training samples.") 34 | logging.info(f"There are {len(valid_uids)} validation samples.") 35 | logging.info(f"There are {len(test_uids)} test samples.") 36 | 37 | 38 | def save_2d_grayscale_image( 39 | image: np.ndarray, 40 | out_path: Path, 41 | ) -> None: 42 | """Save grayscale 2d images. 43 | 44 | Args: 45 | image: (height, width), the values between [0, 1]. 46 | out_path: output path. 47 | """ 48 | out_path.parent.mkdir(parents=True, exist_ok=True) 49 | image = np.asarray(image * 255, dtype="uint8") 50 | Image.fromarray(image, "L").save(str(out_path)) 51 | 52 | 53 | def load_2d_grayscale_image( 54 | image_path: Path, 55 | dtype: np.dtype = np.uint8, 56 | ) -> np.ndarray: 57 | """Load 2d images. 58 | 59 | Args: 60 | image_path: path to the mask. 61 | dtype: data type of the output. 62 | 63 | Returns: 64 | mask: (height, width), the values are between [0, 1]. 65 | """ 66 | mask = Image.open(str(image_path)).convert("L") # value [0, 255] 67 | mask = np.asarray(mask) / 255 # value [0, 1] 68 | mask = np.asarray(mask, dtype=dtype) 69 | return mask 70 | 71 | 72 | def save_3d_image( 73 | image: np.ndarray, 74 | reference_image: sitk.Image, 75 | out_path: Path, 76 | ) -> None: 77 | """Save 3d image. 78 | 79 | Args: 80 | image: (depth, height, width), the values are integers. 81 | reference_image: reference image for copy meta data. 82 | out_path: output path. 83 | """ 84 | out_path.parent.mkdir(parents=True, exist_ok=True) 85 | image = sitk.GetImageFromArray(image) 86 | image.CopyInformation(reference_image) 87 | # output 88 | sitk.WriteImage( 89 | image=image, 90 | fileName=out_path, 91 | useCompression=True, 92 | ) 93 | 94 | 95 | def save_image( 96 | image: np.ndarray, 97 | reference_image: sitk.Image, 98 | out_path: Path, 99 | dtype: np.dtype, 100 | ) -> None: 101 | """Save 2d or 3d image. 102 | 103 | Args: 104 | image: (width, height, depth) or (height, width), 3D is not reversed but 2D is reversed. 105 | reference_image: reference image for copy metadata. 106 | out_path: output path. 107 | dtype: data type of the output. 108 | """ 109 | out_path.parent.mkdir(parents=True, exist_ok=True) 110 | if image.ndim not in [2, 3]: 111 | raise ValueError( 112 | f"Image should be 2D or 3D, but {image.ndim}D is given with shape {image.shape}." 
113 | ) 114 | if image.ndim == 2: 115 | save_2d_grayscale_image( 116 | image=image.astype(dtype=dtype), 117 | out_path=out_path, 118 | ) 119 | else: 120 | # (width, height, depth) -> (depth, height, width) 121 | image = np.transpose(image, axes=[2, 1, 0]).astype(dtype=dtype) 122 | save_3d_image( 123 | image=image, 124 | reference_image=reference_image, 125 | out_path=out_path, 126 | ) 127 | 128 | 129 | def save_ddf( 130 | ddf: np.ndarray, 131 | reference_image: sitk.Image, 132 | out_path: Path, 133 | dtype: np.dtype = np.float64, 134 | ) -> None: 135 | """Save ddf for 3d volumes. 136 | 137 | Args: 138 | ddf: (width, height, depth, 3), unit is 1 without spacing. 139 | reference_image: reference image for copy metadata. 140 | out_path: output path. 141 | dtype: data type of the output. 142 | """ 143 | if ddf.ndim != 4: 144 | raise ValueError(f"Mask should be 4D, but {ddf.ndim}D is given.") 145 | out_path.parent.mkdir(parents=True, exist_ok=True) 146 | 147 | # ddf is scaled by spacing 148 | ddf = np.transpose(ddf, axes=[2, 1, 0, 3]).astype(dtype=dtype) 149 | ddf *= np.expand_dims(reference_image.GetSpacing(), axis=list(range(ddf.ndim - 1))) 150 | 151 | ddf_volume = sitk.GetImageFromArray(ddf, isVector=True) 152 | ddf_volume.SetSpacing(reference_image.GetSpacing()) 153 | ddf_volume.CopyInformation(reference_image) 154 | tx = sitk.DisplacementFieldTransform(ddf_volume) 155 | sitk.WriteTransform(tx, out_path) 156 | -------------------------------------------------------------------------------- /imgx/datasets/tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Fixtures for tests.""" 2 | from pathlib import Path 3 | 4 | import pytest 5 | 6 | 7 | @pytest.fixture(scope="session") 8 | def fixture_path() -> Path: 9 | """Directory path containing fixture data. 10 | 11 | Returns: 12 | Folder path containing the data. 
13 | """ 14 | return Path(__file__).resolve().parent / "fixtures" 15 | -------------------------------------------------------------------------------- /imgx/datasets/tests/fixtures/BB_anon_1789_3_mask_pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/imgx/datasets/tests/fixtures/BB_anon_1789_3_mask_pred.png -------------------------------------------------------------------------------- /imgx/datasets/tests/fixtures/BB_anon_1789_3_mask_pred_postprocessed_0.5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/imgx/datasets/tests/fixtures/BB_anon_1789_3_mask_pred_postprocessed_0.5.png -------------------------------------------------------------------------------- /imgx/datasets/tests/fixtures/BB_anon_1789_3_mask_pred_postprocessed_0.75.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/imgx/datasets/tests/fixtures/BB_anon_1789_3_mask_pred_postprocessed_0.75.png -------------------------------------------------------------------------------- /imgx/datasets/tests/fixtures/BB_anon_425_2_mask_pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/imgx/datasets/tests/fixtures/BB_anon_425_2_mask_pred.png -------------------------------------------------------------------------------- /imgx/datasets/tests/fixtures/BB_anon_425_2_mask_pred_postprocessed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/imgx/datasets/tests/fixtures/BB_anon_425_2_mask_pred_postprocessed.png -------------------------------------------------------------------------------- /imgx/datasets/tests/fixtures/GM_anon_780_3_mask_pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/imgx/datasets/tests/fixtures/GM_anon_780_3_mask_pred.png -------------------------------------------------------------------------------- /imgx/datasets/tests/fixtures/GM_anon_780_3_mask_pred_postprocessed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mathpluscode/ImgX-DiffSeg/ce57729e2dcd1d21960ec797bdff2ef5df7e5101/imgx/datasets/tests/fixtures/GM_anon_780_3_mask_pred_postprocessed.png -------------------------------------------------------------------------------- /imgx/datasets/util.py: -------------------------------------------------------------------------------- 1 | """Util functions for image. 2 | 3 | Some are adapted from 4 | https://github.com/google-research/scenic/blob/03735eb81f64fd1241c4efdb946ea6de3d326fe1/scenic/dataset_lib/dataset_utils.py 5 | """ 6 | from __future__ import annotations 7 | 8 | import numpy as np 9 | 10 | 11 | def get_center_pad_shape( 12 | current_shape: tuple[int, ...], target_shape: tuple[int, ...] 13 | ) -> tuple[tuple[int, ...], tuple[int, ...]]: 14 | """Get pad sizes for sitk.ConstantPad. 
15 | 16 | The padding is added symmetrically. 17 | 18 | Args: 19 | current_shape: current shape of the image. 20 | target_shape: target shape of the image. 21 | 22 | Returns: 23 | pad_lower: shape to pad on the lower side. 24 | pad_upper: shape to pad on the upper side. 25 | """ 26 | pad_lower = [] 27 | pad_upper = [] 28 | for i, size_i in enumerate(current_shape): 29 | pad_i = max(target_shape[i] - size_i, 0) 30 | pad_lower_i = pad_i // 2 31 | pad_upper_i = pad_i - pad_lower_i 32 | pad_lower.append(pad_lower_i) 33 | pad_upper.append(pad_upper_i) 34 | return tuple(pad_lower), tuple(pad_upper) 35 | 36 | 37 | def get_center_crop_shape( 38 | current_shape: tuple[int, ...], target_shape: tuple[int, ...] 39 | ) -> tuple[tuple[int, ...], tuple[int, ...]]: 40 | """Get crop sizes for sitk.Crop. 41 | 42 | The crop is performed symmetrically. 43 | 44 | Args: 45 | current_shape: current shape of the image. 46 | target_shape: target shape of the image. 47 | 48 | Returns: 49 | crop_lower: shape to pad on the lower side. 50 | crop_upper: shape to pad on the upper side. 51 | """ 52 | crop_lower = [] 53 | crop_upper = [] 54 | for i, size_i in enumerate(current_shape): 55 | crop_i = max(size_i - target_shape[i], 0) 56 | crop_lower_i = crop_i // 2 57 | crop_upper_i = crop_i - crop_lower_i 58 | crop_lower.append(crop_lower_i) 59 | crop_upper.append(crop_upper_i) 60 | return tuple(crop_lower), tuple(crop_upper) 61 | 62 | 63 | def try_to_get_center_crop_shape( 64 | label_min: int, label_max: int, current_length: int, target_length: int 65 | ) -> tuple[int, int]: 66 | """Try to crop at the center of label, 1D. 67 | 68 | Args: 69 | label_min: label index minimum, inclusive. 70 | label_max: label index maximum, exclusive. 71 | current_length: current image length. 72 | target_length: target image length. 73 | 74 | Returns: 75 | crop_lower: shape to pad on the lower side. 76 | crop_upper: shape to pad on the upper side. 77 | 78 | Raises: 79 | ValueError: if label min max is out of range. 80 | """ 81 | if label_min < 0 or label_max > current_length: 82 | raise ValueError("Label index out of range.") 83 | 84 | if current_length <= target_length: 85 | # no need of crop 86 | return 0, 0 87 | # attend to perform crop centered at label center 88 | label_center = (label_max - 1 + label_min) / 2.0 89 | bbox_lower = int(np.ceil(label_center - target_length / 2.0)) 90 | bbox_upper = bbox_lower + target_length 91 | # if lower is negative, then have to shift the window to right 92 | bbox_lower = max(bbox_lower, 0) 93 | # if upper is too large, then have to shift the window to left 94 | if bbox_upper > current_length: 95 | bbox_lower -= bbox_upper - current_length 96 | # calculate crop 97 | crop_lower = bbox_lower # bbox index starts at 0 98 | crop_upper = current_length - target_length - crop_lower 99 | return crop_lower, crop_upper 100 | 101 | 102 | def get_center_crop_shape_from_bbox( 103 | bbox_min: tuple[int, ...] | np.ndarray, 104 | bbox_max: tuple[int, ...] | np.ndarray, 105 | current_shape: tuple[int, ...], 106 | target_shape: tuple[int, ...], 107 | ) -> tuple[tuple[int, ...], tuple[int, ...]]: 108 | """Get crop sizes for sitk.Crop from label bounding box. 109 | 110 | The crop is not necessarily performed symmetrically. 111 | 112 | Args: 113 | bbox_min: [start_in_1st_spatial_dim, ...], inclusive, starts at zero. 114 | bbox_max: [end_in_1st_spatial_dim, ...], exclusive, starts at zero. 115 | current_shape: current shape of the image. 116 | target_shape: target shape of the image. 
117 | 118 | Returns: 119 | crop_lower: shape to crop on the lower side. 120 | crop_upper: shape to crop on the upper side. 121 | """ 122 | crop_lower = [] 123 | crop_upper = [] 124 | for i, current_length in enumerate(current_shape): 125 | crop_lower_i, crop_upper_i = try_to_get_center_crop_shape( 126 | label_min=bbox_min[i], 127 | label_max=bbox_max[i], 128 | current_length=current_length, 129 | target_length=target_shape[i], 130 | ) 131 | crop_lower.append(crop_lower_i) 132 | crop_upper.append(crop_upper_i) 133 | return tuple(crop_lower), tuple(crop_upper) 134 | -------------------------------------------------------------------------------- /imgx/device.py: -------------------------------------------------------------------------------- 1 | """Module to handle multi-devices.""" 2 | from __future__ import annotations 3 | 4 | import chex 5 | import jax 6 | import jax.numpy as jnp 7 | from jax import lax 8 | 9 | 10 | def broadcast_to_local_devices(value: chex.ArrayTree) -> chex.ArrayTree: 11 | """Broadcasts an object to all local devices. 12 | 13 | Args: 14 | value: value to be broadcast. 15 | 16 | Returns: 17 | broadcast value. 18 | """ 19 | devices = jax.local_devices() 20 | return jax.tree_map(lambda v: jax.device_put_sharded(len(devices) * [v], devices), value) 21 | 22 | 23 | def get_first_replica_values(value: chex.ArrayTree) -> chex.ArrayTree: 24 | """Gets values from the first replica. 25 | 26 | Args: 27 | value: broadcast value. 28 | 29 | Returns: 30 | value of the first replica. 31 | """ 32 | return jax.tree_map(lambda x: x[0], value) 33 | 34 | 35 | def bind_rng_to_host_or_device( 36 | rng: jnp.ndarray, 37 | bind_to: str | None = None, 38 | axis_name: str | tuple[str, ...] | None = None, 39 | ) -> jnp.ndarray: 40 | """Binds a rng to the host or device. 41 | 42 | https://github.com/google-research/scenic/blob/main/scenic/train_lib/train_utils.py#L577 43 | 44 | Must be called from within a pmapped function. Note that when binding to 45 | "device", we also bind the rng to hosts, as we fold_in the rng with 46 | axis_index, which is unique for devices across all hosts. 47 | 48 | Args: 49 | rng: A jax.random.PRNGKey. 50 | bind_to: Must be one of the 'host' or 'device'. None means no binding. 51 | axis_name: The axis of the devices we are binding rng across, necessary 52 | if bind_to is device. 53 | 54 | Returns: 55 | jax.random.PRNGKey specialized to host/device. 56 | """ 57 | if bind_to is None: 58 | return rng 59 | if bind_to == "host": 60 | return jax.random.fold_in(rng, jax.process_index()) 61 | if bind_to == "device": 62 | return jax.random.fold_in(rng, lax.axis_index(axis_name)) 63 | raise ValueError("`bind_to` should be one of the `[None, 'host', 'device']`") 64 | 65 | 66 | def shard( 67 | pytree: chex.ArrayTree, 68 | num_replicas: int, 69 | ) -> chex.ArrayTree: 70 | """Reshapes all arrays in the pytree to add a leading shard dimension. 71 | 72 | We assume that all arrays in the pytree have leading dimension 73 | divisible by num_devices_per_replica. 74 | 75 | Args: 76 | pytree: A pytree of arrays to be sharded. 77 | num_replicas: number of model replicas. 78 | 79 | Returns: 80 | Sharded data. 81 | """ 82 | 83 | def _shard_array(array: jnp.ndarray) -> jnp.ndarray: 84 | return array.reshape((num_replicas, -1) + array.shape[1:]) 85 | 86 | return jax.tree_map(_shard_array, pytree) 87 | 88 | 89 | def unshard(pytree: chex.ArrayTree, device: jax.Device) -> chex.ArrayTree: 90 | """Reshapes arrays from [ndev, bs, ...] to [host_bs, ...]. 
91 | 92 | Args: 93 | pytree: A pytree of arrays to be sharded. 94 | device: device to put. 95 | 96 | Returns: 97 | Sharded data. 98 | """ 99 | 100 | def _unshard_array(array: jnp.ndarray) -> jnp.ndarray: 101 | ndev, bs = array.shape[:2] 102 | return array.reshape((ndev * bs,) + array.shape[2:]) 103 | 104 | pytree = jax.device_put(pytree, device) 105 | return jax.tree_map(_unshard_array, pytree) 106 | -------------------------------------------------------------------------------- /imgx/diffusion/README.md: -------------------------------------------------------------------------------- 1 | # Diffusion 2 | 3 | The class diagram is illustrated below, where 4 | 5 | - red blocks are classes having abstract methods. 6 | - yellow blocks are classes used during training, with `sample()` not being implemented. 7 | - green blocks are classes used during inference, with `sample()` implemented. 8 | 9 |
10 | *(class diagram: `images/diffusion_class_diagram.png`)* 11 |
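In code, this split can be summarised by the sketch below (class bodies are placeholders for illustration only; see `imgx/diffusion/diffusion.py` and `imgx/diffusion/gaussian/` for the actual implementations):

```python
class Diffusion:
    """Red block: defines the interface; key methods are abstract."""

    def q_sample(self, x_start, noise, t_index):
        raise NotImplementedError

    def sample(self, key, model_out, x_t, t_index):
        raise NotImplementedError


class GaussianDiffusion(Diffusion):
    """Yellow block: implements the forward (noising) process used during training."""

    def q_sample(self, x_start, noise, t_index):
        return x_start  # placeholder; the real class applies the variance schedule


class DDPMSampler(GaussianDiffusion):
    """Green block: adds the reverse-process sampling used at inference."""

    def sample(self, key, model_out, x_t, t_index):
        return model_out, model_out  # placeholder; the real sampler implements DDPM
```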
12 | -------------------------------------------------------------------------------- /imgx/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | """Diffusion related functions.""" 2 | from imgx.diffusion.diffusion import Diffusion 3 | 4 | __all__ = [ 5 | "Diffusion", 6 | ] 7 | -------------------------------------------------------------------------------- /imgx/diffusion/diffusion.py: -------------------------------------------------------------------------------- 1 | """Base diffusion class.""" 2 | from __future__ import annotations 3 | 4 | from collections.abc import Sequence 5 | from dataclasses import dataclass 6 | from typing import Callable 7 | 8 | import jax.numpy as jnp 9 | import jax.random 10 | 11 | 12 | @dataclass 13 | class Diffusion: 14 | """Base class for diffusion.""" 15 | 16 | num_timesteps: int 17 | noise_fn: Callable[..., jnp.ndarray] 18 | 19 | def sample_noise(self, key: jax.Array, shape: Sequence[int], dtype: jnp.dtype) -> jnp.ndarray: 20 | """Return a noise of the same shape as input. 21 | 22 | Define this function to avoid defining randon key. 23 | 24 | Args: 25 | key: random key. 26 | shape: array shape. 27 | dtype: data type. 28 | 29 | Returns: 30 | Noise of the same shape and dtype as x. 31 | """ 32 | return self.noise_fn(key=key, shape=shape, dtype=dtype) 33 | 34 | def t_index_to_t(self, t_index: jnp.ndarray) -> jnp.ndarray: 35 | """Convert t_index to t. 36 | 37 | t_index = 0 corresponds to t = 1 / num_timesteps. 38 | t_index = num_timesteps - 1 corresponds to t = 1. 39 | 40 | Args: 41 | t_index: t_index, shape (batch, ). 42 | 43 | Returns: 44 | t: t, shape (batch, ). 45 | """ 46 | return jnp.asarray(t_index + 1, jnp.float32) / self.num_timesteps 47 | 48 | def q_sample( 49 | self, 50 | x_start: jnp.ndarray, 51 | noise: jnp.ndarray, 52 | t_index: jnp.ndarray, 53 | ) -> jnp.ndarray: 54 | """Sample from q(x_t | x_0). 55 | 56 | Args: 57 | x_start: noiseless input. 58 | noise: same shape as x_start. 59 | t_index: storing index values < self.num_timesteps. 60 | 61 | Returns: 62 | Noisy array with same shape as x_start. 63 | """ 64 | raise NotImplementedError 65 | 66 | def predict_xprev_from_xstart_xt( 67 | self, x_start: jnp.ndarray, x_t: jnp.ndarray, t_index: jnp.ndarray 68 | ) -> jnp.ndarray: 69 | """Get x_{t-1} from x_0 and x_t. 70 | 71 | Args: 72 | x_start: noisy input at t, shape (batch, ...). 73 | x_t: noisy input, same shape as x_start. 74 | t_index: storing index values < self.num_timesteps, shape (batch, ). 75 | 76 | Returns: 77 | predicted x_0, same shape as x_prev. 78 | """ 79 | raise NotImplementedError 80 | 81 | def predict_xstart_from_model_out_xt( 82 | self, 83 | model_out: jnp.ndarray, 84 | x_t: jnp.ndarray, 85 | t_index: jnp.ndarray, 86 | ) -> jnp.ndarray: 87 | """Predict x_0 from model output and x_t. 88 | 89 | Args: 90 | model_out: model output. 91 | x_t: noisy input. 92 | t_index: storing index values < self.num_timesteps. 93 | 94 | Returns: 95 | x_start, same shape as x_t. 96 | """ 97 | raise NotImplementedError 98 | 99 | def variational_lower_bound( 100 | self, 101 | model_out: jnp.ndarray, 102 | x_start: jnp.ndarray, 103 | x_t: jnp.ndarray, 104 | t_index: jnp.ndarray, 105 | ) -> tuple[jnp.ndarray, jnp.ndarray]: 106 | """Variational lower-bound, ELBO, smaller is better. 107 | 108 | Args: 109 | model_out: raw model output, may contain additional parameters. 110 | x_start: noiseless input. 111 | x_t: noisy input, same shape as x_start. 112 | t_index: storing index values < self.num_timesteps. 
113 | 114 | Returns: 115 | - lower bounds, shape (batch, ). 116 | - model_out with the same shape as x_start. 117 | """ 118 | raise NotImplementedError 119 | 120 | def sample( 121 | self, 122 | key: jax.Array, 123 | model_out: jnp.ndarray, 124 | x_t: jnp.ndarray, 125 | t_index: jnp.ndarray, 126 | ) -> tuple[jnp.ndarray, jnp.ndarray]: 127 | """Sample x_{t-1} ~ p(x_{t-1} | x_t). 128 | 129 | Args: 130 | key: random key. 131 | model_out: model predicted output. 132 | If model estimates variance, the last axis will be split. 133 | x_t: noisy x at time t. 134 | t_index: storing index values < self.num_timesteps. 135 | 136 | Returns: 137 | sample: x_{t-1}, same shape as x_t. 138 | x_start_pred: same shape as x_t. 139 | """ 140 | raise NotImplementedError 141 | 142 | def diffusion_loss( 143 | self, 144 | x_start: jnp.ndarray, 145 | x_t: jnp.ndarray, 146 | t_index: jnp.ndarray, 147 | noise: jnp.ndarray, 148 | model_out: jnp.ndarray, 149 | ) -> tuple[dict[str, jnp.ndarray], jnp.ndarray]: 150 | """Diffusion-specific loss function. 151 | 152 | Args: 153 | x_start: noiseless input. 154 | x_t: noisy input. 155 | t_index: storing index values < self.num_timesteps. 156 | noise: sampled noise, same shape as x_t. 157 | model_out: model output, may contain additional parameters. 158 | 159 | Returns: 160 | scalars: dict of losses, each with shape (batch, ). 161 | model_out: same shape as x_start. 162 | """ 163 | raise NotImplementedError 164 | -------------------------------------------------------------------------------- /imgx/diffusion/gaussian/__init__.py: -------------------------------------------------------------------------------- 1 | """Gaussian based continuous diffusion models.""" 2 | -------------------------------------------------------------------------------- /imgx/diffusion/gaussian/sampler.py: -------------------------------------------------------------------------------- 1 | """Module for sampling.""" 2 | from __future__ import annotations 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from imgx.diffusion.gaussian.gaussian_diffusion import GaussianDiffusion 8 | from imgx.diffusion.util import expand, extract_and_expand 9 | 10 | 11 | class DDPMSampler(GaussianDiffusion): 12 | """DDPM https://arxiv.org/abs/2006.11239.""" 13 | 14 | def sample( 15 | self, 16 | key: jax.Array, 17 | model_out: jnp.ndarray, 18 | x_t: jnp.ndarray, 19 | t_index: jnp.ndarray, 20 | ) -> tuple[jnp.ndarray, jnp.ndarray]: 21 | """Sample x_{t-1} ~ p(x_{t-1} | x_t) using DDPM. 22 | 23 | https://arxiv.org/abs/2006.11239 24 | 25 | Args: 26 | key: random key. 27 | model_out: model predicted output. 28 | If model estimates variance, the last axis will be split. 29 | x_t: noisy input, shape (batch, ...). 30 | t_index: storing index values < self.num_timesteps, 31 | shape (batch, ) or broadcast-compatible to x_start shape. 32 | 33 | Returns: 34 | sample: x_{t-1}, same shape as x_t. 35 | x_start_pred: same shape as x_t. 
36 | """ 37 | x_start_pred, mean, log_variance = self.p_mean_variance( 38 | model_out=model_out, 39 | x_t=x_t, 40 | t_index=t_index, 41 | ) 42 | noise = self.sample_noise(key=key, shape=x_t.shape, dtype=x_t.dtype) 43 | 44 | # no noise when t=0 45 | # mean + exp(log(sigma**2)/2) * noise = mean + sigma * noise 46 | nonzero_mask = jnp.array(t_index != 0, dtype=noise.dtype) 47 | nonzero_mask = expand(nonzero_mask, noise.ndim) 48 | sample = mean + nonzero_mask * jnp.exp(0.5 * log_variance) * noise 49 | 50 | return sample, x_start_pred 51 | 52 | 53 | class DDIMSampler(GaussianDiffusion): 54 | """DDIM https://arxiv.org/abs/2010.02502. 55 | 56 | https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim/pipeline_ddim.py 57 | """ 58 | 59 | def sample( 60 | self, 61 | key: jax.Array, 62 | model_out: jnp.ndarray, 63 | x_t: jnp.ndarray, 64 | t_index: jnp.ndarray, 65 | eta: float = 0.0, 66 | ) -> tuple[jnp.ndarray, jnp.ndarray]: 67 | """Sample x_{t-1} ~ p(x_{t-1} | x_t) using DDIM. 68 | 69 | https://arxiv.org/abs/2010.02502 70 | 71 | Args: 72 | key: random key. 73 | model_out: model predicted output. 74 | If model estimates variance, the last axis will be split. 75 | x_t: noisy input, shape (batch, ...). 76 | t_index: storing index values < self.num_timesteps, 77 | shape (batch, ) or broadcast-compatible to x_start shape. 78 | eta: control the noise level in sampling. 79 | 80 | Returns: 81 | sample: x_{t-1}, same shape as x_t. 82 | x_start_pred: same shape as x_t. 83 | """ 84 | # prepare constants 85 | x_start_pred, _ = self.p_mean( 86 | model_out=model_out, 87 | x_t=x_t, 88 | t_index=t_index, 89 | ) 90 | noise = self.predict_noise_from_xstart_xt(x_t=x_t, x_start=x_start_pred, t_index=t_index) 91 | alphas_cumprod_prev = extract_and_expand( 92 | self.alphas_cumprod_prev, t_index=t_index, ndim=x_t.ndim 93 | ) 94 | coeff_start = jnp.sqrt(alphas_cumprod_prev) 95 | log_variance = ( 96 | extract_and_expand( 97 | self.posterior_log_variance_clipped, 98 | t_index=t_index, 99 | ndim=x_t.ndim, 100 | ) 101 | * eta 102 | ) 103 | coeff_noise = jnp.sqrt(1.0 - alphas_cumprod_prev - log_variance**2) 104 | mean = coeff_start * x_start_pred + coeff_noise * noise 105 | 106 | # deterministic for t_index > 0 107 | nonzero_mask = jnp.array(t_index != 0, dtype=x_t.dtype) 108 | nonzero_mask = expand(nonzero_mask, x_t.ndim) 109 | 110 | # sample 111 | noise = self.sample_noise(key=key, shape=x_t.shape, dtype=x_t.dtype) 112 | sample = mean + nonzero_mask * log_variance * noise 113 | return sample, x_start_pred 114 | -------------------------------------------------------------------------------- /imgx/diffusion/gaussian/variance_schedule.py: -------------------------------------------------------------------------------- 1 | """Variance schedule for diffusion models.""" 2 | from __future__ import annotations 3 | 4 | import jax.numpy as jnp 5 | import numpy as np 6 | 7 | 8 | def get_beta_schedule( 9 | num_timesteps: int, 10 | beta_schedule: str, 11 | beta_start: float, 12 | beta_end: float, 13 | ) -> jnp.ndarray: 14 | """Get variance (beta) schedule for q(x_t | x_{t-1}). 15 | 16 | Args: 17 | num_timesteps: number of time steps in total, T. 18 | beta_schedule: schedule for beta. 19 | beta_start: beta for t=0. 20 | beta_end: beta for t=T-1. 21 | 22 | Returns: 23 | Shape (num_timesteps,) array of beta values, for t=0, ..., T-1. 24 | Values are in ascending order. 25 | 26 | Raises: 27 | ValueError: for unknown schedule. 
28 | """ 29 | if beta_schedule == "linear": 30 | return jnp.linspace( 31 | beta_start, 32 | beta_end, 33 | num_timesteps, 34 | ) 35 | if beta_schedule == "quadradic": 36 | return ( 37 | jnp.linspace( 38 | beta_start**0.5, 39 | beta_end**0.5, 40 | num_timesteps, 41 | ) 42 | ** 2 43 | ) 44 | if beta_schedule == "cosine": 45 | 46 | def f(t: float) -> float: 47 | """Eq 17 in https://arxiv.org/abs/2102.09672. 48 | 49 | Args: 50 | t: time step with values in [0, 1]. 51 | 52 | Returns: 53 | Cumulative product of alpha. 54 | """ 55 | return np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2 56 | 57 | betas = [0.0] 58 | alphas_cumprod_prev = 1.0 59 | for i in range(1, num_timesteps): 60 | t = i / (num_timesteps - 1) 61 | alphas_cumprod = f(t) 62 | beta = 1 - alphas_cumprod / alphas_cumprod_prev 63 | betas.append(beta) 64 | return jnp.array(betas) * (beta_end - beta_start) + beta_start 65 | 66 | if beta_schedule == "warmup10": 67 | num_timesteps_warmup = max(num_timesteps // 10, 1) 68 | betas_warmup = ( 69 | jnp.linspace( 70 | beta_start**0.5, 71 | beta_end**0.5, 72 | num_timesteps_warmup, 73 | ) 74 | ** 2 75 | ) 76 | return jnp.concatenate( 77 | [ 78 | betas_warmup, 79 | jnp.ones((num_timesteps - num_timesteps_warmup,)) * beta_end, 80 | ] 81 | ) 82 | if beta_schedule == "warmup50": 83 | num_timesteps_warmup = max(num_timesteps // 2, 1) 84 | betas_warmup = ( 85 | jnp.linspace( 86 | beta_start**0.5, 87 | beta_end**0.5, 88 | num_timesteps_warmup, 89 | ) 90 | ** 2 91 | ) 92 | return jnp.concatenate( 93 | [ 94 | betas_warmup, 95 | jnp.ones((num_timesteps - num_timesteps_warmup,)) * beta_end, 96 | ] 97 | ) 98 | raise ValueError(f"Unknown beta_schedule {beta_schedule}.") 99 | 100 | 101 | def downsample_beta_schedule( 102 | betas: jnp.ndarray, 103 | num_timesteps: int, 104 | num_timesteps_to_keep: int, 105 | ) -> jnp.ndarray: 106 | """Down-sample beta schedule. 107 | 108 | After down-sampling, the first and last values of alphas_cumprod are kept. 109 | 110 | Args: 111 | betas: beta schedule, shape (num_timesteps,). 112 | Values are in ascending order. 113 | num_timesteps: number of time steps in total, T. 114 | num_timesteps_to_keep: number of time steps to keep. 115 | 116 | Returns: 117 | Down-sampled beta schedule, shape (num_timesteps_to_keep,). 118 | """ 119 | if betas.shape != (num_timesteps,): 120 | raise ValueError( 121 | f"betas.shape ({betas.shape}) must be equal to (num_timesteps,)=({num_timesteps},)" 122 | ) 123 | if num_timesteps_to_keep > num_timesteps: 124 | raise ValueError( 125 | f"num_timesteps_to_keep ({num_timesteps_to_keep}) " 126 | f"must be <= num_timesteps ({num_timesteps})" 127 | ) 128 | if (num_timesteps - 1) % (num_timesteps_to_keep - 1) != 0: 129 | raise ValueError( 130 | f"num_timesteps-1={num_timesteps-1} can't be evenly divided by " 131 | f"num_timesteps_to_keep-1={num_timesteps_to_keep-1}." 
132 | ) 133 | if num_timesteps_to_keep < 2: 134 | raise ValueError(f"num_timesteps_to_keep ({num_timesteps_to_keep}) must be >= 2.") 135 | if num_timesteps_to_keep == num_timesteps: 136 | return betas 137 | step_scale = (num_timesteps - 1) // (num_timesteps_to_keep - 1) 138 | alphas_cumprod = jnp.cumprod(1.0 - betas) 139 | # (num_timesteps_to_keep,) 140 | alphas_cumprod = alphas_cumprod[::step_scale] 141 | 142 | # recompute betas 143 | # (num_timesteps_to_keep-1,) 144 | alphas = alphas_cumprod[1:] / alphas_cumprod[:-1] 145 | # (num_timesteps_to_keep,) 146 | alphas = jnp.concatenate([alphas_cumprod[:1], alphas]) 147 | betas = 1 - alphas 148 | return betas 149 | -------------------------------------------------------------------------------- /imgx/diffusion/gaussian/variance_schedule_test.py: -------------------------------------------------------------------------------- 1 | """Test Gaussian diffusion related classes and functions.""" 2 | 3 | 4 | import chex 5 | import jax.numpy as jnp 6 | from absl.testing import parameterized 7 | from chex._src import fake 8 | 9 | from imgx.diffusion.gaussian.variance_schedule import downsample_beta_schedule, get_beta_schedule 10 | 11 | 12 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 13 | def setUpModule() -> None: # pylint: disable=invalid-name 14 | """Fake multi-devices.""" 15 | fake.set_n_cpu_devices(2) 16 | 17 | 18 | class TestGetBetaSchedule(chex.TestCase): 19 | """Test get_beta_schedule.""" 20 | 21 | @parameterized.product( 22 | num_timesteps=[1, 4], 23 | beta_schedule=[ 24 | "linear", 25 | "quadradic", 26 | "cosine", 27 | "warmup10", 28 | "warmup50", 29 | ], 30 | ) 31 | def test_shapes( 32 | self, 33 | num_timesteps: int, 34 | beta_schedule: str, 35 | ) -> None: 36 | """Test output shape.""" 37 | beta_start = 0.0 38 | beta_end = 0.2 39 | got = get_beta_schedule( 40 | num_timesteps=num_timesteps, 41 | beta_schedule=beta_schedule, 42 | beta_start=beta_start, 43 | beta_end=beta_end, 44 | ) 45 | chex.assert_shape(got, (num_timesteps,)) 46 | 47 | assert got[0] == beta_start 48 | if num_timesteps > 1: 49 | chex.assert_trees_all_close(got[-1], beta_end) 50 | 51 | 52 | class TestDownsampleBetaSchedule(chex.TestCase): 53 | """Test downsample_beta_schedule.""" 54 | 55 | @parameterized.named_parameters( 56 | ("same 1001 steps", 1001, 1001), 57 | ("same 101 steps", 101, 101), 58 | ("downsample 21 to 5", 21, 5), 59 | ("downsample 101 to 5", 101, 5), 60 | ("downsample 11 to 3", 11, 3), 61 | ) 62 | def test_values( 63 | self, 64 | num_timesteps: int, 65 | num_timesteps_to_keep: int, 66 | ) -> None: 67 | """Test output values and shapes.""" 68 | betas = jnp.linspace(1e-4, 0.02, num_timesteps) 69 | alphas_cumprod = jnp.cumprod(1.0 - betas) 70 | got = downsample_beta_schedule(betas, num_timesteps, num_timesteps_to_keep) 71 | alphas_cumprod_got = jnp.cumprod(1.0 - got) 72 | chex.assert_shape(got, (num_timesteps_to_keep,)) 73 | chex.assert_trees_all_close(alphas_cumprod_got[0], alphas_cumprod[0]) 74 | chex.assert_trees_all_close(alphas_cumprod_got[-1], alphas_cumprod[-1]) 75 | -------------------------------------------------------------------------------- /imgx/diffusion/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions for diffusion models.""" 2 | from __future__ import annotations 3 | 4 | import jax.numpy as jnp 5 | 6 | 7 | def extract_and_expand(arr: jnp.ndarray, t_index: jnp.ndarray, ndim: int) -> jnp.ndarray: 8 | """Extract values from a 1D array and expand. 
9 | 10 | This function is not jittable. 11 | 12 | Args: 13 | arr: 1D of shape (num_timesteps, ). 14 | t_index: storing index values < self.num_timesteps, 15 | shape (batch, ) or has ndim dimension. 16 | ndim: number of dimensions for an array of shape (batch, ...). 17 | 18 | Returns: 19 | Expanded array of shape (batch, ...), expanded axes have dim 1. 20 | """ 21 | if arr.ndim != 1: 22 | raise ValueError(f"arr must be 1D, got {arr.ndim}D.") 23 | x = arr[t_index] 24 | return expand(x, ndim) 25 | 26 | 27 | def expand(x: jnp.ndarray, ndim: int) -> jnp.ndarray: 28 | """Expand. 29 | 30 | This function is not jittable. 31 | 32 | Args: 33 | x: a 1D or nD array. 34 | ndim: number of dimensions as output. 35 | 36 | Returns: 37 | Expanded array, expanded axes have dim 1. 38 | """ 39 | if x.ndim == 1: 40 | return jnp.expand_dims(x, axis=tuple(range(1, ndim))) 41 | if x.ndim == ndim: 42 | return x 43 | raise ValueError(f"t_index must be 1D or {ndim}D, got {x.ndim}D.") 44 | -------------------------------------------------------------------------------- /imgx/diffusion/util_test.py: -------------------------------------------------------------------------------- 1 | """Test Gaussian diffusion related classes and functions.""" 2 | 3 | 4 | import chex 5 | import jax 6 | import jax.numpy as jnp 7 | from absl.testing import parameterized 8 | from chex._src import fake 9 | 10 | from imgx.diffusion.util import extract_and_expand 11 | 12 | 13 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 14 | def setUpModule() -> None: # pylint: disable=invalid-name 15 | """Fake multi-devices.""" 16 | fake.set_n_cpu_devices(2) 17 | 18 | 19 | class TestExtractAndExpand(chex.TestCase): 20 | """Test extract_and_expand.""" 21 | 22 | @chex.variants(without_jit=True, with_device=True, without_device=True) 23 | @parameterized.named_parameters( 24 | ( 25 | "1d", 26 | 1, 27 | ), 28 | ( 29 | "2d", 30 | 2, 31 | ), 32 | ( 33 | "3d", 34 | 3, 35 | ), 36 | ) 37 | def test_shapes( 38 | self, 39 | ndim: int, 40 | ) -> None: 41 | """Test output shape. 42 | 43 | Args: 44 | ndim: number of dimensions. 45 | """ 46 | batch_size = 2 47 | betas = jnp.array([0, 0.2, 0.5, 1.0]) 48 | num_timesteps = len(betas) 49 | rng = jax.random.PRNGKey(0) 50 | t_index = jax.random.randint(rng, shape=(batch_size,), minval=0, maxval=num_timesteps) 51 | got = self.variant(extract_and_expand)(arr=betas, t_index=t_index, ndim=ndim) 52 | expected_shape = (batch_size,) + (1,) * (ndim - 1) 53 | chex.assert_shape(got, expected_shape) 54 | -------------------------------------------------------------------------------- /imgx/experiment.py: -------------------------------------------------------------------------------- 1 | """Experiment interface.""" 2 | from __future__ import annotations 3 | 4 | from collections.abc import Iterator 5 | from pathlib import Path 6 | 7 | import chex 8 | import jax 9 | import jax.numpy as jnp 10 | import numpy as np 11 | import tensorflow as tf 12 | from absl import logging 13 | from omegaconf import DictConfig 14 | 15 | from imgx.datasets import INFO_MAP 16 | from imgx.metric.util import merge_aggregated_metrics 17 | from imgx.train_state import TrainState 18 | 19 | 20 | class Experiment: 21 | """Experiment for supervised training.""" 22 | 23 | def __init__(self, config: DictConfig) -> None: 24 | """Initializes experiment. 25 | 26 | Args: 27 | config: experiment config. 28 | """ 29 | # Do not use accelerators in data pipeline. 
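        # TensorFlow here only feeds the tf.data input pipeline; hiding GPU/TPU
        # devices from it prevents TensorFlow from reserving accelerator memory
        # that JAX needs.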
30 | try: 31 | tf.config.set_visible_devices([], device_type="GPU") 32 | tf.config.set_visible_devices([], device_type="TPU") 33 | except RuntimeError: 34 | logging.error( 35 | f"Failed to set visible devices, data set may be using GPU/TPUs. " 36 | f"Visible GPU devices: {tf.config.get_visible_devices('GPU')}. " 37 | f"Visible TPU devices: {tf.config.get_visible_devices('TPU')}." 38 | ) 39 | 40 | self.config = config 41 | self.dataset_info = INFO_MAP[self.config.data.name] 42 | self.p_train_step = None # To be defined in train_init 43 | self.p_eval_step = None # To be defined in train_init 44 | 45 | def train_init( 46 | self, batch: dict[str, jnp.ndarray], ckpt_dir: Path | None = None, step: int | None = None 47 | ) -> tuple[TrainState, int]: 48 | """Initialize data loader, loss, networks for training. 49 | 50 | Args: 51 | batch: training data. 52 | ckpt_dir: checkpoint directory to restore from. 53 | step: checkpoint step to restore from, if None use the latest one. 54 | 55 | Returns: 56 | initialized training state. 57 | """ 58 | raise NotImplementedError 59 | 60 | def train_step( 61 | self, train_state: TrainState, batch: dict[str, jnp.ndarray], key: jax.Array 62 | ) -> tuple[TrainState, chex.ArrayTree]: 63 | """Perform a training step. 64 | 65 | Args: 66 | train_state: training state. 67 | batch: training data. 68 | key: random key. 69 | 70 | Returns: 71 | - new training state. 72 | - new random key. 73 | - metric dict. 74 | """ 75 | # key is updated/fold inside pmap function 76 | # to ensure a different key is used per step 77 | train_state, metrics = self.p_train_step( # pylint: disable=not-callable 78 | train_state, 79 | batch, 80 | key, 81 | ) 82 | metrics = merge_aggregated_metrics(metrics) 83 | metrics = jax.tree_map(lambda x: x.item(), metrics) # tensor to values 84 | return train_state, metrics 85 | 86 | def eval_step( 87 | self, 88 | train_state: TrainState, 89 | iterator: Iterator[dict[str, jnp.ndarray]], 90 | num_steps: int, 91 | key: jax.Array, 92 | out_dir: Path | None = None, 93 | ) -> dict[str, jnp.ndarray]: 94 | """Evaluation on entire validation/test data set. 95 | 96 | Args: 97 | train_state: training state. 98 | iterator: data iterator. 99 | num_steps: number of steps for evaluation. 100 | key: random key. 101 | out_dir: output directory, if not None, predictions will be saved. 102 | 103 | Returns: 104 | metric dict. 105 | """ 106 | raise NotImplementedError 107 | 108 | def eval_batch( 109 | self, 110 | train_state: TrainState, 111 | key: jax.Array, 112 | batch: dict[str, jnp.ndarray], 113 | uids: list[str], 114 | device_cpu: jax.Device, 115 | ) -> tuple[list[str], dict[str, np.ndarray], dict[str, np.ndarray]]: 116 | """Evaluate a batch. 117 | 118 | Args: 119 | train_state: training state. 120 | key: random key. 121 | batch: batch data without uid. 122 | uids: uids in the batch, potentially including padded samples. 123 | device_cpu: cpu device. 124 | 125 | Returns: 126 | uids: uids in the batch, excluding padded samples. 127 | metrics: each item has shape (num_samples,). 128 | prediction dict: each item has shape (num_samples, ...). 129 | """ 130 | raise NotImplementedError 131 | -------------------------------------------------------------------------------- /imgx/integration_test.py: -------------------------------------------------------------------------------- 1 | """Test experiments train, valid, and test. 
2 | 3 | mocker.patch, https://pytest-mock.readthedocs.io/en/latest/ 4 | """ 5 | import shutil 6 | from tempfile import TemporaryDirectory 7 | 8 | import pytest 9 | from pytest_mock import MockFixture 10 | 11 | from imgx.run_test import main as run_test 12 | from imgx.run_train import main as run_train 13 | from imgx.run_valid import main as run_valid 14 | 15 | 16 | @pytest.mark.integration() 17 | @pytest.mark.parametrize( 18 | "dataset", 19 | ["muscle_us", "amos_ct"], 20 | ) 21 | def test_segmentation_train_valid_test(mocker: MockFixture, dataset: str) -> None: 22 | """Test train, valid, and test. 23 | 24 | Args: 25 | mocker: mocker, a wrapper of unittest.mock. 26 | dataset: dataset name. 27 | """ 28 | with TemporaryDirectory() as temp_dir: 29 | mocker.resetall() 30 | mocker.patch.dict("os.environ", {"WANDB_MODE": "offline"}) 31 | mocker.patch( 32 | "sys.argv", 33 | ["pytest", "debug=true", "task=seg", f"data={dataset}", f"logging.root_dir={temp_dir}"], 34 | ) 35 | run_train() # pylint: disable=no-value-for-parameter 36 | mocker.patch( 37 | "sys.argv", 38 | ["pytest", f"--log_dir={temp_dir}/wandb/latest-run"], 39 | ) 40 | run_valid() 41 | run_test() 42 | shutil.rmtree(temp_dir) 43 | 44 | 45 | @pytest.mark.integration() 46 | @pytest.mark.parametrize( 47 | "dataset", 48 | ["muscle_us", "amos_ct"], 49 | ) 50 | def test_diffusion_segmentation_train_valid_test(mocker: MockFixture, dataset: str) -> None: 51 | """Test train, valid, and test. 52 | 53 | Args: 54 | mocker: mocker, a wrapper of unittest.mock. 55 | dataset: dataset name. 56 | """ 57 | with TemporaryDirectory() as temp_dir: 58 | mocker.resetall() 59 | mocker.patch.dict("os.environ", {"WANDB_MODE": "offline"}) 60 | mocker.patch( 61 | "sys.argv", 62 | [ 63 | "pytest", 64 | "debug=true", 65 | "task=gaussian_diff_seg", 66 | f"data={dataset}", 67 | f"logging.root_dir={temp_dir}", 68 | ], 69 | ) 70 | run_train() # pylint: disable=no-value-for-parameter 71 | mocker.patch( 72 | "sys.argv", 73 | [ 74 | "pytest", 75 | "--num_timesteps=2", 76 | "--sampler=DDPM", 77 | f"--log_dir={temp_dir}/wandb/latest-run", 78 | ], 79 | ) 80 | run_valid() 81 | run_test() 82 | shutil.rmtree(temp_dir) 83 | -------------------------------------------------------------------------------- /imgx/loss/__init__.py: -------------------------------------------------------------------------------- 1 | """Package for loss functions.""" 2 | from imgx.loss.cross_entropy import cross_entropy, focal_loss 3 | from imgx.loss.deformation import bending_energy_loss, gradient_norm_loss, jacobian_loss 4 | from imgx.loss.dice import dice_loss, dice_loss_from_masks 5 | from imgx.loss.similarity import nrmsd_loss, psnr_loss 6 | 7 | __all__ = [ 8 | "cross_entropy", 9 | "focal_loss", 10 | "dice_loss", 11 | "dice_loss_from_masks", 12 | "psnr_loss", 13 | "nrmsd_loss", 14 | "bending_energy_loss", 15 | "gradient_norm_loss", 16 | "jacobian_loss", 17 | ] 18 | -------------------------------------------------------------------------------- /imgx/loss/cross_entropy.py: -------------------------------------------------------------------------------- 1 | """Loss functions for classification.""" 2 | import jax 3 | import jax.numpy as jnp 4 | import optax 5 | from jax import lax 6 | 7 | 8 | def cross_entropy( 9 | logits: jnp.ndarray, 10 | mask_true: jnp.ndarray, 11 | classes_are_exclusive: bool, 12 | ) -> jnp.ndarray: 13 | """Cross entropy, supporting soft label. 14 | 15 | optax.softmax_cross_entropy returns (batch, ...). 16 | optax.sigmoid_binary_cross_entropy returns (batch, ..., num_classes). 
17 | 18 | Args: 19 | logits: unscaled prediction, (batch, ..., num_classes). 20 | mask_true: probabilities per class, (batch, ..., num_classes). 21 | classes_are_exclusive: if False, each element can be assigned to multiple classes. 22 | 23 | Returns: 24 | Cross entropy loss value of shape (batch, ). 25 | """ 26 | mask_true = mask_true.astype(logits.dtype) 27 | loss = lax.cond( 28 | classes_are_exclusive, 29 | optax.softmax_cross_entropy, 30 | lambda *args: jnp.sum(optax.sigmoid_binary_cross_entropy(*args), axis=-1), 31 | logits, 32 | mask_true, 33 | ) 34 | return jnp.mean(loss, axis=range(1, loss.ndim)) 35 | 36 | 37 | def softmax_focal_loss( 38 | logits: jnp.ndarray, 39 | mask_true: jnp.ndarray, 40 | gamma: float = 2.0, 41 | ) -> jnp.ndarray: 42 | """Focal loss with one-hot / mutual exclusive classes. 43 | 44 | https://arxiv.org/abs/1708.02002 45 | Implementation is similar to optax.softmax_cross_entropy. 46 | 47 | Args: 48 | logits: unscaled prediction, (batch, ..., num_classes). 49 | mask_true: probabilities per class, (batch, ..., num_classes). 50 | gamma: adjust class imbalance, 0 is equivalent to cross entropy. 51 | 52 | Returns: 53 | Loss of shape (batch, ...). 54 | """ 55 | log_p = jax.nn.log_softmax(logits) 56 | p = jnp.exp(log_p) 57 | return -jnp.sum(((1 - p) ** gamma) * log_p * mask_true, axis=-1) 58 | 59 | 60 | def sigmoid_focal_loss( 61 | logits: jnp.ndarray, 62 | mask_true: jnp.ndarray, 63 | gamma: float = 2.0, 64 | ) -> jnp.ndarray: 65 | """Focal loss with multi-hot / non mutual exclusive classes. 66 | 67 | https://arxiv.org/abs/1708.02002 68 | Implementation is similar to optax.sigmoid_binary_cross_entropy. 69 | 70 | Args: 71 | logits: unscaled prediction, (batch, ..., num_classes). 72 | mask_true: probabilities per class, (batch, ..., num_classes). 73 | gamma: adjust class imbalance, 0 is equivalent to cross entropy. 74 | 75 | Returns: 76 | Focal loss value of shape (batch, ..., num_classes). 77 | """ 78 | log_p = jax.nn.log_sigmoid(logits) 79 | log_not_p = jax.nn.log_sigmoid(-logits) 80 | p = jnp.exp(log_p) 81 | return -((1 - p) ** gamma) * log_p * mask_true - (p**gamma) * log_not_p * (1 - mask_true) 82 | 83 | 84 | def focal_loss( 85 | logits: jnp.ndarray, 86 | mask_true: jnp.ndarray, 87 | classes_are_exclusive: bool, 88 | gamma: float = 2.0, 89 | ) -> jnp.ndarray: 90 | """Focal loss. 91 | 92 | https://arxiv.org/abs/1708.02002 93 | softmax_focal_loss returns (batch, ...). 94 | sigmoid_focal_loss returns (batch, ..., num_classes). 95 | 96 | Args: 97 | logits: unscaled prediction, (batch, ..., num_classes). 98 | mask_true: probabilities per class, (batch, ..., num_classes). 99 | classes_are_exclusive: if False, each element can be assigned to multiple classes. 100 | gamma: adjust class imbalance, 0 is equivalent to cross entropy. 101 | 102 | Returns: 103 | Focal loss value of shape (batch, ). 
104 | """ 105 | mask_true = mask_true.astype(logits.dtype) 106 | loss = lax.cond( 107 | classes_are_exclusive, 108 | softmax_focal_loss, 109 | lambda *args: jnp.sum(sigmoid_focal_loss(*args), axis=-1), 110 | logits, 111 | mask_true, 112 | gamma, 113 | ) 114 | return jnp.mean(loss, axis=range(1, loss.ndim)) 115 | -------------------------------------------------------------------------------- /imgx/loss/deformation.py: -------------------------------------------------------------------------------- 1 | """Deformation losses for ddf.""" 2 | from functools import partial 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from imgx.metric.deformation import gradient, jacobian_det 8 | 9 | 10 | def gradient_norm_loss(x: jnp.ndarray, norm_ord: int, spacing: jnp.ndarray) -> jnp.ndarray: 11 | """Calculate noram of gradients of x using central finite difference. 12 | 13 | Args: 14 | x: shape = (batch, d1, d2, ..., dn, channel). 15 | norm_ord: 1 for L1 or 2 for L2. 16 | spacing: spacing between each pixel/voxel, shape = (n,). 17 | 18 | Returns: 19 | shape = (batch,). 20 | """ 21 | # (batch, d1, d2, ..., dn, channel, n) 22 | grad = gradient(x, spacing) 23 | if norm_ord == 1: 24 | return jnp.mean(jnp.abs(grad), axis=tuple(range(1, grad.ndim))) 25 | if norm_ord == 2: 26 | return jnp.mean(grad**2, axis=tuple(range(1, grad.ndim))) 27 | raise ValueError(f"norm_ord = {norm_ord} is not supported.") 28 | 29 | 30 | def bending_energy_loss(x: jnp.ndarray, spacing: jnp.ndarray) -> jnp.ndarray: 31 | """Calculate bending energey (L2 norm of second order gradient) using central finite difference. 32 | 33 | Args: 34 | x: shape = (batch, d1, d2, ..., dn, channel). 35 | spacing: spacing between each pixel/voxel, shape = (n,). 36 | 37 | Returns: 38 | shape = (batch,). 39 | """ 40 | # (batch, d1, d2, ..., dn, channel, n) 41 | grad_1d = gradient(x, spacing) 42 | # (batch, d1, d2, ..., dn, channel, n, n) 43 | grad_2d = jax.vmap(partial(gradient, spacing=spacing), in_axes=-1, out_axes=-1)(grad_1d) 44 | # (batch,) 45 | return jnp.mean(grad_2d**2, axis=tuple(range(1, grad_2d.ndim))) 46 | 47 | 48 | def jacobian_loss(x: jnp.ndarray, spacing: jnp.ndarray) -> jnp.ndarray: 49 | """Calculate Jacobian loss. 50 | 51 | https://arxiv.org/abs/1907.00068 52 | 53 | If the Jacobian determinant is <0, the transformation is folding, thus not volume preserving. 54 | Do not penalize the Jacobian determinant if it is >0. 55 | 56 | Args: 57 | x: shape = (batch, d1, d2, ..., dn, n). 58 | spacing: spacing between each pixel/voxel, shape = (n,). 59 | 60 | Returns: 61 | shape = (batch,). 62 | """ 63 | # (batch, d1, d2, ..., dn) 64 | det = jacobian_det(x, spacing) 65 | det = jnp.clip(det, a_max=0.0) # negative values are folding 66 | # (batch,) 67 | return jnp.mean(-det, axis=tuple(range(1, det.ndim))) # reverse sign 68 | -------------------------------------------------------------------------------- /imgx/loss/deformation_test.py: -------------------------------------------------------------------------------- 1 | """Test deformation loss functions.""" 2 | from functools import partial 3 | 4 | import chex 5 | import jax.numpy as jnp 6 | import numpy as np 7 | from absl.testing import parameterized 8 | from chex._src import fake 9 | 10 | from imgx.loss.deformation import bending_energy_loss, gradient_norm_loss, jacobian_loss 11 | 12 | 13 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 
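# Faking multiple CPU devices exercises the multi-device (pmap) code paths in
# unit tests without requiring real accelerators.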
14 | def setUpModule() -> None: # pylint: disable=invalid-name 15 | """Fake multi-devices.""" 16 | fake.set_n_cpu_devices(2) 17 | 18 | 19 | class TestGradientNormLoss(chex.TestCase): 20 | """Test gradient_norm_loss.""" 21 | 22 | batch = 2 23 | 24 | @chex.all_variants() 25 | @parameterized.product( 26 | shape=[(2, 3), (2, 3, 4), (2, 3, 4, 5)], 27 | norm_ord=[1, 2], 28 | ) 29 | def test_shapes( 30 | self, 31 | shape: tuple[int, ...], 32 | norm_ord: int, 33 | ) -> None: 34 | """Test return shapes. 35 | 36 | Args: 37 | shape: input shape. 38 | norm_ord: 1 for L1 or 2 for L2. 39 | """ 40 | x = jnp.ones((self.batch, *shape)) 41 | spacing = jnp.ones((len(shape),)) 42 | got = self.variant(partial(gradient_norm_loss, norm_ord=norm_ord, spacing=spacing))(x) 43 | chex.assert_shape(got, (self.batch,)) 44 | 45 | @chex.all_variants() 46 | @parameterized.named_parameters( 47 | ( 48 | "1d - L1 norm", 49 | np.array([[[-2.3], [1.1], [-0.5]]]), 50 | 1, 51 | # norm [1.7, 0.9, -0.8] 52 | np.array([3.4 / 3]), 53 | ), 54 | ( 55 | "1d - L2 norm", 56 | np.array([[[-2.3], [1.1], [-0.5]]]), 57 | 2, 58 | # norm [1.7, 0.9, -0.8] 59 | np.array([(1.7 * 1.7 + 0.9 * 0.9 + 0.8 * 0.8) / 3]), 60 | ), 61 | ( 62 | "batch 1d - L1 norm", 63 | np.array([[[2.0], [1.0], [0.1]], [[0.0], [-1.0], [-2.0]], [[0.2], [-1.0], [-2.0]]]), 64 | 1, 65 | # norm (singleton axis is removed) 66 | # [[-0.5, -0.95, -0.45], 67 | # [-0.5, -1.0, -0.5], 68 | # [-0.6, -1.1, -0.5]], 69 | np.array([1.9 / 3, 2.0 / 3, 2.2 / 3]), 70 | ), 71 | ( 72 | "2d - L1 norm", 73 | # (1,3,1,3) 74 | np.array([[[[2.0, 1.0, 0.1]], [[0.0, -1.0, -2.0]], [[0.2, -1.0, -2.0]]]]), 75 | 1, 76 | # dx norm (singleton axis is removed) 77 | # [[-1.0, -1.0, -1.05], 78 | # [-0.9, -1.0, -1.05], 79 | # [0.1, 0.0, 0.0]] 80 | # dy norm are all zeros 81 | np.array([6.1 / 18]), 82 | ), 83 | ) 84 | def test_values( 85 | self, 86 | x: jnp.ndarray, 87 | norm_ord: int, 88 | expected: jnp.ndarray, 89 | ) -> None: 90 | """Test return values. 91 | 92 | Args: 93 | x: input. 94 | norm_ord: norm order. 95 | expected: expected output. 96 | """ 97 | spacing = jnp.ones((x.ndim - 2,)) 98 | got = self.variant(partial(gradient_norm_loss, norm_ord=norm_ord, spacing=spacing))(x) 99 | chex.assert_trees_all_close(got, expected) 100 | 101 | 102 | class TestBendingEnergyLoss(chex.TestCase): 103 | """Test bending_energy_loss.""" 104 | 105 | batch = 2 106 | 107 | @chex.all_variants() 108 | @parameterized.product( 109 | shape=[(2, 3), (2, 3, 4), (2, 3, 4, 5)], 110 | ) 111 | def test_shapes( 112 | self, 113 | shape: tuple[int, ...], 114 | ) -> None: 115 | """Test return shapes. 116 | 117 | Args: 118 | shape: input shape. 119 | """ 120 | x = jnp.ones((self.batch, *shape)) 121 | spacing = jnp.ones((x.ndim - 2,)) 122 | got = self.variant(bending_energy_loss)(x, spacing) 123 | chex.assert_shape(got, (self.batch,)) 124 | 125 | @chex.all_variants() 126 | @parameterized.named_parameters( 127 | ( 128 | "1d", 129 | # (1, 3, 1) 130 | np.array([[[-2.3], [1.1], [-0.5]]]), 131 | # dx [1.7, 0.9, -0.8] 132 | # dxx [-0.4, -1.25, -0.85] 133 | np.array([(0.4 * 0.4 + 1.25 * 1.25 + 0.85 * 0.85) / 3]), 134 | ), 135 | ) 136 | def test_values( 137 | self, 138 | x: jnp.ndarray, 139 | expected: jnp.ndarray, 140 | ) -> None: 141 | """Test return values. 142 | 143 | Args: 144 | x: input. 145 | expected: expected output. 
146 | """ 147 | spacing = jnp.ones((x.ndim - 2,)) 148 | got = self.variant(bending_energy_loss)(x, spacing) 149 | chex.assert_trees_all_close(got, expected) 150 | 151 | 152 | class TestJacobianLoss(chex.TestCase): 153 | """Test jacobian_loss.""" 154 | 155 | batch = 2 156 | 157 | @chex.all_variants() 158 | @parameterized.product( 159 | shape=[(2, 3), (2, 3, 4), (2, 3, 4, 5)], 160 | ) 161 | def test_shapes( 162 | self, 163 | shape: tuple[int, ...], 164 | ) -> None: 165 | """Test return shapes. 166 | 167 | Args: 168 | shape: input shape. 169 | """ 170 | x = jnp.ones((self.batch, *shape, len(shape))) 171 | spacing = jnp.ones((x.ndim - 2,)) 172 | got = self.variant(jacobian_loss)(x, spacing) 173 | chex.assert_shape(got, (self.batch,)) 174 | -------------------------------------------------------------------------------- /imgx/loss/dice.py: -------------------------------------------------------------------------------- 1 | """Loss functions for image segmentation.""" 2 | import jax 3 | import jax.numpy as jnp 4 | from jax import lax 5 | 6 | 7 | def dice_loss_from_masks( 8 | mask_pred: jnp.ndarray, 9 | mask_true: jnp.ndarray, 10 | ) -> jnp.ndarray: 11 | """Mean dice loss, smaller is better. 12 | 13 | Losses are not calculated on instance-classes, where there is no label. 14 | This is to avoid the need of smoothing and potentially nan gradients. 15 | 16 | Args: 17 | mask_pred: binary masks, (batch, ..., num_classes). 18 | mask_true: binary masks, (batch, ..., num_classes). 19 | 20 | Returns: 21 | Dice loss value of shape (batch, num_classes). 22 | """ 23 | reduce_axis = tuple(range(mask_pred.ndim))[1:-1] 24 | # (batch, num_classes) 25 | numerator = 2.0 * jnp.sum(mask_pred * mask_true, axis=reduce_axis) 26 | denominator = jnp.sum(mask_pred + mask_true, axis=reduce_axis) 27 | not_nan_mask = jnp.sum(mask_true, axis=reduce_axis) > 0 28 | # nan loss are replaced by 0.0 29 | return jnp.where( 30 | condition=not_nan_mask, 31 | x=1.0 - numerator / denominator, 32 | y=jnp.nan, 33 | ) 34 | 35 | 36 | def dice_loss( 37 | logits: jnp.ndarray, 38 | mask_true: jnp.ndarray, 39 | classes_are_exclusive: bool, 40 | ) -> jnp.ndarray: 41 | """Mean dice loss, smaller is better. 42 | 43 | Losses are not calculated on instance-classes, where there is no label. 44 | This is to avoid the need of smoothing and potentially nan gradients. 45 | 46 | Args: 47 | logits: unscaled prediction, (batch, ..., num_classes). 48 | mask_true: binary masks, (batch, ..., num_classes). 49 | classes_are_exclusive: classes are exclusive, i.e. no overlap. 50 | 51 | Returns: 52 | Dice loss value of shape (batch, num_classes). 
53 | """ 54 | mask_pred = lax.cond( 55 | classes_are_exclusive, 56 | jax.nn.softmax, 57 | jax.nn.sigmoid, 58 | logits, 59 | ) 60 | return dice_loss_from_masks(mask_pred, mask_true) 61 | -------------------------------------------------------------------------------- /imgx/loss/segmentation.py: -------------------------------------------------------------------------------- 1 | """Vanilla segmentation loss.""" 2 | from __future__ import annotations 3 | 4 | import jax.numpy as jnp 5 | from omegaconf import DictConfig 6 | 7 | from imgx.datasets.dataset_info import DatasetInfo 8 | from imgx.loss import cross_entropy, dice_loss, focal_loss 9 | from imgx.metric import class_proportion, class_volume 10 | 11 | 12 | def segmentation_loss( 13 | logits: jnp.ndarray, 14 | label: jnp.ndarray, 15 | dataset_info: DatasetInfo, 16 | loss_config: DictConfig, 17 | ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]: 18 | """Calculate segmentation loss with auxiliary losses and return metrics. 19 | 20 | Args: 21 | logits: unnormalised logits of shape (batch, ..., num_classes). 22 | label: label of shape (batch, ...). 23 | dataset_info: dataset info with helper functions. 24 | loss_config: have weights of diff losses. 25 | 26 | Returns: 27 | - calculated loss, of shape (batch,). 28 | - metrics, values of shape (batch,). 29 | """ 30 | spacing = jnp.array(dataset_info.image_spacing) 31 | mask_true = dataset_info.label_to_mask(label, axis=-1) 32 | metrics = {} 33 | 34 | # (batch, num_classes) 35 | class_prop_batch_cls = class_proportion(mask_true) 36 | for i in range(dataset_info.num_classes): 37 | metrics[f"class_{i}_proportion_true"] = class_prop_batch_cls[:, i] 38 | class_volume_batch_cls = class_volume(mask_true, spacing) 39 | for i in range(dataset_info.num_classes): 40 | metrics[f"class_{i}_volume_true"] = class_volume_batch_cls[:, i] 41 | 42 | # total loss 43 | loss_batch = jnp.zeros((logits.shape[0],), dtype=logits.dtype) 44 | if loss_config.get("dice", 0.0) > 0: 45 | # (batch, num_classes) 46 | dice_loss_batch_cls = dice_loss( 47 | logits=logits, 48 | mask_true=mask_true, 49 | classes_are_exclusive=dataset_info.classes_are_exclusive, 50 | ) 51 | # (batch, ) 52 | # without background 53 | # mask out non-existing classes 54 | dice_loss_batch = jnp.mean( 55 | dice_loss_batch_cls[:, 1:], axis=-1, where=class_prop_batch_cls[:, 1:] > 0 56 | ) 57 | metrics["dice_loss"] = dice_loss_batch 58 | for i in range(dice_loss_batch_cls.shape[-1]): 59 | metrics[f"dice_loss_class_{i}"] = dice_loss_batch_cls[:, i] 60 | loss_batch += dice_loss_batch * loss_config["dice"] 61 | 62 | if loss_config.get("cross_entropy", 0.0) > 0: 63 | # (batch, ) 64 | ce_loss_batch = cross_entropy( 65 | logits=logits, 66 | mask_true=mask_true, 67 | classes_are_exclusive=dataset_info.classes_are_exclusive, 68 | ) 69 | metrics["cross_entropy_loss"] = ce_loss_batch 70 | loss_batch += ce_loss_batch * loss_config["cross_entropy"] 71 | 72 | if loss_config.get("focal", 0.0) > 0: 73 | # (batch, ) 74 | focal_loss_batch = focal_loss( 75 | logits=logits, 76 | mask_true=mask_true, 77 | classes_are_exclusive=dataset_info.classes_are_exclusive, 78 | ) 79 | metrics["focal_loss"] = focal_loss_batch 80 | loss_batch += focal_loss_batch * loss_config["focal"] 81 | metrics["total_loss"] = loss_batch 82 | return loss_batch, metrics 83 | -------------------------------------------------------------------------------- /imgx/loss/segmentation_test.py: -------------------------------------------------------------------------------- 1 | """Test segmentation loss.""" 2 | from 
functools import partial 3 | 4 | import chex 5 | import jax.numpy as jnp 6 | from absl.testing import parameterized 7 | from chex._src import fake 8 | from omegaconf import DictConfig 9 | 10 | from imgx.datasets import INFO_MAP 11 | from imgx.loss.segmentation import segmentation_loss 12 | 13 | 14 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 15 | def setUpModule() -> None: # pylint: disable=invalid-name 16 | """Fake multi-devices.""" 17 | fake.set_n_cpu_devices(2) 18 | 19 | 20 | class TestSegmentationLoss(chex.TestCase): 21 | """Test segmentation_loss.""" 22 | 23 | batch_size = 2 24 | 25 | @chex.all_variants() 26 | @parameterized.product( 27 | dataset_name=sorted(INFO_MAP.keys()), 28 | loss_config=[ 29 | { 30 | "cross_entropy": 1.0, 31 | "dice": 1.0, 32 | "focal": 1.0, 33 | }, 34 | { 35 | "dice": 1.0, 36 | }, 37 | ], 38 | ) 39 | def test_shapes( 40 | self, 41 | dataset_name: str, 42 | loss_config: dict[str, float], 43 | ) -> None: 44 | """Test return shapes. 45 | 46 | Args: 47 | dataset_name: dataset name. 48 | loss_config: loss config. 49 | """ 50 | dataset_info = INFO_MAP[dataset_name] 51 | shape = dataset_info.image_spatial_shape 52 | shape = tuple(max(x // 16, 2) for x in shape) # reduce shape to speed up test 53 | logits = jnp.ones( 54 | (self.batch_size, *shape, dataset_info.num_classes), 55 | dtype=jnp.float32, 56 | ) 57 | label = jnp.ones((self.batch_size, *shape), dtype=jnp.int32) 58 | 59 | got_loss_batch, got_metrics = self.variant( 60 | partial( 61 | segmentation_loss, 62 | dataset_info=dataset_info, 63 | loss_config=DictConfig(loss_config), 64 | ) 65 | )(logits, label) 66 | chex.assert_shape(got_loss_batch, (self.batch_size,)) 67 | for v in got_metrics.values(): 68 | chex.assert_shape(v, (self.batch_size,)) 69 | -------------------------------------------------------------------------------- /imgx/loss/similarity.py: -------------------------------------------------------------------------------- 1 | """Image similarity loss functions.""" 2 | import jax.numpy as jnp 3 | 4 | from imgx import EPS 5 | from imgx.metric import nrmsd, psnr 6 | 7 | 8 | def psnr_loss( 9 | image1: jnp.ndarray, 10 | image2: jnp.ndarray, 11 | value_range: float = 1.0, 12 | eps: float = EPS, 13 | ) -> jnp.ndarray: 14 | """Peak signal-to-noise ratio (PSNR) loss. 15 | 16 | Args: 17 | image1: image of shape (batch, ..., channels). 18 | image2: image of shape (batch, ..., channels). 19 | value_range: value range of input images. 20 | eps: epsilon, if two images are identical, MSE=0. 21 | 22 | Returns: 23 | PSNR loss of shape (batch,). 24 | """ 25 | return -psnr( 26 | image1=image1, 27 | image2=image2, 28 | value_range=value_range, 29 | eps=eps, 30 | ) 31 | 32 | 33 | def nrmsd_loss( 34 | image_pred: jnp.ndarray, 35 | image_true: jnp.ndarray, 36 | eps: float = EPS, 37 | ) -> jnp.ndarray: 38 | """Normalized root-mean-square-deviation (NRMSD) loss. 39 | 40 | Args: 41 | image_pred: predicted image of shape (batch, ..., channels). 42 | image_true: ground truth image of shape (batch, ..., channels). 43 | eps: epsilon, if two images are identical, MSE=0. 44 | 45 | Returns: 46 | NRMSD loss of shape (batch,). 
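    As an illustrative example (hypothetical values): a constant prediction
    error of 0.2 against a ground-truth image with mean intensity 1.0 gives a
    loss of 0.2 / 1.0 = 0.2; identical images give a value close to zero, with
    eps guarding against division by zero.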
47 | """ 48 | return nrmsd( 49 | image_pred=image_pred, 50 | image_true=image_true, 51 | eps=eps, 52 | ) 53 | -------------------------------------------------------------------------------- /imgx/metric/__init__.py: -------------------------------------------------------------------------------- 1 | """Module for metrics.""" 2 | from imgx.metric.area import class_proportion, class_volume 3 | from imgx.metric.centroid import centroid_distance 4 | from imgx.metric.deformation import jacobian_det 5 | from imgx.metric.dice import dice_score, iou, stability 6 | from imgx.metric.similarity import nrmsd, psnr, ssim 7 | from imgx.metric.smoothing import gaussian_smooth_label, smooth_label 8 | from imgx.metric.surface_distance import ( 9 | aggregated_surface_distance, 10 | average_surface_distance, 11 | hausdorff_distance, 12 | normalized_surface_dice, 13 | normalized_surface_dice_from_distances, 14 | ) 15 | 16 | __all__ = [ 17 | "dice_score", 18 | "iou", 19 | "average_surface_distance", 20 | "aggregated_surface_distance", 21 | "normalized_surface_dice", 22 | "normalized_surface_dice_from_distances", 23 | "hausdorff_distance", 24 | "centroid_distance", 25 | "class_proportion", 26 | "class_volume", 27 | "stability", 28 | "ssim", 29 | "psnr", 30 | "nrmsd", 31 | "jacobian_det", 32 | "gaussian_smooth_label", 33 | "smooth_label", 34 | ] 35 | -------------------------------------------------------------------------------- /imgx/metric/area.py: -------------------------------------------------------------------------------- 1 | """Metrics to measure foreground area.""" 2 | import jax 3 | import numpy as np 4 | from jax import numpy as jnp 5 | 6 | MM3_TO_ML = 0.001 7 | 8 | 9 | def class_proportion(mask: jnp.ndarray) -> jnp.ndarray: 10 | """Calculate proportion per class. 11 | 12 | This metric does not consider spacing. 13 | 14 | Args: 15 | mask: shape = (batch, d1, ..., dn, num_classes). 16 | 17 | Returns: 18 | Proportion, shape = (batch, num_classes). 19 | """ 20 | reduce_axes = tuple(range(1, mask.ndim - 1)) 21 | volume = jnp.float32(np.prod(mask.shape[1:-1])) 22 | sqrt_volume = jnp.sqrt(volume) 23 | mask = jnp.float32(mask) 24 | # attempt to avoid over/underflow 25 | return jnp.sum(mask / sqrt_volume, axis=reduce_axes) / sqrt_volume 26 | 27 | 28 | def get_volume(label: jnp.ndarray, spacing: jnp.ndarray) -> jnp.ndarray: 29 | """Calculate volume from binary label/mask. 30 | 31 | This metric considers spacing. 32 | 33 | Args: 34 | label: binary label, of shape (batch, d1, ..., dn). 35 | spacing: (n,), spacing of each dimension, in mm. 36 | 37 | Returns: 38 | volume: volume in ml, (batch,). 39 | """ 40 | volume_per_voxel = jnp.prod(spacing) * MM3_TO_ML # ml 41 | return jnp.sum(label, axis=list(range(1, label.ndim))) * volume_per_voxel 42 | 43 | 44 | def class_volume(mask: jnp.ndarray, spacing: jnp.ndarray) -> jnp.ndarray: 45 | """Calculate volume per class. 46 | 47 | This metric does consider spacing. 48 | 49 | Args: 50 | mask: shape = (batch, d1, ..., dn, num_classes). 51 | spacing: (n,), spacing of each dimension, in mm. 52 | 53 | Returns: 54 | volume: volume in ml, (batch, num_classes). 
55 | """ 56 | return jax.vmap(get_volume, in_axes=(-1, None), out_axes=-1)(mask, spacing) 57 | -------------------------------------------------------------------------------- /imgx/metric/area_test.py: -------------------------------------------------------------------------------- 1 | """Test area functions.""" 2 | 3 | import chex 4 | import numpy as np 5 | from absl.testing import parameterized 6 | from chex._src import fake 7 | 8 | from imgx.metric.area import class_proportion, class_volume, get_volume 9 | 10 | 11 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 12 | def setUpModule() -> None: # pylint: disable=invalid-name 13 | """Fake multi-devices.""" 14 | fake.set_n_cpu_devices(2) 15 | 16 | 17 | class TestClassProportion(chex.TestCase): 18 | """Test class_proportion.""" 19 | 20 | @chex.all_variants() 21 | @parameterized.named_parameters( 22 | ( 23 | "1d-1class", 24 | np.asarray([False, True, True, False])[..., None], 25 | np.asarray([0.5])[..., None], 26 | ), 27 | ( 28 | "1d-1class-empty", 29 | np.asarray([False, False, False, False])[..., None], 30 | np.asarray([0.0])[..., None], 31 | ), 32 | ( 33 | "1d-2classes", 34 | np.asarray([[False, True], [True, True], [True, False], [False, False]]), 35 | np.asarray([[0.5, 0.5]]), 36 | ), 37 | ( 38 | "2d-1class", 39 | np.array( 40 | [ 41 | [False, False, True, False, False], 42 | [False, True, False, True, False], 43 | [False, False, False, False, False], 44 | [False, False, False, False, False], 45 | ] 46 | )[..., None], 47 | np.asarray([3.0 / 20.0])[..., None], 48 | ), 49 | ) 50 | def test_values(self, mask: np.ndarray, expected: np.ndarray) -> None: 51 | """Test exact values. 52 | 53 | Args: 54 | mask: shape = (batch, d1, ..., dn, num_classes). 55 | expected: expected coordinates. 56 | """ 57 | got = self.variant(class_proportion)( 58 | mask=mask[None, ...], 59 | ) 60 | chex.assert_trees_all_close(got, expected) 61 | 62 | 63 | class TestGetVolume(chex.TestCase): 64 | """Test get_volume.""" 65 | 66 | @chex.all_variants() 67 | @parameterized.product( 68 | image_shape=[(8, 12, 6), (8, 12)], 69 | batch_size=[4, 1], 70 | ) 71 | def test_shapes(self, image_shape: tuple[int, ...], batch_size: int) -> None: 72 | """Test output shapes. 73 | 74 | Args: 75 | image_shape: spatial shape of images. 76 | batch_size: number of samples in batch. 77 | """ 78 | mask = np.zeros((batch_size, *image_shape), dtype=np.bool_) 79 | spacing = np.ones(len(image_shape)) 80 | got = self.variant(get_volume)(mask, spacing) 81 | chex.assert_shape(got, (batch_size,)) 82 | 83 | 84 | class TestClassVolume(chex.TestCase): 85 | """Test class_volume.""" 86 | 87 | @chex.all_variants() 88 | @parameterized.product( 89 | image_shape=[(8, 12, 6), (8, 12)], 90 | num_classes=[1, 2, 3], 91 | batch_size=[4, 1], 92 | ) 93 | def test_shapes(self, image_shape: tuple[int, ...], num_classes: int, batch_size: int) -> None: 94 | """Test output shapes. 95 | 96 | Args: 97 | image_shape: spatial shape of images. 98 | num_classes: number of classes. 99 | batch_size: number of samples in batch. 
100 | """ 101 | mask = np.zeros((batch_size, *image_shape, num_classes), dtype=np.bool_) 102 | spacing = np.ones(len(image_shape)) 103 | got = self.variant(class_volume)(mask, spacing) 104 | chex.assert_shape(got, (batch_size, num_classes)) 105 | -------------------------------------------------------------------------------- /imgx/metric/centroid.py: -------------------------------------------------------------------------------- 1 | """Metric centroid distance.""" 2 | from __future__ import annotations 3 | 4 | import jax.numpy as jnp 5 | 6 | 7 | def get_centroid( 8 | mask: jnp.ndarray, 9 | grid: jnp.ndarray, 10 | ) -> tuple[jnp.ndarray, jnp.ndarray]: 11 | """Calculate the centroid of the mask. 12 | 13 | Args: 14 | mask: boolean mask of shape = (batch, d1, ..., dn, num_classes) 15 | grid: shape = (n, d1, ..., dn) 16 | 17 | Returns: 18 | centroid of shape = (batch, n, num_classes). 19 | nan mask of shape = (batch, num_classes). 20 | """ 21 | mask_reduce_axes = tuple(range(1, mask.ndim - 1)) 22 | grid_reduce_axes = tuple(range(2, mask.ndim)) 23 | # (batch, n, d1, ..., dn, num_classes) 24 | masked_grid = jnp.expand_dims(mask, axis=1) * jnp.expand_dims(grid, axis=(0, -1)) 25 | # (batch, n, num_classes) 26 | numerator = jnp.sum(masked_grid, axis=grid_reduce_axes) 27 | # (batch, num_classes) 28 | summed_mask = jnp.sum(mask, axis=mask_reduce_axes) 29 | # (batch, 1, num_classes) 30 | denominator = summed_mask[:, None, :] 31 | # if mask is not empty return real centroid, else nan 32 | centroid = jnp.where(condition=denominator > 0, x=numerator / denominator, y=jnp.nan) 33 | return centroid, jnp.array(summed_mask == 0, dtype=jnp.bool_) 34 | 35 | 36 | def centroid_distance( 37 | mask_true: jnp.ndarray, 38 | mask_pred: jnp.ndarray, 39 | grid: jnp.ndarray, 40 | spacing: jnp.ndarray | None = None, 41 | ) -> jnp.ndarray: 42 | """Calculate the L2-distance between two centroids. 43 | 44 | Args: 45 | mask_true: shape = (batch, d1, ..., dn, num_classes). 46 | mask_pred: shape = (batch, d1, ..., dn, num_classes). 47 | grid: shape = (n, d1, ..., dn). 48 | spacing: spacing of pixel/voxels along each dimension, (n,). 49 | 50 | Returns: 51 | distance, shape = (batch, num_classes). 52 | """ 53 | # centroid (batch, n, num_classes) nan_mask (batch, num_classes) 54 | centroid_true, nan_mask_true = get_centroid( 55 | mask=mask_true, 56 | grid=grid, 57 | ) 58 | centroid_pred, nan_mask_pred = get_centroid( 59 | mask=mask_pred, 60 | grid=grid, 61 | ) 62 | nan_mask = nan_mask_true | nan_mask_pred 63 | if spacing is not None: 64 | centroid_true = jnp.where( 65 | condition=nan_mask[:, None, :], 66 | x=jnp.nan, 67 | y=centroid_true * spacing[None, :, None], 68 | ) 69 | centroid_pred = jnp.where( 70 | condition=nan_mask[:, None, :], 71 | x=jnp.nan, 72 | y=centroid_pred * spacing[None, :, None], 73 | ) 74 | 75 | # return nan if the centroid cannot be defined for one sample with one class 76 | return jnp.where( 77 | condition=nan_mask, 78 | x=jnp.nan, 79 | y=jnp.linalg.norm(centroid_true - centroid_pred, axis=1), 80 | ) 81 | -------------------------------------------------------------------------------- /imgx/metric/deformation.py: -------------------------------------------------------------------------------- 1 | """Deformation metrics for ddf.""" 2 | from __future__ import annotations 3 | 4 | from jax import numpy as jnp 5 | 6 | 7 | def gradient_along_axis(x: jnp.ndarray, axis: int, spacing: float | jnp.ndarray) -> jnp.ndarray: 8 | """Calculate gradients on one axis of using central finite difference. 
9 | 10 | https://en.wikipedia.org/wiki/Finite_difference 11 | dx[i] = (x[i+1] - x[i-1]) / 2 12 | The edge values are padded. 13 | 14 | Args: 15 | x: shape = (d1, d2, ..., dn). 16 | axis: axis to calculate gradient. 17 | spacing: spacing between each pixel/voxel. 18 | 19 | Returns: 20 | shape = (d1, d2, ..., dn). 21 | """ 22 | # repeat edge values 23 | # (d1, ..., di+2, ..., dn) 24 | x = jnp.pad( 25 | x, 26 | pad_width=[(0, 0)] * axis + [(1, 1)] + [(0, 0)] * (x.ndim - axis - 1), 27 | mode="edge", 28 | ) 29 | # x[i-1] 30 | # (d1, ..., di, ..., dn) 31 | indices = jnp.arange(0, x.shape[axis] - 2) 32 | x_prev = jnp.take(x, indices=indices, axis=axis, indices_are_sorted=True) 33 | # x[i+1] 34 | # (d1, ..., di, ..., dn) 35 | indices = jnp.arange(2, x.shape[axis]) 36 | x_next = jnp.take(x, indices=indices, axis=axis, indices_are_sorted=True) 37 | return (x_next - x_prev) / 2 / spacing 38 | 39 | 40 | def gradient(x: jnp.ndarray, spacing: jnp.ndarray) -> jnp.ndarray: 41 | """Calculate gradients per axis of using central finite difference. 42 | 43 | Args: 44 | x: shape = (batch, d1, d2, ..., dn, channel). 45 | spacing: spacing between each pixel/voxel, shape = (n,). 46 | 47 | Returns: 48 | shape = (batch, d1, d2, ..., dn, channel, n). 49 | """ 50 | return jnp.stack( 51 | [gradient_along_axis(x, axis, spacing[axis - 1]) for axis in range(1, x.ndim - 1)], axis=-1 52 | ) 53 | 54 | 55 | def jacobian_det(x: jnp.ndarray, spacing: jnp.ndarray) -> jnp.ndarray: 56 | """Calculate Jacobian matrix of ddf. 57 | 58 | https://arxiv.org/abs/1907.00068 59 | 60 | Args: 61 | x: shape = (batch, d1, d2, ..., dn, n). 62 | spacing: spacing between each pixel/voxel, shape = (n,). 63 | 64 | Returns: 65 | shape = (batch, d1, d2, ..., dn). 66 | """ 67 | n = x.shape[-1] 68 | # (batch, d1, d2, ..., dn, n, n) 69 | grad = gradient(x, spacing) 70 | # (1, 1, ..., 1, n, n) 71 | eye = jnp.expand_dims(jnp.eye(n), axis=range(n + 1)) 72 | # (batch, d1, d2, ..., dn) 73 | return jnp.linalg.det(grad + eye) 74 | -------------------------------------------------------------------------------- /imgx/metric/dice.py: -------------------------------------------------------------------------------- 1 | """Metric functions for image segmentation.""" 2 | import jax.numpy as jnp 3 | 4 | 5 | def dice_score( 6 | mask_pred: jnp.ndarray, 7 | mask_true: jnp.ndarray, 8 | ) -> jnp.ndarray: 9 | """Soft Dice score, larger is better. 10 | 11 | Args: 12 | mask_pred: soft mask with probabilities, (batch, ..., num_classes). 13 | mask_true: one hot targets, (batch, ..., num_classes). 14 | 15 | Returns: 16 | Dice score of shape (batch, num_classes). 17 | """ 18 | # additions between bools results in errors 19 | if mask_pred.dtype == jnp.bool_: 20 | mask_pred = mask_pred.astype(jnp.float32) 21 | if mask_true.dtype == jnp.bool_: 22 | mask_true = mask_true.astype(jnp.float32) 23 | 24 | reduce_axis = tuple(range(mask_pred.ndim))[1:-1] 25 | numerator = 2.0 * jnp.sum(mask_pred * mask_true, axis=reduce_axis) 26 | denominator = jnp.sum(mask_pred + mask_true, axis=reduce_axis) 27 | return jnp.where( 28 | condition=denominator > 0, 29 | x=numerator / denominator, 30 | y=jnp.nan, 31 | ) 32 | 33 | 34 | def iou( 35 | mask_pred: jnp.ndarray, 36 | mask_true: jnp.ndarray, 37 | ) -> jnp.ndarray: 38 | """IOU (Intersection Over Union), or Jaccard index. 39 | 40 | Args: 41 | mask_pred: binary mask of predictions, (batch, ..., num_classes). 42 | mask_true: one hot targets, (batch, ..., num_classes). 43 | 44 | Returns: 45 | IoU of shape (batch, num_classes). 
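    As an illustrative example (hypothetical values): if a class is predicted
    on 2 voxels, present on 1 voxel in the target and the overlap is 1 voxel,
    IoU = 1 / (2 + 1 - 1) = 0.5; classes absent from both masks return nan.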
46 | """ 47 | # additions between bools results in errors 48 | if mask_pred.dtype == jnp.bool_: 49 | mask_pred = mask_pred.astype(jnp.float32) 50 | if mask_true.dtype == jnp.bool_: 51 | mask_true = mask_true.astype(jnp.float32) 52 | 53 | reduce_axis = tuple(range(mask_pred.ndim))[1:-1] 54 | numerator = jnp.sum(mask_pred * mask_true, axis=reduce_axis) 55 | sum_mask = jnp.sum(mask_pred + mask_true, axis=reduce_axis) 56 | denominator = sum_mask - numerator 57 | return jnp.where(condition=sum_mask > 0, x=numerator / denominator, y=jnp.nan) 58 | 59 | 60 | def stability( 61 | logits: jnp.ndarray, 62 | threshold: float = 0.0, 63 | threshold_offset: float = 1.0, 64 | ) -> jnp.ndarray: 65 | """Calculate stability of predictions. 66 | 67 | https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py 68 | 69 | Args: 70 | logits: shape = (batch, ..., num_classes). 71 | threshold: threshold for prediction. 72 | threshold_offset: offset for threshold. 73 | 74 | Returns: 75 | Stability of shape (batch, num_classes). 76 | 77 | Raises: 78 | ValueError: if threshold_offset is negative. 79 | """ 80 | if threshold_offset < 0: 81 | raise ValueError(f"threshold_offset must be non-negative, got {threshold_offset}.") 82 | # logits max values is 0 83 | logits -= jnp.mean(logits, axis=-1, keepdims=True) 84 | mask_high_threshold = logits >= (threshold + threshold_offset) 85 | mask_low_threshold = logits >= (threshold - threshold_offset) 86 | return iou(mask_high_threshold, mask_low_threshold) 87 | -------------------------------------------------------------------------------- /imgx/metric/distribution.py: -------------------------------------------------------------------------------- 1 | """Metric functions for probability distributions.""" 2 | import jax.numpy as jnp 3 | 4 | 5 | def normal_kl( 6 | p_mean: jnp.ndarray, 7 | p_log_variance: jnp.ndarray, 8 | q_mean: jnp.ndarray, 9 | q_log_variance: jnp.ndarray, 10 | ) -> jnp.ndarray: 11 | r"""Compute the KL divergence between two 1D normal distributions. 12 | 13 | KL[p||q] = \int p \log (p / q) dx 14 | 15 | Although the inputs are arrays, each value is considered independently. 16 | This function is not symmetric. 17 | 18 | Input array shapes should be broadcast-compatible. 19 | 20 | Args: 21 | p_mean: mean of distribution p. 22 | p_log_variance: log variance of distribution p. 23 | q_mean: mean of distribution q. 24 | q_log_variance: log variance of distribution q. 25 | 26 | Returns: 27 | KL divergence. 28 | """ 29 | return 0.5 * ( 30 | -1.0 31 | + q_log_variance 32 | - p_log_variance 33 | + jnp.exp(p_log_variance - q_log_variance) 34 | + ((p_mean - q_mean) ** 2) * jnp.exp(-q_log_variance) 35 | ) 36 | 37 | 38 | def approx_standard_normal_cdf(x: jnp.ndarray) -> jnp.ndarray: 39 | """Approximate cumulative distribution function of standard normal. 40 | 41 | if x ~ Normal(mean, var), then cdf(z) = p(x <= z) 42 | 43 | https://www.aimspress.com/article/doi/10.3934/math.2022648#b13 44 | https://www.jstor.org/stable/2346872 45 | 46 | Args: 47 | x: array of any shape with any float values. 48 | 49 | Returns: 50 | CDF estimation. 51 | """ 52 | return 0.5 * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3))) 53 | 54 | 55 | def discretized_gaussian_log_likelihood( 56 | x: jnp.ndarray, 57 | mean: jnp.ndarray, 58 | log_variance: jnp.ndarray, 59 | x_delta: float = 1.0 / 255.0, 60 | x_bound: float = 0.999, 61 | ) -> jnp.ndarray: 62 | """Log-likelihood of a normal distribution discretizing to an image. 
63 | 64 | p(y=x) is approximated by p(y <= x+delta) - p(y <= x-delta). 65 | 66 | Args: 67 | x: target image, with value in [-1, 1]. 68 | mean: normal distribution mean. 69 | log_variance: log of distribution variance. 70 | x_delta: discretization step, used to estimate probability. 71 | x_bound: values with abs > x_bound are calculated differently. 72 | 73 | Returns: 74 | Discretized log likelihood over 2*delta. 75 | """ 76 | log_scales = 0.5 * log_variance 77 | centered_x = x - mean 78 | inv_stdv = jnp.exp(-log_scales) 79 | 80 | # let y be a variable 81 | # cdf(z+delta) = p(y <= z+delta) 82 | plus_in = inv_stdv * (centered_x + x_delta) # z 83 | cdf_plus = approx_standard_normal_cdf(plus_in) 84 | # log( p(y <= z+delta) ) 85 | log_cdf_plus = jnp.log(cdf_plus.clip(min=1e-12)) 86 | 87 | # cdf(z-delta) = p(y <= z-delta) 88 | minus_in = inv_stdv * (centered_x - x_delta) 89 | cdf_minus = approx_standard_normal_cdf(minus_in) 90 | # log( 1-p(y <= z-delta) ) = log( p(y > z-delta) ) 91 | log_one_minus_cdf_minus = jnp.log((1.0 - cdf_minus).clip(min=1e-12)) 92 | 93 | # p(z-delta < y <= z+delta) 94 | cdf_delta = cdf_plus - cdf_minus 95 | log_cdf_delta = jnp.log(cdf_delta.clip(min=1e-12)) 96 | 97 | # if x < -0.999, log( p(y <= z+delta) ) 98 | # if x > 0.999, log( p(y > z-delta) ) 99 | # if -0.999 <= x <= 0.999, log( p(z-delta < y <= z+delta) ) 100 | return jnp.where( 101 | x < -x_bound, 102 | log_cdf_plus, 103 | jnp.where(x > x_bound, log_one_minus_cdf_minus, log_cdf_delta), 104 | ) 105 | -------------------------------------------------------------------------------- /imgx/metric/segmentation_test.py: -------------------------------------------------------------------------------- 1 | """Test functions in imgx.metric.util.""" 2 | 3 | 4 | import chex 5 | import jax 6 | import numpy as np 7 | from chex._src import fake 8 | 9 | from imgx.metric.segmentation import ( 10 | get_jit_segmentation_metrics, 11 | get_non_jit_segmentation_metrics, 12 | get_non_jit_segmentation_metrics_per_step, 13 | ) 14 | 15 | 16 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 
17 | def setUpModule() -> None: # pylint: disable=invalid-name 18 | """Fake multi-devices.""" 19 | fake.set_n_cpu_devices(2) 20 | 21 | 22 | class TestGetSegmentationMetrics(chex.TestCase): 23 | """Test get_segmentation_metrics.""" 24 | 25 | batch = 2 26 | num_classes = 3 27 | spatial_shape = (4, 5, 6) 28 | spacing = np.array((0.2, 0.5, 1.0)) 29 | mask_shape = (batch, *spatial_shape, num_classes) 30 | 31 | @chex.all_variants() 32 | def test_jit_shapes(self) -> None: 33 | """Test shapes.""" 34 | key = jax.random.PRNGKey(0) 35 | key_pred, key_true = jax.random.split(key) 36 | mask_pred = jax.random.uniform(key_pred, shape=self.mask_shape) 37 | mask_true = jax.random.uniform(key_true, shape=self.mask_shape) 38 | 39 | got = self.variant(get_jit_segmentation_metrics)(mask_pred, mask_true, self.spacing) 40 | for _, v in got.items(): 41 | chex.assert_shape(v, (self.batch,)) 42 | 43 | @chex.variants(without_jit=True, with_device=True, without_device=True) 44 | def test_nonjit_shapes(self) -> None: 45 | """Test shapes.""" 46 | key = jax.random.PRNGKey(0) 47 | key_pred, key_true = jax.random.split(key) 48 | mask_pred = jax.random.uniform(key_pred, shape=self.mask_shape) 49 | mask_true = jax.random.uniform(key_true, shape=self.mask_shape) 50 | 51 | got = self.variant(get_non_jit_segmentation_metrics)(mask_pred, mask_true, self.spacing) 52 | for _, v in got.items(): 53 | chex.assert_shape(v, (self.batch,)) 54 | 55 | @chex.variants(without_jit=True, with_device=True, without_device=True) 56 | def test_nonjit_per_step_shapes(self) -> None: 57 | """Test shapes.""" 58 | num_steps = 2 59 | key = jax.random.PRNGKey(0) 60 | key_pred, key_true = jax.random.split(key) 61 | mask_pred = jax.random.uniform(key_pred, shape=(*self.mask_shape, num_steps)) 62 | mask_true = jax.random.uniform(key_true, shape=self.mask_shape) 63 | 64 | got = self.variant(get_non_jit_segmentation_metrics_per_step)( 65 | mask_pred, mask_true, self.spacing 66 | ) 67 | for _, v in got.items(): 68 | chex.assert_shape(v, (self.batch, num_steps)) 69 | -------------------------------------------------------------------------------- /imgx/metric/similarity.py: -------------------------------------------------------------------------------- 1 | """Image similarity metrics.""" 2 | 3 | import jax.numpy as jnp 4 | 5 | from imgx import EPS 6 | from imgx.metric.smoothing import get_conv 7 | 8 | 9 | def ssim( 10 | image1: jnp.ndarray, 11 | image2: jnp.ndarray, 12 | max_val: float = 1.0, 13 | kernel_sigma: float = 1.5, 14 | kernel_size: int = 11, 15 | kernel_type: str = "gaussian", 16 | k1: float = 0.01, 17 | k2: float = 0.03, 18 | ) -> jnp.ndarray: 19 | """Calculate Structural similarity index metric (SSIM). 20 | 21 | https://en.wikipedia.org/wiki/Structural_similarity 22 | https://github.com/Project-MONAI/MONAI/blob/ccd32ca5e9e84562d2f388b45b6724b5c77c1f57/monai/metrics/regression.py#L240 23 | 24 | SSIM is calculated per window. The window size is specified by `window_size`. 25 | The window is convolved with a Gaussian kernel specified by `kernel_sigma`. 26 | 27 | SSIM as loss is not stable and may cause NaNs gradients. 28 | https://github.com/tensorflow/tensorflow/issues/50400 29 | https://github.com/tensorflow/tensorflow/issues/57353 30 | 31 | Args: 32 | image1: image of shape (batch, ..., channels). 33 | image2: image of shape (batch, ..., channels). 34 | max_val: maximum value of input images, minimum is assumed to be zero. 35 | kernel_sigma: sigma for Gaussian kernel. 36 | kernel_size: size for kernel. 
37 | kernel_type: type of kernel, "gaussian" or "uniform". 38 | k1: stability constant for luminance. 39 | k2: stability constant for contrast. 40 | 41 | Returns: 42 | SSIM of shape (batch,). 43 | """ 44 | num_spatial_dims = image1.ndim - 2 45 | conv = get_conv( 46 | num_spatial_dims=num_spatial_dims, 47 | kernel_sigma=kernel_sigma, 48 | kernel_size=kernel_size, 49 | kernel_type=kernel_type, 50 | padding="VALID", 51 | ) 52 | 53 | # (batch, ..., channels) 54 | # the spatial dims are reduced by the conv 55 | mean1 = conv(image1) 56 | mean2 = conv(image2) 57 | mean11 = conv(image1 * image1) 58 | mean12 = conv(image1 * image2) 59 | mean22 = conv(image2 * image2) 60 | var1 = mean11 - mean1 * mean1 61 | var2 = mean22 - mean2 * mean2 62 | covar12 = mean12 - mean1 * mean2 63 | 64 | c1 = (k1 * max_val) ** 2 65 | c2 = (k2 * max_val) ** 2 66 | numerator = (2 * mean1 * mean2 + c1) * (2 * covar12 + c2) 67 | denominator = (mean1**2 + mean2**2 + c1) * (var1 + var2 + c2) 68 | 69 | # (batch,) 70 | return jnp.mean(numerator / denominator, axis=tuple(range(1, image1.ndim))) 71 | 72 | 73 | def psnr( 74 | image1: jnp.ndarray, 75 | image2: jnp.ndarray, 76 | value_range: float = 1.0, 77 | eps: float = EPS, 78 | ) -> jnp.ndarray: 79 | """Calculate Peak signal-to-noise ratio (PSNR) metric. 80 | 81 | https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 82 | https://github.com/Project-MONAI/MONAI/blob/ccd32ca5e9e84562d2f388b45b6724b5c77c1f57/monai/metrics/regression.py#L186 83 | 84 | Args: 85 | image1: image of shape (batch, ..., channels). 86 | image2: image of shape (batch, ..., channels). 87 | value_range: value range of input images. 88 | eps: epsilon, if two images are identical, MSE=0. 89 | 90 | Returns: 91 | PSNR of shape (batch,). 92 | """ 93 | mse = jnp.mean((image1 - image2) ** 2, axis=range(1, image1.ndim)) 94 | mse = jnp.maximum(mse, eps) 95 | return 20 * jnp.log10(value_range) - 10 * jnp.log10(mse) 96 | 97 | 98 | def nrmsd( 99 | image_pred: jnp.ndarray, 100 | image_true: jnp.ndarray, 101 | eps: float = EPS, 102 | ) -> jnp.ndarray: 103 | """Calculate normalized root-mean-square-deviation (NRMSD) metric. 104 | 105 | The normalization is performed by the mean of ground truth image. 106 | https://en.wikipedia.org/wiki/Root-mean-square_deviation 107 | 108 | Args: 109 | image_pred: predicted image of shape (batch, ..., channels). 110 | image_true: ground truth image of shape (batch, ..., channels). 111 | eps: epsilon, if two images are identical, MSE=0. 112 | 113 | Returns: 114 | NRMSD of shape (batch,). 
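    As an illustrative example (hypothetical values): if the prediction differs
    from the ground truth by a constant 0.1 and the ground-truth image has mean
    intensity 0.5, then RMSD = 0.1 and NRMSD = 0.1 / 0.5 = 0.2.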
115 | """ 116 | mse = jnp.mean((image_pred - image_true) ** 2, axis=range(1, image_true.ndim)) 117 | rmsd = jnp.sqrt(mse) 118 | denominator = jnp.maximum(jnp.mean(image_true, axis=range(1, image_true.ndim)), eps) 119 | return rmsd / denominator 120 | -------------------------------------------------------------------------------- /imgx/metric/similarity_test.py: -------------------------------------------------------------------------------- 1 | """Test similarity functions.""" 2 | 3 | from __future__ import annotations 4 | 5 | import chex 6 | import jax.numpy as jnp 7 | import numpy as np 8 | from absl.testing import parameterized 9 | from chex._src import fake 10 | 11 | from imgx.metric.similarity import nrmsd, psnr, ssim 12 | 13 | 14 | def setUpModule() -> None: # pylint: disable=invalid-name 15 | """Fake multi-devices.""" 16 | fake.set_n_cpu_devices(2) 17 | 18 | 19 | class TestSSIM(chex.TestCase): 20 | """Test SSIM.""" 21 | 22 | @parameterized.product( 23 | ( 24 | {"image_shape": (13,), "in_channels": 1}, 25 | {"image_shape": (11, 13), "in_channels": 1}, 26 | {"image_shape": (11, 13), "in_channels": 3}, 27 | {"image_shape": (11, 13, 15), "in_channels": 2}, 28 | ), 29 | kernel_type=("gaussian", "uniform"), 30 | ) 31 | def test_shapes( 32 | self, 33 | image_shape: tuple[int, ...], 34 | in_channels: int, 35 | kernel_type: str, 36 | ) -> None: 37 | """Test return shapes. 38 | 39 | Args: 40 | image_shape: image shapes. 41 | in_channels: number of input channels. 42 | kernel_type: type of kernel, "gaussian" or "uniform". 43 | """ 44 | batch_size = 2 45 | 46 | image = jnp.ones((batch_size, *image_shape, in_channels)) 47 | got = ssim(image, image, kernel_type=kernel_type) 48 | 49 | chex.assert_shape(got, (batch_size,)) 50 | chex.assert_trees_all_close(got, jnp.ones((batch_size,))) 51 | 52 | @parameterized.named_parameters( 53 | ( 54 | "1d", 55 | # x 56 | # mean(x)=0.6 57 | # mean(x**2)=0.6 58 | # var(x)=0.24 59 | np.array([0.0, 1.0, 0.0, 1.0, 1.0]), 60 | # mean(y)=0.4 61 | # mean(y**2)=0.252 62 | # var(y)=0.092 63 | # mean(xy)=0.36 64 | # covar(x,y)=0.12 65 | np.array([0.2, 0.8, 0, 0.3, 0.7]), 66 | 5, 67 | "uniform", 68 | (2 * 0.6 * 0.4 + 0.0001) 69 | * (2 * 0.12 + 0.0009) 70 | / ((0.6**2 + 0.4**2 + 0.0001) * (0.24 + 0.092 + 0.0009)), 71 | ), 72 | ) 73 | def test_values( 74 | self, 75 | image1: np.ndarray, 76 | image2: np.ndarray, 77 | kernel_size: int, 78 | kernel_type: str, 79 | expected: float, 80 | ) -> None: 81 | """Test return values. 82 | 83 | Args: 84 | image1: image 1. 85 | image2: image 2. 86 | kernel_size: kernel size. 87 | kernel_type: type of kernel, "gaussian" or "uniform". 88 | expected: expected SSIM value. 89 | """ 90 | image1 = image1[None, ..., None] 91 | image2 = image2[None, ..., None] 92 | got = ssim( 93 | jnp.array(image1), jnp.array(image2), kernel_size=kernel_size, kernel_type=kernel_type 94 | ) 95 | chex.assert_trees_all_close(got, jnp.array([expected])) 96 | 97 | 98 | class TestPSNR(chex.TestCase): 99 | """Test PSNR.""" 100 | 101 | @parameterized.product( 102 | ( 103 | {"image_shape": (13,), "in_channels": 1}, 104 | {"image_shape": (11, 13), "in_channels": 1}, 105 | {"image_shape": (11, 13), "in_channels": 3}, 106 | {"image_shape": (11, 13, 15), "in_channels": 2}, 107 | ), 108 | ) 109 | def test_shapes( 110 | self, 111 | image_shape: tuple[int, ...], 112 | in_channels: int, 113 | ) -> None: 114 | """Test return shapes. 115 | 116 | Args: 117 | image_shape: image shapes. 118 | in_channels: number of input channels. 
119 | 120 | """ 121 | batch_size = 2 122 | 123 | image = jnp.ones((batch_size, *image_shape, in_channels)) 124 | got = psnr(image, image) 125 | 126 | chex.assert_shape(got, (batch_size,)) 127 | chex.assert_tree_all_finite(got) 128 | 129 | 130 | class TestNRMSD(chex.TestCase): 131 | """Test NRMSD.""" 132 | 133 | @parameterized.product( 134 | ( 135 | {"image_shape": (13,), "in_channels": 1}, 136 | {"image_shape": (11, 13), "in_channels": 1}, 137 | {"image_shape": (11, 13), "in_channels": 3}, 138 | {"image_shape": (11, 13, 15), "in_channels": 2}, 139 | ), 140 | ) 141 | def test_shapes( 142 | self, 143 | image_shape: tuple[int, ...], 144 | in_channels: int, 145 | ) -> None: 146 | """Test return shapes. 147 | 148 | Args: 149 | image_shape: image shapes. 150 | in_channels: number of input channels. 151 | 152 | """ 153 | batch_size = 2 154 | 155 | image = jnp.zeros((batch_size, *image_shape, in_channels)) 156 | got = nrmsd(image, image) 157 | 158 | chex.assert_shape(got, (batch_size,)) 159 | chex.assert_tree_all_finite(got) 160 | -------------------------------------------------------------------------------- /imgx/metric/smoothing.py: -------------------------------------------------------------------------------- 1 | """Label/image smoothing functions.""" 2 | from typing import Callable 3 | 4 | import jax 5 | from jax import lax 6 | from jax import numpy as jnp 7 | 8 | 9 | def gaussian_kernel( 10 | num_spatial_dims: int, 11 | kernel_sigma: float, 12 | kernel_size: int, 13 | ) -> jnp.ndarray: 14 | """Gaussian kernel for convolution. 15 | 16 | Args: 17 | num_spatial_dims: number of spatial dimensions. 18 | kernel_sigma: sigma for Gaussian kernel. 19 | kernel_size: size for Gaussian kernel. 20 | 21 | Returns: 22 | Gaussian kernel of shape (window_size, ..., window_size), ndim=num_spatial_dims. 23 | """ 24 | if kernel_size % 2 == 0: 25 | raise ValueError(f"kernel_size = {kernel_size} must be odd.") 26 | # (kernel_size,) 27 | dist = jnp.arange((1 - kernel_size) / 2, (1 + kernel_size) / 2) 28 | kernel_1d = jnp.exp(-jnp.power(dist / kernel_sigma, 2) / 2) 29 | kernel_1d /= kernel_1d.sum() 30 | if num_spatial_dims == 1: 31 | return kernel_1d 32 | 33 | kernel_nd = jnp.ones((kernel_size,) * num_spatial_dims) 34 | for i in range(num_spatial_dims): 35 | kernel_nd *= jnp.expand_dims(kernel_1d, axis=[j for j in range(num_spatial_dims) if j != i]) 36 | return kernel_nd 37 | 38 | 39 | def get_conv( 40 | num_spatial_dims: int, 41 | kernel_sigma: float, 42 | kernel_size: int, 43 | kernel_type: str, 44 | padding: str, 45 | ) -> Callable[[jnp.ndarray], jnp.ndarray]: 46 | """Get Gaussian convolution function. 47 | 48 | The function performs convolution on the spatial dimensions for each feature channel. 49 | 50 | Args: 51 | num_spatial_dims: number of spatial dimensions. 52 | kernel_sigma: sigma for Gaussian kernel. 53 | kernel_size: size for Gaussian kernel. 54 | kernel_type: type of kernel, "gaussian" or "uniform". 55 | padding: padding type, "SAME" or "VALID". 56 | 57 | Returns: 58 | Gaussian convolution function that takes input of shape (batch, ..., channels).
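    An illustrative usage sketch (hypothetical parameters):

        conv = get_conv(
            num_spatial_dims=2,
            kernel_sigma=1.5,
            kernel_size=3,
            kernel_type="gaussian",
            padding="SAME",
        )
        y = conv(x)  # x: (batch, h, w, channels) -> y: (batch, h, w, channels)

    With padding="VALID", each spatial dimension shrinks by kernel_size - 1.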
59 | """ 60 | if num_spatial_dims > 4: 61 | raise ValueError(f"num_spatial_dims = {num_spatial_dims} must be <= 4.") 62 | if kernel_type == "gaussian": 63 | # (window_size, ..., window_size), ndim=num_spatial_dims 64 | kernel = gaussian_kernel( 65 | num_spatial_dims, 66 | kernel_sigma=kernel_sigma, 67 | kernel_size=kernel_size, 68 | ) 69 | elif kernel_type == "uniform": 70 | kernel = jnp.ones((kernel_size,) * num_spatial_dims) / kernel_size**num_spatial_dims 71 | else: 72 | raise ValueError(f"kernel_type = {kernel_type} must be 'gaussian' or 'uniform'.") 73 | # (1,1,window_size, ..., window_size), ndim=num_spatial_dims+2 74 | kernel = kernel[None, None, ...] 75 | 76 | spatial_dilation = "WHDT"[:num_spatial_dims] 77 | lhs_spec = f"N{spatial_dilation}C" 78 | rhs_spec = f"OI{spatial_dilation}" 79 | out_spec = f"N{spatial_dilation}C" 80 | 81 | def conv(x: jnp.ndarray) -> jnp.ndarray: 82 | """Convolution function. 83 | 84 | Args: 85 | x: (batch, ...) 86 | 87 | Returns: 88 | (batch, ...) 89 | """ 90 | return lax.conv_general_dilated( 91 | lhs=x[..., None], 92 | rhs=kernel.astype(x.dtype), 93 | dimension_numbers=(lhs_spec, rhs_spec, out_spec), 94 | window_strides=(1,) * num_spatial_dims, 95 | padding=padding, 96 | )[..., 0] 97 | 98 | return jax.vmap(conv, in_axes=-1, out_axes=-1) 99 | 100 | 101 | def smooth_label( 102 | mask: jnp.ndarray, classes_are_exclusive: bool, label_smoothing: float = 0.0 103 | ) -> jnp.ndarray: 104 | """Label smoothing. 105 | 106 | If classes are exclusive, the even distribution has p = 1 / num_classes per class. 107 | If classes are not exclusive, the even distribution has p = 0.5 per class. 108 | 109 | Args: 110 | mask: probabilities per class, (batch, ..., num_classes). 111 | classes_are_exclusive: if False, each element can be assigned to multiple classes. 112 | label_smoothing: label smoothing factor between 0 and 1, 0.0 means no smoothing. 113 | 114 | Returns: 115 | Label smoothed mask. 116 | """ 117 | p_even = lax.select(classes_are_exclusive, 1.0 / mask.shape[-1], 0.5) 118 | return mask * (1 - label_smoothing) + p_even * label_smoothing 119 | 120 | 121 | def gaussian_smooth_label( 122 | mask: jnp.ndarray, 123 | classes_are_exclusive: bool, 124 | kernel_sigma: float = 1.5, 125 | kernel_size: int = 3, 126 | ) -> jnp.ndarray: 127 | """Label smoothing using Gaussian kernels. 128 | 129 | The smoothing is performed on the spatial dimensions per class. 130 | https://arxiv.org/abs/2104.05788 131 | 132 | Args: 133 | mask: probabilities per class, (batch, ..., num_classes). 134 | classes_are_exclusive: if False, each element can be assigned to multiple classes. 135 | kernel_sigma: sigma for Gaussian kernel. 136 | kernel_size: size for kernel. 137 | 138 | Returns: 139 | Label smoothed mask. 
140 | """ 141 | mask = get_conv( 142 | num_spatial_dims=mask.ndim - 2, 143 | kernel_sigma=kernel_sigma, 144 | kernel_size=kernel_size, 145 | kernel_type="gaussian", 146 | padding="SAME", 147 | )(mask) 148 | # if classes are exclusive, the sum should be 1 149 | mask = lax.cond( 150 | classes_are_exclusive, 151 | lambda x: x / jnp.sum(x, axis=-1, keepdims=True), 152 | lambda x: x, 153 | mask, 154 | ) 155 | return mask 156 | -------------------------------------------------------------------------------- /imgx/metric/util.py: -------------------------------------------------------------------------------- 1 | """Util functions.""" 2 | from __future__ import annotations 3 | 4 | import chex 5 | import jax.numpy as jnp 6 | 7 | 8 | def aggregate_metrics(metrics: dict[str, chex.ArrayTree]) -> dict[str, chex.ArrayTree]: 9 | """Calculate min/mean/max statistics for each key. 10 | 11 | Args: 12 | metrics: metric dict. 13 | 14 | Returns: 15 | aggregated metric dict. 16 | """ 17 | agg_metrics = {} 18 | for key, value in metrics.items(): 19 | agg_metrics[f"mean_{key}"] = jnp.nanmean(value) 20 | agg_metrics[f"min_{key}"] = jnp.nanmin(value) 21 | agg_metrics[f"max_{key}"] = jnp.nanmax(value) 22 | return agg_metrics 23 | 24 | 25 | def aggregate_metrics_for_diffusion( 26 | metrics: dict[str, jnp.ndarray], 27 | t_index: jnp.ndarray, 28 | ) -> dict[str, jnp.ndarray]: 29 | """Aggregate metrics for diffusion. 30 | 31 | Args: 32 | metrics: dict of metrics, each of shape (batch,). 33 | t_index: time of shape (batch, ...), values in [0, num_timesteps). 34 | 35 | Returns: 36 | aggregated metrics. 37 | """ 38 | mask_t_min = t_index == jnp.min(t_index) 39 | mask_t_max = t_index == jnp.max(t_index) 40 | metrics_diff = {} 41 | for key, value in metrics.items(): 42 | metrics_diff[f"{key}_t_min"] = jnp.nanmean(value, where=mask_t_min) 43 | metrics_diff[f"{key}_t_max"] = jnp.nanmean(value, where=mask_t_max) 44 | return metrics_diff 45 | 46 | 47 | def merge_aggregated_metrics(metrics: dict[str, chex.ArrayTree]) -> dict[str, chex.ArrayTree]: 48 | """Merge/aggregate aggregated metrics. 49 | 50 | Args: 51 | metrics: metric dict containing aggregated metrics. 52 | 53 | Returns: 54 | Merged aggregated metric dict. 55 | """ 56 | min_metrics = {} 57 | max_metrics = {} 58 | mean_metrics = {} 59 | for k in metrics: 60 | if k.startswith("min_"): 61 | min_metrics[k] = jnp.nanmin(metrics[k]) 62 | elif k.startswith("max_"): 63 | max_metrics[k] = jnp.nanmax(metrics[k]) 64 | else: 65 | mean_metrics[k] = jnp.nanmean(metrics[k]) 66 | return { 67 | **min_metrics, 68 | **max_metrics, 69 | **mean_metrics, 70 | } 71 | 72 | 73 | def flatten_diffusion_metrics(metrics: dict[str, jnp.ndarray]) -> dict[str, jnp.ndarray]: 74 | """Flatten metrics dict for diffusion models. 75 | 76 | Args: 77 | metrics: dict of metrics, each value of shape (batch, num_steps). 78 | 79 | Returns: 80 | metrics: dict of metrics, each value of shape (batch, ). 
81 | """ 82 | metrics_flatten = {} 83 | for k, v in metrics.items(): 84 | for i in range(v.shape[-1]): 85 | metrics_flatten[f"{k}_step_{i}"] = v[..., i] 86 | metrics_flatten[k] = v[..., -1] 87 | return metrics_flatten 88 | -------------------------------------------------------------------------------- /imgx/model/__init__.py: -------------------------------------------------------------------------------- 1 | """Package for models.""" 2 | from imgx.model.unet.unet import Unet # noqa: F401 3 | 4 | SUPPORTED_VISION_MODELS = [ 5 | "Unet", 6 | ] 7 | 8 | __all__ = SUPPORTED_VISION_MODELS 9 | -------------------------------------------------------------------------------- /imgx/model/attention/__init__.py: -------------------------------------------------------------------------------- 1 | """Attention and transformer related modules.""" 2 | from imgx.model.attention.efficient_attention import dot_product_attention_with_qkv_chunks 3 | from imgx.model.attention.transformer import TransformerEncoder 4 | 5 | __all__ = [ 6 | "dot_product_attention_with_qkv_chunks", 7 | "TransformerEncoder", 8 | ] 9 | -------------------------------------------------------------------------------- /imgx/model/attention/transformer.py: -------------------------------------------------------------------------------- 1 | """Generic transformer related functions. 2 | 3 | https://github.com/deepmind/dm-haiku/blob/main/examples/transformer/model.py 4 | https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py 5 | """ 6 | from __future__ import annotations 7 | 8 | import flax.linen as nn 9 | import jax.numpy as jnp 10 | import numpy as np 11 | 12 | from imgx.model.attention.efficient_attention import dot_product_attention_with_qkv_chunks 13 | from imgx.model.basic import MLP 14 | 15 | 16 | class TransformerEncoder(nn.Module): 17 | """A transformer encoder/decoder. 18 | 19 | The architecture is from 20 | https://github.com/google-deepmind/dm-haiku/blob/main/examples/transformer/model.py 21 | """ 22 | 23 | num_heads: int 24 | num_layers: int = 1 25 | autoregressive: bool = False 26 | widening_factor: int = 4 27 | add_position_embedding: bool = False 28 | dropout: float = 0.0 29 | remat: bool = True # if True also uses efficient attention 30 | dtype: jnp.dtype = jnp.float32 31 | 32 | @nn.compact 33 | def __call__( 34 | self, 35 | is_train: bool, 36 | x: jnp.ndarray, 37 | mask: jnp.ndarray | None = None, 38 | ) -> jnp.ndarray: 39 | """Transformer encoder forward pass. 40 | 41 | Args: 42 | is_train: whether in training mode. 43 | x: shape (batch, ..., model_size). 44 | mask: shape (batch, ...) or None. 45 | Tokens with False values are ignored. 46 | 47 | Returns: 48 | (batch, ..., model_size). 49 | """ 50 | batch, *spatial_shape, channels = x.shape 51 | if len(spatial_shape) > 1: 52 | # x: shape(batch, seq_len, model_size). 53 | # mask: shape(batch, seq_len) or None. 
54 | x = x.reshape((batch, -1, channels)) 55 | if mask is not None: 56 | mask = mask.reshape((batch, -1)) 57 | 58 | kernel_init = nn.initializers.variance_scaling( 59 | scale=2 / self.num_layers, 60 | mode="fan_in", 61 | distribution="truncated_normal", 62 | ) 63 | 64 | # define classes 65 | attn_cls = nn.MultiHeadDotProductAttention 66 | attention_fn = nn.dot_product_attention 67 | if self.remat: 68 | attn_cls = nn.remat(attn_cls) 69 | attention_fn = dot_product_attention_with_qkv_chunks 70 | 71 | _, seq_len, model_size = x.shape 72 | 73 | if model_size % self.widening_factor != 0: 74 | raise ValueError( 75 | f"Model size {model_size} is not divisible by widening factor " 76 | f"{self.widening_factor}" 77 | ) 78 | 79 | # compute mask if provided 80 | if mask is not None: 81 | if self.autoregressive: 82 | # compute causal mask for autoregressive sequence modelling. 83 | # (1, 1, seq_len, seq_len) 84 | causal_mask = np.tril(np.ones((1, 1, seq_len, seq_len))) 85 | # (batch, 1, seq_len, seq_len) 86 | mask = mask[:, None, None, :] * causal_mask 87 | else: 88 | # (batch, 1, 1, seq_len) * (batch, 1, seq_len, 1) 89 | # -> (batch, 1, seq_len, seq_len) 90 | mask = mask[:, None, None, :] * mask[:, None, :, None] 91 | 92 | # embed the input tokens and positions. 93 | if self.add_position_embedding: 94 | positional_embeddings = self.param( 95 | "transformer_positional_embeddings", 96 | nn.initializers.truncated_normal(stddev=0.02), 97 | (1, seq_len, model_size), 98 | ) 99 | x += positional_embeddings 100 | x = nn.Dropout(rate=self.dropout, deterministic=not is_train)(x) 101 | 102 | h = x 103 | for _ in range(self.num_layers): 104 | h_attn = nn.LayerNorm(dtype=self.dtype)(h) 105 | h_attn = attn_cls( 106 | num_heads=self.num_heads, 107 | # head_dim = qkv_features // num_heads 108 | qkv_features=model_size // self.widening_factor * self.num_heads, 109 | out_features=model_size, 110 | attention_fn=attention_fn, 111 | kernel_init=kernel_init, 112 | dtype=self.dtype, 113 | )(inputs_q=h_attn, inputs_k=h_attn, inputs_v=h_attn, mask=mask) 114 | h_attn = nn.Dropout(rate=self.dropout, deterministic=not is_train)(h_attn) 115 | h += h_attn 116 | 117 | h_dense = nn.LayerNorm(dtype=self.dtype)(h) 118 | h_dense = MLP( 119 | emb_size=model_size * self.widening_factor, 120 | output_size=model_size, 121 | dtype=self.dtype, 122 | kernel_init=kernel_init, 123 | remat=self.remat, 124 | )(h_dense) 125 | h_dense = nn.Dropout(rate=self.dropout, deterministic=not is_train)(h_dense) 126 | h += h_dense 127 | 128 | h = nn.LayerNorm(dtype=self.dtype)(h) 129 | 130 | if len(spatial_shape) > 1: 131 | # (batch, spatial_shape, model_size). 132 | h = h.reshape((batch, *spatial_shape, model_size)) 133 | 134 | return h 135 | -------------------------------------------------------------------------------- /imgx/model/basic.py: -------------------------------------------------------------------------------- 1 | """Basic functions and modules.""" 2 | from __future__ import annotations 3 | 4 | from typing import Callable 5 | 6 | import flax.linen as nn 7 | import jax 8 | import jax.numpy as jnp 9 | 10 | 11 | class Identity(nn.Module): 12 | """Identity module.""" 13 | 14 | dtype: jnp.dtype = jnp.float32 15 | 16 | @nn.compact 17 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 18 | """Forward pass. 19 | 20 | Args: 21 | x: input. 22 | 23 | Returns: 24 | input. 25 | """ 26 | return x 27 | 28 | 29 | class InstanceNorm(nn.Module): 30 | """Instance norm. 31 | 32 | The norm is calculated on axes excluding batch and features. 
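    For example (illustrative), an input of shape (batch, h, w, channel) is
    normalised using mean and variance computed over (h, w) separately for
    every (sample, channel) pair.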
33 | """ 34 | 35 | dtype: jnp.dtype = jnp.float32 36 | 37 | @nn.compact 38 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 39 | """Forward pass. 40 | 41 | Args: 42 | x: input with batch axis, (batch, ..., channel). 43 | 44 | Returns: 45 | Normalised input. 46 | """ 47 | reduction_axes = tuple(range(x.ndim)[slice(1, -1)]) 48 | return nn.LayerNorm( 49 | reduction_axes=reduction_axes, 50 | )(x) 51 | 52 | 53 | def sinusoidal_positional_embedding( 54 | x: jnp.ndarray, 55 | dim: int, 56 | max_period: int = 10000, 57 | dtype: jnp.dtype = jnp.float32, 58 | ) -> jnp.ndarray: 59 | """Create sinusoidal timestep embeddings. 60 | 61 | Half defined by sin, half by cos. 62 | For position x, the embeddings are (for i = 0,...,half_dim-1) 63 | sin(x / (max_period ** (i/half_dim))) 64 | cos(x / (max_period ** (i/half_dim))) 65 | 66 | Args: 67 | x: (..., ), with values in [0, 1]. 68 | dim: embedding dimension, assume to be evenly divided by two. 69 | max_period: controls the minimum frequency of the embeddings. 70 | dtype: dtype of the embeddings. 71 | 72 | Returns: 73 | Embedding of size (..., dim). 74 | """ 75 | ndim_x = len(x.shape) 76 | if dim % 2 != 0: 77 | raise ValueError(f"dim must be evenly divided by two, got {dim}.") 78 | half_dim = dim // 2 79 | # (half_dim,) 80 | freq = jnp.arange(0, half_dim, dtype=dtype) 81 | freq = jnp.exp(-jnp.log(max_period) * freq / half_dim) 82 | # (..., half_dim) 83 | freq = jnp.expand_dims(freq, axis=tuple(range(ndim_x))) 84 | args = x[..., None] * max_period * freq 85 | # (..., dim) 86 | return jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1) 87 | 88 | 89 | class MLP(nn.Module): 90 | """Two-layer MLP.""" 91 | 92 | emb_size: int 93 | output_size: int 94 | activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.gelu 95 | kernel_init: Callable[ 96 | [jax.Array, jnp.shape, jnp.dtype], jnp.ndarray 97 | ] = nn.initializers.lecun_normal() 98 | remat: bool = True 99 | dtype: jnp.dtype = jnp.float32 100 | 101 | @nn.compact 102 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 103 | """Forward pass. 104 | 105 | Args: 106 | x: shape (..., in_size) 107 | 108 | Returns: 109 | shape (..., out_size) 110 | """ 111 | dense_cls = nn.remat(nn.Dense) if self.remat else nn.Dense 112 | x = dense_cls( 113 | self.emb_size, 114 | kernel_init=self.kernel_init, 115 | dtype=self.dtype, 116 | )(x) 117 | x = self.activation(x) 118 | x = dense_cls( 119 | self.output_size, 120 | kernel_init=self.kernel_init, 121 | dtype=self.dtype, 122 | )(x) 123 | return x 124 | -------------------------------------------------------------------------------- /imgx/model/basic_test.py: -------------------------------------------------------------------------------- 1 | """Test basic functions for model.""" 2 | 3 | 4 | from functools import partial 5 | 6 | import chex 7 | import jax 8 | import jax.numpy as jnp 9 | from absl.testing import parameterized 10 | from chex._src import fake 11 | 12 | from imgx.model.basic import MLP, InstanceNorm, sinusoidal_positional_embedding 13 | 14 | 15 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 
16 | def setUpModule() -> None: # pylint: disable=invalid-name 17 | """Fake multi-devices.""" 18 | fake.set_n_cpu_devices(2) 19 | 20 | 21 | class TestInstanceNorm(chex.TestCase): 22 | """Test the function sinusoidal_positional_embedding.""" 23 | 24 | @chex.all_variants() 25 | @parameterized.named_parameters( 26 | ( 27 | "1d", 28 | (2,), 29 | ), 30 | ( 31 | "2d", 32 | (2, 3), 33 | ), 34 | ) 35 | def test_shapes( 36 | self, 37 | in_shape: tuple[int, ...], 38 | ) -> None: 39 | """Test output shapes under different device condition. 40 | 41 | Args: 42 | in_shape: input shape. 43 | """ 44 | rng = {"params": jax.random.PRNGKey(0)} 45 | norm = InstanceNorm() 46 | x = jax.random.uniform( 47 | jax.random.PRNGKey(0), 48 | shape=in_shape, 49 | ) 50 | out, _ = self.variant(norm.init_with_output)(rng, x) 51 | chex.assert_shape(out, in_shape) 52 | 53 | 54 | class TestSinusoidalPositionalEmbedding(chex.TestCase): 55 | """Test the function sinusoidal_positional_embedding.""" 56 | 57 | @chex.all_variants() 58 | @parameterized.named_parameters( 59 | ("1d case 1", (2,), 4, 5), 60 | ( 61 | "1d case 2", 62 | (2,), 63 | 8, 64 | 10000, 65 | ), 66 | ( 67 | "2d", 68 | (2, 3), 69 | 8, 70 | 10000, 71 | ), 72 | ) 73 | def test_shapes(self, in_shape: tuple[int, ...], dim: int, max_period: int) -> None: 74 | """Test output shapes under different device condition. 75 | 76 | Args: 77 | in_shape: input shape. 78 | dim: embedding dimension, assume to be evenly divided by two. 79 | max_period: controls the minimum frequency of the embeddings. 80 | """ 81 | rng = jax.random.PRNGKey(0) 82 | x = jax.random.uniform( 83 | rng, 84 | shape=in_shape, 85 | ) 86 | out = self.variant( 87 | partial(sinusoidal_positional_embedding, dim=dim, max_period=max_period) 88 | )(x) 89 | chex.assert_shape(out, (*in_shape, dim)) 90 | 91 | 92 | class TestMLP(chex.TestCase): 93 | """Test MLP.""" 94 | 95 | emb_size: int = 4 96 | output_size: int = 8 97 | 98 | @chex.all_variants() 99 | @parameterized.product( 100 | in_shape=[(3, 4, 5), (3, 4), (3,)], 101 | remat=[True, False], 102 | ) 103 | def test_shapes( 104 | self, 105 | in_shape: tuple[int, ...], 106 | remat: bool, 107 | ) -> None: 108 | """Test output shapes.""" 109 | rng = {"params": jax.random.PRNGKey(0)} 110 | mlp = MLP( 111 | emb_size=self.emb_size, 112 | output_size=self.output_size, 113 | remat=remat, 114 | ) 115 | out, _ = self.variant(mlp.init_with_output)(rng, jnp.ones(in_shape)) 116 | chex.assert_shape(out, (*in_shape[:-1], self.output_size)) 117 | -------------------------------------------------------------------------------- /imgx/model/conv_test.py: -------------------------------------------------------------------------------- 1 | """Test conv layers.""" 2 | 3 | import chex 4 | import jax 5 | from absl.testing import parameterized 6 | from chex._src import fake 7 | 8 | from imgx.model.conv import ConvDownSample, ConvResBlock, ConvUpSample 9 | 10 | 11 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 
12 | def setUpModule() -> None: # pylint: disable=invalid-name 13 | """Fake multi-devices.""" 14 | fake.set_n_cpu_devices(2) 15 | 16 | 17 | class TestConvDownSample(chex.TestCase): 18 | """Test ConvDownSample.""" 19 | 20 | batch = 2 21 | 22 | @chex.all_variants() 23 | @parameterized.named_parameters( 24 | ( 25 | "1d", 26 | (12,), 27 | (2,), 28 | (6,), 29 | ), 30 | ( 31 | "2d", 32 | (12, 13), 33 | (2, 2), 34 | (6, 7), 35 | ), 36 | ( 37 | "2d - different scale factors", 38 | (12, 13), 39 | (4, 2), 40 | (3, 7), 41 | ), 42 | ( 43 | "3d - large scale factor", 44 | (2, 4, 8), 45 | (4, 4, 4), 46 | (1, 1, 2), 47 | ), 48 | ( 49 | "3d", 50 | (12, 13, 14), 51 | (2, 2, 2), 52 | (6, 7, 7), 53 | ), 54 | ) 55 | def test_shapes( 56 | self, 57 | in_shape: tuple[int, ...], 58 | scale_factor: tuple[int, ...], 59 | out_shape: tuple[int, ...], 60 | ) -> None: 61 | """Test output shapes under different device condition. 62 | 63 | Args: 64 | in_shape: input shape, without batch, channel. 65 | scale_factor: downsample factor. 66 | out_shape: output shape, without batch, channel. 67 | """ 68 | in_channels = 1 69 | out_channels = 1 70 | rng = {"params": jax.random.PRNGKey(0)} 71 | conv = ConvDownSample( 72 | out_channels=out_channels, 73 | scale_factor=scale_factor, 74 | ) 75 | x = jax.random.uniform( 76 | jax.random.PRNGKey(0), 77 | shape=(self.batch, *in_shape, in_channels), 78 | ) 79 | out, _ = self.variant(conv.init_with_output)(rng, x) 80 | chex.assert_shape(out, (self.batch, *out_shape, out_channels)) 81 | 82 | 83 | class TestConvUpSample(chex.TestCase): 84 | """Test ConvUpSample.""" 85 | 86 | batch = 2 87 | 88 | @chex.all_variants() 89 | @parameterized.named_parameters( 90 | ( 91 | "1d", 92 | (3,), 93 | (2,), 94 | (6,), 95 | ), 96 | ( 97 | "2d", 98 | (3, 4), 99 | (2, 2), 100 | (6, 8), 101 | ), 102 | ( 103 | "2d - different scale factors", 104 | (3, 4), 105 | (4, 2), 106 | (12, 8), 107 | ), 108 | ( 109 | "3d", 110 | (2, 3, 4), 111 | (2, 2, 2), 112 | (4, 6, 8), 113 | ), 114 | ) 115 | def test_shapes( 116 | self, 117 | in_shape: tuple[int, ...], 118 | scale_factor: tuple[int, ...], 119 | out_shape: tuple[int, ...], 120 | ) -> None: 121 | """Test output shapes under different device condition. 122 | 123 | Args: 124 | in_shape: input shape, without batch, channel. 125 | scale_factor: up-sampler factor. 126 | out_shape: output shape, without batch, channel. 127 | """ 128 | in_channels = 1 129 | out_channels = 1 130 | rng = {"params": jax.random.PRNGKey(0)} 131 | conv = ConvUpSample( 132 | out_channels=out_channels, 133 | scale_factor=scale_factor, 134 | ) 135 | x = jax.random.uniform( 136 | jax.random.PRNGKey(0), 137 | shape=(self.batch, *in_shape, in_channels), 138 | ) 139 | out, _ = self.variant(conv.init_with_output)(rng, x) 140 | chex.assert_shape(out, (self.batch, *out_shape, out_channels)) 141 | 142 | 143 | class TestConvResBlock(chex.TestCase): 144 | """Test ConvResBlock.""" 145 | 146 | batch = 2 147 | 148 | @chex.all_variants() 149 | @parameterized.product( 150 | in_shape=[(12,), (12, 13), (12, 13, 14)], 151 | has_t=[True, False], 152 | dropout=[0.0, 0.5, 1.0], 153 | is_train=[True, False], 154 | remat=[True, False], 155 | ) 156 | def test_shapes( 157 | self, 158 | in_shape: tuple[int, ...], 159 | has_t: bool, 160 | dropout: float, 161 | is_train: bool, 162 | remat: bool, 163 | ) -> None: 164 | """Test output shapes. 165 | 166 | Args: 167 | in_shape: input shape, without batch, channel. 168 | has_t: whether has time embedding. 169 | dropout: dropout rate. 170 | is_train: whether in training mode. 
171 | remat: remat or not. 172 | """ 173 | kernel_size = (3,) * len(in_shape) 174 | in_channels = 1 175 | t_channels = 2 176 | out_channels = 1 177 | rng = {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)} 178 | conv = ConvResBlock( 179 | out_channels=out_channels, 180 | kernel_size=kernel_size, 181 | dropout=dropout, 182 | remat=remat, 183 | ) 184 | x = jax.random.uniform( 185 | jax.random.PRNGKey(0), 186 | shape=(self.batch, *in_shape, in_channels), 187 | ) 188 | t_emb = None 189 | if has_t: 190 | t_emb = jax.random.uniform( 191 | jax.random.PRNGKey(0), 192 | shape=(self.batch, t_channels), 193 | ) 194 | out, _ = self.variant(conv.init_with_output, static_argnums=(1,))(rng, is_train, x, t_emb) 195 | chex.assert_shape(out, (self.batch, *in_shape, out_channels)) 196 | -------------------------------------------------------------------------------- /imgx/model/slice.py: -------------------------------------------------------------------------------- 1 | """Functions for slicing images.""" 2 | from __future__ import annotations 3 | 4 | import jax.numpy as jnp 5 | 6 | 7 | def merge_spatial_dim_into_batch(x: jnp.ndarray, num_spatial_dims: int) -> jnp.ndarray: 8 | """Merge spatial dimensions into batch dimension. 9 | 10 | Args: 11 | x: array with original shape (batch, ..., in_channels). 12 | num_spatial_dims: target number of spatial dimensions. 13 | 14 | Returns: 15 | array with ndim=num_spatial_dims+2, 16 | shape = (extended_batch, ..., in_channels). 17 | """ 18 | # e.g. if x.shape = (batch, h, w, d, in_channels) 19 | # then x.ndim == 5, num_spatial_dims = 2 20 | # axes = (0, 3, 1, 2, 4) 21 | axes = ( 22 | 0, 23 | *range(num_spatial_dims + 1, x.ndim - 1), 24 | *range(1, num_spatial_dims + 1), 25 | x.ndim - 1, 26 | ) 27 | # move extra dims to front 28 | # e.g. (batch, h, w, d, in_channels) -> (batch, d, h, w, in_channels) 29 | x = jnp.transpose(x, axes) 30 | # e.g. (batch, d, h, w, in_channels) -> (batch*d, h, w, in_channels) 31 | return jnp.reshape(x, (-1, *x.shape[x.ndim - num_spatial_dims - 1 :])) 32 | 33 | 34 | def split_spatial_dim_from_batch( 35 | x: jnp.ndarray, 36 | num_spatial_dims: int, 37 | batch_size: int, 38 | spatial_shape: tuple[int, ...], 39 | ) -> jnp.ndarray: 40 | """Remove spatial dimensions from batch axis. 41 | 42 | Args: 43 | x: array with merged shape (batch, ..., in_channels), 44 | x.ndim=num_spatial_dims+2. 45 | num_spatial_dims: current number of spatial dimensions. 46 | batch_size: batch size. 47 | spatial_shape: original spatial shape. 48 | 49 | Returns: 50 | array with original shape (batch, ..., in_channels). 51 | """ 52 | # e.g. (batch*d, h, w, out_channels) -> (batch, d, h, w, out_channels) 53 | x = jnp.reshape(x, (batch_size, *spatial_shape[num_spatial_dims:], *x.shape[1:])) 54 | 55 | # e.g. if x.shape = (batch, d, h, w, out_channels) 56 | # then x.ndim == 5, num_spatial_dims = 2 57 | # axes = (0, 3, 1, 2, 4) 58 | axes = ( 59 | 0, 60 | *range(x.ndim - 1 - num_spatial_dims, x.ndim - 1), 61 | *range(1, x.ndim - 1 - num_spatial_dims), 62 | x.ndim - 1, 63 | ) 64 | # e.g. 
(batch, d, h, w, out_channels) -> (batch, h, w, d, out_channels) 65 | return jnp.transpose(x, axes) 66 | -------------------------------------------------------------------------------- /imgx/model/slice_test.py: -------------------------------------------------------------------------------- 1 | """Test slicing functions.""" 2 | 3 | from functools import partial 4 | 5 | import chex 6 | import jax.numpy as jnp 7 | from absl.testing import parameterized 8 | from chex._src import fake 9 | 10 | from imgx.model.slice import merge_spatial_dim_into_batch, split_spatial_dim_from_batch 11 | 12 | 13 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 14 | def setUpModule() -> None: # pylint: disable=invalid-name 15 | """Fake multi-devices.""" 16 | fake.set_n_cpu_devices(2) 17 | 18 | 19 | class TestMergeSplitSpatialDims(chex.TestCase): 20 | """Test merge_spatial_dim_into_batch and split_spatial_dim_from_batch.""" 21 | 22 | @chex.all_variants() 23 | @parameterized.named_parameters( 24 | ("2d-1", (2, 3, 4, 5), 1, (8, 3, 5)), 25 | ("2d-2", (2, 3, 4, 5), 2, (2, 3, 4, 5)), 26 | ("3d-1", (2, 3, 4, 5, 6), 1, (40, 3, 6)), 27 | ("3d-2", (2, 3, 4, 5, 6), 2, (10, 3, 4, 6)), 28 | ("3d-3", (2, 3, 4, 5, 6), 3, (2, 3, 4, 5, 6)), 29 | ) 30 | def test_merge_spatial_dim_into_batch( 31 | self, 32 | in_shape: tuple[int, ...], 33 | num_spatial_dims: int, 34 | expected_shape: tuple[int, ...], 35 | ) -> None: 36 | """Test merge_spatial_dim_into_batch. 37 | 38 | Args: 39 | in_shape: input shape. 40 | num_spatial_dims: number of spatial dimensions. 41 | expected_shape: expected output shape. 42 | """ 43 | x = jnp.ones(in_shape) 44 | x = self.variant(partial(merge_spatial_dim_into_batch, num_spatial_dims=num_spatial_dims))( 45 | x 46 | ) 47 | chex.assert_shape(x, expected_shape) 48 | 49 | @chex.all_variants() 50 | @parameterized.named_parameters( 51 | ("2d-1", (8, 3, 5), 1, 2, (3, 4), (2, 3, 4, 5)), 52 | ("2d-2", (2, 3, 4, 5), 2, 2, (3, 4), (2, 3, 4, 5)), 53 | ("3d-1", (40, 3, 6), 1, 2, (3, 4, 5), (2, 3, 4, 5, 6)), 54 | ("3d-2", (10, 3, 4, 6), 2, 2, (3, 4, 5), (2, 3, 4, 5, 6)), 55 | ("3d-3", (2, 3, 4, 5, 6), 3, 2, (3, 4, 5), (2, 3, 4, 5, 6)), 56 | ) 57 | def test_split_spatial_dim_from_batch( 58 | self, 59 | in_shape: tuple[int, ...], 60 | num_spatial_dims: int, 61 | batch_size: int, 62 | spatial_shape: tuple[int, ...], 63 | expected_shape: tuple[int, ...], 64 | ) -> None: 65 | """Test split_spatial_dim_from_batch. 66 | 67 | Args: 68 | in_shape: input shape, with certain spatial axes merged. 69 | num_spatial_dims: number of spatial dimensions. 70 | batch_size: batch size. 71 | spatial_shape: spatial shape. 72 | expected_shape: expected output shape. 
73 | """ 74 | x = jnp.ones(in_shape) 75 | x = self.variant( 76 | partial( 77 | split_spatial_dim_from_batch, 78 | num_spatial_dims=num_spatial_dims, 79 | spatial_shape=spatial_shape, 80 | batch_size=batch_size, 81 | ) 82 | )(x) 83 | chex.assert_shape(x, expected_shape) 84 | -------------------------------------------------------------------------------- /imgx/model/unet/__init__.py: -------------------------------------------------------------------------------- 1 | """Module to build a UNet model.""" 2 | -------------------------------------------------------------------------------- /imgx/model/unet/bottom_encoder.py: -------------------------------------------------------------------------------- 1 | """Image encoder for unet.""" 2 | from __future__ import annotations 3 | 4 | import flax.linen as nn 5 | import jax.numpy as jnp 6 | 7 | from imgx.model.attention import TransformerEncoder 8 | from imgx.model.conv import ConvResBlock 9 | 10 | 11 | class BottomImageEncoderUnet(nn.Module): 12 | """Image encoder module with convolutions for unet.""" 13 | 14 | kernel_size: tuple[int, ...] # convolution layer kernel size 15 | dropout: float = 0.0 # for resnet block 16 | num_heads: int = 8 # for multi head attention 17 | num_layers: int = 1 # for transformer encoder 18 | widening_factor: int = 4 # for key size in MHA 19 | remat: bool = True # reduces memory cost at cost of compute speed 20 | dtype: jnp.dtype = jnp.float32 21 | 22 | @nn.compact 23 | def __call__( 24 | self, 25 | is_train: bool, 26 | image_emb: jnp.ndarray, 27 | t_emb: jnp.ndarray | None, 28 | ) -> jnp.ndarray: 29 | """Encoder the image. 30 | 31 | Args: 32 | is_train: whether in training mode. 33 | image_emb: shape (batch, *spatial_shape, model_size). 34 | t_emb: time embedding, (batch, t_channels). 35 | 36 | Returns: 37 | image_emb: (batch, *spatial_shape, model_size). 38 | """ 39 | model_size = image_emb.shape[-1] 40 | 41 | # conv before attention 42 | # image_emb (batch, *spatial_shape, image_emb_size) 43 | image_emb = ConvResBlock( 44 | out_channels=model_size, 45 | kernel_size=self.kernel_size, 46 | dropout=self.dropout, 47 | remat=self.remat, 48 | )(is_train, image_emb, t_emb) 49 | 50 | # attention 51 | image_emb = TransformerEncoder( 52 | num_heads=self.num_heads, 53 | widening_factor=self.widening_factor, 54 | dropout=self.dropout, 55 | remat=self.remat, 56 | dtype=self.dtype, 57 | )(is_train, image_emb) 58 | 59 | # conv after attention 60 | # image_emb (batch, *spatial_shape, image_emb_size) 61 | image_emb = ConvResBlock( 62 | out_channels=model_size, 63 | kernel_size=self.kernel_size, 64 | dropout=self.dropout, 65 | remat=self.remat, 66 | )(is_train, image_emb, t_emb) 67 | 68 | return image_emb 69 | -------------------------------------------------------------------------------- /imgx/model/unet/bottom_encoder_test.py: -------------------------------------------------------------------------------- 1 | """Test Unet bottom encoder related classes and functions.""" 2 | 3 | import chex 4 | import jax 5 | import jax.numpy as jnp 6 | from absl.testing import parameterized 7 | from chex._src import fake 8 | 9 | from imgx.model.unet.bottom_encoder import BottomImageEncoderUnet 10 | 11 | 12 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 
13 | def setUpModule() -> None: # pylint: disable=invalid-name 14 | """Fake multi-devices.""" 15 | fake.set_n_cpu_devices(2) 16 | 17 | 18 | class TestBottomEncoderUnet(chex.TestCase): 19 | """Test the class BottomEncoderUnet.""" 20 | 21 | batch_size = 2 22 | model_size = 16 23 | num_heads = 2 24 | t_size = 3 25 | 26 | @chex.all_variants() 27 | @parameterized.product( 28 | spatial_shape=[(5, 6), (5, 6, 7)], 29 | with_time=[True, False], 30 | is_train=[True, False], 31 | ) 32 | def test_output_shape( 33 | self, 34 | spatial_shape: tuple[int, ...], 35 | with_time: bool, 36 | is_train: bool, 37 | ) -> None: 38 | """Test output shape.""" 39 | rng = {"params": jax.random.PRNGKey(0)} 40 | kernel_size = (3,) * len(spatial_shape) 41 | encoder = BottomImageEncoderUnet(num_heads=self.num_heads, kernel_size=kernel_size) 42 | image_emb = jnp.ones((self.batch_size, *spatial_shape, self.model_size)) 43 | t_emb = jnp.ones((self.batch_size, self.t_size)) if with_time else None 44 | out, _ = self.variant(encoder.init_with_output, static_argnums=(1,))( 45 | rng, is_train, image_emb, t_emb 46 | ) 47 | chex.assert_shape(out, (self.batch_size, *spatial_shape, self.model_size)) 48 | -------------------------------------------------------------------------------- /imgx/model/unet/downsample_encoder.py: -------------------------------------------------------------------------------- 1 | """Downsample encoder for unet.""" 2 | from __future__ import annotations 3 | 4 | import flax.linen as nn 5 | import jax.numpy as jnp 6 | 7 | from imgx.model.conv import ConvDownSample, ConvNormAct, ConvResBlock 8 | 9 | 10 | class DownsampleEncoder(nn.Module): 11 | """Down-sample encoder module with convolutions for unet.""" 12 | 13 | num_channels: tuple[int, ...] # channel at each depth, including the bottom 14 | patch_size: tuple[int, ...] # first down sampling layer 15 | scale_factor: tuple[int, ...] # spatial down-sampling/up-sampling 16 | kernel_size: tuple[int, ...] # convolution layer kernel size 17 | num_res_blocks: int = 2 # number of residual blocks 18 | dropout: float = 0.0 # for resnet block 19 | remat: bool = True # reduces memory cost at cost of compute speed 20 | dtype: jnp.dtype = jnp.float32 21 | 22 | @nn.compact 23 | def __call__( 24 | self, 25 | is_train: bool, 26 | x: jnp.ndarray, 27 | t_emb: jnp.ndarray | None, 28 | ) -> list[jnp.ndarray]: 29 | """Encoder the image. 30 | 31 | If batch_size = 2, image_shape = (256, 256, 32), num_channels = (1,2,4) 32 | with num_res_blocks = 2, patch_size = 4 33 | the embeddings' shape are: 34 | (2, 256, 256, 32, 1), from first residual block before for loop 35 | (2, 256, 256, 32, 1), from residual block, i=0 36 | (2, 256, 256, 32, 1), from residual block, i=0 37 | (2, 64, 64, 8, 2), from down-sampling block, i=0 38 | (2, 64, 64, 8, 2), from residual block, i=1 39 | (2, 64, 64, 8, 2), from residual block, i=1 40 | (2, 32, 32, 4, 4), from down-sampling block, i=1 41 | (2, 32, 32, 4, 4), from residual block, i=2 42 | (2, 32, 32, 4, 4), from residual block, i=2 43 | 44 | Args: 45 | is_train: whether in training mode. 46 | x: array of shape (batch, *spatial_shape, in_channels). 47 | t_emb: array of shape (batch, t_channels). 48 | 49 | Returns: 50 | List of embeddings from each layer. 
51 | """ 52 | conv_down_sample_cls = nn.remat(ConvDownSample) if self.remat else ConvDownSample 53 | 54 | # encoder raw input 55 | x = ConvNormAct( 56 | out_channels=self.num_channels[0], 57 | kernel_size=self.kernel_size, 58 | remat=self.remat, 59 | )(x) 60 | 61 | # encoding 62 | embeddings = [x] 63 | for i, ch in enumerate(self.num_channels): 64 | # residual blocks 65 | # spatial shape get halved by 2**i 66 | for _ in range(self.num_res_blocks): 67 | x = ConvResBlock( 68 | out_channels=ch, 69 | kernel_size=self.kernel_size, 70 | dropout=self.dropout, 71 | remat=self.remat, 72 | )(is_train, x, t_emb) 73 | embeddings.append(x) 74 | 75 | # down-sampling for non-bottom layers 76 | # spatial shape get halved by 2**(i+1) 77 | if i < len(self.num_channels) - 1: 78 | x = conv_down_sample_cls( 79 | out_channels=self.num_channels[i + 1], 80 | scale_factor=self.patch_size if i == 0 else self.scale_factor, 81 | )(x) 82 | embeddings.append(x) 83 | 84 | return embeddings 85 | -------------------------------------------------------------------------------- /imgx/model/unet/upsample_decoder.py: -------------------------------------------------------------------------------- 1 | """Upsample encoder for unet.""" 2 | from __future__ import annotations 3 | 4 | import flax.linen as nn 5 | import jax 6 | import jax.numpy as jnp 7 | from jax import lax 8 | 9 | from imgx.model.conv import ConvResBlock, ConvUpSample 10 | 11 | 12 | class UpsampleDecoder(nn.Module): 13 | """Upsample decoder module with convolutions for unet.""" 14 | 15 | out_channels: int 16 | num_channels: tuple[int, ...] # channel at each depth, including the bottom 17 | patch_size: tuple[int, ...] # first down sampling layer 18 | scale_factor: tuple[int, ...] # spatial down-sampling/up-sampling 19 | kernel_size: tuple[int, ...] # convolution layer kernel size 20 | num_res_blocks: int = 2 # number of residual blocks 21 | dropout: float = 0.0 # for resnet block 22 | out_kernel_init: jax.nn.initializers.Initializer = nn.linear.default_kernel_init 23 | remat: bool = True # remat reduces memory cost at cost of compute speed 24 | dtype: jnp.dtype = jnp.float32 25 | 26 | @nn.compact 27 | def __call__( 28 | self, 29 | is_train: bool, 30 | embeddings: list[jnp.ndarray], 31 | t_emb: jnp.ndarray | None, 32 | ) -> jnp.ndarray | list[jnp.ndarray]: 33 | """Decode the embedding and perform prediction. 34 | 35 | Args: 36 | is_train: whether in training mode. 37 | embeddings: list of embeddings from each layer. 38 | Starting with the first layer. 39 | t_emb: array of shape (batch, t_channels). 40 | 41 | Returns: 42 | Logits (batch, ..., out_channels). 
43 | """ 44 | if len(embeddings) != len(self.num_channels) * (self.num_res_blocks + 1) + 1: 45 | raise ValueError("MaskDecoderConvUnet input length does not match") 46 | num_spatial_dims = len(self.kernel_size) 47 | conv_up_sample_cls = nn.remat(ConvUpSample) if self.remat else ConvUpSample 48 | conv_cls = nn.remat(nn.Conv) if self.remat else nn.Conv 49 | 50 | # spatial shape get halved by 2**(len(self.num_channels)-1) 51 | # channel = self.num_channels[-1] 52 | x = embeddings.pop() 53 | 54 | for i, ch in enumerate(self.num_channels[::-1]): 55 | # spatial shape get halved by 2**(len(self.num_channels)-1-i) 56 | # channel = ch 57 | for _ in range(self.num_res_blocks + 1): 58 | # add skipped 59 | # use addition instead of concatenation to reduce memory cost 60 | skipped = embeddings.pop() 61 | x += skipped 62 | 63 | # conv 64 | x = ConvResBlock( 65 | out_channels=ch, 66 | kernel_size=self.kernel_size, 67 | dropout=self.dropout, 68 | remat=self.remat, 69 | )(is_train, x, t_emb) 70 | 71 | if i < len(self.num_channels) - 1: 72 | # up-sampling 73 | # skipped.shape <= up-scaled shape 74 | # as padding may be added when down-sampling 75 | skipped_shape = embeddings[-1].shape[1:-1] 76 | # deconv and pad to make emb of same shape as skipped 77 | x = conv_up_sample_cls( 78 | out_channels=self.num_channels[-i - 2], 79 | scale_factor=self.patch_size 80 | if i == len(self.num_channels) - 2 81 | else self.scale_factor, 82 | )(x) 83 | x = lax.dynamic_slice( 84 | x, 85 | start_indices=(0,) * (num_spatial_dims + 2), 86 | slice_sizes=(x.shape[0], *skipped_shape, x.shape[-1]), 87 | ) 88 | out = conv_cls( 89 | features=self.out_channels, 90 | kernel_size=(1,) * num_spatial_dims, 91 | kernel_init=self.out_kernel_init, 92 | )(x) 93 | return out 94 | -------------------------------------------------------------------------------- /imgx/run_valid.py: -------------------------------------------------------------------------------- 1 | """Script to launch evaluation on validation tests.""" 2 | import argparse 3 | import json 4 | from pathlib import Path 5 | 6 | import jax 7 | from absl import logging 8 | from flax import jax_utils 9 | from flax.training import common_utils 10 | from omegaconf import DictConfig, OmegaConf 11 | 12 | from imgx.data.iterator import get_image_tfds_dataset 13 | from imgx.run_train import build_experiment 14 | 15 | logging.set_verbosity(logging.INFO) 16 | 17 | 18 | def get_checkpoint_steps( 19 | log_dir: Path, 20 | ) -> list[int]: 21 | """Get the steps of all available checkpoints. 22 | 23 | Args: 24 | log_dir: Directory of entire log. 25 | 26 | Returns: 27 | A list of available steps. 28 | 29 | Raises: 30 | ValueError: if any file not found. 
31 | """ 32 | ckpt_dir = log_dir / "files" / "ckpt" 33 | steps = [] 34 | for step_dir in ckpt_dir.glob("checkpoint_*/"): 35 | if not step_dir.is_dir(): 36 | continue 37 | ckpt_path = step_dir / "checkpoint" 38 | if not ckpt_path.exists(): 39 | continue 40 | steps.append(int(step_dir.stem.split("_")[-1])) 41 | return sorted(steps, reverse=True) 42 | 43 | 44 | def parse_args() -> argparse.Namespace: 45 | """Parse arguments.""" 46 | parser = argparse.ArgumentParser(description=__doc__) 47 | parser.add_argument( 48 | "--log_dir", 49 | type=Path, 50 | help="Folder of wandb.", 51 | default=None, 52 | ) 53 | parser.add_argument( 54 | "--num_timesteps", 55 | type=int, 56 | help="Number of sampling steps for diffusion_segmentation.", 57 | default=-1, 58 | ) 59 | parser.add_argument( 60 | "--sampler", 61 | type=str, 62 | help="Sampling algorithm for diffusion_segmentation.", 63 | default="", 64 | choices=["", "DDPM", "DDIM"], 65 | ) 66 | args = parser.parse_args() 67 | 68 | return args 69 | 70 | 71 | def load_and_parse_config( 72 | log_dir: Path, 73 | num_timesteps: int, 74 | sampler: str, 75 | ) -> DictConfig: 76 | """Load and parse config. 77 | 78 | Args: 79 | log_dir: Directory of entire log. 80 | num_timesteps: Number of sampling steps for diffusion_segmentation. 81 | sampler: Sampling algorithm for diffusion_segmentation. 82 | 83 | Returns: 84 | Loaded config. 85 | """ 86 | config = OmegaConf.load(log_dir / "files" / "config_backup.yaml") 87 | if config.task.name == "diffusion_segmentation": 88 | if num_timesteps <= 0: 89 | raise ValueError("num_timesteps required for diffusion.") 90 | config.task.sampler.num_inference_timesteps = num_timesteps 91 | logging.info(f"Sampling {num_timesteps} steps.") 92 | if not sampler: 93 | raise ValueError("sampler required for diffusion.") 94 | config.task.sampler.name = sampler 95 | logging.info(f"Using sampler {sampler}.") 96 | return config 97 | 98 | 99 | def main() -> None: 100 | """Main function.""" 101 | args = parse_args() 102 | logging.info(f"Local devices are: {jax.local_devices()}") 103 | 104 | # load config 105 | config = load_and_parse_config( 106 | log_dir=args.log_dir, num_timesteps=args.num_timesteps, sampler=args.sampler 107 | ) 108 | 109 | # find all available checkpoints 110 | steps = get_checkpoint_steps(log_dir=args.log_dir) 111 | 112 | key = jax.random.PRNGKey(config.seed) 113 | key = common_utils.shard_prng_key(key) # each replica has a different key 114 | 115 | # init data 116 | dataset = get_image_tfds_dataset( 117 | dataset_name=config.data.name, 118 | config=config, 119 | ) 120 | train_iter = dataset.train_iter 121 | valid_iter = dataset.valid_iter 122 | platform = jax.local_devices()[0].platform 123 | if platform not in ["cpu", "tpu"]: 124 | train_iter = jax_utils.prefetch_to_device(train_iter, 2) 125 | valid_iter = jax_utils.prefetch_to_device(valid_iter, 2) 126 | 127 | # evaluate 128 | ckpt_dir = args.log_dir / "files" / "ckpt" 129 | run = build_experiment(config=config) 130 | for step in steps: 131 | logging.info(f"Starting valid split evaluation for step {step}.") 132 | 133 | # load checkpoint 134 | batch = next(train_iter) 135 | train_state, _ = run.train_init(batch=batch, ckpt_dir=ckpt_dir, step=step) 136 | 137 | # evaluation 138 | val_metrics = run.eval_step( 139 | train_state=train_state, iterator=valid_iter, num_steps=dataset.num_valid_steps, key=key 140 | ) 141 | 142 | # save metrics 143 | out_dir = ckpt_dir / f"checkpoint_{step}" 144 | if config.task.name == "diffusion_segmentation": 145 | out_dir = out_dir / 
config.task.sampler.name 146 | out_dir.mkdir(parents=True, exist_ok=True) 147 | with open(out_dir / "mean_metrics.json", "w", encoding="utf-8") as f: 148 | json.dump(val_metrics, f, sort_keys=True, indent=4) 149 | 150 | 151 | if __name__ == "__main__": 152 | main() 153 | -------------------------------------------------------------------------------- /imgx/task/__init__.py: -------------------------------------------------------------------------------- 1 | """Module for different learning tasks.""" 2 | -------------------------------------------------------------------------------- /imgx/task/diffusion_segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | """Diffusion-based segmentation task.""" 2 | -------------------------------------------------------------------------------- /imgx/task/diffusion_segmentation/diffusion.py: -------------------------------------------------------------------------------- 1 | """Module for diffusion segmentation.""" 2 | from __future__ import annotations 3 | 4 | from dataclasses import dataclass 5 | 6 | import jax.numpy as jnp 7 | 8 | from imgx.diffusion.diffusion import Diffusion 9 | 10 | 11 | @dataclass 12 | class DiffusionSegmentation(Diffusion): 13 | """Base class for segmentation.""" 14 | 15 | def mask_to_x(self, mask: jnp.ndarray) -> jnp.ndarray: 16 | """Convert mask to x. 17 | 18 | Args: 19 | mask: boolean segmentation mask. 20 | 21 | Returns: 22 | array in diffusion space. 23 | """ 24 | raise NotImplementedError 25 | 26 | def x_to_mask(self, x: jnp.ndarray) -> jnp.ndarray: 27 | """Convert x to mask. 28 | 29 | Args: 30 | x: array in diffusion space. 31 | 32 | Returns: 33 | boolean segmentation mask. 34 | """ 35 | raise NotImplementedError 36 | 37 | def x_to_logits(self, x: jnp.ndarray) -> jnp.ndarray: 38 | """Convert x into model output space, which is logits. 39 | 40 | Args: 41 | x: array in diffusion space. 42 | 43 | Returns: 44 | unnormalised logits. 45 | """ 46 | raise NotImplementedError 47 | 48 | def model_out_to_logits_start( 49 | self, model_out: jnp.ndarray, x_t: jnp.ndarray, t_index: jnp.ndarray 50 | ) -> jnp.ndarray: 51 | """Convert model outputs to logits at time 0, noiseless. 52 | 53 | Args: 54 | model_out: model outputs. 55 | x_t: noisy x at time t. 56 | t_index: storing index values < self.num_timesteps. 57 | 58 | Returns: 59 | logits. 60 | """ 61 | raise NotImplementedError 62 | -------------------------------------------------------------------------------- /imgx/task/diffusion_segmentation/gaussian_diffusion_test.py: -------------------------------------------------------------------------------- 1 | """Test Gaussian diffusion related classes and functions.""" 2 | 3 | 4 | import chex 5 | from chex._src import fake 6 | 7 | from imgx.task.diffusion_segmentation.gaussian_diffusion import GaussianDiffusionSegmentation 8 | 9 | 10 | # Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests. 
11 | def setUpModule() -> None: # pylint: disable=invalid-name 12 | """Fake multi-devices.""" 13 | fake.set_n_cpu_devices(2) 14 | 15 | 16 | class TestGaussianDiffusionSegmentation(chex.TestCase): 17 | """Test the class GaussianDiffusion.""" 18 | 19 | batch_size = 2 20 | 21 | # unet 22 | in_channels = 1 23 | num_classes = 2 24 | num_channels = (1, 2) 25 | 26 | num_timesteps = 5 27 | num_timesteps_beta = 1001 28 | beta_schedule = "linear" 29 | beta_start = 0.0001 30 | beta_end = 0.02 31 | 32 | def test_attributes( 33 | self, 34 | ) -> None: 35 | """Test attribute shape.""" 36 | gd = GaussianDiffusionSegmentation.create( 37 | classes_are_exclusive=True, 38 | num_timesteps=self.num_timesteps, 39 | num_timesteps_beta=self.num_timesteps_beta, 40 | beta_schedule=self.beta_schedule, 41 | beta_start=self.beta_start, 42 | beta_end=self.beta_end, 43 | model_out_type="x_start", 44 | model_var_type="fixed_large", 45 | ) 46 | 47 | chex.assert_shape(gd.betas, (self.num_timesteps,)) 48 | chex.assert_shape(gd.alphas_cumprod, (self.num_timesteps,)) 49 | chex.assert_shape(gd.alphas_cumprod_prev, (self.num_timesteps,)) 50 | chex.assert_shape(gd.alphas_cumprod_next, (self.num_timesteps,)) 51 | chex.assert_shape(gd.sqrt_alphas_cumprod, (self.num_timesteps,)) 52 | chex.assert_shape(gd.sqrt_one_minus_alphas_cumprod, (self.num_timesteps,)) 53 | chex.assert_shape(gd.log_one_minus_alphas_cumprod, (self.num_timesteps,)) 54 | chex.assert_shape(gd.sqrt_recip_alphas_cumprod, (self.num_timesteps,)) 55 | chex.assert_shape(gd.sqrt_recip_alphas_cumprod_minus_one, (self.num_timesteps,)) 56 | chex.assert_shape(gd.posterior_mean_coeff_start, (self.num_timesteps,)) 57 | chex.assert_shape(gd.posterior_mean_coeff_t, (self.num_timesteps,)) 58 | chex.assert_shape(gd.posterior_variance, (self.num_timesteps,)) 59 | chex.assert_shape(gd.posterior_log_variance_clipped, (self.num_timesteps,)) 60 | -------------------------------------------------------------------------------- /imgx/task/diffusion_segmentation/save.py: -------------------------------------------------------------------------------- 1 | """Segmentation related io (file cannot be named as io). 2 | 3 | https://stackoverflow.com/questions/26569828/pycharm-py-initialize-cant-initialize-sys-standard-streams 4 | """ 5 | from __future__ import annotations 6 | 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | 11 | from imgx.task.segmentation.save import save_segmentation_prediction 12 | 13 | 14 | def save_diffusion_segmentation_prediction( 15 | label_pred: np.ndarray, 16 | uids: list[str], 17 | out_dir: Path | None, 18 | tfds_dir: Path, 19 | reference_suffix: str = "mask_preprocessed", 20 | output_suffix: str = "mask_pred", 21 | ) -> None: 22 | """Save segmentation predictions. 23 | 24 | Args: 25 | label_pred: (num_samples, ..., num_timesteps), the values are integers. 26 | uids: (num_samples,). 27 | out_dir: output directory. 28 | tfds_dir: directory saving preprocessed images and labels. 29 | reference_suffix: suffix of reference image. 30 | output_suffix: suffix of output image. 
31 | """ 32 | if out_dir is None: 33 | return 34 | num_timesteps = label_pred.shape[-1] 35 | for i in range(num_timesteps): 36 | save_segmentation_prediction( 37 | label_pred=label_pred[..., i], 38 | uids=uids, 39 | out_dir=out_dir / f"step_{i}", 40 | tfds_dir=tfds_dir, 41 | reference_suffix=reference_suffix, 42 | output_suffix=output_suffix, 43 | ) 44 | -------------------------------------------------------------------------------- /imgx/task/diffusion_segmentation/train_state.py: -------------------------------------------------------------------------------- 1 | """Training state and checkpoints. 2 | 3 | https://github.com/google/flax/blob/main/examples/imagenet/train.py 4 | """ 5 | from __future__ import annotations 6 | 7 | from typing import Callable 8 | 9 | import chex 10 | import flax.linen as nn 11 | import jax 12 | import jax.numpy as jnp 13 | from absl import logging 14 | from flax.training import dynamic_scale as dynamic_scale_lib 15 | from omegaconf import DictConfig 16 | 17 | from imgx.train_state import TrainState as BaseTrainState 18 | from imgx.train_state import init_optimizer 19 | 20 | 21 | class TrainState(BaseTrainState): 22 | """Train state. 23 | 24 | If using nn.BatchNorm, batch_stats needs to be tracked. 25 | https://flax.readthedocs.io/en/latest/guides/batch_norm.html 26 | https://github.com/google/flax/blob/main/examples/imagenet/train.py 27 | """ 28 | 29 | loss_count_hist: jnp.ndarray # mutable 30 | loss_sq_hist: jnp.ndarray # mutable 31 | 32 | 33 | def create_train_state( 34 | key: jax.Array, 35 | batch: dict[str, jnp.ndarray], 36 | model: nn.Module, 37 | config: DictConfig, 38 | initialized: Callable[[jax.Array, chex.ArrayTree, nn.Module], chex.ArrayTree], 39 | ) -> TrainState: 40 | """Create initial training state. 41 | 42 | Args: 43 | key: random key. 44 | batch: batch data for determining input shapes. 45 | model: model. 46 | config: entire configuration. 47 | initialized: function to get initialized model parameters. 48 | 49 | Returns: 50 | initial training state. 51 | """ 52 | dynamic_scale = None 53 | platform = jax.local_devices()[0].platform 54 | if config.half_precision and platform == "gpu": 55 | dynamic_scale = dynamic_scale_lib.DynamicScale() 56 | 57 | params = initialized(key, batch, model) 58 | 59 | # count params 60 | params_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) 61 | logging.info(f"The model has {params_count:,} parameters.") 62 | 63 | # diffusion related 64 | num_timesteps = config.task.diffusion.num_timesteps 65 | loss_count_hist = jnp.zeros((num_timesteps,), dtype=jnp.int32) 66 | loss_sq_hist = jnp.zeros((num_timesteps,), dtype=jnp.float32) 67 | 68 | tx = init_optimizer(config=config) 69 | state = TrainState.create( 70 | apply_fn=model.apply, 71 | params=params, 72 | tx=tx, 73 | dynamic_scale=dynamic_scale, 74 | loss_count_hist=loss_count_hist, 75 | loss_sq_hist=loss_sq_hist, 76 | ) 77 | return state 78 | -------------------------------------------------------------------------------- /imgx/task/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | """Segmentation task.""" 2 | -------------------------------------------------------------------------------- /imgx/task/segmentation/save.py: -------------------------------------------------------------------------------- 1 | """Segmentation related io (file cannot be named as io). 
2 | 3 | https://stackoverflow.com/questions/26569828/pycharm-py-initialize-cant-initialize-sys-standard-streams 4 | """ 5 | from __future__ import annotations 6 | 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | import SimpleITK as sitk # noqa: N813 11 | 12 | from imgx.datasets.save import save_image 13 | 14 | 15 | def save_segmentation_prediction( 16 | label_pred: np.ndarray, 17 | uids: list[str], 18 | out_dir: Path | None, 19 | tfds_dir: Path, 20 | reference_suffix: str = "mask_preprocessed", 21 | output_suffix: str = "mask_pred", 22 | ) -> None: 23 | """Save segmentation predictions. 24 | 25 | Args: 26 | label_pred: (num_samples, ...), the values are integers. 27 | uids: (num_samples,). 28 | out_dir: output directory. 29 | tfds_dir: directory saving preprocessed images and labels. 30 | reference_suffix: suffix of reference image. 31 | output_suffix: suffix of output image. 32 | """ 33 | if out_dir is None: 34 | return 35 | if label_pred.ndim == 3 and np.max(label_pred) > 1: 36 | raise ValueError( 37 | f"Prediction values should be 0 or 1, but " 38 | f"max value is {np.max(label_pred)}. " 39 | f"Multi-class segmentation for 2D images are not supported." 40 | ) 41 | if label_pred.ndim not in [3, 4]: 42 | raise ValueError( 43 | f"Prediction should be 3D or 4D with num_samples axis, but {label_pred.ndim}D is given." 44 | ) 45 | file_suffix = "nii.gz" if label_pred.ndim == 4 else "png" 46 | out_dir.mkdir(parents=True, exist_ok=True) 47 | for i, uid in enumerate(uids): 48 | reference_image = sitk.ReadImage(tfds_dir / f"{uid}_{reference_suffix}.{file_suffix}") 49 | save_image( 50 | image=label_pred[i, ...], 51 | reference_image=reference_image, 52 | out_path=out_dir / f"{uid}_{output_suffix}.{file_suffix}", 53 | dtype=np.uint8, 54 | ) 55 | -------------------------------------------------------------------------------- /imgx/task/util.py: -------------------------------------------------------------------------------- 1 | """Shared utility functions.""" 2 | from __future__ import annotations 3 | 4 | from jax import numpy as jnp 5 | 6 | 7 | def decode_uids(uids: jnp.ndarray) -> list[str]: 8 | """Decode uids. 9 | 10 | Args: 11 | uids: uids in bytes or int. 12 | 13 | Returns: 14 | decoded uids. 15 | """ 16 | decoded = [] 17 | for x in uids.tolist(): 18 | if isinstance(x, bytes): 19 | decoded.append(x.decode("utf-8")) 20 | elif x == 0: 21 | # the batch was not complete, padded with zero 22 | decoded.append("") 23 | else: 24 | raise ValueError(f"uid {x} is not supported.") 25 | return decoded 26 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # package 2 | [build-system] 3 | requires = ["setuptools", "setuptools-scm"] 4 | build-backend = "setuptools.build_meta" 5 | 6 | [project] 7 | name = "imgx" 8 | authors = [ 9 | {name = "Yunguan Fu", email = "yunguan.fu.18@ucl.ac.uk"}, 10 | ] 11 | description = "A Jax-based deep learning toolkit for biomedical applications." 
12 | requires-python = ">=3.9" 13 | license = {text = "Apache-2.0"} 14 | version = "0.3.2" 15 | 16 | [project.scripts] 17 | imgx_train="imgx.run_train:main" 18 | imgx_valid="imgx.run_valid:main" 19 | imgx_test="imgx.run_test:main" 20 | 21 | [tool.setuptools] 22 | packages = ["imgx"] 23 | package-dir = {"imgx"="./imgx"} 24 | 25 | # pytest 26 | [tool.pytest.ini_options] 27 | markers = [ 28 | "slow", # slow unit tests 29 | "integration", # integration tests 30 | ] 31 | 32 | # pre-commit 33 | [tool.isort] 34 | py_version=39 35 | known_third_party = ["SimpleITK","absl","chex","hydra","jax","numpy","omegaconf","optax", 36 | "pandas","pytest","ray","rdkit","setuptools","tensorflow","tensorflow_datasets","tree","wandb"] 37 | multi_line_output = 3 38 | force_grid_wrap = 0 39 | line_length = 100 40 | include_trailing_comma = true 41 | use_parentheses = true 42 | 43 | [tool.mypy] 44 | python_version = "3.9" 45 | warn_unused_configs = true 46 | ignore_missing_imports = true 47 | disable_error_code = ["misc","attr-defined","call-arg"] 48 | show_error_codes = true 49 | files = "**/*.py" 50 | 51 | [tool.ruff] 52 | line-length = 100 53 | # Enable Pyflakes `E` and `F` codes by default. 54 | select = [ 55 | "F", # Pyflakes 56 | "E", "W", # pycodestyle 57 | "UP", # pyupgrade 58 | "N", # pep8-naming 59 | # flake8 60 | "YTT", 61 | "ANN", 62 | "S", 63 | "BLE", 64 | "B", 65 | "A", 66 | "C4", 67 | "T10", 68 | "EM", 69 | "ISC", 70 | "ICN", 71 | "T20", 72 | "PT", 73 | "Q", 74 | "RET", 75 | "SIM", 76 | "TID", 77 | "ARG", 78 | "DTZ", 79 | "PIE", 80 | "PGH", # pygrep-hooks 81 | "RUF", # ruff 82 | "PLC", "PLE", "PLR", "PLW", # Pylint 83 | ] 84 | ignore = [ 85 | "ANN002", # MissingTypeArgs 86 | "ANN003", # MissingTypeKwargs 87 | "ANN101", # MissingTypeSelf 88 | "EM101", # Exception must not use a string literal, assign to variable first 89 | "EM102", # Exception must not use an f-string literal, assign to variable first 90 | "RET504", # Unnecessary variable assignment before `return` statement 91 | "S301", # `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue 92 | "PLR0913", # Too many arguments to function call 93 | "PLR0915", # Too many statements 94 | "PLE0605", # Invalid format for `__all__`, must be `tuple` or `list` 95 | "PLR0912", # Too many branches 96 | ] 97 | # Exclude a variety of commonly ignored directories. 98 | exclude = [ 99 | ".bzr", 100 | ".direnv", 101 | ".eggs", 102 | ".git", 103 | ".hg", 104 | ".mypy_cache", 105 | ".nox", 106 | ".pants.d", 107 | ".ruff_cache", 108 | ".svn", 109 | ".tox", 110 | ".venv", 111 | "__pypackages__", 112 | "_build", 113 | "buck-out", 114 | "build", 115 | "dist", 116 | "node_modules", 117 | "venv", 118 | ] 119 | # Allow unused variables when underscore-prefixed. 120 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 121 | # Assume Python 3.9 122 | target-version = "py39" 123 | 124 | [tool.ruff.mccabe] 125 | # Unlike Flake8, default to a complexity level of 10. 126 | max-complexity = 10 127 | 128 | [tool.ruff.per-file-ignores] 129 | "test_*.py" = ["S101"] 130 | "*_test.py" = ["S101"] 131 | 132 | [tool.ruff.pydocstyle] 133 | # Use Google-style docstrings. 134 | convention = "google" 135 | 136 | [tool.ruff.pylint] 137 | allow-magic-value-types = ["int", "str"] 138 | --------------------------------------------------------------------------------
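The building blocks in imgx/model/basic.py are self-contained, so a short usage sketch may help. The snippet below is illustrative only: the sizes are arbitrary, and the actual wiring of the time embedding into the diffusion U-Net lives in imgx/model/unet/unet.py, which is not reproduced in this section. It simply feeds normalised timesteps through sinusoidal_positional_embedding and projects the result with the two-layer MLP.

import jax
import jax.numpy as jnp

from imgx.model.basic import MLP, sinusoidal_positional_embedding

# timesteps normalised to [0, 1], as the docstring of the function expects
t = jnp.arange(4, dtype=jnp.float32) / 10.0        # shape (4,)
t_emb = sinusoidal_positional_embedding(t, dim=8)  # shape (4, 8)

# project the fixed sinusoidal features with the two-layer MLP
mlp = MLP(emb_size=16, output_size=32)             # sizes are arbitrary here
params = mlp.init(jax.random.PRNGKey(0), t_emb)
out = mlp.apply(params, t_emb)                     # shape (4, 32)

Note that the concatenation in sinusoidal_positional_embedding is [cos, sin], so the first half of each embedding holds the cosine features and the second half the sine features, even though the docstring lists sin first.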
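imgx/model/slice.py defines a transpose-and-reshape pair that folds extra spatial axes into the batch axis and back, presumably so that, for example, a 3D volume can be passed through layers that treat only two axes as spatial without an explicit loop over slices. A minimal round-trip sketch, with shapes taken from the test cases above:

import jax.numpy as jnp

from imgx.model.slice import merge_spatial_dim_into_batch, split_spatial_dim_from_batch

x = jnp.ones((2, 3, 4, 5, 6))  # (batch, h, w, d, channel)

# keep only (h, w) as spatial axes: d is folded into the batch axis
merged = merge_spatial_dim_into_batch(x, num_spatial_dims=2)  # (10, 3, 4, 6)

# undo the merge with the original batch size and spatial shape
restored = split_spatial_dim_from_batch(
    merged,
    num_spatial_dims=2,
    batch_size=2,
    spatial_shape=(3, 4, 5),
)
assert restored.shape == x.shape  # (2, 3, 4, 5, 6)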
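The three U-Net sub-modules hand each other a list of skip embeddings: for D depths, DownsampleEncoder returns 1 + D * num_res_blocks + (D - 1) arrays, and UpsampleDecoder checks for D * (num_res_blocks + 1) + 1, i.e. exactly one more, which matches appending the output of BottomImageEncoderUnet. The file that actually wires them together (imgx/model/unet/unet.py) is not included in this section, so the sketch below is only a guess at the composition, consistent with that length check; TinyUnet and all sizes here are hypothetical.

from __future__ import annotations

import flax.linen as nn
import jax
import jax.numpy as jnp

from imgx.model.unet.bottom_encoder import BottomImageEncoderUnet
from imgx.model.unet.downsample_encoder import DownsampleEncoder
from imgx.model.unet.upsample_decoder import UpsampleDecoder


class TinyUnet(nn.Module):
    """Hypothetical composition of the three sub-modules (not the real Unet)."""

    @nn.compact
    def __call__(
        self, is_train: bool, x: jnp.ndarray, t_emb: jnp.ndarray | None
    ) -> jnp.ndarray:
        num_channels = (4, 8)
        kernel_size = (3, 3)
        # encoder: 1 + D*num_res_blocks + (D-1) = 6 embeddings for D=2, R=2
        embeddings = DownsampleEncoder(
            num_channels=num_channels,
            patch_size=(2, 2),
            scale_factor=(2, 2),
            kernel_size=kernel_size,
        )(is_train, x, t_emb)
        # bottom encoder keeps the bottom resolution and channel count
        bottom = BottomImageEncoderUnet(kernel_size=kernel_size, num_heads=2)(
            is_train, embeddings[-1], t_emb
        )
        # D*(num_res_blocks+1) + 1 = 7 embeddings, as required by the decoder
        return UpsampleDecoder(
            out_channels=3,
            num_channels=num_channels,
            patch_size=(2, 2),
            scale_factor=(2, 2),
            kernel_size=kernel_size,
        )(is_train, [*embeddings, bottom], t_emb)


model = TinyUnet()
image = jnp.ones((1, 8, 8, 1))  # (batch, h, w, channel)
variables = model.init(jax.random.PRNGKey(0), False, image, None)
logits = model.apply(variables, False, image, None)  # (1, 8, 8, 3)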