├── tests ├── __init__,py └── test_first_run.py ├── ganslate ├── __init__.py ├── nn │ ├── __init__.py │ ├── gans │ │ ├── __init__.py │ │ ├── paired │ │ │ └── __init__.py │ │ └── unpaired │ │ │ └── __init__.py │ ├── losses │ │ ├── __init__.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ └── ssim.py │ │ ├── pix2pix_losses.py │ │ ├── cut_losses.py │ │ └── cyclegan_losses.py │ ├── generators │ │ ├── resnet │ │ │ ├── __init__.py │ │ │ ├── resnet2d.py │ │ │ └── resnet3d.py │ │ ├── unet │ │ │ └── __init__.py │ │ ├── vnet │ │ │ └── __init__.py │ │ └── __init__.py │ ├── discriminators │ │ ├── patchgan │ │ │ ├── __init__.py │ │ │ ├── patchgan3d.py │ │ │ ├── patchgan2d.py │ │ │ ├── selfattention_patchgan3d.py │ │ │ └── multiscale_patchgan3d.py │ │ └── __init__.py │ ├── invertible.py │ ├── attention.py │ ├── separable.py │ └── utils.py ├── configs │ ├── __init__.py │ ├── inference.py │ ├── config.py │ ├── training.py │ ├── validation_testing.py │ └── utils.py ├── engines │ ├── __init__.py │ ├── utils.py │ ├── inferer.py │ └── base.py ├── utils │ ├── __init__.py │ ├── cli │ │ ├── __init__.py │ │ ├── scripts │ │ │ ├── __init__.py │ │ │ └── download_datasets.py │ │ ├── cookiecutter_templates │ │ │ ├── __init__.py │ │ │ ├── new_project │ │ │ │ ├── __init__.py │ │ │ │ ├── {{ cookiecutter.project_name }} │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── datasets │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── train_dataset.py │ │ │ │ │ │ ├── infer_dataset.py │ │ │ │ │ │ └── val_test_dataset.py │ │ │ │ │ ├── architectures │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── template_architecture.py │ │ │ │ │ ├── README.md │ │ │ │ │ └── experiments │ │ │ │ │ │ └── template_experiment.yaml │ │ │ │ └── cookiecutter.json │ │ │ └── your_first_run │ │ │ │ ├── __init__.py │ │ │ │ ├── {{ cookiecutter.project_name }} │ │ │ │ ├── __init__.py │ │ │ │ └── default.yaml │ │ │ │ └── cookiecutter.json │ │ └── interface.py │ ├── metrics │ │ ├── __init__.py │ │ └── train_metrics.py │ ├── trackers │ │ ├── __init__.py │ │ ├── tensorboard.py │ │ ├── inference.py │ │ ├── wandb.py │ │ ├── base.py │ │ └── training.py │ ├── csv_saver.py │ ├── sliding_window_inferer.py │ └── environment.py └── data │ ├── utils │ ├── __init__.py │ ├── ops.py │ ├── fov_truncate.py │ ├── normalization.py │ └── image_pool.py │ ├── __init__.py │ ├── paired_image_dataset.py │ ├── samplers.py │ └── unpaired_image_dataset.py ├── projects ├── README.md ├── horse2zebra │ ├── __init__.py │ └── experiments │ │ └── default.yaml ├── cityscapes_label2photo │ ├── __init__.py │ ├── jobscript.sh │ └── experiments │ │ ├── pix2pix.yaml │ │ └── cyclegan.yaml ├── cleargrasp_depth_estimation │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ └── old │ │ │ └── __init__.py │ ├── modules │ │ ├── __init__.py │ │ ├── old │ │ │ ├── cyclegan_multimodal_v2.py │ │ │ └── cyclegan_multimodal_v1_structure.py │ │ └── cyclegan_losses_for_v3.py │ ├── jobscript.sh │ └── experiments │ │ ├── old │ │ ├── pix2pix.yaml │ │ ├── cyclegan_multimodal_v2.yaml │ │ ├── cyclegan_multimodal_v1_structure.yaml │ │ ├── cyclegan_multimodal_v1.yaml │ │ └── cyclegan_multimodal_v3.yaml │ │ ├── pix2pix_new.yaml │ │ ├── cyclegan_naive.yaml │ │ └── cyclegan_balanced.yaml ├── maastro_hx4_pet_translation │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ └── basic.py │ ├── modules │ │ ├── __init__.py │ │ └── hx4_cyclegan_balanced_losses.py │ ├── jobscript.sh │ └── experiments │ │ ├── pix2pix.yaml │ │ ├── cyclegan_naive.yaml │ │ └── cyclegan_balanced.yaml └── brats_mri_sequence_translation │ 
├── __init__.py │ ├── datasets │ ├── __init__.py │ ├── val_test_dataset.py │ └── train_dataset.py │ └── experiments │ ├── cut.yaml │ ├── revgan.yaml │ └── cyclegan.yaml ├── docs ├── api │ └── cli.md ├── requirements.txt ├── package_overview │ ├── 6_trackers.md │ ├── 1_cli.md │ ├── 2_projects.md │ └── 5_engines.md ├── community │ └── contributing.md ├── imgs │ ├── uml-ganslate_engines.png │ └── your_first_run_docker.png ├── tutorials_basic │ └── 1_first_run.md ├── index.md ├── installation.md └── tutorials_advanced │ └── 2_custom_G_and_D_architecture.md ├── .gitattributes ├── CONTRIBUTING.md ├── MANIFEST.in ├── setup.py ├── CITATION.cff ├── pyproject.toml ├── docker └── Dockerfile ├── mkdocs.yml ├── .github └── workflows │ ├── yapf_autoformat.yml │ └── ganslate_testing_suite.yml ├── .gitignore ├── tools └── analyzers │ └── wandb │ ├── utils.py │ └── wandb_analyzer.py ├── setup.cfg └── README.md /tests/__init__,py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/nn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/engines/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/nn/gans/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/data/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/nn/losses/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/utils/cli/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /projects/README.md: -------------------------------------------------------------------------------- 1 | # Placeholder -------------------------------------------------------------------------------- /projects/horse2zebra/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/api/cli.md: -------------------------------------------------------------------------------- 1 | # CLI 2 | 3 | 4 | -------------------------------------------------------------------------------- /ganslate/nn/losses/utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/utils/cli/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/utils/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/utils/trackers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs==1.1.2 2 | -------------------------------------------------------------------------------- /ganslate/nn/generators/resnet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/nn/generators/unet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/nn/generators/vnet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /projects/cityscapes_label2photo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-vendored 2 | -------------------------------------------------------------------------------- /ganslate/nn/discriminators/patchgan/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /projects/cleargrasp_depth_estimation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /projects/maastro_hx4_pet_translation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/utils/cli/cookiecutter_templates/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /projects/brats_mri_sequence_translation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /projects/cleargrasp_depth_estimation/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /projects/cleargrasp_depth_estimation/modules/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /projects/maastro_hx4_pet_translation/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /projects/maastro_hx4_pet_translation/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/package_overview/6_trackers.md: -------------------------------------------------------------------------------- 1 | # Logging and Visualization: -------------------------------------------------------------------------------- /projects/brats_mri_sequence_translation/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /projects/cleargrasp_depth_estimation/datasets/old/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /projects/maastro_hx4_pet_translation/datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Wish to contribute? Here, you can find guidelines. -------------------------------------------------------------------------------- /ganslate/utils/cli/cookiecutter_templates/new_project/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/utils/cli/cookiecutter_templates/your_first_run/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/utils/cli/cookiecutter_templates/new_project/{{ cookiecutter.project_name }}/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/community/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contribution information will be added shortly! 
-------------------------------------------------------------------------------- /ganslate/utils/cli/cookiecutter_templates/your_first_run/{{ cookiecutter.project_name }}/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/nn/gans/paired/__init__.py: -------------------------------------------------------------------------------- 1 | from .pix2pix import Pix2PixConditionalGAN, Pix2PixConditionalGANConfig -------------------------------------------------------------------------------- /ganslate/utils/cli/cookiecutter_templates/new_project/{{ cookiecutter.project_name }}/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ganslate/utils/cli/cookiecutter_templates/new_project/{{ cookiecutter.project_name }}/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/imgs/uml-ganslate_engines.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ganslate-team/ganslate/HEAD/docs/imgs/uml-ganslate_engines.png -------------------------------------------------------------------------------- /docs/imgs/your_first_run_docker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ganslate-team/ganslate/HEAD/docs/imgs/your_first_run_docker.png -------------------------------------------------------------------------------- /ganslate/utils/cli/cookiecutter_templates/new_project/{{ cookiecutter.project_name }}/README.md: -------------------------------------------------------------------------------- 1 | # {{ cookiecutter.project_name }} 2 | 3 | 4 | -------------------------------------------------------------------------------- /ganslate/nn/gans/unpaired/__init__.py: -------------------------------------------------------------------------------- 1 | from .cut import CUT, CUTConfig 2 | from .cyclegan import CycleGAN, CycleGANConfig 3 | from .revgan import RevGAN, RevGANConfig -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include ganslate *.sh 2 | recursive-include ganslate/utils/cli/cookiecutter_templates * 3 | recursive-exclude ganslate/utils/cli/cookiecutter_templates __pycache__/* -------------------------------------------------------------------------------- /ganslate/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .unpaired_image_dataset import UnpairedImageDataset, UnpairedImageDatasetConfig 2 | from .paired_image_dataset import PairedImageDataset, PairedImageDatasetConfig -------------------------------------------------------------------------------- /ganslate/utils/csv_saver.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | class Saver: 4 | def __init__(self) -> None: 5 | self.df = pd.DataFrame() 6 | 7 | def add(self, row): 8 | self.df = self.df.append(row, ignore_index=True) 9 | 10 | def write(self, path): 11 | self.df.to_csv(path) 
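12 | 
13 | 
14 | # Usage sketch (illustrative only; the row keys and the output path are
15 | # assumed examples, not part of the repository):
16 | if __name__ == "__main__":
17 |     saver = Saver()
18 |     saver.add({"iter": 1, "ssim": 0.91})  # each dict becomes one row of the internal DataFrame
19 |     saver.add({"iter": 2, "ssim": 0.93})
20 |     saver.write("metrics.csv")  # persists the accumulated rows via DataFrame.to_csv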
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from pkg_resources import VersionConflict, require 4 | from setuptools import setup 5 | 6 | try: 7 | require('setuptools>=38.3') 8 | except VersionConflict: 9 | print("Error: version of setuptools is too old (<38.3)!") 10 | sys.exit(1) 11 | 12 | if __name__ == "__main__": 13 | setup() -------------------------------------------------------------------------------- /ganslate/utils/cli/cookiecutter_templates/your_first_run/cookiecutter.json: -------------------------------------------------------------------------------- 1 | { 2 | "project_name": "facades_project", 3 | "number_of_iterations": 100000, 4 | "batch_size": 1, 5 | "logging_frequency": 500, 6 | "checkpointing_frequency": 2000, 7 | "generator_model": ["Resnet2D", "Unet2D"], 8 | "cycle_consistency_ssim_percentage": 0.84, 9 | "path": ".", 10 | "enable_cuda": true 11 | } -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.1.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: Hadzic 5 | given-names: Ibrahim 6 | - family-names: Pai 7 | given-names: Suraj 8 | - family-names: Rao 9 | given-names: Chinmay 10 | - family-names: Teuwen 11 | given-names: Jonas 12 | title: ganslate-team/ganslate: Initial public release 13 | version: v0.1.0 14 | date-released: 2017-12-1 15 | -------------------------------------------------------------------------------- /ganslate/nn/discriminators/__init__.py: -------------------------------------------------------------------------------- 1 | from .patchgan.patchgan2d import PatchGAN2D, PatchGAN2DConfig 2 | from .patchgan.patchgan3d import PatchGAN3D, PatchGAN3DConfig 3 | from .patchgan.multiscale_patchgan3d import (MultiScalePatchGAN3D, 4 | MultiScalePatchGAN3DConfig) 5 | from .patchgan.selfattention_patchgan3d import (SelfAttentionPatchGAN3D, 6 | SelfAttentionPatchGAN3DConfig) 7 | -------------------------------------------------------------------------------- /ganslate/configs/inference.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from dataclasses import dataclass 3 | 4 | from ganslate.configs import base, validation_testing 5 | 6 | 7 | @dataclass 8 | class InferenceConfig(base.BaseEngineConfig): 9 | is_deployment: bool = False 10 | dataset: Optional[base.BaseDatasetConfig] = None 11 | sliding_window: Optional[validation_testing.SlidingWindowConfig] = None 12 | checkpointing: base.CheckpointingConfig = base.CheckpointingConfig() 13 | -------------------------------------------------------------------------------- /ganslate/nn/generators/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet.resnet2d import Resnet2D, Resnet2DConfig 2 | from .resnet.resnet3d import Resnet3D, Resnet3DConfig 3 | from .resnet.piresnet3d import Piresnet3D, Piresnet3DConfig 4 | 5 | from .vnet.vnet2d import Vnet2D, Vnet2DConfig 6 | from .vnet.vnet3d import Vnet3D, Vnet3DConfig 7 | from .vnet.selfattention_vnet3d import SelfAttentionVnet3D, SelfAttentionVnet3DConfig 8 | 9 | from .unet.unet2d import Unet2D, Unet2DConfig 10 | from .unet.unet3d import Unet3D, Unet3DConfig 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 |     "setuptools>=42",
4 |     "wheel"
5 | ]
6 | build-backend = "setuptools.build_meta"
7 | [tool.bumpver]
8 | current_version = "0.1.1"
9 | version_pattern = "MAJOR.MINOR.PATCH"
10 | commit_message = "bump version to {new_version}"
11 | commit = true
12 | tag = true
13 | push = true
14 | 
15 | [tool.bumpver.file_patterns]
16 | "pyproject.toml" = [
17 |     'current_version = "{version}"',
18 | ]
19 | "setup.cfg" = [
20 |     "version = {version}"
21 | ]
22 | 
--------------------------------------------------------------------------------
/ganslate/utils/cli/cookiecutter_templates/new_project/cookiecutter.json:
--------------------------------------------------------------------------------
1 | {
2 |     "project_name": null,
3 |     "dataset_name": "{{ cookiecutter.project_name | replace(' ', '') | capitalize}}",
4 |     "cyclegan_name": "{{ cookiecutter.project_name | replace(' ', '') | capitalize}}",
5 |     "number_of_iterations": 117700,
6 |     "batch_size": 1,
7 |     "logging_frequency": 500,
8 |     "checkpointing_frequency": 2000,
9 |     "generator_model": ["Resnet2D", "Unet2D"],
10 |     "cycle_consistency_ssim_percentage": 0.84,
11 |     "path": "."
12 | }
--------------------------------------------------------------------------------
/ganslate/data/utils/ops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def pad(volume, target_shape):
5 |     assert len(target_shape) == len(volume.shape)
6 |     # By default no padding
7 |     pad_width = [(0, 0) for _ in range(len(target_shape))]
8 | 
9 |     for dim in range(len(target_shape)):
10 |         if target_shape[dim] > volume.shape[dim]:
11 |             pad_total = target_shape[dim] - volume.shape[dim]
12 |             pad_per_side = pad_total // 2
13 |             pad_width[dim] = (pad_per_side, pad_total % 2 + pad_per_side)
14 | 
15 |     return np.pad(volume, pad_width, 'constant', constant_values=volume.min())
--------------------------------------------------------------------------------
/ganslate/nn/losses/pix2pix_losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import ganslate.nn.losses.utils.ssim as ssim
3 | 
4 | import logging
5 | logger = logging.getLogger(__name__)
6 | 
7 | 
8 | class Pix2PixLoss:
9 |     """Defines the "pixel-to-pixel" loss (applied voxel-to-voxel for 3D images):
10 |     L1 distance between the fake_B and real_B images.
11 |     """
12 | 
13 |     def __init__(self, conf):
14 |         self.lambda_pix2pix = conf.train.gan.optimizer.lambda_pix2pix
15 |         self.criterion = torch.nn.L1Loss()
16 | 
17 |     def __call__(self, fake_B, real_B):
18 |         pix2pix_loss = self.criterion(fake_B, real_B)
19 |         return self.lambda_pix2pix * pix2pix_loss
20 | 
--------------------------------------------------------------------------------
/ganslate/configs/config.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Any
2 | from dataclasses import dataclass
3 | 
4 | from ganslate.configs.training import TrainConfig
5 | from ganslate.configs.validation_testing import ValidationConfig, TestConfig
6 | from ganslate.configs.inference import InferenceConfig
7 | 
8 | 
9 | @dataclass
10 | class Config:
11 |     # Enables importing project-specific classes located in the project's dir
12 |     project: Optional[Any] = None
13 |     # Modes handled internally
14 |     mode: str = "train"
15 | 
16 |     train: TrainConfig = TrainConfig()
17 |     val: Optional[ValidationConfig] = None
18 |     test: Optional[TestConfig] = None
19 |     infer: Optional[InferenceConfig] = None
20 | 
--------------------------------------------------------------------------------
/ganslate/engines/utils.py:
--------------------------------------------------------------------------------
1 | from ganslate.engines.trainer import Trainer
2 | from ganslate.engines.validator_tester import Tester
3 | from ganslate.engines.inferer import Inferer
4 | from ganslate.utils import communication, environment
5 | from ganslate.utils.builders import build_conf
6 | 
7 | 
8 | ENGINES = {
9 |     'train': Trainer,
10 |     'test': Tester,
11 |     'infer': Inferer
12 | }
13 | 
14 | def init_engine(mode, omegaconf_args):
15 |     assert mode in ENGINES.keys()
16 | 
17 |     # Inits distributed mode if run with torch.distributed.launch
18 |     communication.init_distributed()
19 |     environment.setup_threading()
20 | 
21 |     conf = build_conf(omegaconf_args)
22 |     return ENGINES[mode](conf)
23 | 
--------------------------------------------------------------------------------
/docs/tutorials_basic/1_first_run.md:
--------------------------------------------------------------------------------
1 | # Your First Run (facades)
2 | 
3 | For both types of installation, running the basic facades example is the same.
4 | 
5 | Once the installation is complete and you can access the [CLI as shown](../package_overview/1_cli.md), run
6 | ```console
7 | ganslate your-first-run
8 | ```
9 | On running this, a few options will show up that can be customized. You may also leave them at their default values. Once the prompts
10 | are completed, a folder with a demo `facades` project will be generated in the path you specified.
11 | 
12 | ### Training
13 | Next, you can run the training using the command below:
14 | 
15 | ```console
16 | ganslate train config=<path_to_generated_project>/default.yaml
17 | ```
18 | 
19 | !!! note
20 |     If you have more than one GPU, you can either run the training in distributed mode or [set the CUDA_VISIBLE_DEVICES environment variable](https://developer.nvidia.com/blog/cuda-pro-tip-control-gpu-visibility-cuda_visible_devices/) to use only a single GPU.
--------------------------------------------------------------------------------
/docs/package_overview/1_cli.md:
--------------------------------------------------------------------------------
1 | # Using the command line interface
2 | 
3 | The command line interface for `ganslate` offers a simple way to interact with its various functionalities.
4 | 
5 | After installing the package, you can type
6 | ```
7 | ganslate
8 | ```
9 | 
10 | to explore the various features available in the CLI.
11 | 
12 | These are the available options:
13 | 
14 | ```
15 | Usage: ganslate [OPTIONS] COMMAND [ARGS]...
16 | 
17 |   ganslate - GAN image-to-image translation framework made simple and
18 |   extensible.
19 | 
20 | Options:
21 |   --help  Show this message and exit.
22 | 
23 | Commands:
24 |   download-dataset     Download a dataset.
25 |   download-project     Download a project.
26 |   infer                Do inference with a trained model.
27 |   install-nvidia-apex  Install Nvidia Apex for mixed precision support.
28 |   new-project          Initialize a new project.
29 |   test                 Test a trained model.
30 |   train                Train a model.
31 | ```
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:1.9.0-cuda10.2-cudnn7-devel
2 | 
3 | RUN apt-get -qq update
4 | # libsm6 and libxext6 are needed for cv2
5 | RUN apt-get update && apt-get install -y libxext6 libglib2.0-0 libsm6 build-essential sudo \
6 |     libgl1-mesa-glx git wget rsync tmux nano dcmtk fftw3-dev liblapacke-dev libpng-dev libopenblas-dev jq && \
7 |     rm -rf /var/lib/apt/lists/*
8 | 
9 | RUN adduser --disabled-password --gecos '' ganslate_user
10 | USER ganslate_user
11 | 
12 | # Set up apex for mixed precision
13 | WORKDIR /tmp
14 | RUN git clone https://github.com/NVIDIA/apex \
15 |     && cd apex \
16 |     && pip install -v --disable-pip-version-check --no-cache-dir ./ \
17 |     && cd ..
18 | 
19 | USER root
20 | RUN mkdir /data && chmod 777 /data
21 | USER ganslate_user
22 | 
23 | WORKDIR /home/ganslate_user
24 | 
25 | # Install the ganslate package  # TODO: Replace with final package link
26 | RUN pip install --no-warn-script-location -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple ganslate==0.1.4
27 | 
28 | # Scripts are installed in ~/.local/bin, add it to PATH
29 | ENV PATH "~/.local/bin:$PATH"
30 | 
31 | ENTRYPOINT /bin/bash
32 | 
--------------------------------------------------------------------------------
/tests/test_first_run.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | from ganslate.utils.cli.interface import setup_first_run
4 | from ganslate.engines.utils import init_engine
5 | 
6 | 
7 | def test_first_run():
8 |     """
9 |     Generate the setup for a first run and check that the expected directories
10 |     are created.
11 |     """
12 |     setup_first_run(".", True, extra_context={"number_of_iterations": 2,
13 |                                               "project_name": "first_run_test",
14 |                                               "logging_frequency": 1,
15 |                                               "enable_cuda": False
16 |                                               })
17 | 
18 |     generated_project_dir = Path("first_run_test")
19 |     assert generated_project_dir.is_dir()
20 |     assert (generated_project_dir / "facades" / "train" / "A").is_dir()
21 |     assert (generated_project_dir / "facades" / "train" / "B").is_dir()
22 | 
23 | 
24 | def test_training():
25 |     """
26 |     Run a short dummy training (2 iterations) and check that it completes.
27 | """ 28 | assert init_engine('train', ["config=first_run_test/default.yaml"]).run() is None 29 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: ganslate 2 | 3 | 4 | nav: 5 | 6 | - Home: index.md 7 | - Installation: installation.md 8 | 9 | - Package Overview: 10 | - Commandline Interface: package_overview/1_cli.md 11 | - ganslate Projects: package_overview/2_projects.md 12 | - Datasets: package_overview/3_datasets.md 13 | - Model Architectures and Loss Functions: package_overview/4_architectures.md 14 | - Engines: package_overview/5_engines.md 15 | - Logging and Visualization: package_overview/6_trackers.md 16 | - Configuration: package_overview/7_configuration.md 17 | 18 | - Basic Tutorials: 19 | - First Run: tutorials_basic/1_first_run.md 20 | - Your New Project: tutorials_basic/2_new_project.md 21 | 22 | - Advanced Tutorials: 23 | - Custom GAN Architectures: tutorials_advanced/1_custom_gan_architecture.md 24 | - Custom Generator and Discriminator Architectures: tutorials_advanced/2_custom_G_and_D_architectures.md 25 | 26 | - API: api/* 27 | 28 | - Community: 29 | - Contributing: community/contributing.md 30 | 31 | 32 | theme: readthedocs 33 | 34 | 35 | markdown_extensions: 36 | - admonition -------------------------------------------------------------------------------- /docs/package_overview/2_projects.md: -------------------------------------------------------------------------------- 1 | # `ganslate` Projects 2 | 3 | In `ganslate`, a _project_ refers to a collection of all custom code and configuration files pertaining to your specific task. The project directory is expected to have a certain structure that isolates logically different parts of the project, such as data pipeline, GAN implementation, and configuration. The directory structure is as follows 4 | 5 | ```text 6 | 7 | | 8 | |- datasets 9 | | |- custom_train_dataset.py 10 | | |- custom_val_test_dataset.py 11 | | 12 | |- architectures 13 | | |- custom_gan.py 14 | | 15 | |- experiments 16 | | |- exp1_config.yaml 17 | | 18 | |- __init__.py 19 | | 20 | |- README.md 21 | ``` 22 | 23 | The `__init__.py` file initializes your project directory as Python module which is necessary for `ganslate`'s configuration system to correctly function. (See [configuration](./7_configuration.md) for details). The `README.md` file could contain a description of your task. 24 | 25 | `ganslate` provides a Cookiecutter template which can automatically generate an empty project for you. The tutorial [Your First Project](../tutorials_basic/2_new_project.md) provides detailed instructions on how to create and operate your own project. 
-------------------------------------------------------------------------------- /projects/cityscapes_label2photo/jobscript.sh: -------------------------------------------------------------------------------- 1 | #!/usr/local_rwth/bin/zsh 2 | 3 | 4 | # Job configuration --- 5 | 6 | #SBATCH --job-name=label2photo_pix2pix 7 | #SBATCH --output=/home/zk315372/Chinmay/Git/ganslate/projects/cityscapes_label2photo/slurm_logs/%j.log 8 | 9 | ## OpenMP settings 10 | #SBATCH --cpus-per-task=8 11 | #SBATCH --mem-per-cpu=4G 12 | 13 | ## Request for a node with 2 Tesla P100 GPUs 14 | #SBATCH --gres=gpu:pascal:2 15 | 16 | #SBATCH --time=5:00:00 17 | 18 | ## TO use the UM DKE project account 19 | # #SBATCH --account=um_dke 20 | 21 | 22 | # Load CUDA 23 | module load cuda 24 | 25 | # Debug info 26 | echo; echo 27 | nvidia-smi 28 | echo; echo 29 | 30 | # Execute training 31 | python_interpreter="/home/zk315372/miniconda3/envs/gan_env/bin/python3" 32 | training_file="/home/zk315372/Chinmay/Git/ganslate/tools/train.py" 33 | config_file="/home/zk315372/Chinmay/Git/ganslate/projects/cityscapes_label2photo/experiments/pix2pix.yaml" 34 | 35 | CUDA_VISIBLE_DEVICES=0 $python_interpreter $training_file config=$config_file 36 | 37 | 38 | # ---------------------- 39 | # Run single GPU example: 40 | # CUDA_VISIBLE_DEVICES=0 python tools/train.py config="./projects/cityscapes_label2photo/experiments/cyclegan.yaml" 41 | 42 | # Run distributed example: 43 | # python -m torch.distributed.launch --use_env --nproc_per_node 2 tools/train.py config="./projects/cityscapes_label2photo/experiments/cyclegan.yaml" -------------------------------------------------------------------------------- /projects/cleargrasp_depth_estimation/jobscript.sh: -------------------------------------------------------------------------------- 1 | #!/usr/local_rwth/bin/zsh 2 | 3 | 4 | # Job configuration --- 5 | 6 | #SBATCH --job-name=pix2pix_lambda100 7 | #SBATCH --output=/home/zk315372/Chinmay/Cleargrasp-Depth-Estimation/pix2pix_lambda100/slurm-%j.log 8 | 9 | ## OpenMP settings 10 | #SBATCH --cpus-per-task=8 11 | #SBATCH --mem-per-cpu=4G 12 | 13 | ## Request for a node with 2 Tesla P100 GPUs 14 | #SBATCH --gres=gpu:pascal:2 15 | 16 | #SBATCH --time=4:00:00 17 | 18 | ## TO use the UM DKE project account 19 | # #SBATCH --account=um_dke 20 | 21 | 22 | # Load CUDA 23 | module load cuda 24 | 25 | # Debug info 26 | echo; echo 27 | nvidia-smi 28 | echo; echo 29 | 30 | # Execute training 31 | python_interpreter="/home/zk315372/miniconda3/envs/gan_env/bin/python3" 32 | training_file="/home/zk315372/Chinmay/Git/ganslate/tools/train.py" 33 | config_file="/home/zk315372/Chinmay/Git/ganslate/projects/cleargrasp_depth_estimation/experiments/pix2pix_new.yaml" 34 | 35 | CUDA_VISIBLE_DEVICES=0 $python_interpreter $training_file config=$config_file 36 | 37 | 38 | # ---------------------- 39 | # Run single GPU example: 40 | # CUDA_VISIBLE_DEVICES=0 python tools/train.py config="./projects/cleargrasp_depth_estimation/experiments/pix2pix.yaml" 41 | 42 | # Run distributed example: 43 | # python -m torch.distributed.launch --use_env --nproc_per_node 2 tools/train.py config="./projects/cleargrasp_depth_estimation/experiments/pix2pix.yaml" 44 | -------------------------------------------------------------------------------- /projects/maastro_hx4_pet_translation/jobscript.sh: -------------------------------------------------------------------------------- 1 | #!/usr/local_rwth/bin/zsh 2 | 3 | 4 | # Job configuration --- 5 | 6 | #SBATCH --job-name=hx4_pet_pix2pix 7 | #SBATCH 
--output=/home/zk315372/Chinmay/Git/ganslate/projects/maastro_hx4_pet_translation/slurm_logs/%j.log
8 | 
9 | ## OpenMP settings
10 | #SBATCH --cpus-per-task=8
11 | #SBATCH --mem-per-cpu=4G
12 | 
13 | ## Request for a node with 2 Tesla P100 GPUs
14 | #SBATCH --gres=gpu:pascal:2
15 | 
16 | #SBATCH --time=5:00:00
17 | 
18 | ## TO use the UM DKE project account
19 | # #SBATCH --account=um_dke
20 | 
21 | 
22 | # Load CUDA
23 | module load cuda
24 | 
25 | # Debug info
26 | echo; echo
27 | nvidia-smi
28 | echo; echo
29 | 
30 | # Execute training
31 | python_interpreter="/home/zk315372/miniconda3/envs/gan_env/bin/python3"
32 | training_file="/home/zk315372/Chinmay/Git/ganslate/tools/train.py"
33 | config_file="/home/zk315372/Chinmay/Git/ganslate/projects/maastro_hx4_pet_translation/experiments/pix2pix.yaml"
34 | 
35 | CUDA_VISIBLE_DEVICES=0 $python_interpreter $training_file config=$config_file
36 | 
37 | 
38 | # ----------------------
39 | # Run single GPU example:
40 | # CUDA_VISIBLE_DEVICES=0 python tools/train.py config="./projects/maastro_hx4_pet_translation/experiments/pix2pix.yaml"
41 | 
42 | # Run distributed example:
43 | # python -m torch.distributed.launch --use_env --nproc_per_node 2 tools/train.py config="./projects/maastro_hx4_pet_translation/experiments/pix2pix.yaml"
--------------------------------------------------------------------------------
/projects/cleargrasp_depth_estimation/modules/old/cyclegan_multimodal_v2.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | 
3 | import torch
4 | 
5 | from ganslate.nn.gans.unpaired import cyclegan
6 | from ganslate.nn.losses.adversarial_loss import AdversarialLoss
7 | 
8 | from projects.cleargrasp_depth_estimation.modules.cyclegan_losses_with_structure import CycleGANLossesWithStructure
9 | 
10 | 
11 | @dataclass
12 | class OptimizerV2Config(cyclegan.OptimizerConfig):
13 |     """ Optimizer config for CycleGAN multimodal v2 """
14 |     lambda_structure: float = 0
15 | 
16 | 
17 | @dataclass
18 | class CycleGANMultiModalV2Config(cyclegan.CycleGANConfig):
19 |     """ CycleGANMultiModalV2 config """
20 |     optimizer: OptimizerV2Config = OptimizerV2Config()
21 | 
22 | 
23 | class CycleGANMultiModalV2(cyclegan.CycleGAN):
24 |     """ CycleGAN for multimodal images -- Version 2 """
25 | 
26 |     def __init__(self, conf):
27 |         super().__init__(conf)
28 | 
29 |         # Additional losses used by the model
30 |         structure_loss_names = ['structure_AB', 'structure_BA']
31 |         self.losses.update({name: None for name in structure_loss_names})
32 | 
33 | 
34 |     def init_criterions(self):
35 |         # Standard GAN loss
36 |         self.criterion_adv = AdversarialLoss(
37 |             self.conf.train.gan.optimizer.adversarial_loss_type).to(self.device)
38 |         # G losses - includes the structure-consistency loss
39 |         self.criterion_G = CycleGANLossesWithStructure(self.conf, cyclegan_design_version='v2')
40 | 
--------------------------------------------------------------------------------
/projects/cleargrasp_depth_estimation/modules/cyclegan_losses_for_v3.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | 
4 | from ganslate.nn.losses import cyclegan_losses
5 | 
6 | 
7 | class CycleGANLossesForV3(cyclegan_losses.CycleGANLosses):
8 |     """ Modified to make cycle-consistency account for only
9 |     normalmap images (in domain A) and depthmap images (in domain B),
10 |     and ignore RGB """
11 | 
12 |     def __init__(self, conf):
13 |         self.lambda_AB = conf.train.gan.optimizer.lambda_AB
14 |         self.lambda_BA = conf.train.gan.optimizer.lambda_BA
15 | 
16 |         lambda_identity = conf.train.gan.optimizer.lambda_identity
17 |         proportion_ssim = conf.train.gan.optimizer.proportion_ssim
18 | 
19 |         # Cycle-consistency - L1, with optional weighted combination with SSIM
20 |         self.criterion_cycle = cyclegan_losses.CycleLoss(proportion_ssim)
21 | 
22 | 
23 |     def __call__(self, visuals):
24 |         # Separate out the normalmap and depthmap parts from the visuals tensors
25 |         real_A2, real_B2 = visuals['real_A'][:, 3:], visuals['real_B'][:, 3:]
26 |         fake_A2, fake_B2 = visuals['fake_A'][:, 3:], visuals['fake_B'][:, 3:]
27 |         rec_A2, rec_B2 = visuals['rec_A'][:, 3:], visuals['rec_B'][:, 3:]
28 | 
29 |         losses = {}
30 | 
31 |         # Cycle-consistency loss
32 |         losses['cycle_A'] = self.lambda_AB * self.criterion_cycle(real_A2, rec_A2)
33 |         losses['cycle_B'] = self.lambda_BA * self.criterion_cycle(real_B2, rec_B2)
34 | 
35 |         return losses
36 | 
--------------------------------------------------------------------------------
/projects/maastro_hx4_pet_translation/modules/hx4_cyclegan_balanced_losses.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | 
4 | from ganslate.nn.losses import cyclegan_losses
5 | 
6 | 
7 | class HX4CycleGANBalancedLosses(cyclegan_losses.CycleGANLosses):
8 |     """ Modified to make cycle-consistency account for only
9 |     FDG-PET images (in domain A) and HX4-PET images (in domain B),
10 |     and ignore the CT components """
11 | 
12 |     def __init__(self, conf):
13 |         self.lambda_AB = conf.train.gan.optimizer.lambda_AB
14 |         self.lambda_BA = conf.train.gan.optimizer.lambda_BA
15 | 
16 |         lambda_identity = conf.train.gan.optimizer.lambda_identity
17 |         proportion_ssim = conf.train.gan.optimizer.proportion_ssim
18 | 
19 |         # Cycle-consistency - L1, with optional weighted combination with SSIM
20 |         self.criterion_cycle = cyclegan_losses.CycleLoss(proportion_ssim)
21 | 
22 | 
23 |     def __call__(self, visuals):
24 |         # Separate out the FDG-PET and HX4-PET parts from the visuals tensors
25 |         real_A1, real_B1 = visuals['real_A'][:, :1], visuals['real_B'][:, :1]
26 |         fake_A1, fake_B1 = visuals['fake_A'][:, :1], visuals['fake_B'][:, :1]
27 |         rec_A1, rec_B1 = visuals['rec_A'][:, :1], visuals['rec_B'][:, :1]
28 | 
29 |         losses = {}
30 | 
31 |         # Cycle-consistency loss
32 |         losses['cycle_A'] = self.lambda_AB * self.criterion_cycle(real_A1, rec_A1)
33 |         losses['cycle_B'] = self.lambda_BA * self.criterion_cycle(real_B1, rec_B1)
34 | 
35 |         return losses
--------------------------------------------------------------------------------
/projects/cleargrasp_depth_estimation/modules/old/cyclegan_multimodal_v1_structure.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | 
3 | import torch
4 | 
5 | from ganslate.nn.gans.unpaired import cyclegan
6 | from ganslate.nn.losses.adversarial_loss import AdversarialLoss
7 | 
8 | from projects.cleargrasp_depth_estimation.modules.cyclegan_losses_with_structure import CycleGANLossesWithStructure
9 | 
10 | 
11 | @dataclass
12 | class OptimizerV1StructureConfig(cyclegan.OptimizerConfig):
13 |     """ Structure-consistency config for CycleGAN multimodal v1 """
14 |     lambda_structure: float = 0
15 | 
16 | 
17 | @dataclass
18 | class CycleGANMultiModalV1StructureConfig(cyclegan.CycleGANConfig):
19 |     """ CycleGANMultiModalV1Structure config """
20 |     optimizer: OptimizerV1StructureConfig = OptimizerV1StructureConfig()
21 | 
22 | 
23 | class CycleGANMultiModalV1Structure(cyclegan.CycleGAN):
24 |     """ CycleGAN for multimodal images -- Version 1, with structure-consistency loss """
25 | 
26 |     def __init__(self, conf):
27 |         super().__init__(conf)
28 | 
29 |         # Additional losses used by the model
30 |         structure_loss_names = ['structure_AB', 'structure_BA']
31 |         self.losses.update({name: None for name in structure_loss_names})
32 | 
33 | 
34 |     def init_criterions(self):
35 |         # Standard GAN loss
36 |         self.criterion_adv = AdversarialLoss(
37 |             self.conf.train.gan.optimizer.adversarial_loss_type).to(self.device)
38 |         # G losses - includes the structure-consistency loss
39 |         self.criterion_G = CycleGANLossesWithStructure(self.conf, cyclegan_design_version='v1')
40 | 
--------------------------------------------------------------------------------
/ganslate/nn/losses/cut_losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | 
5 | class PatchNCELoss(nn.Module):
6 | 
7 |     def __init__(self, conf):
8 |         super().__init__()
9 |         self.batch_size = conf.train.batch_size
10 |         self.nce_T = conf.train.gan.optimizer.nce_T
11 | 
12 |         self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
13 | 
14 |     def forward(self, feat_q, feat_k):
15 |         bs, dim = feat_q.shape[:2]
16 |         feat_k = feat_k.detach()
17 | 
18 |         # pos logit
19 |         l_pos = torch.bmm(feat_q.view(bs, 1, -1), feat_k.view(bs, -1, 1))
20 |         l_pos = l_pos.view(bs, 1)
21 | 
22 |         # neg logit
23 |         batch_dim_for_bmm = self.batch_size
24 | 
25 |         # reshape features to batch size
26 |         feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
27 |         feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
28 | 
29 |         num_patches = feat_q.size(1)
30 |         l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
31 | 
32 |         # Diagonal entries are the similarity between identical features, hence meaningless;
33 |         # fill the diagonal with a very small number, exp(-10), which is almost zero.
34 |         diagonal = torch.eye(num_patches, device=feat_q.device, dtype=torch.bool)[None, :, :]
35 |         l_neg_curbatch.masked_fill_(diagonal, -10.0)
36 |         l_neg = l_neg_curbatch.view(-1, num_patches)
37 | 
38 |         out = torch.cat((l_pos, l_neg), dim=1) / self.nce_T
39 | 
40 |         loss = self.cross_entropy_loss(
41 |             out, torch.zeros(out.size(0), dtype=torch.long, device=feat_q.device))
42 | 
43 |         return loss
44 | 
--------------------------------------------------------------------------------
/projects/brats_mri_sequence_translation/experiments/cut.yaml:
--------------------------------------------------------------------------------
1 | project: "./projects/brats_mri_sequence_translation/"
2 | 
3 | train:
4 |   output_dir: "./checkpoints/dke_brats_cut/"
5 |   cuda: True
6 |   n_iters: 20000
7 |   n_iters_decay: 20000
8 |   batch_size: 1
9 |   mixed_precision: False
10 | 
11 |   logging:
12 |     freq: 50
13 |     wandb:
14 |       project: "CUT-BraTS"
15 | 
16 |   checkpointing:
17 |     freq: 5000
18 | 
19 |   dataset:
20 |     _target_: project.datasets.train_dataset.BratsDataset
21 |     root: "/workspace/Task01_BrainTumour/imagesTr"
22 |     num_workers: 8
23 |     patch_size: [32, 176, 176]
24 | 
25 |   gan:
26 |     _target_: ganslate.nn.gans.unpaired.CUT
27 |     nce_layers: [0, 1, 2, 3, 4]
28 |     generator:
29 |       _target_: ganslate.nn.generators.Vnet3D
30 |       use_memory_saving: False
31 |       use_inverse: False
32 |       in_out_channels:
33 |         AB: [1, 1]
34 |       down_blocks: [2, 2, 3]
35 |       up_blocks: [3, 3, 3]
36 | 
37 |     discriminator:
38 |       _target_: ganslate.nn.discriminators.PatchGAN3D
39 |       n_layers: 2
40 |       in_channels:
41 |         B: 1
42 | 
43 |     optimizer:
44 |       lr_D: 0.0002
45 |       lr_G: 0.0004
46 | 
47 | val:
48 |   freq: 40000
49 |   dataset:
50 |     _target_: project.datasets.val_test_dataset.BratsValTestDataset
51 |     root: "/workspace/Task01_BrainTumour/imagesTs"
52 |   sliding_window:
53 |     window_size: ${train.dataset.patch_size}
54 |   metrics:
55 |     cycle_metrics: False
56 | 
--------------------------------------------------------------------------------
/projects/brats_mri_sequence_translation/experiments/revgan.yaml:
--------------------------------------------------------------------------------
1 | project: "./projects/brats_mri_sequence_translation/"
2 | 
3 | train:
4 |   output_dir: "./checkpoints/dke_brats_cyclegan/"
5 |   cuda: True
6 |   n_iters: 20000
7 |   n_iters_decay: 20000
8 |   batch_size: 1
9 |   mixed_precision: False
10 | 
11 |   logging:
12 |     freq: 50
13 |     wandb:
14 |       project: "CUT-BraTS"
15 | 
16 |   checkpointing:
17 |     freq: 5000
18 | 
19 |   dataset:
20 |     _target_: project.datasets.train_dataset.BratsDataset
21 |     root: "/workspace/Task01_BrainTumour/imagesTr"
22 |     num_workers: 8
23 |     patch_size: [32, 176, 176]
24 | 
25 |   gan:
26 |     _target_: ganslate.nn.gans.unpaired.RevGAN
27 |     generator:
28 |       _target_: ganslate.nn.generators.Piresnet3D
29 |       use_memory_saving: True
30 |       use_inverse: True
31 |       in_out_channels:
32 |         AB: [1, 1]
33 |       depth: 5
34 | 
35 |     discriminator:
36 |       _target_: ganslate.nn.discriminators.PatchGAN3D
37 |       n_layers: 2
38 |       in_channels:
39 |         B: 1
40 | 
41 |     optimizer:
42 |       lr_D: 0.0002
43 |       lr_G: 0.0004
44 |       lambda_AB: 5.0
45 |       lambda_BA: 5.0
46 |       lambda_identity: 0
47 |       proportion_ssim: 0
48 | 
49 | val:
50 |   freq: 40000
51 |   dataset:
52 |     _target_: project.datasets.val_test_dataset.BratsValTestDataset
53 |     root: "/workspace/Task01_BrainTumour/imagesTs"
54 |   sliding_window:
55 |     window_size: ${train.dataset.patch_size}
56 |   metrics:
57 |     cycle_metrics: True
58 | 
--------------------------------------------------------------------------------
/projects/brats_mri_sequence_translation/experiments/cyclegan.yaml:
--------------------------------------------------------------------------------
1 | project: "./projects/brats_mri_sequence_translation/"
2 | 
3 | train:
4 |   output_dir: "./checkpoints/dke_brats_cyclegan/"
5 |   cuda: True
6 |   n_iters: 20000
7 |   n_iters_decay: 20000
8 |   batch_size: 1
9 |   mixed_precision: False
10 | 
11 |   logging:
12 |     freq: 50
13 |     wandb:
14 |       project: "CUT-BraTS"
15 | 
16 |   checkpointing:
17 |     freq: 5000
18 | 
19 |   dataset:
20 |     _target_: project.datasets.train_dataset.BratsDataset
21 |     root: "/workspace/Task01_BrainTumour/imagesTr"
22 |     num_workers: 8
23 |     patch_size: [32, 176, 176]
24 | 
25 |   gan:
26 |     _target_: ganslate.nn.gans.unpaired.CycleGAN
27 |     generator:
28 |       _target_: ganslate.nn.generators.Vnet3D
29 |       use_memory_saving: False
30 |       use_inverse: False
31 |       in_out_channels:
32 |         AB: [1, 1]
33 |       down_blocks: [2, 2, 3]
34 |       up_blocks: [3, 3, 3]
35 | 
36 |     discriminator:
37 |       _target_: ganslate.nn.discriminators.PatchGAN3D
38 |       n_layers: 2
39 |       in_channels:
40 |         B: 1
41 | 
42 |     optimizer:
43 |       lr_D: 0.0002
44 |       lr_G: 0.0004
45 |       lambda_AB: 5.0
46 |       lambda_BA: 5.0
47 |       lambda_identity: 0
48 |       proportion_ssim: 0
49 | 
50 | val:
51 |   freq: 40000
52 |   dataset:
53 |     _target_: project.datasets.val_test_dataset.BratsValTestDataset
54 |     root: "/workspace/Task01_BrainTumour/imagesTs"
55 |   sliding_window:
56 |     window_size: ${train.dataset.patch_size}
57 |   metrics:
58 |     cycle_metrics: True
59 | 
--------------------------------------------------------------------------------
/ganslate/configs/training.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from dataclasses import dataclass
3 | from omegaconf import MISSING
4 | from ganslate.configs import base
5 | 
6 | 
7 | @dataclass
8 | class TrainMetricsConfig:
9 |     discriminator_evolution: bool = False
10 |     ssim: bool = False
11 | 
12 | 
13 | @dataclass
14 | class TrainCheckpointingConfig(base.CheckpointingConfig):
15 |     # How often (in iters) to save checkpoints during training
16 |     freq: int = 2000
17 |     # After which iteration should checkpointing begin
18 |     start_after: int = 0
19 |     # If False, the saved optimizers won't be loaded when continuing training
20 |     load_optimizers: bool = True
21 |     # Iteration number of the checkpoint to load for continuing training
22 |     load_iter: Optional[int] = None
23 | 
24 | 
25 | @dataclass
26 | class TrainConfig(base.BaseEngineConfig):
27 |     # TODO: add git hash? will help when re-running or inferencing old runs
28 | 
29 |     ################## Overriding defaults of BaseEngineConfig ######################
30 |     output_dir: str = MISSING
31 |     batch_size: int = MISSING
32 |     cuda: bool = True
33 |     mixed_precision: bool = False
34 |     opt_level: str = "O1"
35 |     checkpointing: TrainCheckpointingConfig = TrainCheckpointingConfig()
36 |     logging: base.LoggingConfig = base.LoggingConfig()
37 |     ###########################################################################
38 | 
39 |     # Number of iters without linear decay of learning rates.
40 |     n_iters: int = MISSING
41 |     # Number of last iters in which the learning rates are linearly decayed.
42 |     n_iters_decay: int = MISSING
43 | 
44 |     gan: base.BaseGANConfig = MISSING
45 | 
46 |     seed: Optional[int] = None
47 |     metrics: TrainMetricsConfig = TrainMetricsConfig()
48 | 
--------------------------------------------------------------------------------
/ganslate/data/utils/fov_truncate.py:
--------------------------------------------------------------------------------
1 | import SimpleITK as sitk
2 | import numpy as np
3 | 
4 | from ganslate.utils import sitk_utils
5 | 
6 | 
7 | def truncate_CBCT_based_on_fov(image: sitk.Image):
8 |     """
9 |     Truncates the CBCT to consider the full FOV of the scan. The first few and last few
10 |     slices generally have a small FOV that covers only around 25-50% of the axial slice;
11 |     these are ignored using simple value-based filtering.
12 | 
13 |     Parameters
14 |     ---------------
15 |     image: Input CBCT image to truncate.
16 | 
17 |     Returns
18 |     ----------------
19 |     filtered_image: Truncated CBCT image
20 |     """
21 |     array = sitk.GetArrayFromImage(image)
22 |     start_idx, end_idx = 0, array.shape[0]
23 | 
24 |     begin_truncate = False
25 | 
26 |     for idx, slice in enumerate(array):
27 | 
28 |         # Calculate the FOV percentage.
29 |         # This estimates the difference between the area of the
30 |         # rectangular axial slice and the circle formed by the FOV.
31 |         # E.g. a 400x400 slice has an area of 160k, while an
32 |         # end-to-end circular FOV has an area of 3.14*200*200.
33 |         percentage_fov = 1 - np.mean(slice == -1024)
34 |         # As soon as the FOV percentage of a slice
35 |         # rises above 75%, set the start index.
36 |         if percentage_fov > 0.75 and start_idx == 0:
37 |             start_idx = idx
38 |             begin_truncate = True
39 | 
40 |         # Once the start index is set and the FOV percentage
41 |         # drops below 75%, set the end index
42 |         if begin_truncate and percentage_fov < 0.75:
43 |             end_idx = idx - 1
44 |             break
45 | 
46 |     image = sitk_utils.slice_image(image, start=(0, 0, start_idx), end=(-1, -1, end_idx))
47 | 
48 |     return image
49 | 
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5494572.svg)](https://doi.org/10.5281/zenodo.5494572)
2 | 
3 | # ganslate
4 | 
5 | A [PyTorch](https://pytorch.org/) framework which aims to make GAN image-to-image translation more accessible to both beginner and advanced users, with:
6 | 
7 | - Simple configuration system
8 | - Extensibility for other datasets or architectures
9 | - Documentation and [video walk-throughs](INSERT_YOUTUBE_PLAYLIST)
10 | 
11 | ## Features
12 | 
13 | - 2D and 3D support
14 | - Mixed precision
15 | - Distributed training
16 | - Tensorboard and [Weights&Biases](https://wandb.ai/site) logging
17 | - Natural and medical image support
18 | - A range of generator and discriminator architectures
19 | 
20 | ## Available GANs
21 | 
22 | - Pix2Pix ([paper](https://arxiv.org/abs/1611.07004))
23 | - CycleGAN ([paper](https://arxiv.org/abs/1703.10593))
24 | - RevGAN ([paper](https://arxiv.org/abs/1902.02729))
25 | - CUT (Contrastive Unpaired Translation) ([paper](https://arxiv.org/abs/2007.15651))
26 | 
27 | ## Projects
28 | `ganslate` was used in:
29 | 
30 | - Project 1
31 | - Project 2
32 | 
33 | ## Citation
34 | 
35 | If you used `ganslate` in your project, please cite:
36 | 
37 | ```text
38 | @software{ibrahim_hadzic_2021_5494572,
39 |   author       = {Ibrahim Hadzic and
40 |                   Suraj Pai and
41 |                   Chinmay Rao and
42 |                   Jonas Teuwen},
43 |   title        = {ganslate-team/ganslate: Initial public release},
44 |   month        = sep,
45 |   year         = 2021,
46 |   publisher    = {Zenodo},
47 |   version      = {v0.1.0},
48 |   doi          = {10.5281/zenodo.5494572},
49 |   url          = {https://doi.org/10.5281/zenodo.5494572}
50 | }
51 | ```
--------------------------------------------------------------------------------
/ganslate/nn/invertible.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | 
3 | from torch import nn
4 | 
5 | import memcnn
6 | 
7 | 
8 | class InvertibleBlock(nn.Module):
9 | 
10 |     def __init__(self, block, keep_input, disable=False):
11 |         """The input block should already be split across channels  # TODO: explain better
12 |         """
13 |         super().__init__()
14 | 
15 |         block = memcnn.AdditiveCoupling(deepcopy(block))
16 |         self.invertible_block = memcnn.InvertibleModuleWrapper(fn=block,
17 |                                                                keep_input=keep_input,
18 |                                                                keep_input_inverse=keep_input,
19 |                                                                disable=disable)
20 | 
21 |     def forward(self, x, inverse=False):
22 |         if inverse:
23 |             return self.invertible_block.inverse(x)
24 |         return self.invertible_block(x)
25 | 
26 | 
27 | class InvertibleSequence(nn.Module):
28 | 
29 |     def __init__(self, block, n_blocks, keep_input, disable=False):
30 |         super().__init__()
31 | 
32 |         sequence = [InvertibleBlock(block, keep_input, disable) for _ in range(n_blocks)]
33 |         self.sequence = nn.Sequential(*sequence)
34 | 
35 |     def forward(self, x, inverse=False):
36 |         if inverse:
37 |             sequence = reversed(self.sequence)
reversed(self.sequence) 38 | else: 39 | sequence = self.sequence 40 | 41 | for i, block in enumerate(sequence): 42 | if i == 0: #https://github.com/silvandeleemput/memcnn/issues/39#issuecomment-599199122 43 | if inverse: 44 | block.invertible_block.keep_input_inverse = True 45 | else: 46 | block.invertible_block.keep_input = True 47 | x = block(x, inverse=inverse) 48 | return x 49 | -------------------------------------------------------------------------------- /.github/workflows/yapf_autoformat.yml: -------------------------------------------------------------------------------- 1 | name: YAPF autoformat 2 | 3 | on: 4 | push: 5 | paths: 6 | - "**.py" 7 | 8 | jobs: 9 | build: 10 | name: YAPF Formatter 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: checkout repo 15 | uses: actions/checkout@v2.3.4 16 | with: 17 | repository: ${{ github.repository }} 18 | token: ${{ github.token }} 19 | 20 | # https://github.com/diegovalenzuelaiturra/yapf-action 21 | - name: YAPF Formatter 22 | uses: diegovalenzuelaiturra/yapf-action@master 23 | with: 24 | args: . --verbose --recursive --in-place --parallel 25 | 26 | - name: action metadata 27 | run: | 28 | echo -e "action : ${{ github.action }}" 29 | echo -e "actor : ${{ github.actor }}" 30 | echo -e "event_name : ${{ github.event_name }}" 31 | echo -e "job : ${{ github.job }}" 32 | echo -e "ref : ${{ github.ref }}" 33 | echo -e "repository : ${{ github.repository }}" 34 | echo -e "run_id : ${{ github.run_id }}" 35 | echo -e "sha : ${{ github.sha }}" 36 | echo -e "workflow : ${{ github.workflow }}" 37 | echo -e "workspace : ${{ github.workspace }}" 38 | 39 | - name: config github 40 | run: | 41 | git config user.name github-actions 42 | git config user.email github-actions@github.com 43 | 44 | - name: add changes 45 | run: | 46 | git add . 47 | 48 | - name: commit changes 49 | run: | 50 | git commit -m "Autoformatted to YAPF" --all | exit 0 51 | 52 | # NOTE : only push changes when the source branch is being pushed to master (target branch) 53 | - name: push changes 54 | if: github.ref == 'refs/heads/master' 55 | run: | 56 | git push 57 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | You can install `ganslate` either through a docker setup or directly on your system. 4 | ## Docker 5 | *Supported operating systems: Linux, [Windows with WSL](https://docs.nvidia.com/cuda/wsl-user-guide/index.html)* 6 | 7 | Dockerized setup is the easiest way to get started with the framework. If you do not have docker installed, you can follow instructions [here](https://docs.docker.com/engine/install/ubuntu/#install-using-the-repository) 8 | 9 | 10 | You can run the docker image, which will give you access to a container with all dependencies installed, using, 11 | 12 | ```console 13 | docker run --gpus all -it surajpaib/ganslate:latest 14 | ``` 15 | 16 | This will drop down to a shell and [you can now check out the Getting Started page](getting_started/first_run.md) 17 | 18 | 19 | 20 | !!! note 21 | To get your data into the docker container, you can use volume mounts. The docker container [mounts volumes from the host system](https://docs.docker.com/storage/volumes/) to allow easier persistence of data. This can be done as `docker run --gpus all --volume=:/data -it ganslate:latest`. 
`<data_dir>` must be replaced with the full path of the directory where your data is located; it will then be mounted on the `/data` path within the docker container. 22 | 23 | ## Local 24 | 25 | !!! note 26 | It is recommended to set up a conda environment for the PyTorch dependencies. You can do this by 27 | [installing conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html#regular-installation) first, followed by `conda create -n ganslate_env python pytorch -c pytorch`. 28 | 29 | You can install the ganslate package along with its dependencies using: 30 | ```console 31 | pip install ganslate 32 | ``` 33 | 34 | The `ganslate` package is now installed. [You can now check out the Getting Started page](getting_started/first_run.md) 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | \.idea/inspectionProfiles/ 108 | 109 | \.idea/ 110 | 111 | \.vscode/ 112 | checkpoints/ 113 | \.wandb 114 | \wandb 115 | 116 | apex 117 | *.nrrd 118 | core.* 119 | -------------------------------------------------------------------------------- /ganslate/utils/cli/cookiecutter_templates/new_project/{{ cookiecutter.project_name }}/datasets/train_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | import torch 4 | 5 | from typing import Tuple 6 | from dataclasses import dataclass 7 | 8 | from torch.utils.data import Dataset 9 | from omegaconf import MISSING 10 | 11 | from ganslate import configs 12 | 13 | 14 | @dataclass 15 | class {{cookiecutter.dataset_name}}TrainConfig(configs.base.BaseDatasetConfig): 16 | # Define other attributes, e.g.: 17 | patch_size: Tuple[int, int] = (128, 128) 18 | ... 
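# For illustration only (the class and path names below are hypothetical): the attributes declared in this config dataclass map directly to the `dataset` block of an experiment YAML, following the pattern used by the project YAMLs in this repository, e.g.: # train: # dataset: # _target_: project.datasets.train_dataset.MyDatasetTrainDataset # root: "/path/to/data" # patch_size: [128, 128]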
19 | 20 | 21 | class {{cookiecutter.dataset_name}}TrainDataset(Dataset): 22 | 23 | def __init__(self, conf): 24 | root_path = Path(conf.train.dataset.root).resolve() 25 | 26 | # Assumes `A` and `B` dirs only for demonstration; list the files they contain 27 | self.paths_A = sorted((root_path / "A").iterdir()) 28 | self.paths_B = sorted((root_path / "B").iterdir()) 29 | 30 | self.num_datapoints_A = len(self.paths_A) 31 | self.num_datapoints_B = len(self.paths_B) 32 | ... 33 | 34 | def __getitem__(self, index): 35 | # Get the pair A and B. 36 | # In unpaired training, select a random index for 37 | # image B so that A and B pairs are not always the same. 38 | # For paired training, it depends on how the data is structured. 39 | index_A = index % self.num_datapoints_A 40 | index_B = random.randint(0, self.num_datapoints_B - 1) 41 | 42 | path_A = self.paths_A[index_A] 43 | path_B = self.paths_B[index_B] 44 | 45 | # Read the images, `read` is a placeholder 46 | A = read(path_A) 47 | B = read(path_B) 48 | 49 | # Preprocess and normalize to [-1,1], `preprocess` is a placeholder 50 | A = preprocess(A) 51 | B = preprocess(B) 52 | 53 | # You need to return a dict with `A` and `B` entries 54 | return {'A': A, 'B': B} 55 | 56 | def __len__(self): 57 | return max(self.num_datapoints_A, self.num_datapoints_B) 58 | -------------------------------------------------------------------------------- /projects/horse2zebra/experiments/default.yaml: -------------------------------------------------------------------------------- 1 | project: "/home/chinmay/git/ganslate/projects/horse2zebra" 2 | 3 | train: 4 | output_dir: "/home/chinmay/git/ganslate/checkpoints/horse2zebra_default" 5 | cuda: True 6 | n_iters: 117700 # (1177 [dataset_size] * 100) 7 | n_iters_decay: 117700 8 | batch_size: 1 9 | mixed_precision: False 10 | 11 | logging: 12 | freq: 500 13 | wandb: 14 | project: "horse2zebra" 15 | 16 | checkpointing: 17 | freq: 20000 18 | 19 | dataset: 20 | _target_: ganslate.data.UnpairedImageDataset 21 | root: "/home/chinmay/Datasets/horse2zebra/train/" 22 | num_workers: 16 23 | image_channels: 3 24 | preprocess: ["resize", "random_flip"] 25 | load_size: [128, 128] # (H, W) 26 | final_size: [128, 128] # (H, W) 27 | 28 | gan: 29 | _target_: ganslate.nn.gans.unpaired.CycleGAN 30 | 31 | generator: 32 | _target_: ganslate.nn.generators.Resnet2D 33 | n_residual_blocks: 9 34 | in_out_channels: 35 | AB: [3, 3] 36 | 37 | discriminator: 38 | _target_: ganslate.nn.discriminators.PatchGAN2D 39 | n_layers: 3 40 | in_channels: 41 | B: 3 42 | 43 | optimizer: 44 | lambda_AB: 10.0 45 | lambda_BA: 10.0 46 | lambda_identity: 0 47 | proportion_ssim: 0 48 | lr_D: 0.0002 49 | lr_G: 0.0002 50 | 51 | metrics: 52 | discriminator_evolution: True 53 | ssim: True 54 | 55 | infer: 56 | checkpointing: 57 | load_iter: 1 58 | dataset: 59 | _target_: ganslate.data.UnpairedImageDataset 60 | root: "/home/chinmay/Datasets/horse2zebra/test/" 61 | num_workers: 16 62 | image_channels: 3 63 | preprocess: ["resize"] 64 | load_size: [128, 128] # (H, W) 65 | final_size: [128, 128] # (H, W) -------------------------------------------------------------------------------- /projects/cityscapes_label2photo/experiments/pix2pix.yaml: -------------------------------------------------------------------------------- 1 | project: "/home/chinmay/git/ganslate/projects/horse2zebra" 2 | 3 | train: 4 | output_dir: "/home/chinmay/git/ganslate/checkpoints/label2photo_pix2pix/" 5 | cuda: True 6 | n_iters: 297500 # 2975 (images) x 100 ("epochs") 7 | n_iters_decay: 5000 8 | batch_size: 1 9 | mixed_precision: False 10 | 11 | logging: 12 | freq: 
100 13 | wandb: 14 | project: "cityscapes_label2photo" 15 | run: "pix2pix_trial" 16 | 17 | checkpointing: 18 | freq: 5000 19 | 20 | dataset: 21 | _target_: ganslate.data.PairedImageDataset 22 | root: "/home/chinmay/Datasets/Cityscapes_label2photo/train" 23 | num_workers: 8 24 | image_channels: 3 25 | preprocess: ['resize', 'random_crop', 'random_flip'] 26 | load_size: [286, 572] 27 | final_size: [256, 512] 28 | 29 | gan: 30 | _target_: ganslate.nn.gans.paired.Pix2PixConditionalGAN 31 | generator: 32 | _target_: ganslate.nn.generators.Unet2D 33 | in_out_channels: 34 | AB: [3, 3] 35 | num_downs: 7 36 | ngf: 128 37 | use_dropout: True 38 | 39 | discriminator: 40 | _target_: ganslate.nn.discriminators.PatchGAN2D 41 | n_layers: 4 42 | in_channels: 43 | B: 6 44 | 45 | optimizer: 46 | lr_D: 0.0001 47 | lr_G: 0.0002 48 | lambda_pix2pix: 30.0 49 | 50 | metrics: 51 | discriminator_evolution: True 52 | 53 | 54 | val: 55 | freq: 5000 56 | dataset: 57 | _target_: ganslate.data.PairedImageDataset 58 | root: "/home/chinmay/Datasets/Cityscapes_label2photo/val" 59 | num_workers: 8 60 | image_channels: 3 61 | preprocess: ['resize'] 62 | load_size: [286, 572] 63 | final_size: [256, 512] 64 | metrics: 65 | cycle_metrics: False 66 | -------------------------------------------------------------------------------- /ganslate/utils/cli/cookiecutter_templates/new_project/{{ cookiecutter.project_name }}/architectures/template_architecture.py: -------------------------------------------------------------------------------- 1 | """ 2 | If you want to write a custom architecture (GAN, generator, discriminator) or a loss function 3 | it is best to check how they are implemented in ganslate: 4 | https://github.com/Maastro-CDS-Imaging-Group/ganslate/tree/master/ganslate/nn 5 | and to follow the documentation: 6 | https://ganslate.readthedocs.io/en/latest/ 7 | """ 8 | 9 | 10 | # ------------------- Custom GAN from scratch ----------------------- 11 | """Implementing a custom GAN from scratch is not trivial, and we advise you to go 12 | through ganslate's GAN source code for an example. 13 | (https://github.com/Maastro-CDS-Imaging-Group/ganslate/tree/master/ganslate/nn/gans) 14 | """ 15 | 16 | # ------------------ Extending an existing GAN ---------------------- 17 | """Extending an existing GAN is much easier. This is an example of how you would start 18 | extending CycleGAN. 19 | """ 20 | from dataclasses import dataclass 21 | from ganslate.nn.gans.unpaired import cyclegan 22 | 23 | 24 | @dataclass 25 | class {{cookiecutter.cyclegan_name}}CycleGANConfig(cyclegan.CycleGANConfig): 26 | pass 27 | 28 | 29 | class {{cookiecutter.cyclegan_name}}CycleGAN(cyclegan.CycleGAN): 30 | 31 | def __init__(self, conf): 32 | super().__init__(conf) 33 | 34 | # Now, extend or redefine method(s). 35 | # In this example, we extend only the `init_criterions()` method. 36 | def init_criterions(self): 37 | # Standard GAN loss [Same as in the original CycleGAN] 38 | self.criterion_adv = AdversarialLoss( 39 | self.conf.train.gan.optimizer.adversarial_loss_type).to(self.device) 40 | 41 | # Fancy loss for generator [Different from the original CycleGAN] 42 | self.criterion_G = YourFancyLoss(self.conf) 43 | 44 | 45 | # ------------------ Custom generator or loss ----------------------- 46 | """No limitations, just basic PyTorch code.
They do need to have their corresponding configs 47 | as can be seen in the documentation and the source code.""" -------------------------------------------------------------------------------- /ganslate/utils/sliding_window_inferer.py: -------------------------------------------------------------------------------- 1 | from monai.inferers import SlidingWindowInferer 2 | import torch 3 | from typing import Callable, Any 4 | 5 | from loguru import logger 6 | 7 | 8 | class SlidingWindowInferer(SlidingWindowInferer):  # extends (and intentionally shadows) monai's SlidingWindowInferer 9 | 10 | def __init__(self, *args, **kwargs): 11 | self.logger = logger 12 | super().__init__(*args, **kwargs) 13 | 14 | def __call__( 15 | self, 16 | inputs: torch.Tensor, 17 | network: Callable[..., torch.Tensor], 18 | *args: Any, 19 | **kwargs: Any, 20 | ) -> torch.Tensor: 21 | 22 | # Check if roi size and full volume size are not matching 23 | if len(self.roi_size) != len(inputs.shape[2:]): 24 | self.logger.debug( 25 | f"ROI size: {self.roi_size} and input volume: {inputs.shape[2:]} do not match \n" 26 | "Broadcasting ROI size to match input volume size.") 27 | 28 | # If they do not match and roi_size is 2D, add another dimension to roi size 29 | if len(self.roi_size) == 2: 30 | self.roi_size = [1, *self.roi_size] 31 | else: 32 | raise RuntimeError("Unsupported roi size, cannot broadcast to volume.") 33 | 34 | return super().__call__(inputs, lambda x: self.network_wrapper(network, x)) 35 | 36 | def network_wrapper(self, network, x, *args, **kwargs): 37 | """ 38 | Wrapper handles cases where inference needs to be done using 39 | 2D models over 3D volume inputs. 40 | 41 | """ 42 | # If depth dim is 1 in [D, H, W] roi size, then the input is 2D and needs 43 | # to be handled accordingly 44 | if self.roi_size[0] == 1: 45 | # Pass [N, C, H, W] to the model as it is 2D. 46 | x = x.squeeze(dim=2) 47 | out = network(x, *args, **kwargs) 48 | # Unsqueeze the network output so it is [N, C, D, H, W] 49 | return out.unsqueeze(dim=2) 50 | 51 | else: 52 | return network(x, *args, **kwargs) 53 | -------------------------------------------------------------------------------- /projects/maastro_hx4_pet_translation/datasets/utils/basic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import SimpleITK as sitk 4 | 5 | from ganslate.utils import sitk_utils 6 | from ganslate.data.utils.body_mask import get_body_mask 7 | from ganslate.data.utils.normalization import min_max_normalize 8 | 9 | 10 | # Body mask settings 11 | OUT_OF_BODY_HU = -1024 12 | OUT_OF_BODY_SUV = 0 13 | HU_THRESHOLD = -300 14 | 15 | 16 | 17 | def apply_body_mask(image_dict, generate_body_mask=False): 18 | 19 | # If body mask doesn't exist, then create one from the available CT using morph. 
ops 20 | if generate_body_mask: 21 | assert image_dict['body-mask'] is None 22 | assert any(['CT' in k for k in image_dict.keys()]) # There should be a CT in the dict to be able to generate a mask 23 | ct_image_name = [k for k in image_dict.keys() if 'CT' in k][0] 24 | image_dict['body-mask'] = get_body_mask(image_dict[ct_image_name], HU_THRESHOLD) 25 | 26 | # Apply masking to any CT or PET image present in image_dict 27 | assert image_dict['body-mask'] is not None 28 | body_mask = image_dict['body-mask'] 29 | for k in image_dict.keys(): 30 | if 'PET' in k: 31 | image_dict[k] = np.where(body_mask, image_dict[k], OUT_OF_BODY_SUV) 32 | elif 'CT' in k: 33 | image_dict[k] = np.where(body_mask, image_dict[k], OUT_OF_BODY_HU) 34 | 35 | return image_dict 36 | 37 | 38 | def clip_and_min_max_normalize(tensor, min_value, max_value): 39 | tensor = torch.clamp(tensor, min_value, max_value) 40 | tensor = min_max_normalize(tensor, min_value, max_value) 41 | return tensor 42 | 43 | 44 | def sitk2np(image_dict): 45 | # WHD to DHW 46 | for k in image_dict.keys(): 47 | if isinstance(image_dict[k], sitk.SimpleITK.Image): 48 | image_dict[k] = sitk_utils.get_npy(image_dict[k]) 49 | return image_dict 50 | 51 | def np2tensor(image_dict): 52 | for k in image_dict.keys(): 53 | image_dict[k] = torch.tensor(image_dict[k]) 54 | return image_dict 55 | -------------------------------------------------------------------------------- /tools/analyzers/wandb/utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import wandb 3 | from loguru import logger  # `logger` is used in `get_wandb_history` below; this import was missing 4 | 5 | def get_wandb_history(run, conf): 6 | """ 7 | Get wandb history from a particular run as a dataframe 8 | """ 9 | df = pd.DataFrame() 10 | 11 | # Get the total number of samples from the total number of iterations 12 | samples = run.summary._json_dict['_step'] 13 | 14 | for _, row in run.history(samples=samples).iterrows(): 15 | row_dict = {'iteration': row['_step']} 16 | 17 | if row_dict['iteration'] % conf.iters_sampling_freq != 0: 18 | continue 19 | 20 | if conf.last_ckpt and row_dict['iteration'] > conf.last_ckpt: 21 | logger.info(f"Stopped collecting samples @{row_dict['iteration']}") 22 | break 23 | 24 | for metric_label in row.keys(): 25 | if list_of_strings_has_substring(conf.rank_descending_keys + conf.rank_ascending_keys, 26 | metric_label): 27 | if not list_of_strings_has_substring(conf.ignore_tags + ['Train'], metric_label): 28 | row_dict[metric_label] = row[metric_label] 29 | 30 | df = df.append(row_dict, ignore_index=True) 31 | 32 | # Drop NaN values 33 | df = df.dropna() 34 | df = df.set_index('iteration') 35 | return df 36 | 37 | 38 | def get_aggregate_ranks(df, metric): 39 | if metric == "mean": 40 | return df.mean(axis=1) / len(df) 41 | elif metric == "mode": 42 | return df.mode(axis=1)[0] / len(df) 43 | 44 | 45 | def flatten(list_of_lists): 46 | return [item for sublist in list_of_lists for item in sublist] 47 | 48 | 49 | def list_of_strings_has_substring(list_of_strings, string): 50 | """Check if any of the strings in `list_of_strings` is a substring of `string`""" 51 | return any([elem.lower() in string.lower() for elem in list_of_strings]) 52 | 53 | 54 | def filter_columns(df, columns): 55 | """Filters out the given columns from the dataframe""" 56 | return df[df.columns[~df.columns.isin(columns)]] 57 | -------------------------------------------------------------------------------- /projects/cleargrasp_depth_estimation/experiments/old/pix2pix.yaml:
-------------------------------------------------------------------------------- 1 | project: "./projects/cleargrasp_depth_estimation/" 2 | 3 | train: 4 | output_dir: "/home/zk315372/Chinmay/Cleargrasp-Depth-Estimation/pix2pix_lambda100/" 5 | cuda: True 6 | n_iters: 62500 # (2500 (images) / 1 (batch_size)) x 25 ("epochs") 7 | n_iters_decay: 62500 # Extra 62500 iters with lr decay 8 | batch_size: 1 9 | mixed_precision: False 10 | 11 | logging: 12 | freq: 50 13 | multi_modality_split: 14 | A: [3, 3] 15 | B: [1] 16 | wandb: 17 | project: "cleargrasp_depth_estimation" 18 | run: "pix2pix_lambda_100" 19 | 20 | checkpointing: 21 | freq: 5000 22 | 23 | dataset: 24 | _target_: project.datasets.old.train_val_pix2pix_dataset.ClearGraspPix2PixDataset 25 | root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/train" 26 | load_size: [512, 256] 27 | num_workers: 8 28 | 29 | gan: 30 | _target_: ganslate.nn.gans.paired.Pix2PixConditionalGAN 31 | generator: 32 | _target_: ganslate.nn.generators.Unet2D 33 | in_out_channels: [6, 1] 34 | num_downs: 4 35 | ngf: 64 36 | use_dropout: True 37 | 38 | discriminator: 39 | _target_: ganslate.nn.discriminators.PatchGAN2D 40 | in_channels: 7 41 | n_layers: 3 42 | kernel_size: [4, 4] 43 | ndf: 64 44 | 45 | optimizer: 46 | lr_D: 0.0001 47 | lr_G: 0.0002 48 | lambda_pix2pix: 100 49 | 50 | metrics: 51 | discriminator_evolution: True 52 | ssim: False 53 | 54 | 55 | val: 56 | freq: 2500 57 | dataset: 58 | _target_: project.datasets.old.train_val_pix2pix_dataset.ClearGraspPix2PixDataset 59 | root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/val" 60 | load_size: [512, 256] 61 | num_workers: 8 62 | metrics: 63 | cycle_metrics: False 64 | -------------------------------------------------------------------------------- /projects/cityscapes_label2photo/experiments/cyclegan.yaml: -------------------------------------------------------------------------------- 1 | project: "/home/chinmay/git/ganslate/projects/horse2zebra" 2 | 3 | train: 4 | output_dir: "/home/chinmay/git/ganslate/checkpoints/label2photo_cyclegan/" 5 | cuda: True 6 | n_iters: 297500 # 2975 x 100 7 | n_iters_decay: 297500 8 | batch_size: 1 9 | mixed_precision: False 10 | 11 | logging: 12 | freq: 100 13 | wandb: 14 | project: "cityscapes_label2photo" 15 | run: "cyclegan_trial" 16 | 17 | checkpointing: 18 | freq: 10000 19 | 20 | dataset: 21 | _target_: ganslate.data.UnpairedImageDataset 22 | root: "/home/chinmay/Datasets/Cityscapes_label2photo/train" 23 | num_workers: 8 24 | image_channels: 3 25 | preprocess: ['resize', 'random_crop', 'random_flip'] 26 | load_size: [286, 572] 27 | final_size: [256, 512] 28 | 29 | gan: 30 | _target_: ganslate.nn.gans.unpaired.CycleGAN 31 | generator: 32 | _target_: ganslate.nn.generators.Resnet2D 33 | n_residual_blocks: 9 34 | in_out_channels: 35 | AB: [3, 3] 36 | 37 | discriminator: 38 | _target_: ganslate.nn.discriminators.PatchGAN2D 39 | n_layers: 3 40 | in_channels: 41 | B: 3 42 | 43 | optimizer: 44 | lr_D: 0.0002 45 | lr_G: 0.0002 46 | lambda_AB: 10.0 47 | lambda_BA: 10.0 48 | lambda_identity: 0 49 | proportion_ssim: 0 50 | 51 | metrics: 52 | discriminator_evolution: True 53 | ssim: True 54 | 55 | 56 | val: 57 | freq: 200 58 | dataset: 59 | _target_: ganslate.data.PairedImageDataset # Paired dataset for validation 60 | root: "/home/chinmay/Datasets/Cityscapes_label2photo/val" 61 | num_workers: 8 62 | image_channels: 3 63 | preprocess: ['resize'] 64 | load_size: [286, 572] 65 | final_size: [256, 512] 66 | metrics: 67 | cycle_metrics: False 
68 | -------------------------------------------------------------------------------- /docs/tutorials_advanced/2_custom_G_and_D_architecture.md: -------------------------------------------------------------------------------- 1 | # Custom Generator or Discriminator Architectures 2 | 3 | In image translation GANs, the "generator" can be any network whose architecture allows it to take an image as input and produce an output image of the same size. The discriminator, in turn, is any network that takes these images as input and produces a real/fake validity score, which may be either a scalar or a 2D/3D map in which each unit casts a fixed receptive field on the input. In `ganslate`, the generator and discriminator networks are defined as standard _PyTorch_ modules, [constructed by inheriting](https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html) from the type `torch.nn.Module`. In addition to defining your custom generator or discriminator network, you must also define a configuration dataclass for your network in the same file, as follows: 4 | ```python 5 | from torch import nn 6 | from dataclasses import dataclass 7 | from ganslate import configs 8 | 9 | @dataclass 10 | class CustomGeneratorConfig(configs.base.BaseGeneratorConfig): 11 | name: str = 'CustomGenerator' 12 | n_residual_blocks: int = 9 13 | use_dropout: bool = False 14 | 15 | class CustomGenerator(nn.Module): 16 | """Create a custom generator module""" 17 | def __init__(self, in_channels, out_channels, norm_type, n_residual_blocks, use_dropout): 18 | # Define the class attributes 19 | ... 20 | 21 | def forward(self, input_tensor): 22 | # Define the forward pass operation 23 | ... 24 | ``` 25 | 26 | Ensure that your YAML configuration file includes the pointer to your `CustomGenerator` as well as the appropriate settings: 27 | ```yaml 28 | project_dir: projects/your_project 29 | ... 30 | 31 | train: 32 | ... 33 | 34 | gan: 35 | ... 36 | 37 | generator: 38 | name: "CustomGenerator" # Name of your custom generator class 39 | n_residual_blocks: 9 # Configuration 40 | in_out_channels: 41 | AB: [3, 3] 42 | ... 43
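# (Other fields declared in `CustomGeneratorConfig`, e.g. `use_dropout`, can be set here in the same way.)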
44 | ``` 45 | -------------------------------------------------------------------------------- /ganslate/nn/attention.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/heykeetae/Self-Attention-GAN/blob/master/sagan_models.py 2 | # Changes made: 3 | # Change 2D Conv to 3D Conv 4 | # N (dimension of query, key, and value vectors) is DxWxH instead of WxH 5 | # Removed attention matrix return in the forward pass 6 | 7 | import torch 8 | from torch import nn 9 | 10 | 11 | # Both discriminator and generator can use attention blocks 12 | class SelfAttentionBlock(nn.Module): 13 | """ Self attention Layer""" 14 | 15 | def __init__(self, in_dim, activation): 16 | super().__init__() 17 | self.channel_in = in_dim 18 | self.activation = activation 19 | 20 | self.query_conv = nn.Conv3d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) 21 | self.key_conv = nn.Conv3d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) 22 | self.value_conv = nn.Conv3d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 23 | self.gamma = nn.Parameter(torch.zeros(1)) 24 | 25 | self.softmax = nn.Softmax(dim=-1) 26 | 27 | def forward(self, x): 28 | """ 29 | inputs : 30 | x : input feature maps (B X C X D X W X H) 31 | returns : 32 | out : self attention value + input feature (B X C X D X W X H); 33 | the attention matrix (B X N X N, N = D*W*H) is computed internally but not returned 34 | """ 35 | m_batchsize, C, depth, width, height = x.size() 36 | proj_query = self.query_conv(x).view(m_batchsize, -1, 37 | depth * width * height).permute(0, 2, 1) # B X N X C 38 | proj_key = self.key_conv(x).view(m_batchsize, -1, depth * width * height) # B X C X N 39 | energy = torch.bmm(proj_query, proj_key) # B X N X N 40 | attention = self.softmax(energy) # B X N X N 41 | proj_value = self.value_conv(x).view(m_batchsize, -1, depth * width * height) # B X C X N 42 | 43 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 44 | out = out.view(m_batchsize, C, depth, width, height) 45 | 46 | out = self.gamma * out + x 47 | return out 48 | -------------------------------------------------------------------------------- /ganslate/utils/trackers/tensorboard.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | from ganslate.utils.trackers.utils import process_visuals_wandb_tensorboard 3 | 4 | 5 | class TensorboardTracker: 6 | 7 | def __init__(self, conf): 8 | self.writer = SummaryWriter(conf[conf.mode].output_dir) 9 | self.image_window = conf[conf.mode].logging.image_window 10 | 11 | def close(self): 12 | self.writer.close() 13 | 14 | def log_iter(self, 15 | iter_idx, 16 | visuals, 17 | mode, 18 | learning_rates=None, 19 | losses=None, 20 | metrics=None): 21 | # Learning rates 22 | if learning_rates is not None: 23 | for name, learning_rate in learning_rates.items(): 24 | self.writer.add_scalar(f"Learning Rates/{name}", learning_rate, iter_idx) 25 | 26 | # Losses 27 | if losses is not None: 28 | for name, loss in losses.items(): 29 | self.writer.add_scalar(f"Losses/{name}", loss, iter_idx) 30 | 31 | # Metrics 32 | if metrics is not None: 33 | for name, metric in metrics.items(): 34 | self.writer.add_scalar(f"Metrics ({mode})/{name}", metric, iter_idx) 35 | 36 | # Normal images 37 | normal_visuals = process_visuals_wandb_tensorboard(visuals, image_window=None) 38 | self._log_images(iter_idx, normal_visuals, tag=mode) 39 | 40 | # Windowed images 41 | if self.image_window: 42 | windowed_visuals = 
process_visuals_wandb_tensorboard(visuals, self.image_window) 43 | self._log_images(iter_idx, windowed_visuals, tag=f"{mode}_windowed") 44 | 45 | def _log_images(self, iter_idx, visuals, tag): 46 | visuals = visuals if isinstance(visuals, list) else [visuals] 47 | for idx, visual in enumerate(visuals): 48 | name, image = visual['name'], visual['image'] 49 | name = f"{idx}_{name}" if len(visuals) > 1 else name 50 | self.writer.add_image(f"{tag}/{name}", image, iter_idx, dataformats='HWC') 51 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = ganslate 3 | version = 0.1.1 4 | author = "ganslate team" 5 | # author-email = 6 | url = https://github.com/Maastro-CDS-Imaging-Group/ganslate 7 | description = GAN image-to-image translation framework made simple and extensible. 8 | long_description = file: README.md 9 | long_description_content_type = text/markdown 10 | license = mit 11 | platforms = any 12 | keywords = one, two 13 | classifiers = 14 | Development Status :: 3 - Alpha 15 | Programming Language :: Python :: 3 :: Only 16 | Intended Audience :: Science/Research 17 | project_urls = 18 | Documentation = https://ganslate.readthedocs.io/en/latest/ 19 | 20 | [options] 21 | zip_safe = False 22 | packages = find: 23 | include_package_data = True 24 | #setup_requires = 25 | # Add here dependencies of your project (semicolon/line-separated), e.g. 26 | install_requires = 27 | torch 28 | opencv-python 29 | simpleitk 30 | memcnn 31 | loguru 32 | wandb 33 | tensorboard 34 | monai 35 | scipy 36 | scikit-image 37 | omegaconf 38 | pandas 39 | click 40 | cookiecutter 41 | wget 42 | bumpver 43 | 44 | # The usage of test_requires is discouraged, see `Dependency Management` docs 45 | # tests_require = pytest; pytest-cov 46 | # Require a specific Python version, e.g. Python 2.7 or >= 3.4 47 | # python_requires = >=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.* 48 | 49 | [options.packages.find] 50 | exclude = 51 | docs 52 | docker 53 | tools 54 | tests 55 | notebooks 56 | projects 57 | 58 | 59 | [options.extras_require] 60 | # Add here additional requirements for extra features, to install with: 61 | # `pip install ganslate[PDF]` like: 62 | # PDF = ReportLab; RXP 63 | # Add here test requirements (semicolon/line-separated) 64 | testing = 65 | pytest 66 | pytest-cov 67 | 68 | [options.entry_points] 69 | console_scripts = 70 | ganslate = ganslate.utils.cli.interface:interface 71 | 72 | [yapf] 73 | based_on_style = google 74 | column_limit = 100 75 | 76 | [pylint] 77 | fail-under = 8 78 | 79 | [pylint.typecheck] 80 | generated-members=torch.*, cv2.* 81 | -------------------------------------------------------------------------------- /ganslate/utils/cli/cookiecutter_templates/new_project/{{ cookiecutter.project_name }}/datasets/infer_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | 4 | from typing import Tuple 5 | from dataclasses import dataclass 6 | 7 | from torch.utils.data import Dataset 8 | from omegaconf import MISSING 9 | 10 | from ganslate import configs 11 | 12 | 13 | @dataclass 14 | class {{cookiecutter.dataset_name}}InferDatasetConfig(configs.base.BaseDatasetConfig): 15 | # Define other attributes, e.g.: 16 | patch_size: Tuple[int, int] = (128, 128) 17 | ...
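# For illustration only (names and values are hypothetical): this config is selected under the `infer` section of an experiment YAML, mirroring the pattern of the project YAMLs in this repository, e.g.: # infer: # dataset: # _target_: project.datasets.infer_dataset.MyDatasetInferDataset # root: "/path/to/data"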
18 | 19 | 20 | class {{cookiecutter.dataset_name}}InferDataset(Dataset): 21 | 22 | def __init__(self, conf): 23 | self.root_path = Path(conf.infer.dataset.root).resolve() 24 | # Assumes one subdirectory per sample, for demonstration 25 | self.sample_dirs = sorted(self.root_path.iterdir()) 26 | def __getitem__(self, index): 27 | # Depends on your dataset dir structure 28 | path_A = self.sample_dirs[index] / "A.png" 29 | 30 | # Read the images, `read` is a placeholder 31 | A = read(path_A) 32 | 33 | # Preprocess and normalize to [-1,1], `preprocess` is a placeholder 34 | A = preprocess(A) 35 | 36 | # Metadata is optionally returned by this method, explained at the end of the method. 37 | # Delete if not necessary. 38 | metadata = { 39 | 'path': str(path_A), 40 | ... 41 | } 42 | 43 | return { 44 | # Notice that the key for inference input is not "A" 45 | "input": A, 46 | # [Optional] metadata - if `save()` is defined *and* if it requires metadata. 47 | "metadata": metadata, 48 | } 49 | 50 | def __len__(self): 51 | # Depending on the dataset dir structure, you might want to change it. 52 | return len(self.sample_dirs) 53 | 54 | def save(self, tensor, save_dir, metadata=None): 55 | """ By default, ganslate logs images in png format. However, if you wish 56 | to save images in a different way, then implement this `save()` method. 57 | For example, you could save medical images in their native format for easier 58 | inspection or usage. 59 | If you do not need this method, remove it. 60 | """ 61 | pass 62 | -------------------------------------------------------------------------------- /ganslate/nn/discriminators/patchgan/patchgan3d.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch import nn 3 | from ganslate.nn.utils import get_norm_layer_3d, is_bias_before_norm 4 | 5 | # Config imports 6 | from dataclasses import dataclass 7 | from ganslate import configs 8 | 9 | 10 | @dataclass 11 | class PatchGAN3DConfig(configs.base.BaseDiscriminatorConfig): 12 | ndf: int = 64 13 | n_layers: int = 3 14 | kernel_size: Tuple[int] = (4, 4, 4) 15 | 16 | 17 | class PatchGAN3D(nn.Module): 18 | 19 | def __init__(self, in_channels, ndf, n_layers, kernel_size, norm_type): 20 | super().__init__() 21 | 22 | norm_layer = get_norm_layer_3d(norm_type) 23 | use_bias = is_bias_before_norm(norm_type) 24 | 25 | kw = kernel_size 26 | padw = 1 27 | sequence = [ 28 | nn.Conv3d(in_channels, ndf, kernel_size=kw, stride=2, padding=padw), 29 | nn.LeakyReLU(0.2, True) 30 | ] 31 | 32 | nf_mult = 1 33 | nf_mult_prev = 1 34 | for n in range(1, n_layers): 35 | nf_mult_prev = nf_mult 36 | nf_mult = min(2**n, 8) 37 | sequence += [ 38 | nn.Conv3d(ndf * nf_mult_prev, 39 | ndf * nf_mult, 40 | kernel_size=kw, 41 | stride=2, 42 | padding=padw, 43 | bias=use_bias), 44 | norm_layer(ndf * nf_mult), 45 | nn.LeakyReLU(0.2, True) 46 | ] 47 | 48 | nf_mult_prev = nf_mult 49 | nf_mult = min(2**n_layers, 8) 50 | sequence += [ 51 | nn.Conv3d(ndf * nf_mult_prev, 52 | ndf * nf_mult, 53 | kernel_size=kw, 54 | stride=1, 55 | padding=padw, 56 | bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 62 | self.model = nn.Sequential(*sequence) 63 | 64 | def forward(self, input): 65 | return self.model(input) 66 | -------------------------------------------------------------------------------- /ganslate/data/paired_image_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | from typing
import Tuple 4 | 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | from ganslate.data.utils.transforms import get_paired_image_transform 9 | from ganslate.utils.io import make_dataset_of_files 10 | 11 | # Config imports 12 | from dataclasses import dataclass, field 13 | from ganslate import configs 14 | 15 | 16 | EXTENSIONS = ['.jpg', '.jpeg', '.png'] 17 | 18 | 19 | @dataclass 20 | class PairedImageDatasetConfig(configs.base.BaseDatasetConfig): 21 | image_channels: int = 3 22 | # Preprocessing instructions for images at load time: 23 | # Initial resizing: 'resize', 'scale_width' 24 | # Random transforms: 'random_zoom', 'random_crop', 'random_flip' 25 | # Note: During val/test, make sure to not include random transforms in the YAML config 26 | preprocess: Tuple[str] = ('resize', 'random_crop', 'random_flip') 27 | # Sizes in (H, W) format 28 | load_size: Tuple[int, int] = field(default_factory=lambda: [286, 572]) 29 | final_size: Tuple[int, int] = field(default_factory=lambda: [256, 512]) 30 | 31 | 32 | class PairedImageDataset(Dataset): 33 | 34 | def __init__(self, conf): 35 | 36 | self.dir_A = Path(conf[conf.mode].dataset.root) / 'A' 37 | self.dir_B = Path(conf[conf.mode].dataset.root) / 'B' 38 | 39 | self.A_paths = make_dataset_of_files(self.dir_A, EXTENSIONS) 40 | self.B_paths = make_dataset_of_files(self.dir_B, EXTENSIONS) 41 | self.n_samples = len(self.A_paths) 42 | 43 | self.transform = get_paired_image_transform(conf) 44 | self.rgb_or_grayscale = 'RGB' if conf[conf.mode].dataset.image_channels == 3 else 'L' 45 | 46 | def __getitem__(self, index): 47 | index = index % self.n_samples 48 | 49 | A_path = self.A_paths[index] 50 | B_path = self.B_paths[index] 51 | 52 | A_img = Image.open(A_path).convert(self.rgb_or_grayscale) 53 | B_img = Image.open(B_path).convert(self.rgb_or_grayscale) 54 | 55 | A, B = self.transform(A_img, B_img) 56 | 57 | return {'A': A, 'B': B} 58 | 59 | def __len__(self): 60 | return self.n_samples 61 | -------------------------------------------------------------------------------- /ganslate/data/samplers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ 3 | Copyright (c) DIRECT Contributors 4 | 5 | This source code is licensed under the MIT license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | # Taken from Detectron 2, licensed under Apache 2.0. 9 | # Changes: 10 | # - Docstring to match the rest of the library 11 | # - Calls to other subroutines which do not exist in DIRECT. 12 | 13 | import itertools 14 | import torch 15 | 16 | from torch.utils.data.sampler import Sampler 17 | from ganslate.utils import communication 18 | 19 | 20 | class InfiniteSampler(Sampler): 21 | """ 22 | In training, we only care about the "infinite stream" of training data. 23 | So this sampler produces an infinite stream of indices and 24 | all workers cooperate to correctly shuffle the indices and sample different indices. 25 | The samplers in each worker effectively produces `indices[worker_id::num_workers]` 26 | where `indices` is an infinite stream of indices consisting of 27 | `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True) 28 | or `range(size) + range(size) + ...` (if shuffle is False) 29 | """ 30 | 31 | def __init__(self, size: int, shuffle: bool = True): 32 | """ 33 | Parameters 34 | ---------- 35 | size : int 36 | Length of the underlying dataset. 
37 | shuffle : bool 38 | If true, the indices will be shuffled 39 | """ 40 | self._size = size 41 | assert size > 0 42 | self._shuffle = shuffle 43 | self._seed = communication.shared_random_seed() 44 | self._rank = communication.get_rank() 45 | self._world_size = communication.get_world_size() 46 | 47 | def __iter__(self): 48 | start = self._rank 49 | yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) 50 | 51 | def _infinite_indices(self): 52 | g = torch.Generator() 53 | g.manual_seed(self._seed) 54 | while True: 55 | if self._shuffle: 56 | yield from torch.randperm(self._size, generator=g) 57 | else: 58 | yield from torch.arange(self._size) 59 | -------------------------------------------------------------------------------- /projects/cleargrasp_depth_estimation/experiments/pix2pix_new.yaml: -------------------------------------------------------------------------------- 1 | project: "./projects/cleargrasp_depth_estimation/" 2 | 3 | train: 4 | output_dir: "/home/zk315372/Chinmay/Cleargrasp-Depth-Estimation/pix2pix_lambda100/" 5 | cuda: True 6 | n_iters: 62500 # (2500 (images) / 1 (batch_size)) x 25 ("epochs") 7 | n_iters_decay: 62500 # Extra 62500 iters with lr decay 8 | batch_size: 1 9 | mixed_precision: False 10 | seed: 1 11 | 12 | logging: 13 | freq: 50 14 | multi_modality_split: 15 | A: [3, 3] 16 | B: [1] 17 | wandb: 18 | project: "cleargrasp_depth_estimation" 19 | run: "pix2pix_lambda100" 20 | 21 | checkpointing: 22 | freq: 25000 23 | 24 | dataset: 25 | _target_: project.datasets.train_dataset.ClearGraspTrainDataset 26 | root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/train" 27 | load_size: [512, 256] 28 | paired: True 29 | require_domain_B_rgb: False 30 | num_workers: 8 31 | 32 | gan: 33 | _target_: ganslate.nn.gans.paired.Pix2PixConditionalGAN 34 | generator: 35 | _target_: ganslate.nn.generators.Unet2D 36 | in_out_channels: 37 | AB: [6, 1] 38 | num_downs: 4 39 | ngf: 64 40 | use_dropout: True 41 | 42 | discriminator: 43 | _target_: ganslate.nn.discriminators.PatchGAN2D 44 | in_channels: 45 | B: 7 46 | n_layers: 3 47 | kernel_size: [4, 4] 48 | ndf: 64 49 | 50 | optimizer: 51 | lr_D: 0.0001 52 | lr_G: 0.0002 53 | lambda_pix2pix: 100 54 | 55 | metrics: 56 | discriminator_evolution: True 57 | ssim: False 58 | 59 | 60 | val: 61 | freq: 2500 62 | dataset: 63 | _target_: project.datasets.val_test_dataset.ClearGraspValTestDataset 64 | root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/val" 65 | load_size: [512, 256] 66 | model_is_cyclegan_balanced: False 67 | num_workers: 8 68 | metrics: 69 | cycle_metrics: False 70 | -------------------------------------------------------------------------------- /ganslate/nn/discriminators/patchgan/patchgan2d.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch import nn 3 | from ganslate.nn.utils import get_norm_layer_2d, is_bias_before_norm 4 | 5 | # Config imports 6 | from dataclasses import dataclass 7 | from ganslate import configs 8 | 9 | 10 | @dataclass 11 | class PatchGAN2DConfig(configs.base.BaseDiscriminatorConfig): 12 | ndf: int = 64 13 | n_layers: int = 3 14 | kernel_size: Tuple[int] = (4, 4) 15 | 16 | 17 | class PatchGAN2D(nn.Module): 18 | 19 | def __init__(self, in_channels, ndf, n_layers, kernel_size, norm_type): 20 | super().__init__() 21 | 22 | norm_layer = get_norm_layer_2d(norm_type) 23 | use_bias = is_bias_before_norm(norm_type) 24 | 25 | kw = kernel_size 26 | padw = 1 27 | 
sequence = [ 28 | # First convolution: maps the `in_channels` input channels to `ndf` feature maps 29 | nn.Conv2d(in_channels, ndf, kernel_size=kw, stride=2, padding=padw), 30 | nn.LeakyReLU(0.2, True) 31 | ] 32 | 33 | nf_mult = 1 34 | nf_mult_prev = 1 35 | for n in range(1, n_layers): 36 | nf_mult_prev = nf_mult 37 | nf_mult = min(2**n, 8) 38 | sequence += [ 39 | nn.Conv2d(ndf * nf_mult_prev, 40 | ndf * nf_mult, 41 | kernel_size=kw, 42 | stride=2, 43 | padding=padw, 44 | bias=use_bias), 45 | norm_layer(ndf * nf_mult), 46 | nn.LeakyReLU(0.2, True) 47 | ] 48 | 49 | nf_mult_prev = nf_mult 50 | nf_mult = min(2**n_layers, 8) 51 | sequence += [ 52 | nn.Conv2d(ndf * nf_mult_prev, 53 | ndf * nf_mult, 54 | kernel_size=kw, 55 | stride=1, 56 | padding=padw, 57 | bias=use_bias), 58 | norm_layer(ndf * nf_mult), 59 | nn.LeakyReLU(0.2, True) 60 | ] 61 | 62 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 63 | self.model = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | return self.model(input) 67 | -------------------------------------------------------------------------------- /ganslate/data/utils/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def min_max_normalize(image, min_value, max_value): 5 | image = image.float() 6 | image = (image - min_value) / (max_value - min_value) 7 | return 2 * image - 1 8 | 9 | 10 | def min_max_denormalize(image, min_value, max_value):  # note: modifies `image` in place via the augmented assignments 11 | image += 1 12 | image /= 2 13 | image *= (max_value - min_value) 14 | image += min_value 15 | return image 16 | 17 | 18 | def z_score_normalize(tensor, scale_to_range=None): 19 | """Performs z-score normalization on a tensor and scales to a range if specified.""" 20 | mean = tensor.mean() 21 | std = tensor.std() 22 | 23 | tensor = (tensor - mean) / std 24 | 25 | if scale_to_range: 26 | delta1 = tensor.max() - tensor.min() 27 | delta2 = scale_to_range[1] - scale_to_range[0] 28 | tensor = (delta2 * (tensor - tensor.min()) / delta1) + scale_to_range[0] 29 | 30 | return tensor 31 | 32 | 33 | def z_score_normalize_with_precomputed_stats(tensor, 34 | mean_std, 35 | original_scale=None, 36 | scale_to_range=None): 37 | """Performs z-score normalization on a tensor using precomputed mean, standard deviation, 38 | and, optionally, min-max scale. Optionally scales the normalized values to the specified range. 39 | This function is useful, e.g., when normalizing a slice using its volume's stats.
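Worked example (illustrative numbers): with mean_std=(100, 50), a value of 200 normalizes to (200 - 100) / 50 = 2.0; if scale_to_range=[-1, 1] is also given, the normalized values are then linearly rescaled into [-1, 1] using the normalized original_scale.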
40 | """ 41 | mean = mean_std[0] 42 | std = mean_std[1] 43 | 44 | tensor = (tensor - mean) / std 45 | 46 | if scale_to_range: 47 | # Volume's min and max values, normalized 48 | original_scale = (torch.Tensor(original_scale) - mean) / std 49 | 50 | delta1 = original_scale[1] - original_scale[0] 51 | delta2 = scale_to_range[1] - scale_to_range[0] 52 | tensor = (delta2 * (tensor - original_scale[0]) / delta1) + scale_to_range[0] 53 | 54 | return tensor 55 | 56 | 57 | def get_stats_for_z_score_denormalization(tensor): 58 | # TODO: to be used in inference 59 | # take into account that the tensor might have been scaled to range 60 | pass 61 | 62 | 63 | def z_score_denormalize(): 64 | pass # TODO 65 | -------------------------------------------------------------------------------- /ganslate/data/unpaired_image_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | from typing import Tuple 4 | 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | from ganslate.data.utils.transforms import get_single_image_transform 9 | from ganslate.utils.io import make_dataset_of_files 10 | 11 | # Config imports 12 | from dataclasses import dataclass, field 13 | from ganslate import configs 14 | 15 | 16 | EXTENSIONS = ['.jpg', '.jpeg', '.png'] 17 | 18 | 19 | @dataclass 20 | class UnpairedImageDatasetConfig(configs.base.BaseDatasetConfig): 21 | image_channels: int = 3 22 | # Preprocessing instructions for images at load time: 23 | # Initial resizing: 'resize', 'scale_width' 24 | # Random transforms: 'random_zoom', 'random_crop', 'random_flip' 25 | preprocess: Tuple[str] = ('resize', 'random_crop', 'random_flip') 26 | # Sizes in (H, W) format 27 | load_size: Tuple[int, int] = field(default_factory=lambda: [286, 286]) 28 | final_size: Tuple[int, int] = field(default_factory=lambda: [256, 256]) 29 | 30 | 31 | class UnpairedImageDataset(Dataset): 32 | 33 | def __init__(self, conf): 34 | 35 | self.dir_A = Path(conf[conf.mode].dataset.root) / 'A' 36 | self.dir_B = Path(conf[conf.mode].dataset.root) / 'B' 37 | 38 | self.A_paths = make_dataset_of_files(self.dir_A, EXTENSIONS) 39 | self.B_paths = make_dataset_of_files(self.dir_B, EXTENSIONS) 40 | self.A_size = len(self.A_paths) 41 | self.B_size = len(self.B_paths) 42 | 43 | self.transform = get_single_image_transform(conf) 44 | self.rgb_or_grayscale = 'RGB' if conf[conf.mode].dataset.image_channels == 3 else 'L' 45 | 46 | def __getitem__(self, index): 47 | index_A = index % self.A_size 48 | index_B = random.randint(0, self.B_size - 1) 49 | 50 | A_path = self.A_paths[index_A] 51 | B_path = self.B_paths[index_B] 52 | 53 | A_img = Image.open(A_path).convert(self.rgb_or_grayscale) 54 | B_img = Image.open(B_path).convert(self.rgb_or_grayscale) 55 | 56 | A = self.transform(A_img) 57 | B = self.transform(B_img) 58 | 59 | return {'A': A, 'B': B} 60 | 61 | def __len__(self): 62 | return max(self.A_size, self.B_size) 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5494572.svg)](https://doi.org/10.5281/zenodo.5494572) 2 | 3 | [![Python package](https://github.com/ganslate-team/ganslate/actions/workflows/ganslate_testing_suite.yml/badge.svg)](https://github.com/ganslate-team/ganslate/actions/workflows/ganslate_testing_suite.yml) 4 | 5 | # `ganslate` - a simple and extensible GAN 
image-to-image translation framework 6 | 7 | For comprehensive documentation, visit: https://ganslate.readthedocs.io/en/latest/ 8 | 9 | ***Note**: The documentation is still in progress! Suggestions and contributions are welcome!* 10 | 11 | `ganslate` is a [PyTorch](https://pytorch.org/) framework which aims to make GAN image-to-image translation more accessible to both beginner and advanced projects with: 12 | 13 | - Simple configuration system 14 | - Extensibility for other datasets or architectures 15 | - Documentation and [video walk-throughs (soon)](INSERT_YOUTUBE_PLAYLIST) 16 | 17 | ## Features 18 | 19 | - 2D and 3D support 20 | - Mixed precision 21 | - Distributed training 22 | - Tensorboard and [Weights&Biases](https://wandb.ai/site) logging 23 | - Natural and medical image support 24 | - A range of generator and discriminator architectures 25 | 26 | ## Available GANs 27 | 28 | - Pix2Pix ([paper](https://arxiv.org/abs/1611.07004)) 29 | - CycleGAN ([paper](https://arxiv.org/abs/1703.10593)) 30 | - RevGAN ([paper](https://arxiv.org/abs/1902.02729)) 31 | - CUT (Contrastive Unpaired Translation) ([paper](https://arxiv.org/abs/2007.15651)) 32 | 33 | ## Projects 34 | `ganslate` was used in: 35 | 36 | - Project 1 37 | - Project 2 38 | 39 | ## Citation 40 | 41 | If you used `ganslate` in your project, please cite: 42 | 43 | ```text 44 | @software{ibrahim_hadzic_2021_5494572, 45 | author = {Ibrahim Hadzic and 46 | Suraj Pai and 47 | Chinmay Rao and 48 | Jonas Teuwen}, 49 | title = {ganslate-team/ganslate: Initial public release}, 50 | month = sep, 51 | year = 2021, 52 | publisher = {Zenodo}, 53 | version = {v0.1.0}, 54 | doi = {10.5281/zenodo.5494572}, 55 | url = {https://doi.org/10.5281/zenodo.5494572} 56 | } 57 | ``` 58 | -------------------------------------------------------------------------------- /.github/workflows/ganslate_testing_suite.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Latest version on pypi 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: [3.7] 20 | 21 | steps: 22 | - uses: actions/checkout@v2 23 | - uses: fregante/setup-git-user@v1 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | python -m pip install pytest 32 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 33 | 34 | - name: Install package dependencies 35 | run: | 36 | python setup.py install 37 | - name: Test with pytest 38 | run: | 39 | pytest 40 | - name: Bump version to patch with bumpver 41 | if: contains(github.event.head_commit.message, '[PATCH]') 42 | run: | 43 | bumpver update --patch 44 | - name: Bump version to minor with bumpver 45 | if: contains(github.event.head_commit.message, '[MINOR]') 46 | run: | 47 | bumpver update --minor 48 | - name: Bump version to major with bumpver 49 | if: contains(github.event.head_commit.message, '[MAJOR]')
50 | run: | 51 | bumpver update --major 52 | - name: Install pypa/build 53 | run: >- 54 | python -m 55 | pip install 56 | build 57 | --user 58 | - name: Build a binary wheel and a source tarball 59 | run: >- 60 | python -m 61 | build 62 | --sdist 63 | --wheel 64 | --outdir dist/ 65 | - name: Publish a Python distribution to PyPI 66 | uses: pypa/gh-action-pypi-publish@release/v1 67 | with: 68 | user: __token__ 69 | password: ${{ secrets.PYPI_API_TOKEN }} 70 | skip_existing: true 71 | -------------------------------------------------------------------------------- /projects/cleargrasp_depth_estimation/experiments/old/cyclegan_multimodal_v2.yaml: -------------------------------------------------------------------------------- 1 | project: "./projects/cleargrasp_depth_estimation/" 2 | 3 | train: 4 | output_dir: "./checkpoints/cleargrasp_cycleganv2_5_struct0_5/" 5 | cuda: True 6 | n_iters: 250000 # (2500 (images) / 2 (batch_size)) x 200 ("epochs") 7 | n_iters_decay: 5000 # Extra 5000 iters with lr decay 8 | batch_size: 1 9 | mixed_precision: False 10 | 11 | logging: 12 | freq: 50 13 | multi_modality_split: 14 | A: [3, 3] 15 | B: [3, 1] 16 | wandb: 17 | project: "cleargrasp_depth_estimation" 18 | run: "cyclegan_v2.5_structure0.5" 19 | 20 | checkpointing: 21 | freq: 10000 22 | 23 | dataset: 24 | _target_: project.datasets.old.train_val_cyclegan_dataset.ClearGraspCycleGANDataset 25 | root: "/workspace/Chinmay-Datasets-Ephemeral/Cleargrasp_rgbnormal2depth_resized/train" 26 | load_size: [512, 256] 27 | paired: False 28 | fetch_rgb_b: True # `True` for v2 29 | num_workers: 8 30 | 31 | gan: 32 | _target_: project.modules.CycleGANMultiModalV2 33 | generator: 34 | _target_: ganslate.nn.generators.Unet2D 35 | in_out_channels_AB: [6, 4] 36 | in_out_channels_BA: [4, 6] 37 | num_downs: 5 38 | ngf: 64 39 | 40 | discriminator: 41 | _target_: ganslate.nn.discriminators.PatchGAN2D 42 | in_channels_B: 4 43 | in_channels_A: 6 44 | n_layers: 4 45 | ndf: 64 46 | 47 | optimizer: 48 | lr_D: 0.0001 49 | lr_G: 0.0002 50 | lambda_AB: 10.0 51 | lambda_BA: 10.0 52 | lambda_identity: 0 53 | proportion_ssim: 0 54 | lambda_structure: 0.5 55 | 56 | metrics: 57 | discriminator_evolution: True 58 | ssim: True 59 | 60 | 61 | val: 62 | freq: 5000 63 | dataset: 64 | _target_: project.datasets.old.train_val_cyclegan_dataset.ClearGraspCycleGANDataset 65 | root: "/workspace/Chinmay-Datasets-Ephemeral/Cleargrasp_rgbnormal2depth_resized/val" 66 | load_size: [512, 256] 67 | fetch_rgb_b: True # `True` for v2 68 | paired: True 69 | num_workers: 8 70 | metrics: 71 | cycle_metrics: True 72 | -------------------------------------------------------------------------------- /projects/cleargrasp_depth_estimation/experiments/old/cyclegan_multimodal_v1_structure.yaml: -------------------------------------------------------------------------------- 1 | project: "./projects/cleargrasp_depth_estimation/" 2 | 3 | train: 4 | output_dir: "./checkpoints/cleargrasp_cycleganv1_struct0_5/" 5 | cuda: True 6 | n_iters: 250000 # (2500 (images) / 1 (batch_size)) x 100 ("epochs") 7 | n_iters_decay: 5000 # Extra 5000 iters with lr decay 8 | batch_size: 1 9 | mixed_precision: False 10 | 11 | logging: 12 | freq: 50 13 | multi_modality_split: 14 | A: [3, 3] 15 | B: [1] 16 | wandb: 17 | project: "cleargrasp_depth_estimation" 18 | run: "cyclegan_v1_structure0.5" 19 | 20 | checkpointing: 21 | freq: 10000 22 | 23 | dataset: 24 | _target_: project.datasets.old.train_val_cyclegan_dataset.ClearGraspCycleGANDataset 25 | root: 
"/workspace/Chinmay-Datasets-Ephemeral/Cleargrasp_rgbnormal2depth_resized/train" 26 | load_size: [512, 256] 27 | paired: False 28 | fetch_rgb_b: False # `False` for v1 29 | num_workers: 8 30 | 31 | gan: 32 | _target_: project.modules.CycleGANMultiModalV1Structure 33 | generator: 34 | _target_: ganslate.nn.generators.Unet2D 35 | in_out_channels_AB: [6, 1] 36 | in_out_channels_BA: [1, 6] 37 | num_downs: 5 38 | ngf: 64 39 | 40 | discriminator: 41 | _target_: ganslate.nn.discriminators.PatchGAN2D 42 | in_channels_B: 1 43 | in_channels_A: 6 44 | n_layers: 4 45 | ndf: 64 46 | 47 | optimizer: 48 | lr_D: 0.0001 49 | lr_G: 0.0002 50 | lambda_AB: 10.0 51 | lambda_BA: 10.0 52 | lambda_identity: 0 53 | proportion_ssim: 0 54 | lambda_structure: 0.5 55 | 56 | metrics: 57 | discriminator_evolution: True 58 | ssim: True 59 | 60 | 61 | val: 62 | freq: 5000 63 | dataset: 64 | _target_: project.datasets.old.train_val_cyclegan_dataset.ClearGraspCycleGANDataset 65 | root: "/workspace/Chinmay-Datasets-Ephemeral/Cleargrasp_rgbnormal2depth_resized/val" 66 | load_size: [512, 256] 67 | fetch_rgb_b: False # `False` for v1 68 | paired: True 69 | num_workers: 8 70 | metrics: 71 | cycle_metrics: True 72 | -------------------------------------------------------------------------------- /projects/cleargrasp_depth_estimation/experiments/old/cyclegan_multimodal_v1.yaml: -------------------------------------------------------------------------------- 1 | project: "./projects/cleargrasp_depth_estimation/" 2 | 3 | train: 4 | output_dir: "/home/zk315372/Chinmay/Cleargrasp-Depth-Estimation/cyclegan_naive/" 5 | cuda: True 6 | n_iters: 62500 # (2500 (images) / 1 (batch_size)) x 25 ("epochs") 7 | n_iters_decay: 62500 # Extra 125000 iters with lr decay 8 | batch_size: 1 9 | mixed_precision: False 10 | 11 | logging: 12 | freq: 50 13 | multi_modality_split: 14 | A: [3, 3] 15 | B: [1] 16 | wandb: 17 | project: "cleargrasp_depth_estimation" 18 | run: "cyclegan_naive" 19 | 20 | checkpointing: 21 | freq: 5000 22 | 23 | dataset: 24 | _target_: project.datasets.old.train_val_cyclegan_dataset.ClearGraspCycleGANDataset 25 | root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/train" 26 | load_size: [512, 256] 27 | paired: False 28 | fetch_rgb_b: False # `False` for v1 29 | num_workers: 8 30 | 31 | gan: 32 | _target_: ganslate.nn.gans.unpaired.CycleGAN 33 | generator: 34 | _target_: ganslate.nn.generators.Unet2D 35 | in_out_channels_AB: [6, 1] # RGB + Normal -> Depth 36 | in_out_channels_BA: [1, 6] # Depth -> Normal + RGB 37 | num_downs: 4 38 | ngf: 64 39 | use_dropout: True 40 | 41 | discriminator: 42 | _target_: ganslate.nn.discriminators.PatchGAN2D 43 | in_channels_B: 1 44 | in_channels_A: 6 45 | n_layers: 3 46 | kernel_size: [4, 4] 47 | ndf: 64 48 | 49 | optimizer: 50 | lr_D: 0.0001 51 | lr_G: 0.0002 52 | lambda_AB: 10.0 53 | lambda_BA: 10.0 54 | lambda_identity: 0 55 | proportion_ssim: 0 56 | 57 | metrics: 58 | discriminator_evolution: True 59 | ssim: False 60 | 61 | 62 | val: 63 | freq: 2500 64 | dataset: 65 | _target_: project.datasets.old.train_val_cyclegan_dataset.ClearGraspCycleGANDataset 66 | root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/val" 67 | load_size: [512, 256] 68 | paired: True 69 | fetch_rgb_b: False # `False` for v1 70 | num_workers: 8 71 | metrics: 72 | cycle_metrics: False 73 | -------------------------------------------------------------------------------- /projects/cleargrasp_depth_estimation/experiments/old/cyclegan_multimodal_v3.yaml: 
-------------------------------------------------------------------------------- 1 | project: "./projects/cleargrasp_depth_estimation/" 2 | 3 | train: 4 | output_dir: "/home/zk315372/Chinmay/Cleargrasp-Depth-Estimation/cyclegan_balanced/" 5 | cuda: True 6 | n_iters: 62500 # (2500 (images) / 1 (batch_size)) x 25 ("epochs") 7 | n_iters_decay: 62500 # Extra 62500 iters with lr decay 8 | batch_size: 1 9 | mixed_precision: False 10 | 11 | logging: 12 | freq: 50 13 | multi_modality_split: 14 | A: [3, 3] 15 | B: [3, 1] 16 | wandb: 17 | project: "cleargrasp_depth_estimation" 18 | run: "cyclegan_balanced" 19 | 20 | checkpointing: 21 | freq: 5000 22 | 23 | dataset: 24 | _target_: project.datasets.old.train_val_cyclegan_dataset.ClearGraspCycleGANDataset 25 | root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/train" 26 | load_size: [512, 256] 27 | paired: False 28 | fetch_rgb_b: True # `True` for v3 29 | num_workers: 8 30 | 31 | gan: 32 | _target_: project.modules.CycleGANMultiModalV3 33 | generator: 34 | _target_: ganslate.nn.generators.Unet2D 35 | in_out_channels_AB: [6, 1] # RGB + Normal -> Depth 36 | in_out_channels_BA: [4, 3] # RGB + Depth -> Normal 37 | num_downs: 4 38 | ngf: 64 39 | use_dropout: True 40 | 41 | discriminator: 42 | _target_: ganslate.nn.discriminators.PatchGAN2D 43 | in_channels_B: 1 44 | in_channels_A: 3 45 | n_layers: 3 46 | kernel_size: [4, 4] 47 | ndf: 64 48 | 49 | optimizer: 50 | lr_D: 0.0001 51 | lr_G: 0.0002 52 | lambda_AB: 10.0 53 | lambda_BA: 10.0 54 | lambda_identity: 0 55 | proportion_ssim: 0 56 | 57 | metrics: 58 | discriminator_evolution: True 59 | ssim: False 60 | 61 | 62 | val: 63 | freq: 2500 64 | dataset: 65 | _target_: project.datasets.old.train_val_cyclegan_dataset.ClearGraspCycleGANDataset 66 | root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/val" 67 | load_size: [512, 256] 68 | fetch_rgb_b: True # `True` for v3 69 | paired: True 70 | num_workers: 8 71 | metrics: 72 | cycle_metrics: False 73 | -------------------------------------------------------------------------------- /ganslate/nn/discriminators/patchgan/selfattention_patchgan3d.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch import nn 3 | from ganslate.nn.utils import get_norm_layer_3d, is_bias_before_norm 4 | from ganslate.nn import attention 5 | 6 | # Config imports 7 | from dataclasses import dataclass 8 | from ganslate import configs 9 | 10 | 11 | @dataclass 12 | class SelfAttentionPatchGAN3DConfig(configs.base.BaseDiscriminatorConfig): 13 | ndf: int = 64 14 | n_layers: int = 3 15 | kernel_size: Tuple[int] = (4, 4, 4) 16 | 17 | 18 | class SelfAttentionPatchGAN3D(nn.Module): 19 | 20 | def __init__(self, in_channels, ndf, n_layers, kernel_size, norm_type): 21 | super().__init__() 22 | 23 | norm_layer = get_norm_layer_3d(norm_type) 24 | use_bias = is_bias_before_norm(norm_type) 25 | 26 | kw = kernel_size 27 | padw = 1 28 | 29 | # Stride changed to 3 to allow memory to fit! 
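# (The self-attention blocks below build an N x N attention matrix with N = D*W*H of the feature map, so downsampling more aggressively in the first convolution keeps that matrix small enough to fit in memory.)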
30 | sequence = [ 31 | nn.Conv3d(in_channels, ndf, kernel_size=kw, stride=3, padding=padw), 32 | nn.LeakyReLU(0.2, True) 33 | ] 34 | 35 | nf_mult = 1 36 | nf_mult_prev = 1 37 | for n in range(1, n_layers): 38 | nf_mult_prev = nf_mult 39 | nf_mult = min(2**n, 8) 40 | sequence += [ 41 | nn.Conv3d(ndf * nf_mult_prev, 42 | ndf * nf_mult, 43 | kernel_size=kw, 44 | stride=2, 45 | padding=padw, 46 | bias=use_bias), 47 | norm_layer(ndf * nf_mult), 48 | nn.LeakyReLU(0.2, True) 49 | ] 50 | 51 | sequence += [attention.SelfAttentionBlock(ndf * nf_mult, 'relu')] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2**n_layers, 8) 55 | sequence += [ 56 | nn.Conv3d(ndf * nf_mult_prev, 57 | ndf * nf_mult, 58 | kernel_size=kw, 59 | stride=1, 60 | padding=padw, 61 | bias=use_bias), 62 | norm_layer(ndf * nf_mult), 63 | nn.LeakyReLU(0.2, True) 64 | ] 65 | 66 | sequence += [attention.SelfAttentionBlock(ndf * nf_mult, 'relu')] 67 | sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 68 | self.model = nn.Sequential(*sequence) 69 | 70 | def forward(self, input): 71 | return self.model(input) 72 | -------------------------------------------------------------------------------- /ganslate/configs/validation_testing.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Any, Dict 2 | from dataclasses import dataclass 3 | from omegaconf import MISSING, II 4 | from ganslate.configs import base 5 | 6 | 7 | @dataclass 8 | class SlidingWindowConfig: 9 | # https://docs.monai.io/en/latest/inferers.html#monai.inferers.SlidingWindowInferer 10 | window_size: Tuple[int] = MISSING 11 | batch_size: int = 1 12 | overlap: float = 0.25 13 | mode: str = 'gaussian' 14 | 15 | 16 | ######################## Val and Test Metrics Configs ######################### 17 | 18 | 19 | @dataclass 20 | class BaseValTestMetricsConfig: 21 | # SSIM metric between the images 22 | ssim: bool = True 23 | # PSNR metric between the images 24 | psnr: bool = True 25 | # Normalized MSE 26 | nmse: bool = True 27 | # MSE 28 | mse: bool = True 29 | # Abs diff between the two images 30 | mae: bool = True 31 | # Normalized Mutual Information 32 | nmi: bool = False 33 | # Chi-squared Histogram Distance 34 | histogram_chi2: bool = False 35 | 36 | 37 | @dataclass 38 | class ValMetricsConfig(BaseValTestMetricsConfig): 39 | # Set to true if cycle metrics need to be logged (i.e., between the original and reconstructed image) 40 | cycle_metrics: bool = True 41 | 42 | 43 | @dataclass 44 | class TestMetricsConfig(BaseValTestMetricsConfig): 45 | # True if the metrics comparing input and ground truth are to be computed as well 46 | compute_over_input: bool = False 47 | # Save per image metrics to a CSV for further analysis 48 | save_to_csv: bool = True 49 | 50 | 51 | ######################## Val and Test General Configs ######################### 52 | 53 | 54 | @dataclass 55 | class BaseValTestConfig(base.BaseEngineConfig): 56 | sliding_window: Optional[SlidingWindowConfig] = None 57 | dataset: Optional[base.BaseDatasetConfig] = None 58 | # Val/test can have multiple datasets provided to it 59 | multi_dataset: Optional[Dict[str, base.BaseDatasetConfig]] = None 60 | 61 | 62 | @dataclass 63 | class ValidationConfig(BaseValTestConfig): 64 | # How frequently to validate (every `freq` iters) 65 | freq: int = MISSING 66 | # After which iteration should validation begin 67 | start_after: int = 0 68 | metrics: ValMetricsConfig = ValMetricsConfig() 69 | 70 | 71 | @dataclass 72 | class
TestConfig(BaseValTestConfig): 73 | checkpointing: base.CheckpointingConfig = base.CheckpointingConfig() 74 | metrics: TestMetricsConfig = TestMetricsConfig() 75 | -------------------------------------------------------------------------------- /ganslate/data/utils/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class ImagePool(): 6 | """This class implements an image buffer that stores previously generated images. 7 | 8 | This buffer enables us to update discriminators using a history of generated images 9 | rather than the ones produced by the latest generators. 10 | """ 11 | 12 | def __init__(self, pool_size): 13 | """Initialize the ImagePool class 14 | 15 | Parameters: 16 | pool_size (int) -- the size of the image buffer; if pool_size=0, no buffer is created 17 | """ 18 | self.pool_size = pool_size 19 | # Create an empty pool 20 | if self.pool_size > 0: 21 | self.num_imgs = 0 22 | self.images = [] 23 | 24 | def query(self, images): 25 | """Return an image from the pool. 26 | 27 | Parameters: 28 | images: the latest generated images from the generator 29 | 30 | Returns images from the buffer. 31 | 32 | With probability 0.5, the buffer returns the input images. 33 | With probability 0.5, it returns images previously stored in the buffer 34 | and inserts the current images into the buffer. 35 | """ 36 | if self.pool_size == 0: # if the buffer size is 0, do nothing 37 | return images 38 | return_images = [] 39 | for image in images: 40 | image = torch.unsqueeze(image.data, 0) 41 | # If the buffer is not full, keep inserting current images to the buffer 42 | if self.num_imgs < self.pool_size: 43 | self.num_imgs = self.num_imgs + 1 44 | self.images.append(image) 45 | return_images.append(image) 46 | else: 47 | p = random.uniform(0, 1) 48 | # By 50% chance, the buffer will return a previously stored image, 49 | # and insert the current image into the buffer 50 | if p > 0.5: 51 | random_id = random.randint(0, self.pool_size - 1) 52 | tmp = self.images[random_id].clone() 53 | self.images[random_id] = image 54 | return_images.append(tmp) 55 | # By another 50% chance, the buffer will return the current image 56 | else: 57 | return_images.append(image) 58 | # Collect all the images and return 59 | return_images = torch.cat(return_images, 0) 60 | return return_images 61 | -------------------------------------------------------------------------------- /ganslate/nn/discriminators/patchgan/multiscale_patchgan3d.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch import nn 3 | import torch 4 | from ganslate.nn.utils import get_norm_layer_3d, is_bias_before_norm 5 | import monai 6 | # Config imports 7 | from dataclasses import dataclass 8 | from ganslate import configs 9 | 10 | # Network imports 11 | from ganslate.nn.discriminators.patchgan import patchgan3d 12 | 13 | 14 | def get_cropped_patch(input: torch.Tensor, scale: int = 1) -> torch.Tensor: 15 | """ 16 | Get a downscaled patch from the input tensor. 17 | The scale determines how much to reduce the size by; a scale of 2 means a patch half the size. 18 | 19 | The patch is extracted randomly from the input tensor. 20 | """ 21 | # Monai transforms expect shape in CDHW format 22 | crop_to_shape = (input.shape[1], input.shape[2] // scale, input.shape[3] // scale, 23 | input.shape[4] // scale) 24 | # Random center enabled but with fixed size.
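# Editor's note (illustrative shape, not from the source): for an input of shape
# (B, C, D, H, W) = (1, 1, 32, 128, 128) and scale=2, crop_to_shape is (1, 16, 64, 64),
# i.e. the channel dimension is kept whole while each spatial side is halved.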
25 | crop_transform = monai.transforms.RandSpatialCrop(crop_to_shape, 26 | random_center=True, 27 | random_size=False) 28 | cropped_input = crop_transform(input) 29 | return cropped_input 30 | 31 | 32 | @dataclass 33 | class MultiScalePatchGAN3DConfig(configs.base.BaseDiscriminatorConfig): 34 | ndf: int = 64 35 | n_layers: int = 3 36 | kernel_size: Tuple[int] = (4, 4, 4) 37 | 38 | # Each scale s reduces the input size to the discriminator by a factor of 1/s. 39 | # So if scales=3, the discriminator discriminates on the original input, 40 | # a randomly sampled patch at 1/2 size, and another at 1/3 size 41 | scales: int = 2 42 | 43 | 44 | class MultiScalePatchGAN3D(nn.Module): 45 | 46 | def __init__(self, in_channels, ndf, n_layers, kernel_size, scales, norm_type): 47 | super().__init__() 48 | # Multiscale PatchGAN consists of multiple PatchGANs. 49 | self.model = nn.ModuleDict() 50 | for scale in range(1, scales + 1): 51 | self.model[str(scale)] = patchgan3d.PatchGAN3D(in_channels, ndf, n_layers, kernel_size, 52 | norm_type) 53 | 54 | def forward(self, input): 55 | model_outputs = {} 56 | for scale, model in self.model.items(): 57 | patch = get_cropped_patch(input, scale=int(scale)) 58 | model_outputs[str(scale)] = model.forward(patch) 59 | 60 | return model_outputs 61 | -------------------------------------------------------------------------------- /projects/cleargrasp_depth_estimation/experiments/cyclegan_naive.yaml: -------------------------------------------------------------------------------- 1 | project: "./projects/cleargrasp_depth_estimation/" 2 | 3 | # output_dir: 4 | train: 5 | output_dir: "/home/zk315372/Chinmay/Cleargrasp-Depth-Estimation/cyclegan_naive/" 6 | cuda: True 7 | n_iters: 62500 # (2500 (images) / 1 (batch_size)) x 25 ("epochs") 8 | n_iters_decay: 62500 # Extra 62500 iters with lr decay 9 | batch_size: 1 10 | mixed_precision: False 11 | seed: 1 12 | 13 | logging: 14 | freq: 50 15 | multi_modality_split: 16 | A: [3, 3] 17 | B: [1] 18 | wandb: 19 | project: "cleargrasp_depth_estimation" 20 | run: "cyclegan_naive" 21 | 22 | checkpointing: 23 | freq: 25000 24 | # load_iter: 125000 25 | 26 | dataset: 27 | _target_: project.datasets.train_dataset.ClearGraspTrainDataset 28 | root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/train" 29 | load_size: [512, 256] 30 | paired: False # Unpaired 31 | require_domain_B_rgb: False # Not required for cyclegan-naive 32 | num_workers: 8 33 | 34 | gan: 35 | _target_: ganslate.nn.gans.unpaired.CycleGAN 36 | generator: 37 | _target_: ganslate.nn.generators.Unet2D 38 | in_out_channels: 39 | AB: [6, 1] # RGB + Normal -> Depth 40 | BA: [1, 6] # Depth -> Normal + RGB 41 | num_downs: 4 42 | ngf: 64 43 | use_dropout: True 44 | 45 | discriminator: 46 | _target_: ganslate.nn.discriminators.PatchGAN2D 47 | in_channels: 48 | B: 1 49 | A: 6 50 | n_layers: 3 51 | kernel_size: [4, 4] 52 | ndf: 64 53 | 54 | optimizer: 55 | lr_D: 0.0001 56 | lr_G: 0.0002 57 | lambda_AB: 10.0 58 | lambda_BA: 10.0 59 | lambda_identity: 0 60 | proportion_ssim: 0 61 | 62 | metrics: 63 | discriminator_evolution: True 64 | ssim: False 65 | 66 | 67 | val: 68 | freq: 2500 69 | dataset: 70 | _target_: project.datasets.val_test_dataset.ClearGraspValTestDataset 71 | root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/val" 72 | load_size: [512, 256] 73 | model_is_cyclegan_balanced: False # False 74 | num_workers: 8 75 | metrics: 76 | cycle_metrics: False 77 | --------------------------------------------------------------------------------
/ganslate/utils/trackers/inference.py: -------------------------------------------------------------------------------- 1 | 2 | from loguru import logger 3 | import time 4 | 5 | import torch 6 | 7 | from ganslate.utils import communication 8 | from ganslate.utils.trackers.base import BaseTracker 9 | from ganslate.utils.trackers.utils import (process_visuals_for_logging, 10 | concat_batch_of_visuals_after_gather) 11 | 12 | 13 | class InferenceTracker(BaseTracker): 14 | 15 | def __init__(self, conf): 16 | super().__init__(conf) 17 | self.logger = logger 18 | 19 | def log_iter(self, visuals, len_dataset): 20 | 21 | def parse_visuals(visuals): 22 | # Gather visuals from different processes to the rank 0 process 23 | visuals = communication.gather(visuals) 24 | visuals = concat_batch_of_visuals_after_gather(visuals) 25 | visuals = process_visuals_for_logging(self.conf, visuals, single_example=False) 26 | return visuals 27 | 28 | def log_message(): 29 | # In case of DDP, if (len_dataset % number of processes != 0), 30 | # it will show more iters than there actually are 31 | if self.iter_idx > len_dataset: 32 | self.iter_idx = len_dataset 33 | 34 | message = (f"{self.iter_idx}/{len_dataset} - loading: {self.t_data:.2f}s" 35 | f" | inference: {self.t_comp:.2f}s | saving: {self.t_save:.2f}s") 36 | self.logger.info(message) 37 | 38 | visuals = parse_visuals(visuals) 39 | log_message() 40 | 41 | for i, visuals_grid in enumerate(visuals): 42 | # In DDP, each process is for a different iter, so incrementing it accordingly 43 | self._save_image(visuals_grid, self.iter_idx + i) 44 | 45 | if self.wandb: 46 | self.wandb.log_iter(iter_idx=self.iter_idx + i, 47 | visuals=visuals_grid, 48 | mode="infer") 49 | 50 | if self.tensorboard: 51 | self.tensorboard.log_iter(iter_idx=self.iter_idx + i, 52 | visuals=visuals_grid, 53 | mode="infer") 54 | 55 | 56 | def start_saving_timer(self): 57 | self.saving_start_time = time.time() 58 | 59 | def end_saving_timer(self): 60 | self.t_save = (time.time() - self.saving_start_time) / self.batch_size 61 | # Reduce computational time data point (avg) and send to the process of rank 0 62 | self.t_save = communication.reduce(self.t_save, average=True, all_reduce=False) 63 | -------------------------------------------------------------------------------- /projects/maastro_hx4_pet_translation/experiments/pix2pix.yaml: -------------------------------------------------------------------------------- 1 | project: "./projects/maastro_hx4_pet_translation/" 2 | 3 | train: 4 | output_dir: "/workspace/Chinmay-Checkpoints-Ephemeral/HX4-PET-Translation/hx4_pet_pix2pix_lambda10/" 5 | cuda: True 6 | n_iters: 30000 7 | n_iters_decay: 30000 8 | batch_size: 1 9 | mixed_precision: False 10 | seed: 1 11 | 12 | logging: 13 | freq: 50 14 | multi_modality_split: 15 | A: [1, 1] 16 | B: [1] 17 | wandb: 18 | project: "maastro_hx4_pet_translation" 19 | run: "pix2pix_lambda10" 20 | 21 | checkpointing: 22 | freq: 1000 23 | 24 | dataset: 25 | _target_: project.datasets.train_dataset.HX4PETTranslationTrainDataset 26 | root: "/workspace/Chinmay-Datasets-Ephemeral/HX4-PET-Translation-Processed/train" 27 | paired: True # Paired training 28 | patch_size: [32, 128, 128] # (D,H,W) 29 | patch_sampling: uniform-random-within-body 30 | num_workers: 8 31 | 32 | gan: 33 | _target_: ganslate.nn.gans.paired.Pix2PixConditionalGAN 34 | generator: 35 | _target_: ganslate.nn.generators.Unet3D 36 | in_out_channels: 37 | AB: [2, 1] 38 | num_downs: 4 39 | ngf: 64 40 | use_dropout: True 41 | 42 |
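# Editor's note: a conditional (pix2pix) discriminator sees the generator's input
# concatenated with the real/fake output, hence in_channels B below is the two
# A channels plus the single B channel: 2 + 1 = 3.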
discriminator: 43 | _target_: ganslate.nn.discriminators.PatchGAN3D 44 | in_channels: 45 | B: 3 46 | n_layers: 3 47 | kernel_size: [4, 4, 4] 48 | ndf: 64 49 | 50 | optimizer: 51 | lr_D: 0.0001 52 | lr_G: 0.0002 53 | lambda_pix2pix: 10.0 54 | 55 | metrics: 56 | discriminator_evolution: True 57 | 58 | 59 | val: 60 | freq: 1000 61 | dataset: 62 | _target_: project.datasets.val_test_dataset.HX4PETTranslationValTestDataset 63 | root: "/workspace/Chinmay-Datasets-Ephemeral/HX4-PET-Translation-Processed/val" 64 | num_workers: 8 65 | supply_masks: False # Do not supply masks -> fewer val metrics to compute, faster training 66 | use_patch_based_inference: True # Patch-based (sliding-window) inference 67 | sliding_window: # Enable sliding window inferer 68 | window_size: ${train.dataset.patch_size} 69 | metrics: # Enabled metrics: MSE, MAE, PSNR, SSIM, NMI, Chi-squared histogram distance 70 | mse: True 71 | mae: True 72 | nmse: False 73 | psnr: True 74 | ssim: True 75 | nmi: True 76 | histogram_chi2: True 77 | cycle_metrics: False -------------------------------------------------------------------------------- /ganslate/utils/metrics/train_metrics.py: -------------------------------------------------------------------------------- 1 | import ganslate.nn.losses.utils.ssim as ssim 2 | import torch 3 | 4 | 5 | class TrainingMetrics: 6 | 7 | def __init__(self, conf): 8 | self.output_distributions = bool(conf.train.metrics.discriminator_evolution) 9 | 10 | if conf.train.metrics.ssim: 11 | self.ssim = ssim.SSIMLoss() 12 | else: 13 | self.ssim = None 14 | 15 | def get_output_metric_D(self, out): 16 | """ 17 | Store fake and real discriminator outputs to analyze training convergence. 18 | Based on ADA-StyleGAN observations: 19 | https://medium.com/swlh/training-gans-with-limited-data-22a7c8ffce78 20 | """ 21 | if self.output_distributions: 22 | # Reduce the output to a tensor if it is dict 23 | if isinstance(out, dict): 24 | out = torch.tensor([elem.detach().mean() for elem in out.values()]) 25 | 26 | else: 27 | out = out.detach() 28 | if len(out.size()) > 1: 29 | return out.mean() 30 | else: 31 | return out 32 | else: 33 | return None 34 | 35 | def get_SSIM_metric(self, input, target): 36 | # Gradient computation not needed for metric computation 37 | input = input.detach() 38 | target = target.detach() 39 | # Data range needs to be positive and normalized 40 | # https://github.com/VainF/pytorch-msssim#2-normalized-input 41 | input = (input + 1) / 2 42 | target = (target + 1) / 2 43 | 44 | if self.ssim: 45 | return 1 - self.ssim(input, target, data_range=1) 46 | else: 47 | return None 48 | 49 | def compute_metrics_D(self, discriminator, pred_real, pred_fake): 50 | # Update metrics with output distributions if enabled 51 | return { 52 | f"{discriminator}_real": self.get_output_metric_D(pred_real), 53 | f"{discriminator}_fake": self.get_output_metric_D(pred_fake) 54 | } 55 | 56 | def compute_metrics_G(self, visuals): 57 | # Update metrics with SSIM if enabled in config 58 | metrics_G = {} 59 | if all([key in visuals for key in ["rec_A", "real_A"]]): 60 | # Update SSIM for forward A->B->A reconstruction 61 | metrics_G['ssim_A'] = self.get_SSIM_metric(visuals["real_A"], visuals["rec_A"]) 62 | 63 | if all([key in visuals for key in ["rec_B", "real_B"]]): 64 | # Update SSIM for backward B->A->B reconstruction 65 | metrics_G['ssim_B'] = self.get_SSIM_metric(visuals["real_B"], visuals["rec_B"]) 66 | 67 | return metrics_G 68 | --------------------------------------------------------------------------------
/projects/brats_mri_sequence_translation/datasets/val_test_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | import SimpleITK as sitk 7 | from ganslate.utils.io import make_dataset_of_files 8 | from ganslate.utils import sitk_utils 9 | from ganslate.data.utils.normalization import z_score_normalize 10 | from ganslate.data.utils.stochastic_focal_patching import StochasticFocalPatchSampler 11 | 12 | # Config imports 13 | from typing import Tuple 14 | from dataclasses import dataclass, field 15 | from omegaconf import MISSING 16 | from ganslate import configs 17 | 18 | 19 | @dataclass 20 | class BratsValTestDatasetConfig(configs.base.BaseDatasetConfig): 21 | source_sequence: str = "flair" 22 | target_sequence: str = "t1w" 23 | 24 | 25 | EXTENSIONS = ['.nii.gz'] 26 | 27 | # MRI sequences z-axis indices in Brats 28 | SEQUENCE_MAP = {"flair": 0, "t1w": 1, "t1gd": 2, "t2w": 3} 29 | 30 | 31 | def get_mri_sequence(sitk_image, sequence_name): 32 | z_index = SEQUENCE_MAP[sequence_name.lower()] 33 | 34 | size = list(sitk_image.GetSize()) 35 | size[3] = 0 36 | index = [0, 0, 0, z_index] 37 | 38 | extractor = sitk.ExtractImageFilter() 39 | extractor.SetSize(size) 40 | extractor.SetIndex(index) 41 | return extractor.Execute(sitk_image) 42 | 43 | 44 | class BratsValTestDataset(Dataset): 45 | 46 | def __init__(self, conf): 47 | dir_brats = conf[conf.mode].dataset.root 48 | self.paths_brats = make_dataset_of_files(dir_brats, EXTENSIONS) 49 | self.num_datapoints = len(self.paths_brats) 50 | 51 | self.source_sequence = conf[conf.mode].dataset.source_sequence 52 | self.target_sequence = conf[conf.mode].dataset.target_sequence 53 | 54 | def __getitem__(self, index): 55 | mri = sitk_utils.load(self.paths_brats[index]) 56 | 57 | A = get_mri_sequence(mri, self.source_sequence) 58 | B = get_mri_sequence(mri, self.target_sequence) 59 | 60 | A = sitk_utils.get_tensor(A) 61 | B = sitk_utils.get_tensor(B) 62 | 63 | # Z-score normalization per volume 64 | A = z_score_normalize(A, scale_to_range=(-1, 1)) 65 | B = z_score_normalize(B, scale_to_range=(-1, 1)) 66 | 67 | # Add channel dimension (1 = grayscale) 68 | A = A.unsqueeze(0) 69 | B = B.unsqueeze(0) 70 | 71 | return {'A': A, 'B': B} 72 | 73 | def __len__(self): 74 | return self.num_datapoints 75 | 76 | def denormalize(self, tensor): 77 | """Allows the Tester and Validator to calculate the metrics in 78 | the original range of values. 
79 | """ 80 | # TODO: TEMPORARYYYYY 81 | return (tensor + 1) / 2 82 | -------------------------------------------------------------------------------- /ganslate/utils/cli/cookiecutter_templates/your_first_run/{{ cookiecutter.project_name }}/default.yaml: -------------------------------------------------------------------------------- 1 | project: null 2 | 3 | train: 4 | output_dir: "./checkpoints/facades_default" 5 | cuda: {{ cookiecutter.enable_cuda }} 6 | n_iters: {{ cookiecutter.number_of_iterations }} 7 | n_iters_decay: {{ cookiecutter.number_of_iterations }} 8 | batch_size: {{ cookiecutter.batch_size }} 9 | mixed_precision: False 10 | 11 | logging: 12 | freq: {{ cookiecutter.logging_frequency }} 13 | 14 | checkpointing: 15 | freq: {{ cookiecutter.checkpointing_frequency }} 16 | 17 | dataset: 18 | _target_: ganslate.data.UnpairedImageDataset 19 | root: "{{ cookiecutter.path }}/{{ cookiecutter.project_name }}/facades/train/" 20 | num_workers: 4 21 | image_channels: 3 22 | preprocess: ["resize", "random_flip"] 23 | load_size: [256, 256] 24 | 25 | gan: 26 | _target_: ganslate.nn.gans.unpaired.CycleGAN 27 | 28 | generator: 29 | _target_: ganslate.nn.generators.{{ cookiecutter.generator_model }} 30 | n_residual_blocks: 9 31 | in_out_channels: 32 | AB: [3, 3] 33 | 34 | discriminator: 35 | _target_: ganslate.nn.discriminators.PatchGAN2D 36 | n_layers: 3 37 | in_channels: 38 | B: 3 39 | 40 | optimizer: 41 | lambda_AB: 10.0 42 | lambda_BA: 10.0 43 | lambda_identity: 0 44 | proportion_ssim: {{ cookiecutter.cycle_consistency_ssim_percentage }} 45 | lr_D: 0.0002 46 | lr_G: 0.0004 47 | 48 | metrics: 49 | discriminator_evolution: True 50 | ssim: True 51 | 52 | # Uncomment to enable validation, a folder called val needs to be 53 | # created with some A->B paired samples within it. 
54 | # val: 55 | # freq: {{ cookiecutter.logging_frequency }} * 10 56 | # dataset: 57 | # _target_: ganslate.data.PairedImageDataset # Paired dataset for validation 58 | # root: "{{ cookiecutter.path }}/{{ cookiecutter.project_name }}/facades/val" 59 | # num_workers: 4 60 | # image_channels: 3 61 | # preprocess: ["resize"] 62 | # load_size: [256, 256] 63 | # flip: False 64 | # metrics: 65 | # cycle_metrics: False 66 | 67 | infer: 68 | checkpointing: 69 | load_iter: 1 70 | dataset: 71 | _target_: ganslate.data.UnpairedImageDataset 72 | root: "{{ cookiecutter.path }}/{{ cookiecutter.project_name }}/facades/test/" 73 | num_workers: 4 74 | image_channels: 3 75 | preprocess: ["resize"] 76 | load_size: [256, 256] 77 | -------------------------------------------------------------------------------- /ganslate/utils/cli/cookiecutter_templates/new_project/{{ cookiecutter.project_name }}/experiments/template_experiment.yaml: -------------------------------------------------------------------------------- 1 | project: {{cookiecutter.path}} 2 | 3 | train: 4 | output_dir: "./checkpoints/{{ cookiecutter.project_name }}" 5 | cuda: True 6 | n_iters: {{ cookiecutter.number_of_iterations }} 7 | n_iters_decay: {{ cookiecutter.number_of_iterations }} 8 | batch_size: {{ cookiecutter.batch_size }} 9 | mixed_precision: False 10 | 11 | logging: 12 | freq: {{ cookiecutter.logging_frequency }} 13 | 14 | checkpointing: 15 | freq: {{ cookiecutter.checkpointing_frequency }} 16 | 17 | dataset: 18 | _target_: project.datasets.{{cookiecutter.dataset_name}}TrainDataset 19 | root: "{{ cookiecutter.path }}/{{ cookiecutter.project_name }}/maps/train/" 20 | num_workers: 4 21 | image_channels: 3 22 | preprocess: ["resize", "random_flip"] 23 | load_size: [256, 256] 24 | 25 | gan: 26 | _target_: ganslate.nn.gans.unpaired.CycleGAN 27 | 28 | generator: 29 | _target_: ganslate.nn.generators.{{ cookiecutter.generator_model }} 30 | in_out_channels: 31 | AB: [3, 3] 32 | 33 | discriminator: 34 | _target_: ganslate.nn.discriminators.PatchGAN2D 35 | n_layers: 3 36 | in_channels: 37 | B: 3 38 | 39 | optimizer: 40 | lambda_AB: 10.0 41 | lambda_BA: 10.0 42 | lambda_identity: 0 43 | proportion_ssim: {{ cookiecutter.cycle_consistency_ssim_percentage }} 44 | lr_D: 0.0002 45 | lr_G: 0.0004 46 | 47 | metrics: 48 | discriminator_evolution: True 49 | ssim: True 50 | 51 | # Uncomment to enable validation; a folder called val needs to be 52 | # created with some A->B paired samples within it.
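# Editor's note: as with the first-run config above, replace the
# "{{ cookiecutter.logging_frequency }} * 10" expression below with a literal
# integer when uncommenting -- YAML does not evaluate arithmetic.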
53 | # val: 54 | # freq: {{ cookiecutter.logging_frequency }} * 10 55 | # dataset: 56 | # _target_: project.datasets.{{cookiecutter.dataset_name}}ValTestDataset # Paired dataset for validation 57 | # root: "{{ cookiecutter.path }}/{{ cookiecutter.project_name }}/maps/val" 58 | # num_workers: 4 59 | # image_channels: 3 60 | # preprocess: ["resize"] 61 | # load_size: [256, 256] 62 | # flip: False 63 | # metrics: 64 | # cycle_metrics: False 65 | 66 | infer: 67 | checkpointing: 68 | load_iter: 1 69 | dataset: 70 | _target_: project.datasets.{{cookiecutter.dataset_name}}InferenceDataset 71 | root: "{{ cookiecutter.path }}/{{ cookiecutter.project_name }}/maps/test/" 72 | num_workers: 4 73 | image_channels: 3 74 | preprocess: ["resize"] 75 | load_size: [256, 256] 76 | -------------------------------------------------------------------------------- /ganslate/utils/cli/cookiecutter_templates/new_project/{{ cookiecutter.project_name }}/datasets/val_test_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | 4 | from typing import Tuple 5 | from dataclasses import dataclass 6 | 7 | from torch.utils.data import Dataset 8 | from omegaconf import MISSING 9 | 10 | from ganslate import configs 11 | 12 | 13 | @dataclass 14 | class {{cookiecutter.dataset_name}}ValTestDatasetConfig(configs.base.BaseDatasetConfig): 15 | # Define other attributes, e.g.: 16 | patch_size: Tuple[int, int] = (128, 128) # a tuple, not a list -- a mutable default would raise an error in a dataclass 17 | ... 18 | 19 | 20 | class {{cookiecutter.dataset_name}}ValTestDataset(Dataset): 21 | 22 | def __init__(self, conf): 23 | self.root_path = Path(conf[conf.mode].dataset.root).resolve() 24 | 25 | def __getitem__(self, index): 26 | # Assigning paths for A and B depends on your dataset dir structure 27 | path_A = self.root_path[index] / "A.png" 28 | path_B = self.root_path[index] / "B.png" 29 | 30 | # Read the images, `read` is a placeholder 31 | A = read(path_A) 32 | B = read(path_B) 33 | 34 | # Preprocess and normalize to [-1,1], `preprocess` is a placeholder 35 | A = preprocess(A) 36 | B = preprocess(B) 37 | 38 | # Metadata is optionally returned by this method, explained at the end of the method. 39 | # Delete if not necessary. 40 | metadata = { 41 | 'path': str(path_A), 42 | # ... add any other fields you need 43 | } 44 | 45 | # Masks are optionally returned by this method, explained at the end of the method. 46 | # Delete if not necessary. 47 | masks = {} 48 | path_foreground_mask = self.root_path[index] / "foreground.png" 49 | foreground_mask = read(path_foreground_mask) 50 | masks["foreground"] = foreground_mask 51 | 52 | return {'A': A, 53 | 'B': B, 54 | # [Optional] metadata - if `save()` is defined *and* if it requires metadata. 55 | "metadata": metadata, 56 | # [Optional] masks - a dict of masks, used during the validation or 57 | # testing to also calculate metrics over specific regions of the image. 58 | "masks": masks 59 | } 60 | 61 | def save(self, tensor, save_dir, metadata=None): 62 | """ By default, ganslate logs images in png format. However, if you wish 63 | to save images in a different way, then implement this `save()` method. 64 | For example, you could save medical images in their native format for easier 65 | inspection or usage. 66 | If you do not need this method, remove it. 67 | """ 68 | pass 69 | 70 | def __len__(self): 71 | # Depending on the dataset dir structure, you might want to change it.
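# Editor's note -- a minimal sketch under a hypothetical layout (one subdirectory
# per sample): list the samples once in __init__, e.g.
#     self.sample_dirs = sorted(p for p in self.root_path.iterdir() if p.is_dir())
# then return len(self.sample_dirs) here and read from self.sample_dirs[index]
# in __getitem__, since indexing a plain pathlib.Path (self.root_path[index])
# is not supported.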
72 | return len(self.root_path) 73 | 74 | -------------------------------------------------------------------------------- /projects/cleargrasp_depth_estimation/experiments/cyclegan_balanced.yaml: -------------------------------------------------------------------------------- 1 | # Note: The design of CycleGAN-balanced is unusual in that the conventional notion of a domain A image and a domain B image does not apply. 2 | # Therefore, the Ganslate framework does not support it without hacking some code (specifically, the `get_metrics()` function in the file `ganslate/utils/metrics/val_test_metrics.py`) 3 | 4 | project: "./projects/cleargrasp_depth_estimation/" 5 | 6 | 7 | train: 8 | output_dir: "/home/zk315372/Chinmay/Cleargrasp-Depth-Estimation/cyclegan_balanced/" 9 | cuda: True 10 | n_iters: 62500 # (2500 (images) / 1 (batch_size)) x 25 ("epochs") 11 | n_iters_decay: 62500 # Extra 62500 iters with lr decay 12 | batch_size: 1 13 | mixed_precision: False 14 | seed: 1 15 | 16 | logging: 17 | freq: 50 18 | multi_modality_split: 19 | A: [3, 3] 20 | B: [3, 1] 21 | wandb: 22 | project: "cleargrasp_depth_estimation" 23 | run: "cyclegan_balanced" 24 | 25 | checkpointing: 26 | freq: 25000 27 | # load_iter: 125000 28 | 29 | dataset: 30 | _target_: project.datasets.train_dataset.ClearGraspTrainDataset 31 | root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/train" 32 | load_size: [512, 256] 33 | paired: False # Unpaired 34 | require_domain_B_rgb: True # Required here for cyclegan-balanced 35 | num_workers: 8 36 | 37 | gan: 38 | _target_: project.modules.CycleGANMultiModalV3 39 | generator: 40 | _target_: ganslate.nn.generators.Unet2D 41 | in_out_channels: 42 | AB: [6, 1] # RGB + Normal -> Depth 43 | BA: [4, 3] # RGB + Depth -> Normal 44 | num_downs: 4 45 | ngf: 64 46 | use_dropout: True 47 | 48 | discriminator: 49 | _target_: ganslate.nn.discriminators.PatchGAN2D 50 | in_channels: 51 | B: 1 52 | A: 3 53 | n_layers: 3 54 | kernel_size: [4, 4] 55 | ndf: 64 56 | 57 | optimizer: 58 | lr_D: 0.0001 59 | lr_G: 0.0002 60 | lambda_AB: 10.0 61 | lambda_BA: 10.0 62 | lambda_identity: 0 63 | proportion_ssim: 0 64 | 65 | metrics: 66 | discriminator_evolution: True 67 | ssim: False 68 | 69 | 70 | val: 71 | freq: 2500 72 | dataset: 73 | _target_: project.datasets.val_test_dataset.ClearGraspValTestDataset 74 | root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/val" 75 | load_size: [512, 256] 76 | model_is_cyclegan_balanced: True # True 77 | num_workers: 8 78 | metrics: 79 | cycle_metrics: False 80 | -------------------------------------------------------------------------------- /ganslate/utils/trackers/wandb.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | import wandb 3 | import os 4 | import torch 5 | import numpy as np 6 | from ganslate.utils.trackers.utils import process_visuals_wandb_tensorboard 7 | 8 | 9 | def torch_npy_to_python(x): 10 | if isinstance(x, torch.Tensor) or np.isscalar(x): 11 | return x.item() 12 | return x 13 | 14 | 15 | class WandbTracker: 16 | 17 | def __init__(self, conf): 18 | project = conf[conf.mode].logging.wandb.project 19 | entity = conf[conf.mode].logging.wandb.entity 20 | conf_dict = OmegaConf.to_container(conf, resolve=True) 21 | run_dir = conf[conf.mode].output_dir 22 | 23 | if wandb.run is None: 24 | if conf[conf.mode].checkpointing.load_iter and conf[conf.mode].logging.wandb.id: 25 | # Source: https://docs.wandb.ai/library/resuming
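# Editor's note: this resume path assumes the config provides the id of the
# earlier W&B run (e.g. logging.wandb.id: "1a2b3c4d" -- a hypothetical value)
# alongside checkpointing.load_iter, so logging continues in the existing run
# rather than starting a new one.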
26 | os.environ["WANDB_RESUME"] = "allow" 27 | os.environ["WANDB_RUN_ID"] = conf[conf.mode].logging.wandb.id 28 | 29 | wandb.init(project=project, entity=entity, config=conf_dict, dir=run_dir) 30 | 31 | if conf[conf.mode].logging.wandb.run: 32 | wandb.run.name = conf[conf.mode].logging.wandb.run 33 | 34 | self.image_window = None 35 | if conf[conf.mode].logging.image_window: 36 | self.image_window = conf[conf.mode].logging.image_window 37 | 38 | def log_iter(self, 39 | iter_idx, 40 | visuals, 41 | mode, 42 | learning_rates=None, 43 | losses=None, 44 | metrics=None): 45 | """""" 46 | log_dict = {} 47 | 48 | # Learning rates 49 | if learning_rates: 50 | for name, learning_rate in learning_rates.items(): 51 | log_dict[f"Learning rate: {name}"] = learning_rate 52 | 53 | # Losses 54 | if losses: 55 | for name, loss in losses.items(): 56 | log_dict[f"Loss: {name}"] = torch_npy_to_python(loss) 57 | 58 | # Metrics 59 | if metrics: 60 | for name, metric in metrics.items(): 61 | log_dict[f"Metric: {name} ({mode})"] = torch_npy_to_python(metric) 62 | 63 | normal_visuals = process_visuals_wandb_tensorboard(visuals, 64 | image_window=None, 65 | is_wandb=True) 66 | log_dict[f"Images ({mode})"] = normal_visuals 67 | 68 | if self.image_window: 69 | windowed_visuals = process_visuals_wandb_tensorboard(visuals, 70 | self.image_window, 71 | is_wandb=True) 72 | log_dict[f"Windowed images ({mode})"] = windowed_visuals 73 | wandb.log(log_dict, step=iter_idx) 74 | -------------------------------------------------------------------------------- /ganslate/utils/trackers/base.py: -------------------------------------------------------------------------------- 1 | import time 2 | from pathlib import Path 3 | 4 | import torchvision 5 | from omegaconf import OmegaConf 6 | from ganslate.utils import communication, io 7 | 8 | from ganslate.utils.trackers.tensorboard import TensorboardTracker 9 | from ganslate.utils.trackers.wandb import WandbTracker 10 | 11 | 12 | class BaseTracker: 13 | """"Base for training and inference trackers.""" 14 | 15 | def __init__(self, conf): 16 | self.conf = conf 17 | self.batch_size = self.conf[conf.mode].batch_size 18 | self.output_dir = Path(self.conf[self.conf.mode].output_dir) / self.conf.mode 19 | self.iter_idx = None 20 | self.iter_end_time = None 21 | self.iter_start_time = None 22 | self.t_data = None 23 | self.t_comp = None 24 | 25 | self.wandb, self.tensorboard = self._setup_wandb_tensorboard(conf) 26 | self._save_config(conf) 27 | 28 | def _save_config(self, conf): 29 | if communication.get_rank() == 0: 30 | config_path = self.output_dir / f"{self.conf.mode}_config.yaml" 31 | io.mkdirs(config_path.parent) 32 | with open(config_path, "w") as file: 33 | file.write(OmegaConf.to_yaml(conf)) 34 | 35 | def _setup_wandb_tensorboard(self, conf): 36 | wandb, tensorboard = None, None 37 | if communication.get_rank() == 0: 38 | if conf[conf.mode].logging.wandb: 39 | wandb = WandbTracker(conf) 40 | if conf[conf.mode].logging.tensorboard: 41 | tensorboard = TensorboardTracker(conf) 42 | return wandb, tensorboard 43 | 44 | def set_iter_idx(self, iter_idx): 45 | self.iter_idx = iter_idx 46 | 47 | def start_computation_timer(self): 48 | self.iter_start_time = time.time() 49 | 50 | def start_dataloading_timer(self): 51 | self.iter_end_time = time.time() 52 | 53 | def end_computation_timer(self): 54 | self.t_comp = (time.time() - self.iter_start_time) / self.batch_size 55 | # reduce computational time data point (avg) and send to the process of rank 0 56 | self.t_comp = 
communication.reduce(self.t_comp, average=True, all_reduce=False) 57 | 58 | def end_dataloading_timer(self): 59 | self.t_data = self.iter_start_time - self.iter_end_time 60 | # reduce data loading per data point (avg) and send to the process of rank 0 61 | self.t_data = communication.reduce(self.t_data, average=True, all_reduce=False) 62 | 63 | def close(self): 64 | if communication.get_rank() == 0 and self.tensorboard: 65 | self.tensorboard.close() 66 | 67 | def _save_image(self, visuals, name): 68 | if communication.get_rank() == 0: 69 | image_name, image = visuals['name'], visuals['image'] 70 | file_path = Path(self.output_dir) / f"images/{name}_{image_name}.png" 71 | io.mkdirs(file_path.parent) 72 | torchvision.utils.save_image(image, file_path) 73 | -------------------------------------------------------------------------------- /projects/maastro_hx4_pet_translation/experiments/cyclegan_naive.yaml: -------------------------------------------------------------------------------- 1 | project: "./projects/maastro_hx4_pet_translation/" 2 | 3 | train: 4 | output_dir: "/workspace/Chinmay-Checkpoints-Ephemeral/HX4-PET-Translation/hx4_pet_cyclegan_naive/" 5 | cuda: True 6 | n_iters: 30000 7 | n_iters_decay: 30000 8 | batch_size: 1 9 | mixed_precision: False 10 | seed: 1 11 | 12 | logging: 13 | freq: 50 14 | multi_modality_split: 15 | A: [1, 1] 16 | B: [1] 17 | wandb: 18 | project: "maastro_hx4_pet_translation" 19 | run: "cyclegan_naive_lambdas10" 20 | 21 | checkpointing: 22 | freq: 1000 23 | 24 | dataset: 25 | _target_: project.datasets.train_dataset.HX4PETTranslationTrainDataset 26 | root: "/workspace/Chinmay-Datasets-Ephemeral/HX4-PET-Translation-Processed/train" 27 | paired: False # Unpaired training 28 | require_ldct_for_training: False # ldCT not required for training 29 | patch_size: [32, 128, 128] # (D,H,W) 30 | patch_sampling: uniform-random-within-body-sf 31 | focal_region_proportion: [0.6, 0.35, 0.35] # (D,H,W) 32 | num_workers: 8 33 | 34 | gan: 35 | _target_: ganslate.nn.gans.unpaired.CycleGAN 36 | generator: 37 | _target_: ganslate.nn.generators.Unet3D 38 | in_out_channels: 39 | AB: [2, 1] 40 | BA: [1, 2] 41 | num_downs: 4 42 | ngf: 64 43 | use_dropout: True 44 | 45 | discriminator: 46 | _target_: ganslate.nn.discriminators.PatchGAN3D 47 | in_channels: 48 | B: 1 49 | A: 2 50 | n_layers: 3 51 | kernel_size: [4, 4, 4] 52 | ndf: 64 53 | 54 | optimizer: 55 | lr_D: 0.0001 56 | lr_G: 0.0002 57 | lambda_AB: 10.0 58 | lambda_BA: 10.0 59 | lambda_identity: 0 60 | proportion_ssim: 0 61 | 62 | metrics: 63 | discriminator_evolution: True 64 | ssim: False # `False`, to match HX4-CycleGAN-balanced 65 | 66 | 67 | val: 68 | freq: 1000 69 | dataset: 70 | _target_: project.datasets.val_test_dataset.HX4PETTranslationValTestDataset 71 | root: "/workspace/Chinmay-Datasets-Ephemeral/HX4-PET-Translation-Processed/val" 72 | num_workers: 8 73 | supply_masks: False # Do not supply masks -> fewer val metrics to compute, faster training 74 | use_patch_based_inference: True # Patch-based (sliding-window) inference 75 | sliding_window: # Enable sliding window inferer 76 | window_size: ${train.dataset.patch_size} 77 | metrics: # Enabled metrics: MSE, MAE, PSNR, SSIM, NMI, Chi-squared histogram distance 78 | mse: True 79 | mae: True 80 | nmse: False 81 | psnr: True 82 | ssim: True 83 | nmi: True 84 | histogram_chi2: True 85 | cycle_metrics: False # `False`, to match HX4-CycleGAN-balanced 86 | -------------------------------------------------------------------------------- /ganslate/nn/generators/resnet/resnet2d.py:
-------------------------------------------------------------------------------- 1 | from torch import nn 2 | from ganslate.nn.utils import get_norm_layer_2d, is_bias_before_norm 3 | 4 | # Config imports 5 | from dataclasses import dataclass 6 | from ganslate import configs 7 | 8 | 9 | @dataclass 10 | class Resnet2DConfig(configs.base.BaseGeneratorConfig): 11 | n_residual_blocks: int = 9 12 | 13 | 14 | class Resnet2D(nn.Module): 15 | 16 | def __init__(self, in_channels, out_channels, norm_type, n_residual_blocks=9): 17 | super().__init__() 18 | 19 | norm_layer = get_norm_layer_2d(norm_type) 20 | use_bias = is_bias_before_norm(norm_type) 21 | 22 | # Initial convolution block 23 | model = [ 24 | nn.ReflectionPad2d(3), 25 | nn.Conv2d(in_channels, 64, 7, bias=use_bias), 26 | norm_layer(64), 27 | nn.ReLU(inplace=True) 28 | ] 29 | 30 | # Downsampling 31 | in_features = 64 32 | out_features = in_features * 2 33 | for _ in range(2): 34 | model += [ 35 | nn.Conv2d(in_features, out_features, 3, stride=2, padding=1, bias=use_bias), 36 | norm_layer(out_features), 37 | nn.ReLU(inplace=True) 38 | ] 39 | in_features = out_features 40 | out_features = in_features * 2 41 | 42 | # Residual blocks 43 | for _ in range(n_residual_blocks): 44 | model += [ResidualBlock(in_features, norm_type)] 45 | 46 | self.encoder = nn.ModuleList(model) 47 | 48 | # Upsampling 49 | out_features = in_features // 2 50 | for _ in range(2): 51 | model += [ 52 | nn.ConvTranspose2d(in_features, 53 | out_features, 54 | 3, 55 | stride=2, 56 | padding=1, 57 | output_padding=1), 58 | norm_layer(out_features), 59 | nn.ReLU(inplace=True) 60 | ] 61 | in_features = out_features 62 | out_features = in_features // 2 63 | 64 | # Output layer 65 | model += [nn.ReflectionPad2d(3), nn.Conv2d(64, out_channels, 7, bias=use_bias), nn.Tanh()] 66 | 67 | self.model = nn.Sequential(*model) 68 | 69 | def forward(self, x): 70 | return self.model(x) 71 | 72 | 73 | class ResidualBlock(nn.Module): 74 | 75 | def __init__(self, in_features, norm_type): 76 | super().__init__() 77 | norm_layer = get_norm_layer_2d(norm_type) 78 | use_bias = is_bias_before_norm(norm_type) 79 | 80 | conv_block = [ 81 | nn.ReflectionPad2d(1), 82 | nn.Conv2d(in_features, in_features, 3, bias=use_bias), 83 | norm_layer(in_features), 84 | nn.ReLU(inplace=True), 85 | nn.ReflectionPad2d(1), 86 | nn.Conv2d(in_features, in_features, 3, bias=use_bias), 87 | norm_layer(in_features) 88 | ] 89 | 90 | self.conv_block = nn.Sequential(*conv_block) 91 | 92 | def forward(self, x): 93 | return x + self.conv_block(x) 94 | -------------------------------------------------------------------------------- /ganslate/nn/generators/resnet/resnet3d.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from ganslate.nn.utils import get_norm_layer_3d, is_bias_before_norm 3 | 4 | # Config imports 5 | from dataclasses import dataclass 6 | from ganslate import configs 7 | 8 | 9 | @dataclass 10 | class Resnet3DConfig(configs.base.BaseGeneratorConfig): 11 | n_residual_blocks: int = 9 12 | 13 | 14 | class Resnet3D(nn.Module): 15 | """Note: Unlike 2D version, this one uses ReplicationPad instead of RefectionPad""" 16 | 17 | def __init__(self, in_channels, out_channels, norm_type, n_residual_blocks=9): 18 | super().__init__() 19 | 20 | norm_layer = get_norm_layer_3d(norm_type) 21 | use_bias = is_bias_before_norm(norm_type) 22 | 23 | # Initial convolution block 24 | model = [ 25 | nn.ReplicationPad3d(3), 26 | nn.Conv3d(in_channels, 64, 7, bias=use_bias), 27 | 
norm_layer(64), 28 | nn.ReLU(inplace=True) 29 | ] 30 | 31 | # Downsampling 32 | in_features = 64 33 | out_features = in_features * 2 34 | for _ in range(2): 35 | model += [ 36 | nn.Conv3d(in_features, out_features, 3, stride=2, padding=1, bias=use_bias), 37 | norm_layer(out_features), 38 | nn.ReLU(inplace=True) 39 | ] 40 | in_features = out_features 41 | out_features = in_features * 2 42 | 43 | # Residual blocks 44 | for _ in range(n_residual_blocks): 45 | model += [ResidualBlock(in_features, norm_type)] 46 | 47 | # Upsampling 48 | out_features = in_features // 2 49 | for _ in range(2): 50 | model += [ 51 | nn.ConvTranspose3d(in_features, 52 | out_features, 53 | 3, 54 | stride=2, 55 | padding=1, 56 | output_padding=1), 57 | norm_layer(out_features), 58 | nn.ReLU(inplace=True) 59 | ] 60 | in_features = out_features 61 | out_features = in_features // 2 62 | 63 | # Output layer 64 | model += [nn.ReplicationPad3d(3), nn.Conv3d(64, out_channels, 7, bias=use_bias), nn.Tanh()] 65 | 66 | self.model = nn.Sequential(*model) 67 | 68 | def forward(self, x): 69 | return self.model(x) 70 | 71 | 72 | class ResidualBlock(nn.Module): 73 | 74 | def __init__(self, in_features, norm_type): 75 | super().__init__() 76 | norm_layer = get_norm_layer_3d(norm_type) 77 | use_bias = is_bias_before_norm(norm_type) 78 | 79 | conv_block = [ 80 | nn.ReplicationPad3d(1), 81 | nn.Conv3d(in_features, in_features, 3, bias=use_bias), 82 | norm_layer(in_features), 83 | nn.ReLU(inplace=True), 84 | nn.ReplicationPad3d(1), 85 | nn.Conv3d(in_features, in_features, 3, bias=use_bias), 86 | norm_layer(in_features) 87 | ] 88 | 89 | self.conv_block = nn.Sequential(*conv_block) 90 | 91 | def forward(self, x): 92 | return x + self.conv_block(x) 93 | -------------------------------------------------------------------------------- /ganslate/nn/separable.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _triple 3 | 4 | 5 | class SeparableConv3d(nn.Module): 6 | 7 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True): 8 | super().__init__() 9 | kernel_size = _triple(kernel_size) 10 | stride = _triple(stride) 11 | padding = _triple(padding) 12 | 13 | depthwise_kernel = (1, kernel_size[1], kernel_size[2]) 14 | pointwise_kernel = (kernel_size[0], 1, 1) 15 | 16 | depthwise_stride = (1, stride[1], stride[2]) 17 | pointwise_stride = (stride[0], 1, 1) 18 | 19 | depthwise_padding = (0, padding[1], padding[2]) 20 | pointwise_padding = (padding[0], 0, 0) 21 | 22 | # construct the layers 23 | self.conv_depthwise = nn.Conv3d(in_channels, 24 | out_channels, 25 | kernel_size=depthwise_kernel, 26 | stride=depthwise_stride, 27 | padding=depthwise_padding, 28 | bias=bias) 29 | 30 | self.conv_pointwise = nn.Conv3d(out_channels, 31 | out_channels, 32 | kernel_size=pointwise_kernel, 33 | stride=pointwise_stride, 34 | padding=pointwise_padding, 35 | bias=bias) 36 | 37 | def forward(self, input): 38 | out = self.conv_depthwise(input) 39 | out = self.conv_pointwise(out) 40 | return out 41 | 42 | 43 | class SeparableConvTranspose3d(nn.Module): 44 | 45 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True): 46 | super().__init__() 47 | kernel_size = _triple(kernel_size) 48 | stride = _triple(stride) 49 | padding = _triple(padding) 50 | 51 | depthwise_kernel = (1, kernel_size[1], kernel_size[2]) 52 | pointwise_kernel = (kernel_size[0], 1, 1) 53 | 54 | depthwise_stride = (1, stride[1], stride[2]) 
55 | pointwise_stride = (stride[0], 1, 1) 56 | 57 | depthwise_padding = (0, padding[1], padding[2]) 58 | pointwise_padding = (padding[0], 0, 0) 59 | 60 | # construct the layers 61 | self.conv_transp_depthwise = nn.ConvTranspose3d(in_channels, 62 | out_channels, 63 | kernel_size=depthwise_kernel, 64 | stride=depthwise_stride, 65 | padding=depthwise_padding, 66 | bias=bias) 67 | 68 | self.conv_transp_pointwise = nn.ConvTranspose3d(out_channels, 69 | out_channels, 70 | kernel_size=pointwise_kernel, 71 | stride=pointwise_stride, 72 | padding=pointwise_padding, 73 | bias=bias) 74 | 75 | def forward(self, input): 76 | out = self.conv_transp_depthwise(input) 77 | out = self.conv_transp_pointwise(out) 78 | return out 79 | -------------------------------------------------------------------------------- /ganslate/utils/environment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import sys 4 | from os import PathLike 5 | from pathlib import Path 6 | from typing import Optional 7 | 8 | import cv2 9 | import numpy as np 10 | import SimpleITK as sitk 11 | import torch 12 | from omegaconf import OmegaConf 13 | from ganslate.utils import communication, io 14 | 15 | from loguru import logger 16 | 17 | 18 | def setup_logging_with_config(conf, debug=False): 19 | 20 | output_dir = Path(conf[conf.mode].output_dir).resolve() 21 | io.mkdirs(output_dir) 22 | 23 | filename = None 24 | # Log file only for the global main process 25 | if communication.get_rank() == 0: 26 | filename = Path(output_dir) / f"{conf.mode}_log.txt" 27 | # Stdout for *local* main process only 28 | use_stdout = communication.get_local_rank() == 0 or debug 29 | log_level = 'INFO' if not debug else 'DEBUG' 30 | 31 | setup_logging(use_stdout, filename, log_level=log_level) 32 | 33 | logger.info(f'Configuration:\n{OmegaConf.to_yaml(conf)}') 34 | logger.info(f'Saving checkpoints, logs and config to: {output_dir}') 35 | logger.info(f'Python version: {sys.version.strip()}') 36 | logger.info(f'PyTorch version: {torch.__version__}') # noqa 37 | logger.info(f'CUDA {torch.version.cuda} - cuDNN {torch.backends.cudnn.version()}') 38 | logger.info(f'Global rank: {communication.get_rank()}') 39 | logger.info(f'Local rank: {communication.get_local_rank()}') 40 | 41 | 42 | def setup_logging(use_stdout: Optional[bool] = True, 43 | filename: Optional[PathLike] = None, 44 | log_level: Optional[str] = 'INFO') -> None: 45 | """ 46 | Parameters 47 | ---------- 48 | use_stdout : bool 49 | Write output to standard out. 50 | filename : PathLike 51 | Filename to write log to. 52 | log_level : str 53 | Logging level as in the `python.logging` library. 
54 | 55 | Returns 56 | ------- 57 | None 58 | """ 59 | if log_level not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'EXCEPTION']: 60 | raise ValueError(f'Unexpected log level got {log_level}.') 61 | 62 | formatter = ("[{time:YYYY-MM-DD HH:mm:ss}]" 63 | "[{name}][{level}]" 64 | " - {message}") 65 | 66 | # Clear the default handlers 67 | logger.remove() 68 | 69 | if use_stdout: 70 | logger.add(sys.stdout, level=log_level, format=formatter, colorize=True) 71 | if filename is not None: 72 | logger.add(filename, level=log_level, format=formatter) 73 | 74 | 75 | def set_seed(seed=0): 76 | # Inspired also from: https://stackoverflow.com/a/57417097 77 | logger.info(f"Reproducible mode ON with seed : {seed}") 78 | torch.manual_seed(seed) 79 | np.random.seed(seed) 80 | random.seed(seed) 81 | os.environ['PYTHONHASHSEED'] = str(seed) 82 | 83 | 84 | def setup_threading(): 85 | """ 86 | Sets max threads for SimpleITK and Opencv. 87 | For numpy etc. set OMP_NUM_THREADS=1 as an env var while running the training script, 88 | e.g., OMP_NUM_THREADS=1 python tools/train.py ... 89 | """ 90 | logger.warning(""" 91 | Max threads for SimpleITK and Opencv set to 1 92 | For numpy etc. set OMP_NUM_THREADS=1 as an env var while running the training script, 93 | e.g., OMP_NUM_THREADS=1 python tools/train.py ... 94 | """) 95 | MAX_THREADS = 1 96 | sitk.ProcessObject_SetGlobalDefaultNumberOfThreads(MAX_THREADS) 97 | cv2.setNumThreads(MAX_THREADS) 98 | -------------------------------------------------------------------------------- /projects/maastro_hx4_pet_translation/experiments/cyclegan_balanced.yaml: -------------------------------------------------------------------------------- 1 | # Note: The design of CycleGAN-balanced is a bit weird in the sense that the conventional notion of a domain A image and a domain B image does not apply. 
2 | # Therefore, the Ganslate framework does not support it without hacking some code (specifically, the `get_metrics()` function in the file `ganslate/utils/metrics/val_test_metrics.py`) 3 | 4 | 5 | project: "./projects/maastro_hx4_pet_translation/" 6 | 7 | train: 8 | output_dir: "/workspace/Chinmay-Checkpoints-Ephemeral/HX4-PET-Translation/hx4_pet_cyclegan_balanced/" 9 | cuda: True 10 | n_iters: 30000 11 | n_iters_decay: 30000 12 | batch_size: 1 13 | mixed_precision: False 14 | seed: 1 15 | 16 | logging: 17 | freq: 50 18 | multi_modality_split: 19 | A: [1, 1] 20 | B: [1, 1] # ldCT is the 2nd component 21 | wandb: 22 | project: "maastro_hx4_pet_translation" 23 | run: "cyclegan_balanced_lambdas10" 24 | 25 | checkpointing: 26 | freq: 1000 27 | 28 | dataset: 29 | _target_: project.datasets.train_dataset.HX4PETTranslationTrainDataset 30 | root: "/workspace/Chinmay-Datasets-Ephemeral/HX4-PET-Translation-Processed/train" 31 | paired: False # Unpaired training 32 | require_ldct_for_training: True # ldCT required for training 33 | patch_size: [32, 128, 128] # (D,H,W) 34 | patch_sampling: uniform-random-within-body-sf 35 | focal_region_proportion: [0.6, 0.35, 0.35] # (D,H,W) 36 | num_workers: 8 37 | 38 | gan: 39 | _target_: project.modules.HX4CycleGANBalanced 40 | generator: 41 | _target_: ganslate.nn.generators.Unet3D 42 | in_out_channels: 43 | AB: [2, 1] # Both G's take 2 inputs and predict 1 output 44 | BA: [2, 1] # 45 | num_downs: 4 46 | ngf: 64 47 | use_dropout: True 48 | 49 | discriminator: 50 | _target_: ganslate.nn.discriminators.PatchGAN3D 51 | in_channels: 52 | B: 1 # Both D's evaluate a single modality (the PETs) 53 | A: 1 # 54 | n_layers: 3 55 | kernel_size: [4, 4, 4] 56 | ndf: 64 57 | 58 | optimizer: 59 | lr_D: 0.0001 60 | lr_G: 0.0002 61 | lambda_AB: 10.0 62 | lambda_BA: 10.0 63 | lambda_identity: 0 64 | proportion_ssim: 0 65 | 66 | metrics: 67 | discriminator_evolution: True 68 | ssim: False # `False` because it would be computed with the dummy array included, and would hence be wrong 69 | 70 | 71 | val: 72 | freq: 1000 73 | dataset: 74 | _target_: project.datasets.val_test_dataset.HX4PETTranslationValTestDataset 75 | root: "/workspace/Chinmay-Datasets-Ephemeral/HX4-PET-Translation-Processed/val" 76 | num_workers: 8 77 | supply_masks: False # Do not supply masks -> fewer val metrics to compute, faster training 78 | model_is_hx4_cyclegan_balanced: True # Using HX4-CycleGAN-balanced 79 | use_patch_based_inference: True # Patch-based (sliding-window) inference 80 | sliding_window: # Enable sliding window inferer 81 | window_size: ${train.dataset.patch_size} 82 | metrics: # Enabled metrics: MSE, MAE, PSNR, SSIM, NMI, Chi-squared histogram distance 83 | mse: True 84 | mae: True 85 | nmse: False 86 | psnr: True 87 | ssim: True 88 | nmi: True 89 | histogram_chi2: True 90 | cycle_metrics: False # `False` because the validation cycle is hardcoded to the default (naive) behavior 91 | -------------------------------------------------------------------------------- /ganslate/engines/inferer.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | 3 | from ganslate.engines.base import BaseEngineWithInference 4 | from ganslate.utils import environment 5 | from ganslate.utils.builders import build_gan, build_loader 6 | from ganslate.utils.trackers.inference import InferenceTracker 7 | from ganslate.utils import communication 8 | 9 | 10 | class Inferer(BaseEngineWithInference): 11 | 12 | def __init__(self, conf): 13 | super().__init__(conf) 14 | self.logger = logger 15 |
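# Editor's note -- a minimal usage sketch for deployment mode (assuming
# infer.is_deployment is set in the config; `conf` and `input_tensor` are
# placeholders):
#     inferer = Inferer(conf)
#     output = inferer.infer(input_tensor)  # `run()` is disallowed in deployment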
16 | # Logging, dataloader and tracker only when not in deployment mode 17 | if not self.conf.infer.is_deployment: 18 | assert self.conf.infer.dataset, "Please specify the dataset for inference." 19 | environment.setup_logging_with_config(self.conf) 20 | self.tracker = InferenceTracker(self.conf) 21 | self.data_loader = build_loader(self.conf) 22 | 23 | self.model = build_gan(self.conf) 24 | 25 | def _set_mode(self): 26 | self.conf.mode = 'infer' 27 | 28 | def run(self): 29 | assert not self.conf.infer.is_deployment, \ 30 | "`Inferer.run()` cannot be used in deployment, please use `Inferer.infer()`." 31 | 32 | self.logger.info("Inference started.") 33 | 34 | self.tracker.start_dataloading_timer() 35 | for i, data in enumerate(self.data_loader): 36 | # Iteration index 37 | # (1) When using DDP, multiply with world size since each process does an iteration 38 | # (2) Multiply with batch size to get accurate info on how many examples are done 39 | # (3) Add 1 to start from iter 1 instead of 0 40 | iter_idx = i * communication.get_world_size() * self.conf.infer.batch_size + 1 41 | self.tracker.set_iter_idx(iter_idx) 42 | if i == 0: 43 | input_key = self._get_input_key(data) 44 | if not hasattr(self.data_loader.dataset, "save"): 45 | self.logger.warning( 46 | "The dataset class used does not have a 'save' method." 47 | " It is not necessary, however, it may be useful in cases" 48 | " where the outputs should be stored individually" 49 | " ('images/' folder saves input and output in a single image), " 50 | " or in a specific format.") 51 | 52 | self.tracker.start_computation_timer() 53 | self.tracker.end_dataloading_timer() 54 | out = self.infer(data[input_key]) 55 | self.tracker.end_computation_timer() 56 | 57 | self.tracker.start_saving_timer() 58 | # Save the output as specified in dataset`s `save` method, if implemented 59 | metadata = data["metadata"] if "metadata" in data else None 60 | self.save_generated_tensor(generated_tensor=out, 61 | metadata=metadata, 62 | data_loader=self.data_loader) 63 | self.tracker.end_saving_timer() 64 | 65 | visuals = {"input": data[input_key], "output": out.cpu()} 66 | self.tracker.log_iter(visuals, len(self.data_loader.dataset)) 67 | self.tracker.start_dataloading_timer() 68 | self.tracker.close() 69 | 70 | def _get_input_key(self, data): 71 | """The dataset (dataloader) needs to return a dict with input data 72 | either under the key 'input' or 'A'.""" 73 | if "input" in data: 74 | return "input" 75 | elif "A" in data: 76 | return "A" 77 | else: 78 | raise ValueError("An inference dataset needs to provide" 79 | "the input data under the dict key 'input' or 'A'.") 80 | -------------------------------------------------------------------------------- /ganslate/utils/cli/interface.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import inspect 3 | import git 4 | import shutil 5 | from pathlib import Path 6 | import click 7 | from cookiecutter.main import cookiecutter 8 | from ganslate.utils.cli import cookiecutter_templates 9 | from ganslate.utils.cli.scripts import download_datasets 10 | from ganslate.engines.utils import init_engine 11 | 12 | 13 | COOKIECUTTER_TEMPLATES_DIR = Path(inspect.getfile(cookiecutter_templates)).parent 14 | 15 | # Interface 16 | @click.group() 17 | def interface(): 18 | """ganslate - GAN image-to-image translation framework made simple and extensible.""" 19 | pass 20 | 21 | # Train 22 | @interface.command(help="Train a model.") 23 | @click.argument("omegaconf_args", 
24 | def train(omegaconf_args):
25 |     init_engine('train', omegaconf_args).run()
26 | 
27 | # Test
28 | @interface.command(help="Test a trained model. Requires paired data.")
29 | @click.argument("omegaconf_args", nargs=-1)
30 | def test(omegaconf_args):
31 |     init_engine('test', omegaconf_args).run()
32 | 
33 | # Infer
34 | @interface.command(help="Do inference with a trained model.")
35 | @click.argument("omegaconf_args", nargs=-1)
36 | def infer(omegaconf_args):
37 |     init_engine('infer', omegaconf_args).run()
38 | 
39 | # New project
40 | @interface.command(help="Initialize a new project.")
41 | @click.argument("path", default="./")
42 | def new_project(path):
43 |     template = str(COOKIECUTTER_TEMPLATES_DIR / "new_project")
44 |     cookiecutter(template, output_dir=path)
45 | 
46 | # First run
47 | def setup_first_run(path, no_input=False, extra_context=None):
48 |     template = str(COOKIECUTTER_TEMPLATES_DIR / "your_first_run")
49 |     project_path = cookiecutter(template, output_dir=path, no_input=no_input,
50 |                                 overwrite_if_exists=True, extra_context=extra_context)
51 |     download_datasets.download("facades", project_path)
52 | 
53 | 
54 | @interface.command(help="Fetch resources for your first run (downloads the facades dataset).")
55 | @click.argument("path", default="./")
56 | def your_first_run(path):
57 |     setup_first_run(path)
58 | 
59 | # Download project
60 | @interface.command(help="Download a project.")
61 | @click.argument("name")
62 | @click.argument("path")
63 | def download_project(name, path):
64 |     # TODO: implement project download; currently a placeholder.
65 |     print(name, path)
66 | 
67 | # Download dataset
68 | @interface.command(help="Download a dataset.")
69 | @click.argument("name")
70 | @click.argument("path")
71 | def download_dataset(name, path):
72 |     download_datasets.download(name, path)
73 | 
74 | # Install Nvidia Apex
75 | @interface.command(help="Install Nvidia Apex for mixed precision support.")
76 | @click.option(
77 |     "--cpp/--python",
78 |     default=True,
79 |     help=("C++ support is faster and preferred, use the Python fallback "
80 |           "only when CUDA is not installed natively.")
81 | )
82 | def install_nvidia_apex(cpp):
83 |     # TODO: Installing with C++ support is a pain due to CUDA installations,
84 |     # waiting for https://github.com/pytorch/pytorch/issues/40497#issuecomment-908685435
85 |     # to switch to PyTorch AMP and get rid of Nvidia Apex
86 | 
87 |     # Remove the folder if it already exists from a previous, cancelled attempt.
88 |     shutil.rmtree("./nvidia-apex-tmp", ignore_errors=True)
89 |     git.Repo.clone_from("https://github.com/NVIDIA/apex", './nvidia-apex-tmp')
90 | 
91 |     cmd = ['pip', 'install', '-v', '--disable-pip-version-check', '--no-cache-dir']
92 |     if cpp:
93 |         # Pass each option as its own argv element; embedding quotes in a single
94 |         # command string would hand pip literal quote characters.
95 |         cmd += ['--global-option=--cpp_ext', '--global-option=--cuda_ext']
96 |     cmd += ['./nvidia-apex-tmp']
97 | 
98 |     subprocess.run(cmd)
99 |     shutil.rmtree("./nvidia-apex-tmp")
100 | 
101 | if __name__ == "__main__":
102 |     interface()
103 | 
--------------------------------------------------------------------------------
/ganslate/engines/base.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from abc import ABC, abstractmethod
3 | from pathlib import Path
4 | from loguru import logger
5 | 
6 | from ganslate.utils import sliding_window_inferer
7 | from ganslate.utils.io import decollate
8 | 
9 | 
10 | class BaseEngine(ABC):
11 | 
12 |     def __init__(self, conf):
13 |         # Deep copy to isolate the conf.mode of an engine from other engines (e.g., train from val)
14 |         self.conf = copy.deepcopy(conf)
15 |         self._set_mode()
16 | 
17 |         self.output_dir = Path(conf[conf.mode].output_dir) / self.conf.mode
18 |         self.model = None
19 |         self.logger = logger
20 | 
21 |     @abstractmethod
22 |     def _set_mode(self):
23 |         """Sets the mode for the particular engine.
24 |         E.g., 'train' for Trainer, 'val' for Validator, etc."""
25 |         self.conf.mode = ...
26 | 
27 | 
28 | class BaseEngineWithInference(BaseEngine):
29 | 
30 |     def __init__(self, conf):
31 |         super().__init__(conf)
32 |         self.sliding_window_inferer = self._init_sliding_window_inferer()
33 | 
34 |     def infer(self, data, *args, **kwargs):
35 |         data = data.to(self.model.device)
36 |         # Sliding window (i.e. patch-wise) inference
37 |         if self.sliding_window_inferer:
38 |             return self.sliding_window_inferer(data, self.model.infer, *args, **kwargs)
39 |         return self.model.infer(data, *args, **kwargs)
40 | 
41 |     def _init_sliding_window_inferer(self):
42 |         sw = self.conf[self.conf.mode].sliding_window
43 |         if not sw:
44 |             return None
45 | 
46 |         return sliding_window_inferer.SlidingWindowInferer(roi_size=sw.window_size,
47 |                                                            sw_batch_size=sw.batch_size,
48 |                                                            overlap=sw.overlap,
49 |                                                            mode=sw.mode,
50 |                                                            cval=-1)
51 | 
52 |     def save_generated_tensor(self,
53 |                               generated_tensor,
54 |                               metadata,
55 |                               data_loader,
56 |                               idx=None,
57 |                               dataset_name=None):
58 |         # A dataset object has to have a `save()` method if it
59 |         # wishes to save the outputs in a particular way or format
60 |         save_fn = getattr(data_loader.dataset, "save", False)
61 |         if save_fn:
62 |             # Tolerates `save` methods with and without a `metadata` argument
63 |             def save(tensor, save_dir, metadata=None):
64 |                 if metadata is None:
65 |                     save_fn(tensor=tensor, save_dir=save_dir)
66 |                 else:
67 |                     save_fn(tensor=tensor, save_dir=save_dir, metadata=metadata)
68 | 
69 |             # Output dir
70 |             save_dir = "saved/"
71 |             if dataset_name is not None:
72 |                 save_dir += f"{dataset_name}/"
73 |             if idx is not None:
74 |                 save_dir += f"{idx}/"
75 |             save_dir = self.output_dir / save_dir
76 | 
77 |             # Metadata
78 |             if metadata:
79 |                 # After decollate, it is a list of length equal to batch_size,
80 |                 # containing separate metadata for each tensor in the mini-batch
81 |                 metadata = decollate(metadata, batch_size=len(generated_tensor))
82 | 
83 |             # Loop over the batch and save each tensor
84 |             for batch_idx in range(len(generated_tensor)):
85 |                 tensor = generated_tensor[batch_idx]
86 |                 current_metadata = metadata[batch_idx] if metadata is not None else None
87 |                 save(tensor=tensor, save_dir=save_dir, metadata=current_metadata)
88 | 
--------------------------------------------------------------------------------
/projects/brats_mri_sequence_translation/datasets/train_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import Dataset
6 | import SimpleITK as sitk
7 | from ganslate.utils.io import make_dataset_of_files
8 | from ganslate.utils import sitk_utils
9 | from ganslate.data.utils.normalization import z_score_normalize
10 | from ganslate.data.utils.stochastic_focal_patching import StochasticFocalPatchSampler
11 | 
12 | # Config imports
13 | from typing import Tuple
14 | from dataclasses import dataclass, field
15 | from omegaconf import MISSING
16 | from ganslate import configs
17 | 
18 | 
19 | @dataclass
20 | class BratsDatasetConfig(configs.base.BaseDatasetConfig):
21 |     patch_size: Tuple[int, int, int] = (32, 32, 32)
22 |     # Proportion of the focal region size compared to the original volume size
23 |     focal_region_proportion: float = 0
24 |     source_sequence: str = "flair"
25 |     target_sequence: str = "t1w"
26 | 
27 | 
28 | EXTENSIONS = ['.nii.gz']
29 | 
30 | # MRI sequences' z-axis indices in BraTS
31 | SEQUENCE_MAP = {"flair": 0, "t1w": 1, "t1gd": 2, "t2w": 3}
32 | 
33 | 
34 | def get_mri_sequence(sitk_image, sequence_name):
35 |     z_index = SEQUENCE_MAP[sequence_name.lower()]
36 | 
37 |     size = list(sitk_image.GetSize())
38 |     size[3] = 0
39 |     index = [0, 0, 0, z_index]
40 | 
41 |     extractor = sitk.ExtractImageFilter()
42 |     extractor.SetSize(size)
43 |     extractor.SetIndex(index)
44 |     return extractor.Execute(sitk_image)
45 | 
46 | 
47 | class BratsDataset(Dataset):
48 | 
49 |     def __init__(self, conf):
50 |         dir_brats = conf.train.dataset.root
51 |         self.paths_brats = make_dataset_of_files(dir_brats, EXTENSIONS)
52 |         self.num_datapoints = len(self.paths_brats)
53 | 
54 |         focal_region_proportion = conf.train.dataset.focal_region_proportion
55 |         self.patch_size = np.array(conf.train.dataset.patch_size)
56 |         self.patch_sampler = StochasticFocalPatchSampler(self.patch_size, focal_region_proportion)
57 | 
58 |         self.source_sequence = conf.train.dataset.source_sequence
59 |         self.target_sequence = conf.train.dataset.target_sequence
60 | 
61 |     def __getitem__(self, index):
62 |         index_A = index % self.num_datapoints
63 |         index_B = random.randint(0, self.num_datapoints - 1)
64 | 
65 |         path_A = self.paths_brats[index_A]
66 |         path_B = self.paths_brats[index_B]
67 | 
68 |         # Load the NIfTI files as SimpleITK objects
69 |         A = sitk_utils.load(path_A)
70 |         B = sitk_utils.load(path_B)
71 | 
72 |         A = get_mri_sequence(A, self.source_sequence)
73 |         B = get_mri_sequence(B, self.target_sequence)
74 | 
75 |         if (sitk_utils.is_image_smaller_than(A, self.patch_size) or
76 |                 sitk_utils.is_image_smaller_than(B, self.patch_size)):
77 |             raise ValueError("Volume size is smaller than the defined patch size."
78 |                              f"\nA: {sitk_utils.get_torch_like_size(A)}"
79 |                              f"\nB: {sitk_utils.get_torch_like_size(B)}"
80 |                              f"\npatch_size: {self.patch_size}.")
81 | 
82 |         A = sitk_utils.get_tensor(A)
83 |         B = sitk_utils.get_tensor(B)
84 | 
85 |         # Extract patches
86 |         A, B = self.patch_sampler.get_patch_pair(A, B)
87 |         # Z-score normalization per volume
88 |         A = z_score_normalize(A, scale_to_range=(-1, 1))
89 |         B = z_score_normalize(B, scale_to_range=(-1, 1))
90 | 
91 |         # Add channel dimension (1 = grayscale)
92 |         A = A.unsqueeze(0)
93 |         B = B.unsqueeze(0)
94 | 
95 |         return {'A': A, 'B': B}
96 | 
97 |     def __len__(self):
98 |         return self.num_datapoints
99 | 
--------------------------------------------------------------------------------
/ganslate/configs/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import importlib
3 | from pathlib import Path
4 | 
5 | from loguru import logger
6 | from omegaconf import DictConfig, OmegaConf
7 | 
8 | from ganslate.utils.io import import_attr
9 | 
10 | def init_config(conf, config_class):
11 |     # Run-specific config
12 |     conf = conf if isinstance(conf, DictConfig) else OmegaConf.load(str(conf))
13 | 
14 |     # Allows the framework to find user-defined, project-specific classes and their configs
15 |     if conf.project:
16 | 
17 |         assert isinstance(conf.project, str), "project needs to be a str path"
18 | 
19 |         # Import the project as a module named "project"
20 |         # https://stackoverflow.com/a/41595552
21 |         project_path = Path(conf.project).resolve() / "__init__.py"
22 |         assert project_path.is_file(), f"No `__init__.py` in project `{project_path}`."
23 | 
24 |         spec = importlib.util.spec_from_file_location("project", str(project_path))
25 |         project_module = importlib.util.module_from_spec(spec)
26 |         spec.loader.exec_module(project_module)
27 |         sys.modules["project"] = project_module
28 | 
29 |         logger.info(f"Project directory {conf.project} added to the"
30 |                     " path as `project` to allow imports of modules from it.")
31 | 
32 | 
33 |     # Make yaml mergeable by instantiating the dataclasses
34 |     conf = instantiate_dataclasses_from_yaml(conf)
35 |     # Merge the default and run-specific configs
36 |     return OmegaConf.merge(OmegaConf.structured(config_class), conf)
37 | 
38 | 
39 | def instantiate_dataclasses_from_yaml(conf):
40 |     """Goes through a config and instantiates the fields that are dataclasses.
41 |     Each such dataclass must have a "_target_" entry, which is used to import its
42 |     config dataclass using "_target_" + "Config" as the class name.
43 |     Instantiates the deepest dataclasses first, as otherwise OmegaConf would throw an error.
44 |     """
45 |     for key in get_all_conf_keys(conf):
46 |         # Get the field for that key
47 |         field = OmegaConf.select(conf, key)
48 |         if is_dataclass(field):
49 |             dataclass = init_dataclass(field)
50 |             # Update the field for that key with the newly instantiated dataclass
51 |             OmegaConf.update(conf, key, OmegaConf.merge(dataclass, field), merge=False)
52 |     return conf
53 | 
54 | 
55 | def init_dataclass(field):
56 |     """Initialize a dataclass. Requires the field to have a "_target_" entry.
57 |     Assumes that the class name is of the format "_target_" + "Config", e.g. "MRIDatasetConfig".
58 |     """
59 |     dataclass = f'{field["_target_"]}Config'
60 |     dataclass = import_attr(dataclass)
61 |     return OmegaConf.structured(dataclass)
62 | 
63 | 
64 | def is_dataclass(field):
65 |     """If a field contains a `_target_` key, it is a dataclass."""
66 |     return bool(isinstance(field, DictConfig) and "_target_" in field)
67 | 
68 | 
69 | def get_all_conf_keys(conf):
70 |     """Get all keys from a conf and order them from the deepest to the shallowest."""
71 |     conf = OmegaConf.to_container(conf)
72 |     keys = list(iterate_nested_dict_keys(conf))
73 |     # Order deeper to shallower
74 |     return keys[::-1]
75 | 
76 | 
77 | def iterate_nested_dict_keys(dictionary):
78 |     """Returns an iterator that returns all keys of a nested dictionary, ordered
79 |     from the shallowest to the deepest key. The nested keys are in the dot-list format,
80 |     e.g. "gan.discriminator.in_channels".
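81 |     For example, {"gan": {"discriminator": {"in_channels": 1}}} yields
82 |     "gan", then "gan.discriminator", then "gan.discriminator.in_channels".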
81 | """ 82 | if isinstance(dictionary, dict): 83 | current_level_keys = [] 84 | for key in dictionary.keys(): 85 | current_level_keys.append(key) 86 | yield key 87 | for key in current_level_keys: 88 | value = dictionary[key] 89 | for ret in iterate_nested_dict_keys(value): 90 | yield f"{key}.{ret}" 91 | -------------------------------------------------------------------------------- /ganslate/utils/trackers/training.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | from pathlib import Path 3 | 4 | from ganslate.utils import communication 5 | from ganslate.utils.trackers.base import BaseTracker 6 | from ganslate.utils.trackers.utils import process_visuals_for_logging 7 | 8 | 9 | class TrainingTracker(BaseTracker): 10 | 11 | def __init__(self, conf): 12 | super().__init__(conf) 13 | self.logger = logger 14 | self.log_freq = conf.train.logging.freq 15 | 16 | def log_iter(self, learning_rates, losses, visuals, metrics): 17 | """Parameters: # TODO: update this 18 | iters (int) -- current training iteration 19 | losses (tuple/list) -- training losses 20 | t_comp (float) -- computational time per data point (normalized by batch_size) 21 | t_data (float) -- data loading time per data point (normalized by batch_size) 22 | """ 23 | if self.iter_idx % self.log_freq != 0: 24 | return 25 | 26 | def parse_visuals(visuals): 27 | # Note: Gather not necessary as in val/test, enough to log one example when training. 28 | visuals = {k: v for k, v in visuals.items() if v is not None} 29 | visuals = process_visuals_for_logging(self.conf, visuals, single_example=True) 30 | # `single_example=True` returns a single example from the batch, selecting it 31 | return visuals[0] 32 | 33 | def parse_losses(losses): 34 | losses = {k: v for k, v in losses.items() if v is not None} 35 | # Reduce losses (avg) and send to the process of rank 0 36 | losses = communication.reduce(losses, average=True, all_reduce=False) 37 | return losses 38 | 39 | def parse_metrics(metrics): 40 | metrics = {k: v for k, v in metrics.items() if v is not None} 41 | # Training metrics are optional 42 | if metrics: 43 | # Reduce metrics (avg) and send to the process of rank 0 44 | metrics = communication.reduce(metrics, average=True, all_reduce=False) 45 | return metrics 46 | 47 | def log_message(): 48 | message = '\n' + 20 * '-' + ' ' 49 | # Iteration, computing time, dataloading time 50 | message += f"(iter: {self.iter_idx} | comp: {self.t_comp:.3f}, data: {self.t_data:.3f}" 51 | message += " | " 52 | # Learning rates 53 | for i, (name, learning_rate) in enumerate(learning_rates.items()): 54 | message += "" if i == 0 else ", " 55 | message += f"{name}: {learning_rate:.7f}" 56 | message += ') ' + 20 * '-' + '\n' 57 | # Losses 58 | for name, loss in losses.items(): 59 | message += f"{name}: {loss:.3f} " 60 | self.logger.info(message) 61 | 62 | def log_visuals(): 63 | self._save_image(visuals, self.iter_idx) 64 | 65 | visuals = parse_visuals(visuals) 66 | losses = parse_losses(losses) 67 | metrics = parse_metrics(metrics) 68 | 69 | log_message() 70 | log_visuals() 71 | 72 | if self.wandb: 73 | self.wandb.log_iter(iter_idx=self.iter_idx, 74 | visuals=visuals, 75 | mode="train", 76 | learning_rates=learning_rates, 77 | losses=losses, 78 | metrics=metrics) 79 | 80 | if self.tensorboard: 81 | self.tensorboard.log_iter(iter_idx=self.iter_idx, 82 | visuals=visuals, 83 | mode="train", 84 | learning_rates=learning_rates, 85 | losses=losses, 86 | metrics=metrics) 87 | 
--------------------------------------------------------------------------------
/docs/package_overview/5_engines.md:
--------------------------------------------------------------------------------
1 | # Engines
2 | 
3 | `ganslate` defines four _engines_ that implement processes crucial to the deep learning workflow. These are `Trainer`, `Validator`, `Tester`, and `Inferer`. The following UML diagram shows the design of `ganslate`'s `engines` module and the relationships between the different engine classes defined in it.
4 | 
5 | ![alt text](../imgs/uml-ganslate_engines.png "Relationship between ganslate's engine classes")
6 | 
7 | 
8 | 
9 | ------------
10 | ## `Trainer`
11 | 
12 | The `Trainer` class ([source](https://github.com/ganslate-team/ganslate/ganslate/engines/trainer.py)) implements the training procedure and is instantiated at the start of the training process. Upon initialization, the trainer object executes the following tasks:
13 | 1. Preparing the environment.
14 | 2. Initializing the GAN model, training data loader, training tracker, and validator.
15 | 
16 | The `Trainer` class provides the `run()` method which defines the training logic. This includes:
17 | 1. Fetching data from the training dataloader.
18 | 2. Invoking the GAN model's methods that set the inputs and perform the forward pass, backpropagation, and parameter update.
19 | 3. Obtaining the results of the iteration, which include the computed images, loss values, metrics, and I/O and computation times, and pushing them into the experiment tracker for logging.
20 | 4. Running model validation.
21 | 5. Saving checkpoints locally.
22 | 6. Updating the learning rates.
23 | 
24 | All configuration pertaining to the `Trainer` is grouped under the `'train'` mode in `ganslate`.
25 | 
26 | 
27 | 
28 | --------------
29 | ## `Validator`
30 | The `Validator` class ([source](https://github.com/ganslate-team/ganslate/ganslate/engines/validator_tester.py)) inherits almost all of its properties and functionality from the `BaseValTestEngine` and is responsible for performing validation of a model during the training process. It is instantiated and utilized within the `Trainer`, which supplies it with its configuration and the model. Upon initialization, a `Validator` object executes the following:
31 | 1. Initializes the sliding window inferer, validation data loader, validation tracker, and the validation-test metricizer.
32 | 
33 | The `run()` method of the `Validator` iterates over the validation dataset and executes the following steps:
34 | 1. Fetching data from the validation data loader.
35 | 2. Running inference on the given model and holding the computed images.
36 | 3. Saving the computed image and its relevant metadata (useful in the case of medical images).
37 | 4. Calculating image quality/similarity metrics by comparing the generated image with the ground truth.
38 | 5. Pushing the images and metrics into the validation tracker for logging.
39 | 
40 | All configuration pertaining to the `Validator` is grouped under the `'val'` mode in `ganslate`.
41 | 
42 | 
43 | 
44 | -----------
45 | ## `Tester`
46 | The `Tester` class ([source](https://github.com/ganslate-team/ganslate/ganslate/engines/validator_tester.py)), like the `Validator`, inherits from the `BaseValTestEngine` and has the same properties and functionality as the `Validator`. The only difference is that a `Tester` instance sets up the environment and builds its own GAN model, and is therefore used independently of the `Trainer`, as the CLI sketch below illustrates.
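47 | 
48 | For instance, the command-line interface (see `ganslate/utils/cli/interface.py`) builds and runs a `Tester` through the same entry point used for all engines, `init_engine`. A sketch (here, `omegaconf_args` stands for the OmegaConf-style `key=value` overrides collected by the CLI):
49 | 
50 | ```python
51 | from ganslate.engines.utils import init_engine
52 | 
53 | # Each CLI command maps a mode string to its engine and runs it:
54 | init_engine('test', omegaconf_args).run()    # -> Tester
55 | init_engine('train', omegaconf_args).run()   # -> Trainer
56 | init_engine('infer', omegaconf_args).run()   # -> Inferer
57 | ```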
58 | 
59 | All configuration pertaining to the `Tester` is grouped under the `'test'` mode in `ganslate`.
60 | 
61 | 
62 | 
63 | ------------
64 | ## `Inferer`
65 | The `Inferer` class ([source](https://github.com/ganslate-team/ganslate/ganslate/engines/validator_tester.py)) is a simplified inference engine without any mechanism for metric calculation; it therefore expects data without a ground truth to compare against. Under normal circumstances, it still performs utility tasks such as fetching data from a data loader, tracking I/O and computation time, and logging and saving images. When used in _deployment_ mode, however, the `Inferer` acts as a minimal inference engine that can easily be integrated into other applications.
--------------------------------------------------------------------------------
/ganslate/nn/utils.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn import init
3 | from torch.optim import lr_scheduler
4 | 
5 | from ganslate.nn import separable
6 | 
7 | 
8 | def init_net(network, conf, device):
9 |     init_weights(network, conf.train.gan.weight_init_type, conf.train.gan.weight_init_gain)
10 |     return network.to(device)
11 | 
12 | 
13 | def init_weights(net, weight_init_type='normal', gain=0.02):
14 | 
15 |     def init_func(m):
16 |         classname = m.__class__.__name__
17 |         if hasattr(m, 'weight') and (classname.find('Conv') != -1 or
18 |                                      classname.find('Linear') != -1):
19 |             if weight_init_type == 'normal':
20 |                 init.normal_(m.weight.data, 0.0, gain)
21 |             elif weight_init_type == 'xavier':
22 |                 init.xavier_normal_(m.weight.data, gain=gain)
23 |             elif weight_init_type == 'kaiming':
24 |                 init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
25 |             elif weight_init_type == 'orthogonal':
26 |                 init.orthogonal_(m.weight.data, gain=gain)
27 |             else:
28 |                 raise NotImplementedError(
29 |                     f"initialization method `{weight_init_type}` is not implemented")
30 |             if hasattr(m, 'bias') and m.bias is not None:
31 |                 init.constant_(m.bias.data, 0.0)
32 |         elif classname.find('BatchNorm3d') != -1:
33 |             init.normal_(m.weight.data, 1.0, gain)
34 |             init.constant_(m.bias.data, 0.0)
35 | 
36 |     net.apply(init_func)
37 | 
38 | 
39 | def get_conv_layer_3d(is_separable=False):
40 |     if is_separable:
41 |         return separable.SeparableConv3d
42 |     else:
43 |         return nn.Conv3d
44 | 
45 | 
46 | def get_conv_transpose_layer_3d(is_separable=False):
47 |     if is_separable:
48 |         return separable.SeparableConvTranspose3d
49 |     else:
50 |         return nn.ConvTranspose3d
51 | 
52 | 
53 | def get_norm_layer_2d(norm_type='instance'):
54 |     if norm_type == 'batch':
55 |         return nn.BatchNorm2d
56 |     elif norm_type == 'instance':
57 |         return nn.InstanceNorm2d
58 |     else:
59 |         raise NotImplementedError(f"Normalization layer `{norm_type}` not supported")
60 | 
61 | 
62 | def get_norm_layer_3d(norm_type='instance'):
63 |     if norm_type == 'batch':
64 |         return nn.BatchNorm3d
65 |     elif norm_type == 'instance':
66 |         return nn.InstanceNorm3d
67 |     else:
68 |         raise NotImplementedError(f"Normalization layer `{norm_type}` not supported")
69 | 
70 | 
71 | def is_bias_before_norm(norm_type='instance'):
72 |     """When using BatchNorm, the preceding Conv layer does not use bias,
73 |     but it does if using InstanceNorm.
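74 |     (BatchNorm's learnable affine shift makes a preceding conv bias redundant;
75 |     InstanceNorm, with PyTorch's default affine=False, does not.)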
74 | """ 75 | if norm_type == 'instance': 76 | return True 77 | elif norm_type == 'batch': 78 | return False 79 | else: 80 | raise NotImplementedError(f"Normalization layer `{norm_type}` not supported") 81 | 82 | 83 | def get_scheduler(optimizer, conf): 84 | """Return a scheduler that keeps the same learning rate for the first epochs 85 | and linearly decays the rate to zero over the next epochs. 86 | Parameters: 87 | optimizer -- the optimizer of the network 88 | TODO 89 | """ 90 | 91 | def lambda_rule(iter_idx): 92 | start_iter = 1 93 | if conf.train.checkpointing.load_iter: 94 | start_iter += conf.train.checkpointing.load_iter 95 | lr_l = 1.0 - max( 96 | 0, iter_idx + start_iter - conf.train.n_iters) / float(conf.train.n_iters_decay + 1) 97 | return lr_l 98 | 99 | return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 100 | 101 | 102 | def get_network_device(network): 103 | """Returns the device of the network. Assumes that the whole network is on a single device.""" 104 | return next(network.parameters()).device 105 | -------------------------------------------------------------------------------- /ganslate/utils/cli/scripts/download_datasets.py: -------------------------------------------------------------------------------- 1 | import wget 2 | import zipfile 3 | import os 4 | from pathlib import Path 5 | import shutil 6 | 7 | AVAILABLE_DATASETS = ["ae_photos", "apple2orange", "summer2winter_yosemite", "horse2zebra", \ 8 | "monet2photo", "cezanne2photo","ukiyoe2photo", "vangogh2photo", "maps", \ 9 | "cityscapes", "facades", "iphone2dslr_flower", "mini", "mini_pix2pix", "mini_colorization"] 10 | 11 | def download(name, path): 12 | if name not in AVAILABLE_DATASETS: 13 | print(""".Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos 14 | 15 | facades: 400 images from the CMP Facades dataset. [Citation] 16 | cityscapes: 2975 images from the Cityscapes training set. [Citation]. Note: Due to license issue, we cannot directly provide the Cityscapes dataset. Please download the Cityscapes dataset from https://cityscapes-dataset.com 17 | maps: 1096 training images scraped from Google Maps. 18 | horse2zebra: 939 horse images and 1177 zebra images downloaded from ImageNet using keywords wild horse and zebra 19 | apple2orange: 996 apple images and 1020 orange images downloaded from ImageNet using keywords apple and navel orange. 20 | summer2winter_yosemite: 1273 summer Yosemite images and 854 winter Yosemite images were downloaded using Flickr API. See more details in our paper. 21 | monet2photo, vangogh2photo, ukiyoe2photo, cezanne2photo: The art images were downloaded from Wikiart. The real photos are downloaded from Flickr using the combination of the tags landscape and landscapephotography. The training set size of each class is Monet:1074, Cezanne:584, Van Gogh:401, Ukiyo-e:1433, Photographs:6853. 22 | iphone2dslr_flower: both classes of images were downlaoded from Flickr. The training set size of each class is iPhone:1813, DSLR:3316. See more details in our paper. 
23 | 
24 | Reference: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/datasets.md
25 | """)
26 | 
27 |     else:
28 | 
29 |         assert Path(path).is_dir(), f"{path} provided is not a directory"
30 |         url = f"https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{name}.zip"
31 | 
32 | 
33 |         path_to_zip_file = f"{path}/{name}.zip"
34 | 
35 |         # Remove the file if it already exists, to handle corrupt files.
36 |         if os.path.isfile(path_to_zip_file):
37 |             os.remove(path_to_zip_file)
38 | 
39 |         print(f"Fetching the {name} dataset from {url}:")
40 |         wget.download(url, out=path_to_zip_file)
41 | 
42 |         if Path(f"{path}/{name}").is_dir():
43 |             shutil.rmtree(Path(f"{path}/{name}"))
44 | 
45 |         print(f"Extracting zip file to {path}")
46 |         with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
47 |             zip_ref.extractall(path)
48 | 
49 |         os.remove(path_to_zip_file)
50 | 
51 |         print("Reorganizing folder structure for ganslate")
52 |         # Make folders for train and test
53 |         train_path = Path(f"{path}/{name}/train")
54 |         test_path = Path(f"{path}/{name}/test")
55 | 
56 |         train_path.mkdir(parents=True, exist_ok=True)
57 |         test_path.mkdir(parents=True, exist_ok=True)
58 | 
59 |         # Move the contents of the download into the path structure required by ganslate
60 |         shutil.move(f"{path}/{name}/trainA", str(train_path / "A"))
61 |         shutil.move(f"{path}/{name}/trainB", str(train_path / "B"))
62 |         shutil.move(f"{path}/{name}/testA", str(test_path / "A"))
63 |         shutil.move(f"{path}/{name}/testB", str(test_path / "B"))
64 | 
65 | 
66 | if __name__ == "__main__":
67 |     import argparse
68 | 
69 |     parser = argparse.ArgumentParser()
70 | 
71 |     parser.add_argument("name")
72 |     parser.add_argument("path")
73 | 
74 |     args = parser.parse_args()
75 | 
76 |     download(args.name, args.path)
77 | 
--------------------------------------------------------------------------------
/ganslate/nn/losses/cyclegan_losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import ganslate.nn.losses.utils.ssim as ssim
3 | 
4 | from loguru import logger
5 | 
6 | 
7 | class CycleGANLosses:
8 |     """Defines losses used for optimizing the generators in a CycleGAN setup.
9 | Consists of: 10 | (1) Cycle-consistency loss (weighted combination of L1 and, optionally, SSIM) 11 | (2) Identity loss 12 | """ 13 | 14 | def __init__(self, conf): 15 | self.lambda_AB = conf.train.gan.optimizer.lambda_AB 16 | self.lambda_BA = conf.train.gan.optimizer.lambda_BA 17 | 18 | lambda_identity = conf.train.gan.optimizer.lambda_identity 19 | proportion_ssim = conf.train.gan.optimizer.proportion_ssim 20 | 21 | # Cycle-consistency - L1, with optional weighted combination with SSIM 22 | self.criterion_cycle = CycleLoss(proportion_ssim) 23 | if lambda_identity > 0: 24 | self.criterion_idt = IdentityLoss(lambda_identity) 25 | else: 26 | self.criterion_idt = None 27 | 28 | def is_using_identity(self): 29 | """Check if idt_A and idt_B should be computed.""" 30 | return True if self.criterion_idt else False 31 | 32 | def __call__(self, visuals): 33 | real_A, real_B = visuals['real_A'], visuals['real_B'] 34 | fake_A, fake_B = visuals['fake_A'], visuals['fake_B'] 35 | rec_A, rec_B = visuals['rec_A'], visuals['rec_B'] 36 | idt_A, idt_B = visuals['idt_A'], visuals['idt_B'] 37 | 38 | losses = {} 39 | 40 | # cycle-consistency loss 41 | # || G_BA(G_AB(real_A)) - real_A|| 42 | losses['cycle_A'] = self.lambda_AB * self.criterion_cycle(real_A, rec_A) 43 | # || G_AB(G_BA(real_B)) - real_B|| 44 | losses['cycle_B'] = self.lambda_BA * self.criterion_cycle(real_B, rec_B) 45 | 46 | # identity loss 47 | if self.criterion_idt: 48 | if idt_A is not None and idt_B is not None: 49 | # || G_AB(real_B) - real_B || 50 | losses['idt_B'] = self.lambda_AB * self.criterion_idt(idt_B, real_B) 51 | # || G_BA(real_A) - real_A || 52 | losses['idt_A'] = self.lambda_BA * self.criterion_idt(idt_A, real_A) 53 | 54 | else: 55 | raise ValueError( 56 | "idt_A and/or idt_B is not computed but the identity loss is defined.") 57 | 58 | return losses 59 | 60 | 61 | class CycleLoss: 62 | 63 | def __init__(self, proportion_ssim): 64 | self.criterion = torch.nn.L1Loss() 65 | if proportion_ssim > 0: 66 | self.ssim_criterion = ssim.SSIMLoss() 67 | # weights for addition of SSIM and L1 losses 68 | self.alpha = proportion_ssim 69 | self.beta = 1 - proportion_ssim 70 | else: 71 | self.ssim_criterion = None 72 | 73 | def __call__(self, real, reconstructed): 74 | # regular L1 cycle-consistency 75 | cycle_loss_L1 = self.criterion(reconstructed, real) 76 | 77 | # cycle-consistency using a weighted combination of SSIM and L1 78 | if self.ssim_criterion: 79 | # Data range needs to be positive and normalized 80 | # https://github.com/VainF/pytorch-msssim#2-normalized-input 81 | ssim_real = (real + 1) / 2 82 | ssim_reconstructed = (reconstructed + 1) / 2 83 | 84 | # SSIM criterion returns distance metric 85 | cycle_loss_ssim = self.ssim_criterion(ssim_reconstructed, ssim_real, data_range=1) 86 | 87 | # weighted sum of SSIM and L1 losses for both forward and backward cycle losses 88 | return self.alpha * cycle_loss_ssim + self.beta * cycle_loss_L1 89 | else: 90 | return cycle_loss_L1 91 | 92 | 93 | class IdentityLoss: 94 | 95 | def __init__(self, lambda_identity): 96 | self.lambda_identity = lambda_identity 97 | self.criterion = torch.nn.L1Loss() 98 | 99 | def __call__(self, idt, real): 100 | loss_idt = self.criterion(idt, real) 101 | return loss_idt * self.lambda_identity 102 | -------------------------------------------------------------------------------- /ganslate/nn/losses/utils/ssim.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) ganslate Contributors 3 | # 
Changes added by Maastro-CDS-Imaging-Group: https://github.com/Maastro-CDS-Imaging-Group/ganslate
4 | # Clean and simplify SSIM computation, similar to fastMRI SSIM.
5 | 
6 | # Taken from: https://github.com/VainF/pytorch-msssim/blob/master/pytorch_msssim/ssim.py
7 | # Licensed under MIT.
8 | # Copyright 2020 by Gongfan Fang, Zhejiang University.
9 | # All rights reserved.
10 | # Some changes are made to work together with DIRECT.
11 | 
12 | # ----------------------------------------------------
13 | # Taken from DIRECT https://github.com/directgroup/direct
14 | # Copyright (c) DIRECT Contributors
15 | # Added support for mixed precision by allowing one image to be of type `half` and the other `float`.
16 | # ----------------------------------------------------
17 | 
18 | import torch
19 | import torch.nn.functional as F
20 | 
21 | 
22 | def _fspecial_gauss_1d(size, sigma, device=None, dtype=None):
23 |     """
24 |     Create a 1D gaussian kernel.
25 |     Parameters
26 |     ----------
27 |     size : int
28 |         The size of the gaussian kernel
29 |     sigma : float
30 |         The standard deviation of the normal distribution
31 |     Returns
32 |     -------
33 |     torch.Tensor: 1D kernel (1 x 1 x size)
34 |     """
35 |     coords = torch.arange(size, dtype=dtype, device=device).float()
36 |     coords -= size // 2
37 |     g = torch.exp(-(coords**2) / (2 * sigma**2))
38 |     g /= g.sum()
39 |     # Return the window as 1 x 1 x size
40 |     return g.view(1, 1, *g.shape)
41 | 
42 | 
43 | def gaussian_filter(input, win):
44 |     """
45 |     Blur the input with a 1D kernel.
46 |     """
47 |     out = F.conv2d(input, win, stride=1, padding=0, groups=input.shape[1])
48 |     return F.conv2d(out, win.transpose(2, 3), stride=1, padding=0, groups=input.shape[1])
49 | 
50 | 
51 | class SSIMLoss(torch.nn.Module):
52 | 
53 |     def __init__(self, win_size=11, win_sigma=1.5, K=(0.01, 0.03)):
54 |         r"""Class for SSIM computation.
55 |         Args:
56 |             win_size (int, optional): the size of the gauss kernel
57 |             win_sigma (float, optional): sigma of the normal distribution
58 |             K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 (e.g. 0.4) if you get negative or NaN results.
59 |             Note: `data_range` (the value range of the input images, usually 1.0 or 255) is an argument of `forward`, not of the constructor.
60 | """ 61 | super().__init__() 62 | self.win_size, self.win_sigma = win_size, win_sigma 63 | self.K = K 64 | 65 | def forward(self, X, Y, data_range=1): 66 | assert X.shape == Y.shape, "X and Y need to be the same shape" 67 | assert X.ndim in [4, 5], "Dimensions of input must be NxCxHxW or NxCxDxHxW" 68 | 69 | # if NxCxDxHxW, convert NxC to N only giving NxDxHxW 70 | if X.ndim == 5: 71 | X = X.view(-1, *X.shape[2:]) 72 | Y = Y.view(-1, *Y.shape[2:]) 73 | channels = X.shape[1] 74 | 75 | # Create 1D gaussian window and repeat it over channel dims 76 | win = _fspecial_gauss_1d(self.win_size, self.win_sigma, dtype=X.dtype, device=X.device) 77 | win = win.repeat(channels, 1, 1, 1) 78 | 79 | K1, K2 = self.K 80 | compensation = 1.0 81 | 82 | C1 = (K1 * data_range)**2 83 | C2 = (K2 * data_range)**2 84 | 85 | mu1 = gaussian_filter(X, win) 86 | mu2 = gaussian_filter(Y, win) 87 | 88 | sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1.pow(2)) 89 | sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2.pow(2)) 90 | sigma12 = compensation * (gaussian_filter(X * Y, win) - (mu1 * mu2)) 91 | 92 | S1 = (2 * mu1 * mu2 + C1) / (mu1.pow(2) + mu2.pow(2) + C1) 93 | S2 = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) # set alpha=beta=gamma=1 94 | 95 | # SSIM Distance metric approximation from: https://ece.uwaterloo.ca/~z70wang/publications/TIP_SSIM_MathProperties.pdf 96 | # Add relu here since floating point rounding errors can lead this value to be slightly negative! 97 | S = torch.relu(2 - (S1 + S2)) 98 | D_map = torch.sqrt(S) 99 | return D_map.mean() 100 | -------------------------------------------------------------------------------- /tools/analyzers/wandb/wandb_analyzer.py: -------------------------------------------------------------------------------- 1 | # Use this script to analyze data from wandb logs. 
5 | import wandb
6 | import pandas as pd
7 | from typing import Optional, Tuple, List
8 | from omegaconf import MISSING, OmegaConf
9 | 
10 | import utils
11 | from dataclasses import dataclass, field
12 | 
13 | from loguru import logger
14 | 
15 | 
16 | ############################### Analyzer Configuration ########################################
17 | @dataclass
18 | class AnalyzerConfig:
19 |     # Wandb entity and project
20 |     entity: str = MISSING
21 |     project: str = MISSING
22 |     # Select a particular run ID to run the analyzer on
23 |     run_id: str = MISSING
24 |     # Run the validation analyzer only up to the specified last checkpoint
25 |     last_ckpt: Optional[int] = None
26 |     # Metric tags to ignore in the analysis
27 |     ignore_tags: List = field(default_factory=list)
28 |     # Additionally group by metric tags during the ranking process
29 |     group_by: List = field(default_factory=list)
30 |     # Once the metrics are ranked, this determines how the ranks are aggregated
31 |     aggregate_ranks_by: str = "mean"
32 |     # Sampling frequency applied to the total number of iterations.
33 |     # If iters_sampling_freq = 10, every 10th iteration is considered for ranking
34 |     iters_sampling_freq: int = 1
35 |     # Metric tags to include in the analysis,
36 |     # split into descending and ascending keys for ranking
37 |     rank_descending_keys: List = field(default_factory=lambda: ["psnr", "ssim"])
38 |     rank_ascending_keys: List = field(default_factory=lambda: ["mae", "mse", "nmse"])
39 | 
40 | 
41 | # Example: python tools/analyzers/wandb/wandb_analyzer.py entity=maastro-clinic project="Media_Experiments" run_id="348tusn"
42 | # group_by=[phantom,BODY] ignore_tags=['cycle']
43 | #################################################################################################
44 | 
45 | 
46 | def main(conf):
47 |     api = wandb.Api()
48 |     api.entity = conf.entity
49 | 
50 |     for run in api.runs(f"{conf.project}"):
51 |         if run.id == conf.run_id:
52 |             logger.info(f"Loading {run.name} ...")
53 |             df = utils.get_wandb_history(run, conf)
54 | 
55 |             # Over all the checkpoints in the run history, get ranks for the metrics
56 |             # by going over each metric and ranking the series in ascending or descending order
57 |             for label, series in df.items():
58 |                 if utils.list_of_strings_has_substring(conf.rank_descending_keys, label):
59 |                     df[label] = series.rank(ascending=False)
60 |                 elif utils.list_of_strings_has_substring(conf.rank_ascending_keys, label):
61 |                     df[label] = series.rank(ascending=True)
62 |                 else:
63 |                     logger.warning(f'{label} not in the ascending or descending set of keys')
64 | 
65 |             # Aggregate 'all' metrics based on rank and the selected method of ordering
66 |             df[f'{conf.aggregate_ranks_by}_rank_all_metrics'] = utils.get_aggregate_ranks(
67 |                 df, conf.aggregate_ranks_by)
68 |             sort_by = [f'{conf.aggregate_ranks_by}_rank_all_metrics']
69 | 
70 |             # Check for any provided groups to be inspected among the metrics
71 |             for group_key in conf.group_by:
72 |                 # Collect the df columns whose metric tag matches this group key
73 |                 group_metric_cols = [col for col in df.columns if group_key.lower() in col.lower()]
74 |                 group_df = df[group_metric_cols]
75 |                 # Aggregate the 'group_key' metrics based on rank and the selected method
76 |                 df[f'{conf.aggregate_ranks_by}_rank_{group_key}_metrics'] = utils.get_aggregate_ranks(
77 |                     group_df, conf.aggregate_ranks_by)
78 |                 sort_by += [f'{conf.aggregate_ranks_by}_rank_{group_key}_metrics']
79 | 
80 |             # For 'all' metrics and the grouped metrics, sort by the aggregate rank
81 |             for val in sort_by:
82 |                 df = df.sort_values(by=val)
83 |                 df[val].to_csv(f"{run.name}_{val}.csv")
84 |                 logger.info(f"Top 5 iterations for {val}: \n {df[val].head()}\n")
85 | 
86 | 
87 | if __name__ == "__main__":
88 |     cli = OmegaConf.from_cli()
89 |     conf = AnalyzerConfig()
90 |     conf = OmegaConf.merge(conf, cli)
91 |     main(conf)
92 | 
--------------------------------------------------------------------------------
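For intuition, the ranking-and-aggregation step above behaves like this minimal pandas sketch (values and column names are illustrative; since `utils.get_aggregate_ranks` is not shown here, it is emulated with a simple row-wise mean, matching the default `aggregate_ranks_by="mean"`):

    import pandas as pd

    # Two checkpoints (rows) and two metrics (columns):
    # lower MAE is better, higher SSIM is better.
    df = pd.DataFrame({"val/mae": [0.10, 0.08], "val/ssim": [0.91, 0.94]},
                      index=[1000, 2000])  # iteration indices

    df["val/mae"] = df["val/mae"].rank(ascending=True)     # rank 1 = lowest MAE
    df["val/ssim"] = df["val/ssim"].rank(ascending=False)  # rank 1 = highest SSIM

    # Aggregate the per-metric ranks (here: mean) and sort, best iteration first.
    df["mean_rank_all_metrics"] = df[["val/mae", "val/ssim"]].mean(axis=1)
    print(df.sort_values(by="mean_rank_all_metrics"))      # iteration 2000 ranks first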