├── .gitignore ├── Makefile ├── README.md ├── config ├── backbone │ ├── bsp_swinir.yaml │ ├── swinia.yaml │ ├── swinir.yaml │ └── unet.yaml ├── config.yaml └── experiment │ ├── bsd68.yaml │ ├── fmd.yaml │ ├── fmd_deconvolution.yaml │ ├── hanzi.yaml │ ├── imagenet.yaml │ ├── microtubules.yaml │ ├── microtubules_generated.yaml │ ├── microtubules_original.yaml │ ├── planaria.yaml │ ├── sidd.yaml │ ├── ssi.yaml │ ├── synthetic.yaml │ └── synthetic_grayscale.yaml ├── data └── README.md ├── evaluate.py ├── noise2same ├── __init__.py ├── backbone │ ├── __init__.py │ ├── bsp_swinir.py │ ├── decoder │ │ └── __init__.py │ ├── encoder │ │ └── __init__.py │ ├── swinia.py │ ├── swinir.py │ └── unet.py ├── contrast.py ├── dataset │ ├── __init__.py │ ├── abc.py │ ├── bsd68.py │ ├── dummy.py │ ├── fmd.py │ ├── getter.py │ ├── hanzi.py │ ├── imagenet.py │ ├── microtubules.py │ ├── planaria.py │ ├── sidd.py │ ├── ssi.py │ ├── synthetic.py │ ├── synthetic_grayscale.py │ ├── transforms.py │ └── util.py ├── evaluator.py ├── fft_conv.py ├── model.py ├── ops │ ├── __init__.py │ └── wrappers.py ├── optimizers │ ├── __init__.py │ └── esadam.py ├── psf │ ├── __init__.py │ ├── microscope_psf.py │ └── psf_convolution.py ├── trainer.py └── util.py ├── requirements.txt ├── scripts ├── bsd68_swin.sh ├── bsd68_unet.sh ├── deconvolution │ ├── fmd_unet.sh │ ├── microtubules_inv_mse_before_psf_boundary.sh │ ├── ssi.sh │ ├── ssi_boundary.sh │ ├── ssi_inv_mse_before_and_after_psf_boundary.sh │ ├── ssi_inv_mse_before_psf.sh │ ├── ssi_inv_mse_before_psf_boundary.sh │ ├── ssi_only_masked.sh │ └── ssi_only_masked_boundary.sh ├── fmd_swinia.sh ├── fmd_swinia_dp.sh ├── fmd_unet.sh ├── hanzi_swin.sh ├── hanzi_unet.sh ├── imagenet_swin.sh ├── imagenet_unet.sh ├── sidd_swinia.sh ├── sidd_unet.sh ├── synthetic_grayscale_swinia.sh ├── synthetic_grayscale_swinia_dp.sh ├── synthetic_grayscale_unet.sh ├── synthetic_swinia.sh ├── synthetic_swinia_dp.sh └── synthetic_unet.sh ├── tests ├── test_contrast.py ├── test_dataset.py ├── test_metrics.py ├── test_model.py └── test_psf.py ├── train.py ├── utils.py └── weights └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | weights/*.pth 3 | weights/*.zip 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | 138 | # pytype static type analyzer 139 | .pytype/ 140 | 141 | # Cython debug symbols 142 | cython_debug/ 143 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | TEST_PATH=./tests 2 | MODULES_PATH=./noise2same 3 | TRAIN_PATH=train.py evaluate.py 4 | 5 | .PHONY: format lint test commit 6 | 7 | format: 8 | isort $(MODULES_PATH) $(TRAIN_PATH) $(TEST_PATH) 9 | black $(MODULES_PATH) $(TRAIN_PATH) $(TEST_PATH) 10 | 11 | lint: 12 | isort -c $(MODULES_PATH) $(TRAIN_PATH) $(TEST_PATH) 13 | black --check $(MODULES_PATH) $(TRAIN_PATH) $(TEST_PATH) 14 | mypy $(MODULES_PATH) $(TRAIN_PATH) $(TEST_PATH) 15 | 16 | test: 17 | python3 -m unittest discover -s $(TEST_PATH) -t $(TEST_PATH) 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Noise2Same (PyTorch) 2 | 3 | PyTorch reimplementation of [Noise2Same](https://github.com/divelab/Noise2Same). 4 | Work in progress. 5 | 6 | ## Usage 7 | 8 | The default configuration is located in `config/config.yaml`. 9 | Experiment configs in `config/experiment` may override the defaults. 10 | 11 | ### Training 12 | 13 | To run an experiment for BSD68, execute 14 | ```bash 15 | python train.py +experiment=bsd68 16 | ``` 17 | Four experiments from Noise2Same are supported: `bsd68`, `hanzi`, `imagenet`, `planaria`.
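Hydra composes the final configuration, so any default from `config/config.yaml` can be overridden from the command line in the same dotlist style (the values below are only illustrative):
```bash
python train.py +experiment=bsd68 training.batch_size=32 optim.lr=1e-4
```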
18 | 19 | Training logs and model weights will be saved to `results/<experiment>/<backbone>/train/<datetime>` (see `hydra.run.dir` in `config/config.yaml`). 20 | 21 | ### Evaluation 22 | 23 | To evaluate a trained model, point `evaluate.py` to the Hydra output directory of its training run: 24 | ```bash 25 | python evaluate.py --train_dir results/bsd68/unet/train/<datetime> --checkpoint best 26 | ``` 27 | The training config is restored from `<train_dir>/.hydra/config.yaml`, and the checkpoint (`last` by default, or `best`) 28 | is loaded from `<train_dir>/checkpoints/`. Pretrained weights from the table below can be stored in `weights/` (ignored by git). 29 | 30 | The model's outputs and per-image scores (RMSE, PSNR, SSIM) will be saved to `<train_dir>/evaluate/<datetime>`. 31 | 32 | 33 | ## Results replication 34 | 35 | We replicate the main results of Noise2Same (Table 3): 36 | 37 | 38 | | Dataset | Ours (Noise2Self) | Noise2Same paper | Ours (Noise2Same) | Weights | 39 | |---------------------|-------------------|-----------------------|-----------------------|----------| 40 | | BSD68 | 26.73 | 27.95 | 28.11 | [Drive](https://drive.google.com/file/d/1YTlHpL-C4JaRtfp8YUiXfzppX0Tgs4lC/view?usp=sharing)| 41 | | HanZi | | 14.38 | 14.83 |[Drive](https://drive.google.com/file/d/1WHd_BUqlibrDERWwzs4ReSZu2s8Ya1Y2/view?usp=sharing)| 42 | | ImageNet | | 22.26 | 22.81 |[Drive](https://drive.google.com/file/d/12Rxp30DmwmYBq6ZtgnPD-SMd9u1FeRER/view?usp=sharing)| 43 | | Planaria (C1/C2/C3) | | 29.48 / 26.93 / 22.41 | 29.14 / 27.11 / 22.80 |[Drive](https://drive.google.com/file/d/17Yz_f8RNOu7nEztSOug_1PLQFKYYj_Mf/view?usp=sharing)| 44 | 45 | -------------------------------------------------------------------------------- /config/backbone/bsp_swinir.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | backbone_name: bsp_swinir 3 | 4 | backbone: 5 | embed_dim: 96 6 | upscale: 1 7 | window_size: 8 8 | depths: [ 6, 6, 6, 6, 6, 6 ] 9 | num_heads: [ 6, 6, 6, 6, 6, 6 ] 10 | -------------------------------------------------------------------------------- /config/backbone/swinia.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | backbone_name: swinia 3 | 4 | backbone: 5 | embed_dim: 144 6 | window_size: 8 7 | depths: [ 4, 4, 4, 4, 4 ] 8 | num_heads: [ 16, 16, 16, 16, 16 ] 9 | dilations: [ 1, 1, 1, 1, 1 ] 10 | shuffles: [ 1, 2, 4, 2, 1 ] 11 | 12 | optim: 13 | lr: 1e-3 14 | scheduler: cosine 15 | eta_min: 1e-6 -------------------------------------------------------------------------------- /config/backbone/swinir.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | backbone_name: swinir 3 | 4 | backbone: 5 | embed_dim: 96 6 | upscale: 1 7 | window_size: 8 8 | depths: [ 6, 6, 6, 6, 6, 6 ] 9 | num_heads: [ 6, 6, 6, 6, 6, 6 ] 10 | -------------------------------------------------------------------------------- /config/backbone/unet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | backbone_name: unet 3 | 4 | backbone: 5 | base_channels: 96 6 | kernel_size: 3 7 | depth: 3 8 | encoding_block_sizes: [ 1, 1, 0 ] 9 | decoding_block_sizes: [ 1, 1 ] 10 | downsampling: [ conv, conv ] 11 | skip_method: concat 12 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: results/${experiment}/${backbone_name}/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S} 4 | 5 | project: noise2same 6 | device: 0 7 | seed: 56 8 | check: False 9 | evaluate: True 10 | 11 | model: 12 | lambda_rec: 1 13 | lambda_bsp: 0 14 | lambda_inv: 2
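  # lambda_rec and lambda_inv above weight the reconstruction and invariance terms
  # of the Noise2Same loss; the remaining lambda_* terms below appear to be
  # extensions of this implementation and are disabled (0) by default.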
15 | lambda_inv_deconv: 0 16 | lambda_proj: 0 17 | lambda_bound: 0 18 | lambda_sharp: 0 19 | masked_inv_deconv: True 20 | mask_percentage: 0.5 21 | masking: gaussian 22 | noise_mean: 0 23 | noise_std: 0.2 24 | residual: False 25 | regularization_key: image 26 | mode: "noise2same" 27 | 28 | training: 29 | steps_per_epoch: 1000 30 | steps: 50000 31 | batch_size: 64 32 | num_workers: 8 33 | crop: 64 34 | validate: True 35 | val_partition: 1.0 36 | val_batch_size: 4 37 | monitor: bsp_mse 38 | amp: True 39 | info_padding: False 40 | 41 | data: 42 | n_dim: 2 43 | n_channels: 1 44 | standardize: True # subtract mean and divide by std for each image separately 45 | add_blur_and_noise: False 46 | 47 | optim: 48 | optimizer: adam 49 | lr: 0.0004 50 | weight_decay: 0 51 | scheduler: lambda 52 | decay_rate: 0.5 53 | decay_steps: 5e3 # how many steps to decrease by decay rate 54 | staircase: True # integer division 55 | -------------------------------------------------------------------------------- /config/experiment/bsd68.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | experiment: bsd68 3 | data: 4 | n_dim: 2 5 | n_channels: 1 6 | 7 | training: 8 | crop: 128 # 64 in the paper 9 | steps: 80000 10 | batch_size: 32 # 64 in the paper 11 | 12 | model: 13 | lambda_inv: 0.95 14 | -------------------------------------------------------------------------------- /config/experiment/fmd.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | experiment: fmd 3 | data: 4 | n_dim: 2 5 | n_channels: 1 6 | part: cf_fish 7 | 8 | training: 9 | crop: 64 10 | steps: 20000 11 | batch_size: 64 12 | monitor: val_mse 13 | 14 | optim: 15 | weight_decay: 1e-8 -------------------------------------------------------------------------------- /config/experiment/fmd_deconvolution.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | experiment: fmd 3 | data: 4 | n_dim: 2 5 | n_channels: 1 6 | part: cf_fish 7 | add_blur_and_noise: True 8 | 9 | training: 10 | crop: 128 11 | steps: 10000 12 | steps_per_epoch: 200 13 | batch_size: 16 14 | val_batch_size: 1 15 | monitor: val_mse 16 | 17 | optim: 18 | optimizer: adam 19 | lr: 0.0004 20 | weight_decay: 1e-8 21 | decay_rate: 0.5 22 | decay_steps: 1e3 # how many steps to decrease by decay rate 23 | staircase: True # integer division 24 | 25 | psf: 26 | path: null 27 | psf_size: null 28 | psf_pad_mode: replicate # check if reflect affected anything 29 | psf_fft: auto -------------------------------------------------------------------------------- /config/experiment/hanzi.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | experiment: hanzi 3 | data: 4 | version: 0 5 | noise_level: 3 6 | 7 | training: 8 | steps: 50000 9 | crop: 64 10 | batch_size: 64 11 | validate: True 12 | val_partition: 0.05 13 | monitor: bsp_mse -------------------------------------------------------------------------------- /config/experiment/imagenet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | experiment: imagenet 3 | data: 4 | n_channels: 3 5 | version: 0 6 | 7 | training: 8 | steps: 50000 9 | crop: 64 10 | batch_size: 64 11 | validate: True 12 | val_partition: 0.1 13 | monitor: bsp_mse -------------------------------------------------------------------------------- /config/experiment/microtubules.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | experiment: microtubules 3 | evaluate: True 4 | 5 | data: 6 | n_dim: 3 7 | path: data/microtubules-simulation 8 | input_name: input-generated-poisson-gaussian-2e-4.tif 9 | tile_size: 64 10 | tile_step: 48 11 | add_blur_and_noise: False 12 | 13 | network: 14 | base_channels: 48 15 | skip_method: add 16 | 17 | training: 18 | steps_per_epoch: 100 19 | steps: 15000 20 | crop: 64 21 | batch_size: 4 22 | validate: False 23 | monitor: rec_mse 24 | 25 | optim: 26 | decay_steps: 2000 27 | 28 | psf: 29 | path: data/microtubules-simulation/psf-bw-31.tif # near the data by default; parametrize? 30 | psf_size: null 31 | psf_pad_mode: replicate # check if reflect affected anything 32 | psf_fft: auto 33 | -------------------------------------------------------------------------------- /config/experiment/microtubules_generated.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | experiment: microtubules 3 | evaluate: True 4 | 5 | data: 6 | n_dim: 3 7 | path: data/microtubules-simulation 8 | input_name: ground-truth.tif 9 | tile_size: 64 10 | tile_step: 48 11 | add_blur_and_noise: True 12 | 13 | network: 14 | base_channels: 48 15 | skip_method: add 16 | 17 | training: 18 | steps_per_epoch: 100 19 | steps: 15000 20 | crop: 64 21 | batch_size: 4 22 | validate: False 23 | monitor: rec_mse 24 | 25 | optim: 26 | decay_steps: 2000 27 | 28 | psf: 29 | path: null 30 | psf_size: null 31 | psf_pad_mode: replicate # check if reflect affected anything 32 | psf_fft: auto 33 | -------------------------------------------------------------------------------- /config/experiment/microtubules_original.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | experiment: microtubules 3 | data: 4 | n_dim: 3 5 | path: data/microtubules-simulation 6 | input_name: input.tif 7 | tile_size: 64 8 | tile_step: 64 # to prevent overlap 9 | 10 | network: 11 | base_channels: 48 12 | skip_method: add 13 | 14 | training: 15 | steps_per_epoch: 100 16 | steps: 15000 17 | crop: 64 18 | batch_size: 4 19 | validate: False 20 | monitor: rec_mse 21 | amp: True # No inf checks were recorded for this optimizer 22 | info_padding: True 23 | 24 | optim: 25 | decay_steps: 2000 26 | 27 | psf: 28 | path: data/microtubules-simulation/psf.tif # near the data by default; parametrize? 
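  # The psf_* keys below configure the model's PSF convolution (see noise2same/psf/psf_convolution.py);
  # psf_fft: auto presumably selects between direct and FFT-based convolution (noise2same/fft_conv.py)
  # depending on the kernel size.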
29 | psf_size: null 30 | psf_pad_mode: constant # check if reflect affected anything 31 | psf_fft: auto 32 | -------------------------------------------------------------------------------- /config/experiment/planaria.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | experiment: planaria 3 | data: 4 | n_dim: 3 5 | 6 | training: 7 | steps: 50000 8 | crop: 64 9 | batch_size: 4 -------------------------------------------------------------------------------- /config/experiment/sidd.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | experiment: sidd 3 | data: 4 | n_channels: 3 5 | 6 | training: 7 | steps: 50000 8 | crop: 64 9 | batch_size: 64 10 | validate: True 11 | val_partition: 0.1 12 | monitor: bsp_mse -------------------------------------------------------------------------------- /config/experiment/ssi.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | experiment: ssi 3 | 4 | evaluate: True 5 | 6 | data: 7 | path: data/ssi 8 | n_dim: 2 9 | n_channels: 1 10 | input_name: drosophila 11 | 12 | training: 13 | num_workers: 4 14 | crop: 128 15 | steps_per_epoch: 30 16 | steps: 3000 17 | batch_size: 16 18 | validate: False 19 | monitor: rec_mse 20 | 21 | optim: 22 | optimizer: adam 23 | lr: 0.0004 24 | decay_rate: 0.5 25 | decay_steps: 5e2 # how many steps to decrease by decay rate 26 | staircase: True # integer division 27 | 28 | 29 | psf: 30 | path: null 31 | psf_size: null 32 | psf_pad_mode: replicate # check if reflect affected anything -------------------------------------------------------------------------------- /config/experiment/synthetic.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | experiment: synthetic 3 | data: 4 | n_dim: 2 5 | n_channels: 3 6 | noise_type: gaussian 7 | noise_param: 25 8 | 9 | training: 10 | crop: 64 11 | steps: 50000 12 | batch_size: 64 13 | validate: True 14 | -------------------------------------------------------------------------------- /config/experiment/synthetic_grayscale.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | experiment: synthetic_grayscale 3 | data: 4 | n_dim: 2 5 | n_channels: 1 6 | noise_type: gaussian 7 | noise_param: 25 8 | 9 | training: 10 | crop: 64 11 | steps: 50000 12 | batch_size: 64 13 | validate: True 14 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | See https://github.com/divelab/Noise2Same/tree/main/Denoising_data -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import datetime 4 | from pathlib import Path 5 | from pprint import pprint 6 | 7 | from argparse import ArgumentParser 8 | import numpy as np 9 | import pandas as pd 10 | from omegaconf import OmegaConf 11 | from torch.utils.data import DataLoader, Dataset, ConcatDataset 12 | from tqdm.auto import tqdm 13 | 14 | from noise2same import model, util 15 | from noise2same.dataset.getter import ( 16 | get_planaria_dataset_and_gt, 17 | get_test_dataset_and_gt, 18 | ) 19 | from noise2same.evaluator import Evaluator 20 | from utils import parametrize_backbone_and_head 
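# Example invocation for a finished Hydra training run (hypothetical path; the CLI is
# defined by the ArgumentParser at the bottom of this file):
#   python evaluate.py --train_dir results/bsd68/unet/train/2023-01-01_00-00-00 --checkpoint best
# Any extra key=value arguments are merged into the saved config as an OmegaConf dotlist.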
21 | 22 | 23 | def get_loader( 24 | dataset: Dataset, 25 | experiment: str, 26 | num_workers: int, 27 | ): 28 | loader = None 29 | if experiment.lower() in ("bsd68", "fmd", "imagenet", "sidd", "hanzi", "synthetic", "synthetic_grayscale"): 30 | loader = DataLoader( 31 | dataset, 32 | batch_size=1, # todo customize 33 | num_workers=num_workers, 34 | shuffle=False, 35 | pin_memory=True, 36 | drop_last=False, 37 | ) 38 | elif experiment == "ssi": 39 | loader = DataLoader( 40 | dataset, 41 | batch_size=1, # todo customize 42 | num_workers=num_workers, 43 | shuffle=False, 44 | pin_memory=True, 45 | drop_last=False, 46 | ) 47 | return loader 48 | 49 | 50 | def get_ground_truth_and_predictions( 51 | evaluator: Evaluator, 52 | experiment: str, 53 | ground_truth: np.ndarray, 54 | cwd: Path, 55 | loader: DataLoader = None, 56 | dataset: Dataset = None, 57 | half: bool = False 58 | ): 59 | if experiment in ("bsd68", "fmd", "hanzi", "sidd", "synthetic", "synthetic_grayscale"): 60 | add_blur_and_noise = getattr(dataset, "add_blur_and_noise", False) 61 | if add_blur_and_noise: 62 | print("Validate for deconvolution") 63 | predictions, _ = evaluator.inference(loader, half=half) 64 | elif experiment in ("imagenet",): 65 | predictions, indices = evaluator.inference(loader, half=half, empty_cache=True) 66 | ground_truth = [ground_truth[i] for i in indices] 67 | elif experiment in ("microtubules",): 68 | predictions, _ = evaluator.inference_single_image_dataset( 69 | dataset, half=half, batch_size=1 70 | ) 71 | elif experiment in ("planaria",): 72 | files = sorted( 73 | glob.glob(str(cwd / "data/Denoising_Planaria/test_data/GT/*.tif")) 74 | ) 75 | predictions = {"c1": [], "c2": [], "c3": [], "y": []} 76 | for f in tqdm(files): 77 | datasets, gt = get_planaria_dataset_and_gt(f) 78 | predictions["y"].append(gt) 79 | for c in range(1, 4): 80 | predictions[f"c{c}"].append( 81 | evaluator.inference_single_image_dataset( 82 | datasets[f"c{c}"], half=half, batch_size=1 83 | ) 84 | ) 85 | else: 86 | raise ValueError 87 | 88 | # Rearrange predictions List[Dict[str, array]] -> Dict[str, List[array]] 89 | if experiment not in ("planaria", "microtubules"): 90 | predictions = {k: [d[k].squeeze() for d in predictions] for k in predictions[0]} 91 | 92 | return ground_truth, predictions 93 | 94 | 95 | def get_scores( 96 | ground_truth: np.ndarray, 97 | predictions: np.ndarray, 98 | experiment: str 99 | ): 100 | # Calculate scores 101 | if experiment in ("bsd68",): 102 | scores = [ 103 | util.calculate_scores(gtx, pred, data_range=255) 104 | for gtx, pred in zip(ground_truth, predictions["image"]) 105 | ] 106 | elif experiment in ("synthetic", "synthetic_grayscale", "sidd", "fmd",): 107 | scale = 255 if experiment.startswith("synthetic") else 1 108 | multichannel = experiment in ("synthetic", "sidd") 109 | scores = [ 110 | # https://github.com/TaoHuang2018/Neighbor2Neighbor/blob/2fff2978/train.py#L446 111 | # SSIM is not exactly the same as the original Neighbor2Neighbor implementation, 112 | # because skimage uses padding (which is more fair), while the original implementation crops the borders. 113 | # However, the difference is negligible (<0.001 in their favor). 
114 | util.calculate_scores(gtx.astype(np.float32), 115 | np.clip(pred * scale + 0.5, 0, 255).astype(np.uint8).astype(np.float32), 116 | data_range=255, 117 | multichannel=multichannel, 118 | gaussian_weights=True, 119 | ) 120 | for gtx, pred in zip(ground_truth, predictions["image"]) 121 | ] 122 | elif experiment in ("hanzi",): 123 | scores = [ 124 | util.calculate_scores(gtx * 255, pred, data_range=255, scale=True) 125 | for gtx, pred in zip(ground_truth, predictions["image"]) 126 | ] 127 | elif experiment in ("imagenet",): 128 | scores = [ 129 | util.calculate_scores( 130 | gtx, 131 | pred, 132 | data_range=255, 133 | scale=True, 134 | multichannel=True, 135 | ) 136 | for gtx, pred in zip(ground_truth, predictions["image"]) 137 | ] 138 | elif experiment in ("microtubules",): 139 | scores = util.calculate_scores(ground_truth, predictions, normalize_pairs=True) 140 | elif experiment in ("planaria",): 141 | scores = [] 142 | for c in range(1, 4): 143 | scores_c = [ 144 | util.calculate_scores(gt, x, normalize_pairs=True) 145 | for gt, x in tqdm( 146 | zip(predictions["y"], predictions[f"c{c}"]), 147 | total=len(predictions["y"]), 148 | ) 149 | ] 150 | scores.append(pd.DataFrame(scores_c).assign(c=c)) 151 | scores = pd.concat(scores) 152 | else: 153 | raise ValueError 154 | return scores 155 | 156 | 157 | def evaluate( 158 | evaluator: Evaluator, 159 | ground_truth: np.ndarray, 160 | experiment: str, 161 | cwd: Path, 162 | train_dir: Path, 163 | loader: DataLoader = None, 164 | dataset: Dataset = None, 165 | num_workers: int = None, 166 | half: bool = False, 167 | save_results: bool = True, 168 | verbose: bool = True, 169 | ): 170 | assert loader is not None or dataset is not None 171 | 172 | if loader is None: 173 | loader = get_loader(dataset, experiment, num_workers) 174 | 175 | ground_truth, predictions = get_ground_truth_and_predictions( 176 | evaluator, experiment, ground_truth, cwd, loader, dataset, half 177 | ) 178 | 179 | scores = get_scores(ground_truth, predictions, experiment) 180 | scores = pd.DataFrame(scores) 181 | 182 | if experiment in ("synthetic", "synthetic_grayscale",): 183 | # Label each score with its dataset name and repeat id 184 | # by default {"kodak": 10, "bsd300": 3, "set14": 20} but the code below generalizes to any number of repeats 185 | dataset_name = [] 186 | repeat_id = [] 187 | repeat = 0 188 | assert isinstance(dataset, ConcatDataset) 189 | for ds in dataset.datasets: 190 | assert isinstance(ds, ConcatDataset) 191 | for repeat_ds in ds.datasets: 192 | repeat_id += [repeat] * len(repeat_ds) 193 | dataset_name += [repeat_ds.name] * len(repeat_ds) 194 | repeat += 1 195 | repeat = 0 196 | scores = scores.assign(dataset_name=dataset_name, repeat_id=repeat_id) 197 | evaluation_dir = train_dir / f'evaluate' / datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 198 | evaluation_dir.mkdir(parents=True, exist_ok=True) 199 | 200 | if save_results: 201 | print("Saving results to", evaluation_dir) 202 | scores.to_csv(evaluation_dir / "scores.csv") 203 | np.savez(evaluation_dir / "predictions.npz", **predictions) 204 | 205 | if experiment in ("planaria",): 206 | scores = scores.groupby("c").mean() 207 | elif experiment in ("synthetic", "synthetic_grayscale",): 208 | if verbose: 209 | print("\nBefore averaging over repeats:") 210 | pprint(scores.groupby(["dataset_name", "repeat_id"]).mean()) 211 | scores = scores.groupby("dataset_name").mean().drop(columns="repeat_id") 212 | else: 213 | scores = scores.mean() 214 | 215 | if verbose: 216 | print("\nEvaluation 
results:") 217 | pprint(scores) 218 | 219 | scores = scores.to_dict() 220 | if experiment in ("synthetic", "synthetic_grayscale", "planaria", ): 221 | # Flatten scores dict as "metric.dataset" to make it compatible with wandb 222 | scores = {f"{metric}.{dataset_name}": score for metric, dataset_dict in scores.items() 223 | for dataset_name, score in dataset_dict.items()} 224 | return scores 225 | 226 | 227 | def main(train_dir: Path, checkpoint: str = 'last', other_args: list = None) -> None: 228 | 229 | cfg = OmegaConf.load(f'{train_dir}/.hydra/config.yaml') 230 | if other_args is not None: 231 | cfg.merge_with_dotlist(other_args) 232 | 233 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{cfg.device}" 234 | 235 | print(f"Evaluate backbone {cfg.backbone_name} on experiment {cfg.experiment}, work in {train_dir}") 236 | 237 | cwd = Path(os.getcwd()) 238 | 239 | dataset, ground_truth = None, None 240 | if cfg.experiment not in ("planaria",): 241 | # For some datasets we need custom loading 242 | dataset, ground_truth = get_test_dataset_and_gt(cfg, cwd) 243 | 244 | backbone, head = parametrize_backbone_and_head(cfg) 245 | 246 | mdl = model.Noise2Same( 247 | n_dim=cfg.data.n_dim, 248 | in_channels=cfg.data.n_channels, 249 | psf=cfg.psf.path if "psf" in cfg else None, 250 | psf_size=cfg.psf.psf_size if "psf" in cfg else None, 251 | psf_pad_mode=cfg.psf.psf_pad_mode if "psf" in cfg else None, 252 | backbone=backbone, 253 | head=head, 254 | **cfg.model, 255 | ) 256 | 257 | checkpoint_path = train_dir / Path(f"checkpoints/model{'_last' if checkpoint == 'last' else ''}.pth") 258 | 259 | # Run evaluation 260 | half = getattr(cfg, "amp", False) 261 | masked = getattr(cfg, "masked", False) 262 | evaluator = Evaluator(mdl, checkpoint_path=checkpoint_path, masked=masked) 263 | evaluate( 264 | evaluator, ground_truth, cfg.experiment, cwd, train_dir, dataset=dataset, half=half, 265 | num_workers=cfg.training.num_workers 266 | ) 267 | 268 | 269 | if __name__ == "__main__": 270 | parser = ArgumentParser() 271 | parser.add_argument("--train_dir", required=True, 272 | help="Path to hydra train directory") 273 | parser.add_argument("--checkpoint", choices=["last", "best"], 274 | default="last", help="The checkpoint to evaluate, 'last' or 'best'") 275 | args, unknown_args = parser.parse_known_args() 276 | main(Path(args.train_dir), args.checkpoint, unknown_args) 277 | -------------------------------------------------------------------------------- /noise2same/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/papkov/noise2same.pytorch/dc3697be1dfd27e78b0b129d4aecb5bcf98ccfcb/noise2same/__init__.py -------------------------------------------------------------------------------- /noise2same/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .swinir import SwinIR 2 | from .unet import UNet, RegressionHead 3 | 4 | from . 
import decoder, encoder 5 | 6 | __all__ = ["SwinIR", "UNet", "RegressionHead", "decoder", "encoder"] 7 | -------------------------------------------------------------------------------- /noise2same/backbone/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /noise2same/backbone/encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/papkov/noise2same.pytorch/dc3697be1dfd27e78b0b129d4aecb5bcf98ccfcb/noise2same/backbone/encoder/__init__.py -------------------------------------------------------------------------------- /noise2same/backbone/swinia.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Tuple, Optional, Iterable, Set 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor as T 7 | from timm.models.layers import to_2tuple, trunc_normal_ 8 | import numpy as np 9 | import einops 10 | 11 | 12 | def connect_shortcut(layer: nn.Module, x: T, y: T) -> T: 13 | x = torch.cat([x, y], -1) 14 | return layer(x) 15 | 16 | 17 | class MLP(nn.Module): 18 | 19 | def __init__( 20 | self, 21 | in_features: int = 96, 22 | out_features: int = 96, 23 | n_layers: int = 1, 24 | hidden_features: Optional[int] = None, 25 | act_layer: nn.Module = nn.GELU, 26 | drop=0., 27 | ): 28 | super().__init__() 29 | hidden_features = hidden_features or out_features 30 | features = [hidden_features] * (n_layers + 1) 31 | features[0], features[-1] = in_features, out_features 32 | self.layers = nn.ModuleList([ 33 | nn.Sequential( 34 | nn.Linear(features[i], features[i + 1]), 35 | nn.LayerNorm(features[i + 1]) 36 | ) for i in range(n_layers) 37 | ]) 38 | self.act = act_layer() 39 | self.drop = nn.Dropout(drop) 40 | 41 | def forward(self, x: T) -> T: 42 | for layer in self.layers: 43 | x = layer(x) 44 | x = self.act(x) 45 | x = self.drop(x) 46 | return x 47 | 48 | 49 | class DiagonalWindowAttention(nn.Module): 50 | 51 | def __init__( 52 | self, 53 | embed_dim: int = 96, 54 | window_size: Tuple[int] = (8, 8), 55 | dilation: int = 1, 56 | shuffle: int = 1, 57 | num_heads: int = 6, 58 | attn_drop: float = 0.05, 59 | proj_drop: float = 0.05, 60 | ): 61 | super().__init__() 62 | self.softmax = nn.Softmax(dim=-1) 63 | self.attn_drop = nn.Dropout(attn_drop) 64 | self.norm = nn.LayerNorm(embed_dim) 65 | self.window_size = window_size 66 | self.num_patches = np.prod(window_size).item() 67 | self.num_heads = num_heads 68 | self.embed_dim = embed_dim 69 | self.dilation = dilation 70 | self.shuffle = shuffle 71 | 72 | head_dim = embed_dim // num_heads 73 | self.scale = head_dim ** -0.5 / shuffle 74 | self.proj = nn.Linear(embed_dim, embed_dim) 75 | self.proj_drop = nn.Dropout(proj_drop) 76 | self.norm_q = nn.LayerNorm([shuffle ** 2, head_dim]) 77 | self.norm_k = nn.LayerNorm([shuffle ** 2, head_dim]) 78 | self.norm_v = nn.LayerNorm([shuffle ** 2, head_dim]) 79 | self.norm2 = nn.LayerNorm(embed_dim) 80 | 81 | window_bias_shape = [2 * s - 1 for s in window_size] 82 | self.relative_position_bias_table = nn.Parameter(torch.zeros(np.prod(window_bias_shape).item(), num_heads)) 83 | 84 | coords = torch.stack(torch.meshgrid([torch.arange(s) for s in window_size], indexing='ij')) 85 | coords_flatten = torch.flatten(coords, 1) 86 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 87 | relative_coords = 
torch.einsum("n...->...n", relative_coords) 88 | 89 | coefficients = [np.prod(window_bias_shape[i:]) for i in range(1, len(window_size))] 90 | for dim, size in enumerate(window_size): 91 | relative_coords[..., dim] += size - 1 92 | if dim < len(window_size) - 1: 93 | relative_coords[..., dim] *= coefficients[dim] 94 | 95 | relative_position_index = relative_coords.sum(-1).flatten() 96 | self.register_buffer("relative_position_index", relative_position_index) 97 | trunc_normal_(self.relative_position_bias_table, std=.02) 98 | 99 | def head_partition(self, x: T) -> T: 100 | return einops.rearrange(x, 'nw (wh sh ww sw) (nh ch) -> nw nh (wh ww) (sh sw) ch', 101 | wh=self.window_size[0], sh=self.shuffle, sw=self.shuffle, nh=self.num_heads) 102 | 103 | def head_partition_reversed(self, x: T) -> T: 104 | return einops.rearrange(x, "nw nh (wh ww) (sh sw) ch -> nw (wh sh ww sw) (nh ch)", 105 | wh=self.window_size[0], sh=self.shuffle) 106 | 107 | def forward( 108 | self, 109 | query: T, 110 | key: T, 111 | value: T, 112 | mask: T, 113 | ) -> T: 114 | shortcut = query 115 | query, key, value = map(self.head_partition, (query, key, value)) 116 | query = self.norm_q(query) 117 | key = self.norm_k(key) 118 | value = self.norm_v(value) 119 | query = query * self.scale 120 | attn = torch.einsum("...qsc,...ksc->...qk", query, key) 121 | relative_position_bias = einops.rearrange( 122 | self.relative_position_bias_table[self.relative_position_index], 123 | "(np1 np2) nh -> 1 nh np1 np2", np1=self.num_patches 124 | ) 125 | attn = attn + relative_position_bias 126 | attn += einops.repeat(mask, f"nw np1 np2 -> (b nw) 1 np1 np2", b=attn.shape[0] // mask.shape[0]).to(attn.device) 127 | 128 | attn = self.softmax(attn) 129 | attn = self.attn_drop(attn) 130 | query = torch.einsum("...qk,...ksc->...qsc", attn, value) 131 | query, key, value = map(self.head_partition_reversed, (query, key, value)) 132 | query = self.norm2(query + shortcut) 133 | query = self.proj(query) 134 | query = self.proj_drop(query) 135 | return query 136 | 137 | 138 | class TransformerBlock(nn.Module): 139 | 140 | def __init__( 141 | self, 142 | embed_dim: int = 96, 143 | window_size: int = 8, 144 | shift_size: Tuple[int] = (0, 0), 145 | num_heads: int = 6, 146 | dilation: int = 1, 147 | shuffle: int = 1, 148 | input_size: Tuple[int] = (128, 128), 149 | attn_drop: float = 0.05, 150 | proj_drop: float = 0.05, 151 | ): 152 | super().__init__() 153 | self.softmax = nn.Softmax(dim=-1) 154 | self.mlp = MLP(embed_dim, embed_dim, n_layers=2) 155 | self.attn = DiagonalWindowAttention( 156 | embed_dim, to_2tuple(window_size), dilation, shuffle, num_heads, 157 | attn_drop, proj_drop 158 | ) 159 | self.norm1 = nn.LayerNorm(embed_dim) 160 | self.norm2 = nn.LayerNorm(embed_dim) 161 | self.window_size = window_size 162 | self.shift_size = np.array(shift_size) 163 | self.dilation = dilation 164 | self.shuffle = shuffle 165 | self.embed_dim = embed_dim 166 | self.input_size = input_size 167 | self.attn_mask = self.calculate_mask(input_size) 168 | self.shortcut = MLP(embed_dim * 2, embed_dim) 169 | 170 | def shift_image(self, x: T) -> T: 171 | if np.all(self.shift_size == 0): 172 | return x 173 | return torch.roll(x, shifts=tuple(-self.shift_size * self.dilation * self.shuffle), dims=(1, 2)) 174 | 175 | def shift_image_reversed(self, x: T) -> T: 176 | if np.all(self.shift_size == 0): 177 | return x 178 | return torch.roll(x, shifts=tuple(self.shift_size * self.dilation * self.shuffle), dims=(1, 2)) 179 | 180 | def window_partition(self, x: T) -> T: 181 | return 
einops.rearrange(x, 'b (h wh sh dh) (w ww sw dw) c -> (b h w dh dw) (wh sh ww sw) c', 182 | wh=self.window_size, ww=self.window_size, sh=self.shuffle, sw=self.shuffle, 183 | dh=self.dilation, dw=self.dilation) 184 | 185 | def window_partition_reversed(self, x: T, x_size: Iterable[int]) -> T: 186 | height, width = x_size 187 | h = height // (self.window_size * self.shuffle * self.dilation) 188 | w = width // (self.window_size * self.shuffle * self.dilation) 189 | return einops.rearrange(x, '(b h w dh dw) (wh sh ww sw) c -> b (h wh sh dh) (w ww sw dw) c', 190 | h=h, w=w, sh=self.shuffle, sw=self.shuffle, wh=self.window_size, 191 | dh=self.dilation, dw=self.dilation) 192 | 193 | def mask_window_partition(self, mask: T) -> T: 194 | return einops.rearrange(mask, 'b (h wh dh) (w ww dw) c -> (b h w dh dw) (wh ww) c', 195 | wh=self.window_size, ww=self.window_size, dh=self.dilation, dw=self.dilation) 196 | 197 | def calculate_mask(self, x_size: Iterable[int]) -> T: 198 | x_size = [s // self.shuffle for s in x_size] 199 | attn_mask = torch.zeros((1, *x_size, 1)) 200 | if np.any(self.shift_size != 0): 201 | slices = [(slice(0, s), slice(s, None)) for s in -self.shift_size * self.dilation] 202 | cnt = 0 203 | for h, w in itertools.product(*slices): 204 | attn_mask[:, h, w, :] = cnt 205 | cnt += 1 206 | attn_mask = self.mask_window_partition(attn_mask) 207 | attn_mask = einops.rearrange(attn_mask, "nw np 1 -> nw 1 np") - attn_mask 208 | torch.diagonal(attn_mask, dim1=-2, dim2=-1).fill_(1) 209 | attn_mask = attn_mask.masked_fill(attn_mask != 0, -10 ** 9) 210 | return attn_mask 211 | 212 | def forward( 213 | self, 214 | query: T, 215 | key: T, 216 | value: T, 217 | ) -> T: 218 | image_size = key.shape[1:-1] 219 | shortcut = query 220 | query, key, value = map(self.shift_image, (query, key, value)) 221 | query, key, value = map(self.window_partition, (query, key, value)) 222 | mask = self.attn_mask if image_size == self.input_size else self.calculate_mask(image_size).to(key.device) 223 | query = self.attn(query, key, value, mask=mask) 224 | query, key, value = map(self.window_partition_reversed, (query, key, value), [image_size] * 3) 225 | query, key, value = map(self.shift_image_reversed, (query, key, value)) 226 | query = query + self.mlp(query) 227 | query = connect_shortcut(self.shortcut, query, shortcut) 228 | return query 229 | 230 | 231 | class ResidualGroup(nn.Module): 232 | 233 | def __init__( 234 | self, 235 | embed_dim: int = 96, 236 | window_size: int = 8, 237 | depth: int = 6, 238 | num_heads: int = 6, 239 | dilation: int = 1, 240 | shuffle: int = 1, 241 | input_size: Tuple[int] = (128, 128), 242 | ): 243 | super().__init__() 244 | shift_size = window_size // 2 245 | shifts = ((0, 0), (0, shift_size), (shift_size, shift_size), (shift_size, 0)) 246 | self.blocks = nn.ModuleList([ 247 | TransformerBlock( 248 | embed_dim=embed_dim, 249 | window_size=window_size, 250 | shift_size=shifts[i % 4], 251 | num_heads=num_heads, 252 | dilation=dilation, 253 | shuffle=shuffle, 254 | input_size=input_size, 255 | ) for i in range(depth) 256 | ]) 257 | self.shortcut = MLP(embed_dim * 2, embed_dim) 258 | 259 | def forward( 260 | self, 261 | query: T, 262 | key: T, 263 | value: T, 264 | ) -> T: 265 | shortcut = query 266 | for block in self.blocks: 267 | query = block(query, key, value) 268 | query = connect_shortcut(self.shortcut, query, shortcut) 269 | return query 270 | 271 | 272 | class SwinIA(nn.Module): 273 | 274 | def __init__( 275 | self, 276 | in_chans: int = 1, 277 | embed_dim: int = 96, 278 | 
window_size: int = 8, 279 | input_size: int = 128, 280 | depths: Tuple[int] = (8, 4, 4, 4, 4, 4, 8), 281 | num_heads: Tuple[int] = (6, 6, 6, 6, 6, 6, 6), 282 | dilations: Tuple[int] = (1, 1, 1, 1, 1, 1, 1), 283 | shuffles: Tuple[int] = (1, 1, 1, 1, 1, 1, 1) 284 | ): 285 | super().__init__() 286 | self.window_size = window_size 287 | self.num_heads = num_heads 288 | self.embed_k = MLP(in_chans, embed_dim) 289 | self.embed_v = MLP(in_chans, embed_dim) 290 | self.proj_last = nn.Linear(embed_dim, in_chans) 291 | self.shortcut = MLP(embed_dim * 2, embed_dim) 292 | self.absolute_pos_embed = nn.Parameter(torch.zeros(window_size ** 2, embed_dim // num_heads[0])) 293 | trunc_normal_(self.absolute_pos_embed, std=.02) 294 | self.groups = nn.ModuleList([ 295 | ResidualGroup( 296 | embed_dim=embed_dim, 297 | window_size=window_size, 298 | depth=d, 299 | num_heads=n, 300 | dilation=dl, 301 | shuffle=sh, 302 | input_size=to_2tuple(input_size) 303 | ) for i, (d, n, dl, sh) in enumerate(zip(depths, num_heads, dilations, shuffles)) 304 | ]) 305 | self.apply(self._init_weights) 306 | 307 | def _init_weights(self, m: nn.Module): 308 | if isinstance(m, nn.Linear): 309 | trunc_normal_(m.weight, std=.02) 310 | if isinstance(m, nn.Linear) and m.bias is not None: 311 | nn.init.constant_(m.bias, 0) 312 | elif isinstance(m, nn.LayerNorm): 313 | nn.init.constant_(m.bias, 0) 314 | nn.init.constant_(m.weight, 1.0) 315 | 316 | @torch.jit.ignore 317 | def no_weight_decay(self) -> Set[str]: 318 | return {'absolute_pos_embed'} 319 | 320 | @torch.jit.ignore 321 | def no_weight_decay_keywords(self) -> Set[str]: 322 | return {'relative_position_bias_table'} 323 | 324 | def forward(self, x: T) -> T: 325 | x = einops.rearrange(x, 'b c ... -> b ... c') 326 | k = self.embed_k(x) 327 | v = self.embed_v(x) 328 | wh, ww = x.shape[1] // self.window_size, x.shape[2] // self.window_size 329 | full_pos_embed = einops.repeat(self.absolute_pos_embed, "(ws1 ws2) ch -> b (wh ws1) (ww ws2) (nh ch)", 330 | b=x.shape[0], ws1=self.window_size, wh=wh, ww=ww, nh=self.num_heads[0]) 331 | q, k, v = full_pos_embed, k + full_pos_embed, v + full_pos_embed 332 | shortcuts = [] 333 | mid = len(self.groups) // 2 334 | for i, group in enumerate(self.groups): 335 | if i < mid: 336 | q_ = group(q, k, v) 337 | shortcuts.append(q_) 338 | elif shortcuts: 339 | q = group(q, k, v) 340 | q = connect_shortcut(self.shortcut, q, shortcuts.pop()) 341 | else: 342 | q = group(q, k, v) 343 | q = self.proj_last(q) 344 | q = einops.rearrange(q, 'b ... 
c -> b c ...') 345 | return q 346 | -------------------------------------------------------------------------------- /noise2same/backbone/unet.py: -------------------------------------------------------------------------------- 1 | # translated from 2 | # https://github.com/divelab/Noise2Same/blob/main/network.py 3 | # https://github.com/divelab/Noise2Same/blob/main/resnet_module.py 4 | from typing import Tuple 5 | 6 | import torch 7 | from torch import Tensor as T 8 | from torch import nn 9 | 10 | 11 | class ProjectHead(nn.Sequential): 12 | """ 13 | Implements projection head for contrastive learning as per 14 | "Exploring Cross-Image Pixel Contrast for Semantic Segmentation" 15 | https://arxiv.org/abs/2101.11939 16 | https://github.com/tfzhou/ContrastiveSeg 17 | 18 | Provides high-dimensional L2-normalized pixel embeddings (256-d from 1x1 conv by default) 19 | """ 20 | 21 | def __init__( 22 | self, 23 | in_channels: int, 24 | out_channels: int = 256, 25 | n_dim: int = 2, 26 | kernel_size: int = 1, 27 | ): 28 | assert n_dim in (2, 3) 29 | conv = nn.Conv2d if n_dim == 2 else nn.Conv3d 30 | conv_1 = conv( 31 | in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2 32 | ) 33 | relu = nn.ReLU(inplace=True) 34 | conv_2 = conv( 35 | out_channels, 36 | out_channels, 37 | kernel_size=kernel_size, 38 | padding=kernel_size // 2, 39 | ) 40 | super().__init__(conv_1, relu, conv_2, relu) 41 | 42 | def forward(self, x): 43 | x = super().forward(x) 44 | x = nn.functional.normalize(x, p=2, dim=1) 45 | return x 46 | 47 | 48 | class RegressionHead(nn.Sequential): 49 | def __init__( 50 | self, in_channels: int, out_channels: int, n_dim: int = 2, kernel_size: int = 1 51 | ): 52 | """ 53 | Denoising regression head BN-ReLU-Conv 54 | 55 | https://github.com/divelab/Noise2Same/blob/main/models.py 56 | :param in_channels: 57 | :param out_channels: 58 | :param n_dim: 59 | :param kernel_size: 60 | """ 61 | assert n_dim in (2, 3) 62 | conv = nn.Conv2d if n_dim == 2 else nn.Conv3d 63 | bn = nn.BatchNorm2d if n_dim == 2 else nn.BatchNorm3d 64 | 65 | bn = bn(num_features=in_channels) 66 | relu = nn.ReLU(inplace=True) 67 | conv = conv( 68 | in_channels=in_channels, 69 | out_channels=out_channels, 70 | kernel_size=kernel_size, 71 | padding=kernel_size // 2, 72 | bias=False, 73 | ) 74 | super().__init__(bn, relu, conv) 75 | 76 | 77 | class ResidualUnit(nn.Module): 78 | def __init__( 79 | self, 80 | in_channels: int, 81 | out_channels: int, 82 | n_dim: int = 2, 83 | kernel_size: int = 3, 84 | downsample: bool = False, 85 | ): 86 | super().__init__() 87 | self.in_channels = in_channels 88 | self.out_channels = out_channels 89 | self.n_dim = n_dim 90 | self.kernel_size = kernel_size 91 | self.downsample = downsample 92 | 93 | bn = nn.BatchNorm2d if n_dim == 2 else nn.BatchNorm3d 94 | conv = nn.Conv2d if n_dim == 2 else nn.Conv3d 95 | stride = 2 if downsample else 1 96 | 97 | self.act = nn.ReLU(inplace=True) 98 | # todo parametrize as in the original repo (bn momentum is inverse) 99 | self.bn = bn(in_channels, momentum=1 - 0.997, eps=1e-5) 100 | self.conv_shortcut = conv( 101 | in_channels=in_channels, 102 | out_channels=out_channels, 103 | kernel_size=1, 104 | padding=0, 105 | stride=stride, 106 | bias=False, 107 | ) 108 | 109 | self.layers = nn.Sequential( 110 | conv( 111 | in_channels=in_channels, 112 | out_channels=out_channels, 113 | kernel_size=2 if downsample else kernel_size, 114 | padding=0 if downsample else kernel_size // 2, 115 | stride=stride, 116 | bias=False, 117 | ), 118 | 
bn(out_channels), 119 | self.act, 120 | conv( 121 | in_channels=out_channels, 122 | out_channels=out_channels, 123 | kernel_size=kernel_size, 124 | padding=kernel_size // 2, 125 | stride=1, 126 | bias=False, 127 | ), 128 | ) 129 | 130 | def forward(self, x: T) -> T: 131 | shortcut = x 132 | x = self.bn(x) 133 | x = self.act(x) 134 | if self.in_channels != self.out_channels or self.downsample: 135 | shortcut = self.conv_shortcut(x) 136 | x = self.layers(x) 137 | return x + shortcut 138 | 139 | 140 | class ResidualBlock(nn.Module): 141 | def __init__( 142 | self, 143 | in_channels: int, 144 | out_channels: int, 145 | block_size: int = 1, 146 | n_dim: int = 2, 147 | kernel_size: int = 3, 148 | downsample: bool = False, 149 | ): 150 | super().__init__() 151 | self.in_channels = in_channels 152 | self.out_channels = out_channels 153 | self.n_dim = n_dim 154 | self.kernel_size = kernel_size 155 | self.downsample = downsample 156 | self.block_size = block_size 157 | 158 | self.block = nn.Sequential( 159 | *[ 160 | ResidualUnit( 161 | in_channels=in_channels if i == 0 else out_channels, 162 | out_channels=out_channels, 163 | n_dim=n_dim, 164 | kernel_size=kernel_size, 165 | downsample=downsample if i == 0 else False, 166 | ) 167 | for i in range(0, block_size) 168 | ] 169 | ) 170 | 171 | def forward(self, x: T) -> T: 172 | return self.block(x) 173 | 174 | 175 | class EncoderBlock(nn.Module): 176 | def __init__( 177 | self, 178 | in_channels: int, 179 | out_channels: int, 180 | block_size: int = 1, 181 | n_dim: int = 2, 182 | kernel_size: int = 3, 183 | downsampling: str = "conv", 184 | ): 185 | super().__init__() 186 | self.in_channels = in_channels 187 | self.out_channels = out_channels 188 | self.n_dim = n_dim 189 | self.kernel_size = kernel_size 190 | self.block_size = block_size 191 | 192 | conv = nn.Conv2d if n_dim == 2 else nn.Conv3d 193 | 194 | if downsampling == "res": 195 | downsampling_block = ResidualBlock( 196 | in_channels=in_channels, 197 | out_channels=out_channels, 198 | n_dim=n_dim, 199 | kernel_size=kernel_size, 200 | block_size=1, 201 | downsample=True, 202 | ) 203 | elif downsampling == "conv": 204 | downsampling_block = conv( 205 | in_channels=in_channels, 206 | out_channels=out_channels, 207 | kernel_size=2, 208 | stride=2, 209 | bias=True, 210 | ) 211 | else: 212 | raise ValueError("downsampling should be one of `res`, `conv`") 213 | 214 | self.block = nn.Sequential( 215 | downsampling_block, 216 | ResidualBlock( 217 | in_channels=out_channels, 218 | out_channels=out_channels, 219 | n_dim=n_dim, 220 | block_size=block_size, 221 | downsample=False, 222 | kernel_size=kernel_size, 223 | ), 224 | ) 225 | 226 | def forward(self, x: T) -> T: 227 | return self.block(x) 228 | 229 | 230 | class UNet(nn.Module): 231 | def __init__( 232 | self, 233 | in_channels: int, 234 | base_channels: int = 96, 235 | kernel_size: int = 3, 236 | n_dim: int = 2, 237 | depth: int = 3, 238 | encoding_block_sizes: Tuple[int, ...] = (1, 1, 0), 239 | decoding_block_sizes: Tuple[int, ...] = (1, 1), 240 | downsampling: Tuple[str, ...]
= ("conv", "conv"), 241 | skip_method: str = "concat", 242 | ): 243 | """ 244 | 245 | configuration: https://github.com/divelab/Noise2Same/blob/main/network_configure.py 246 | architecture: https://github.com/divelab/Noise2Same/blob/main/network.py 247 | 248 | :param n_dim: 249 | :param depth: 250 | :param base_channels: 251 | :param encoding_block_sizes: 252 | :param decoding_block_sizes: 253 | :param downsampling: 254 | :param skip_method: 255 | """ 256 | super().__init__() 257 | 258 | assert depth == len(encoding_block_sizes) 259 | assert encoding_block_sizes[0] > 0 260 | assert encoding_block_sizes[-1] == 0 261 | assert depth == len(decoding_block_sizes) + 1 262 | assert depth == len(downsampling) + 1 263 | assert skip_method in ["add", "concat", "cat"] 264 | 265 | self.in_channels = in_channels 266 | self.n_dim = n_dim 267 | self.depth = depth 268 | self.base_channels = base_channels 269 | self.encoding_block_sizes = encoding_block_sizes 270 | self.decoding_block_sizes = decoding_block_sizes 271 | self.downsampling = downsampling 272 | self.skip_method = skip_method 273 | print(f"Use {self.skip_method} skip method") 274 | 275 | conv = nn.Conv2d if n_dim == 2 else nn.Conv3d 276 | conv_transpose = nn.ConvTranspose2d if n_dim == 2 else nn.ConvTranspose3d 277 | 278 | self.conv_first = conv( 279 | in_channels=in_channels, 280 | out_channels=base_channels, 281 | kernel_size=kernel_size, 282 | padding=kernel_size // 2, 283 | stride=1, 284 | bias=False, 285 | ) 286 | 287 | # Encoder 288 | self.encoder_blocks = nn.ModuleList( 289 | [ 290 | ResidualBlock( 291 | in_channels=base_channels, 292 | out_channels=base_channels, 293 | n_dim=n_dim, 294 | kernel_size=kernel_size, 295 | block_size=encoding_block_sizes[0], 296 | ) 297 | ] 298 | ) 299 | 300 | out_channels = base_channels 301 | for i in range(2, self.depth + 1): 302 | in_channels = base_channels * (2 ** (i - 2)) 303 | out_channels = base_channels * (2 ** (i - 1)) 304 | 305 | # Here 306 | 307 | # todo downsampling 308 | 309 | self.encoder_blocks.append( 310 | EncoderBlock( 311 | in_channels=in_channels, 312 | out_channels=out_channels, 313 | n_dim=n_dim, 314 | kernel_size=kernel_size, 315 | block_size=encoding_block_sizes[i - 1], 316 | downsampling=downsampling[i - 2], 317 | ) 318 | ) 319 | 320 | # Bottom block 321 | self.bottom_block = ResidualBlock( 322 | in_channels=out_channels, 323 | out_channels=base_channels * (2 ** (depth - 1)), 324 | n_dim=n_dim, 325 | kernel_size=kernel_size, 326 | block_size=1, 327 | ) 328 | 329 | # Decoder 330 | self.decoder_blocks = nn.ModuleList() 331 | self.upsampling_blocks = nn.ModuleList() 332 | for i in range(self.depth - 1, 0, -1): 333 | in_channels = int(base_channels * (2 ** i)) 334 | out_channels = int(base_channels * (2 ** (i - 1))) 335 | 336 | # todo parametrize to use linear upsampling optionally 337 | self.upsampling_blocks.append( 338 | conv_transpose( 339 | in_channels=in_channels, 340 | out_channels=out_channels, 341 | kernel_size=2, 342 | stride=2, 343 | bias=True, 344 | ) 345 | ) 346 | 347 | # Here goes skip connection, then decoder block 348 | self.decoder_blocks.append( 349 | ResidualBlock( 350 | in_channels=out_channels * (2 if self.skip_method != "add" else 1), 351 | out_channels=out_channels, 352 | n_dim=n_dim, 353 | kernel_size=kernel_size, 354 | block_size=decoding_block_sizes[depth - 1 - i], 355 | ) 356 | ) 357 | 358 | def forward(self, x: T) -> T: 359 | encoder_outputs = [] 360 | x = self.conv_first(x) 361 | # print("First conv", x.shape) 362 | x = self.encoder_blocks[0](x) 363 | # 
print("Encoder 0", x.shape) 364 | 365 | for i, encoder_block in enumerate(self.encoder_blocks[1:]): 366 | encoder_outputs.append(x) 367 | x = encoder_block(x) 368 | # print(f"Encoder {i+1}", x.shape) 369 | 370 | x = self.bottom_block(x) 371 | # print("Bottom", x.shape) 372 | 373 | for i, (upsampling_block, decoder_block, skip) in enumerate( 374 | zip(self.upsampling_blocks, self.decoder_blocks, encoder_outputs[::-1]) 375 | ): 376 | x = upsampling_block(x) 377 | # print(f"Upsampling {i}", x.shape) 378 | if self.skip_method == "add": 379 | x.add_(skip) 380 | elif self.skip_method in ("cat", "concat"): 381 | x = torch.cat([x, skip], dim=1) 382 | else: 383 | raise ValueError 384 | x = decoder_block(x) 385 | # print(f"Decoder {i}", x.shape) 386 | 387 | # x = self.conv_last(x) 388 | # print("Last conv", x.shape) 389 | return x 390 | -------------------------------------------------------------------------------- /noise2same/contrast.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from einops import rearrange, reduce 5 | from torch import Tensor as T 6 | from torch.nn.modules.loss import _Loss 7 | 8 | 9 | class PixelContrastLoss(_Loss): 10 | def __init__(self, temperature: float = 0.1): 11 | super(PixelContrastLoss, self).__init__() 12 | 13 | self.temperature = temperature 14 | 15 | def forward(self, out_raw: T, out_mask: T, mask: T) -> T: 16 | """ 17 | 18 | :param out_raw: tensor (B, E, H, W) 19 | :param out_mask: tensor (B, E, H, W) 20 | :param mask: tensor (B, 1, H, W) 21 | :return: 22 | """ 23 | mask = rearrange(mask, "b e h w -> (b e h w)") # e == 1 24 | 25 | emb_raw = rearrange(out_raw, "b e h w -> (b h w) e")[mask.bool()] 26 | emb_mask = rearrange(out_mask, "b e h w -> (b h w) e")[mask.bool()] 27 | rand_idx = torch.randperm(emb_raw.size(0)) 28 | 29 | pos_dot = torch.einsum("be,be->b", emb_raw, emb_mask) / self.temperature 30 | neg_dot_raw = ( 31 | torch.einsum("be,be->b", emb_raw[rand_idx], emb_mask) / self.temperature 32 | ) 33 | neg_dot_mask = ( 34 | torch.einsum("be,be->b", emb_raw, emb_mask[rand_idx]) / self.temperature 35 | ) 36 | neg_dot = torch.stack([neg_dot_raw, neg_dot_mask], dim=-1) 37 | 38 | pos_max_val = torch.max(pos_dot) 39 | neg_max_val = torch.max(neg_dot) 40 | max_val = torch.max(torch.stack([pos_max_val, neg_max_val])) 41 | 42 | numerator = torch.exp(pos_dot - max_val) 43 | 44 | denominator = ( 45 | reduce(torch.exp(neg_dot - max_val), "b k -> b", "sum") + numerator 46 | ) 47 | loss = -torch.log((numerator / denominator) + 1e-8) 48 | return loss 49 | -------------------------------------------------------------------------------- /noise2same/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import bsd68, hanzi, imagenet, planaria, transforms, util 2 | 3 | __all__ = ["bsd68", "hanzi", "imagenet", "planaria", "transforms", "util"] 4 | -------------------------------------------------------------------------------- /noise2same/dataset/abc.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Any, Dict, List, Optional, Tuple, Union 5 | 6 | import albumentations as albu 7 | import h5py 8 | import numpy as np 9 | import torch 10 | from albumentations import BasicTransform, Compose 11 | from albumentations.pytorch import ToTensorV2 12 | from pytorch_toolbelt.inference.tiles import ImageSlicer 13 | from skimage import io 14 | from torch import Tensor as T 15 | from torch.utils.data import Dataset 16 | 17 | from noise2same.dataset import transforms as t3d 18 | from noise2same.dataset.util import mask_like_image 19 | from noise2same.util import normalize_percentile 20 | 21 | 22 | @dataclass 23 | class AbstractNoiseDataset(Dataset, ABC): 24 | """ 25 | Abstract noise dataset 26 | """ 27 | 28 | path: Union[Path, str] 29 | mask_percentage: float = 0.5 30 | pad_divisor: int = 8 31 | channel_last: bool = True 32 | standardize: bool = True 33 | standardize_by_channel: bool = False 34 | n_dim: int = 2 35 | mean: Optional[Union[float, np.ndarray]] = None 36 | std: Optional[Union[float, np.ndarray]] = None 37 | transforms: Optional[ 38 | Union[ 39 | List[BasicTransform], 40 | Compose, 41 | List[Compose], 42 | List[t3d.BaseTransform3D], 43 | t3d.Compose, 44 | List[t3d.Compose], 45 | ] 46 | ] = None 47 | 48 | def _validate(self) -> bool: 49 | """ 50 | Check init arguments types and values 51 | :return: bool 52 | """ 53 | return True 54 | 55 | def __post_init__(self) -> None: 56 | """ 57 | Get a list of images, compose provided transforms with a list of necessary post-transforms 58 | :return: 59 | """ 60 | if not self._validate(): 61 | raise ValueError("Validation failed") 62 | 63 | self.path = Path(self.path) 64 | if not self.path.is_dir() and self.path.suffix not in ( 65 | ".tif", 66 | ".tiff", 67 | ): 68 | raise ValueError( 69 | f"Incorrect path: {self.path} is not a directory and {self.path.suffix} is not a TIFF file" 70 | ) 71 | 72 | images = self._get_images() 73 | self.images = images['noisy_input'] 74 | self.ground_truth = images.get('ground_truth', None) 75 | if not isinstance(self.transforms, list): 76 | self.transforms = [self.transforms] 77 | self.transforms = self._compose_transforms( 78 | self.transforms + self._get_post_transforms(), 79 | additional_targets={"ground_truth": "image"} if self.ground_truth is not None else None 80 | ) 81 | 82 | def __len__(self) -> int: 83 | return len(self.images) 84 | 85 | @abstractmethod 86 | def _compose_transforms(self, *args, **kwargs) -> Union[Compose, t3d.Compose]: 87 | """ 88 | Compose a list of transforms with a specific function 89 | :param args: 90 | :param kwargs: 91 | :return: 92 | """ 93 | raise NotImplementedError 94 | 95 | @abstractmethod 96 | def _apply_transforms(self, image: np.ndarray, mask: np.ndarray, ground_truth: np.ndarray = None) -> Dict[str, T]: 97 | """ 98 | Apply transforms to both image and mask 99 | :param image: 100 | :param mask: 101 | :return: 102 | """ 103 | raise NotImplementedError 104 | 105 | @abstractmethod 106 | def _get_post_transforms( 107 | self, 108 | ) -> Union[List[BasicTransform], List[t3d.BaseTransform3D]]: 109 | """ 110 | Necessary post-transforms (e.g.
ToTensor) 111 | :return: 112 | """ 113 | raise NotImplementedError 114 | 115 | @abstractmethod 116 | def _get_images(self) -> Dict[str, Union[List[str], np.ndarray]]: 117 | """ 118 | Obtain images or their paths from file system 119 | :return: dict with a list of images or paths to them 120 | """ 121 | raise NotImplementedError 122 | 123 | @abstractmethod 124 | def _read_image(self, image_or_path: Union[str, np.ndarray]) -> np.ndarray: 125 | """ 126 | Read a single image from file system or preloaded array 127 | :param image_or_path: 128 | :return: 129 | """ 130 | raise NotImplementedError 131 | 132 | def __getitem__(self, i: int) -> Dict[str, Any]: 133 | """ 134 | :param i: int, index 135 | :return: dict(image, mask, mean, std[, ground_truth]) 136 | """ 137 | image = self._read_image(self.images[i]).astype(np.float32) 138 | if image.ndim == self.n_dim: 139 | image = np.expand_dims(image, axis=-1 if self.channel_last else 0) 140 | 141 | ground_truth = None 142 | if self.ground_truth is not None: 143 | ground_truth = self._read_image(self.ground_truth[i]).astype(np.float32) 144 | if ground_truth.ndim == self.n_dim: 145 | ground_truth = np.expand_dims(ground_truth, axis=-1 if self.channel_last else 0) 146 | 147 | mask = self._mask_like_image(image) 148 | # this was noise_patch in the original code; concatenating noise and mask does not make sense here 149 | # https://github.com/divelab/Noise2Same/blob/main/models.py#L154 150 | # noise_mask = np.concatenate([noise, mask], axis=-1) 151 | ret = self._apply_transforms(image, mask, ground_truth=ground_truth) 152 | if self.standardize: 153 | # by default, self.mean and self.std are None, and normalization is done by patch 154 | ret["image"], ret["mean"], ret["std"] = self._standardize(ret["image"], 155 | None if self.mean is None else torch.as_tensor(self.mean), 156 | None if self.std is None else torch.as_tensor(self.std)) 157 | if self.ground_truth is not None: 158 | ret["ground_truth"], _, _ = self._standardize(ret["ground_truth"], ret["mean"], ret["std"]) 159 | else: 160 | # in case the data was normalized or standardized before 161 | ret["mean"] = torch.tensor(0).view((1,) * ret["image"].ndim) 162 | ret["std"] = torch.tensor(1).view((1,) * ret["image"].ndim) 163 | 164 | return ret 165 | 166 | def _mask_like_image(self, image: np.ndarray) -> np.ndarray: 167 | return mask_like_image( 168 | image, mask_percentage=self.mask_percentage, channels_last=self.channel_last 169 | ) 170 | 171 | def _standardize(self, image: T, mean: T = None, std: T = None) -> Tuple[T, T, T]: 172 | """ 173 | Normalize an image by mean and std 174 | :param image: tensor 175 | :return: normalized image, mean, std 176 | """ 177 | # Image is already a tensor, hence channel-first 178 | dim = tuple(range(1, image.ndim)) 179 | if not self.standardize_by_channel: 180 | dim = (0,) + dim 181 | # normalize as per the paper 182 | # TODO in the paper channels are not specified. do they matter?
try with dim=(1, 2) 183 | mean = torch.mean(image, dim=dim, keepdim=True) if mean is None else mean 184 | std = torch.std(image, dim=dim, keepdim=True) if std is None else std 185 | image = (image - mean) / std 186 | return image, mean, std 187 | 188 | 189 | @dataclass 190 | class AbstractNoiseDataset2D(AbstractNoiseDataset, ABC): 191 | def _compose_transforms(self, *args, **kwargs) -> Compose: 192 | return Compose(*args, **kwargs) 193 | 194 | def _get_post_transforms(self) -> List[BasicTransform]: 195 | return [ 196 | albu.PadIfNeeded( 197 | min_height=None, 198 | min_width=None, 199 | pad_height_divisor=self.pad_divisor, 200 | pad_width_divisor=self.pad_divisor, 201 | ), 202 | ToTensorV2(transpose_mask=True) 203 | ] 204 | 205 | def _apply_transforms(self, image, mask, ground_truth=None) -> Dict[str, T]: 206 | if ground_truth is None: 207 | return self.transforms(image=image, mask=mask) 208 | return self.transforms(image=image, mask=mask, ground_truth=ground_truth) 209 | 210 | 211 | @dataclass 212 | class AbstractNoiseDataset3D(AbstractNoiseDataset, ABC): 213 | channel_last: bool = False 214 | n_dim: int = 3 215 | transforms: Optional[ 216 | Union[List[t3d.BaseTransform3D], t3d.Compose, List[t3d.Compose]] 217 | ] = None 218 | 219 | def _compose_transforms(self, *args, **kwargs) -> t3d.Compose: 220 | return t3d.Compose(*args, **kwargs) 221 | 222 | def _get_post_transforms(self) -> List[t3d.BaseTransform3D]: 223 | return [t3d.ToTensor(transpose=False)] 224 | 225 | def _apply_transforms(self, image: np.ndarray, mask: np.ndarray) -> Dict[str, T]: 226 | ret = { 227 | "image": self.transforms(image, resample=True), 228 | "mask": self.transforms(mask, resample=False), 229 | } 230 | return ret 231 | 232 | 233 | @dataclass 234 | class AbstractNoiseDataset3DLarge(AbstractNoiseDataset3D, ABC): 235 | """ 236 | For large images where we standardize a full-size image 237 | """ 238 | 239 | input_name: str = None 240 | tile_size: int = 64 241 | tile_step: int = 48 242 | mean: float = 0 243 | std: float = 1 244 | weight: str = "pyramid" 245 | 246 | def __getitem__(self, i: int) -> Dict[str, Any]: 247 | """ 248 | :param i: int, index 249 | :return: dict(image, mask, mean, std, crop) 250 | """ 251 | image, crop = self._read_image(self.images[i]) 252 | mask = self._mask_like_image(image) 253 | ret = self._apply_transforms(image.astype(np.float32), mask) 254 | # standardization/normalization step removed since we process the full-sized image 255 | ret["mean"], ret["std"] = ( 256 | # TODO can rewrite just for self.mean and std? 
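# for these large volumes, self.mean and self.std were computed once over the whole image in _get_images, so every tile below is standardized with the same dataset-level statistics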
257 | torch.tensor(self.mean if self.standardize else 0).view(1, 1, 1, 1), 258 | torch.tensor(self.std if self.standardize else 1).view(1, 1, 1, 1), 259 | ) 260 | ret["crop"] = crop 261 | return ret 262 | 263 | def _read_image(self, image_or_path: List[int]) -> Tuple[np.ndarray, List[int]]: 264 | image, crop = self.tiler.crop_tile(image=self.image, crop=image_or_path) 265 | return np.moveaxis(image, -1, 0), crop 266 | 267 | def _read_large_image(self): 268 | self.image = io.imread(str(self.path / self.input_name)).astype(np.float32) 269 | 270 | def _get_images(self) -> Union[List[str], np.ndarray]: 271 | self._read_large_image() 272 | 273 | if len(self.image.shape) < 4: 274 | self.image = self.image[..., np.newaxis] 275 | 276 | if self.standardize: 277 | self.mean = self.image.mean() 278 | self.std = self.image.std() 279 | self.image = (self.image - self.mean) / self.std 280 | else: 281 | self.image = normalize_percentile(self.image) 282 | 283 | self.tiler = ImageSlicer( 284 | self.image.shape, 285 | tile_size=self.tile_size, 286 | tile_step=self.tile_step, 287 | weight=self.weight, 288 | is_channels=True, 289 | ) 290 | 291 | return self.tiler.crops 292 | 293 | 294 | @dataclass 295 | class AbstractNoiseDataset3DLargeH5(AbstractNoiseDataset3DLarge): 296 | def _read_large_image(self): 297 | with h5py.File(str(self.path / self.input_name), "r") as f: 298 | self.image = np.array(f["image"], dtype=np.float32) 299 | -------------------------------------------------------------------------------- /noise2same/dataset/bsd68.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import List, Union, Dict 4 | 5 | import numpy as np 6 | 7 | from noise2same.dataset.abc import AbstractNoiseDataset2D 8 | 9 | 10 | @dataclass 11 | class BSD68DatasetPrepared(AbstractNoiseDataset2D): 12 | path: Union[Path, str] = "data/BSD68" 13 | mode: str = "train" 14 | 15 | def _validate(self) -> bool: 16 | assert self.mode in ("train", "val", "test") 17 | return True 18 | 19 | def _get_images(self) -> Dict[str, Union[List[str], np.ndarray]]: 20 | path = self.path / self.mode 21 | files = list(path.glob("*.npy")) 22 | return { 23 | "noisy_input": np.load(files[0].as_posix(), allow_pickle=True) 24 | } 25 | 26 | def _read_image(self, image_or_path: Union[str, np.ndarray]) -> np.ndarray: 27 | return image_or_path 28 | -------------------------------------------------------------------------------- /noise2same/dataset/dummy.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import numpy as np 5 | 6 | from noise2same.dataset.abc import AbstractNoiseDataset3DLarge 7 | 8 | 9 | @dataclass 10 | class DummyDataset3DLarge(AbstractNoiseDataset3DLarge): 11 | image: Optional[np.ndarray] = None 12 | image_size: int = 256 13 | path: str = "." 
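# The dummy dataset below synthesizes a random volume of shape (image_size,) * n_dim,
# useful for smoke-testing the tiling and inference code paths without data on disk.
# Illustrative usage (hypothetical values):
#   ds = DummyDataset3DLarge(image_size=64, tile_size=32, tile_step=24)
#   batch = ds[0]  # dict with image, mask, mean, std, crop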
14 | 15 | def _read_large_image(self): 16 | if self.image is None: 17 | shape = (self.image_size,) * self.n_dim 18 | self.image = np.random.rand(*shape) 19 | self.image = self.image.astype(np.float32) 20 | -------------------------------------------------------------------------------- /noise2same/dataset/fmd.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import List, Union, Dict 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | from noise2same.dataset.abc import AbstractNoiseDataset2D 10 | from noise2same.dataset.util import ( 11 | add_microscope_blur_2d, 12 | add_poisson_gaussian_noise, 13 | normalize, 14 | ) 15 | 16 | from tqdm import tqdm 17 | 18 | 19 | @dataclass 20 | class FMDDatasetPrepared(AbstractNoiseDataset2D): 21 | path: Union[Path, str] = "data/FMD" 22 | mode: str = "train" 23 | part: str = "cf_fish" 24 | add_blur_and_noise: bool = False 25 | 26 | def _validate(self) -> bool: 27 | assert self.mode in ("train", "val", "test") 28 | assert self.part in ("cf_fish", "cf_mice", "tp_mice") 29 | return True 30 | 31 | def _get_images(self) -> Dict[str, Union[List[str], np.ndarray]]: 32 | path = self.path / { 33 | 'cf_fish': 'Confocal_FISH', 34 | 'cf_mice': 'Confocal_MICE', 35 | 'tp_mice': 'TwoPhoton_MICE' 36 | }[self.part] 37 | folders = list(range(1, 19)) + [20] if self.mode == 'train' else [19] 38 | paths = {sub: [ 39 | np.concatenate([ 40 | np.expand_dims(cv2.imread(str(path / sub / str(i) / image), cv2.IMREAD_GRAYSCALE), 0) for image 41 | in sorted(os.listdir(path / sub / str(i))) if image.endswith('png') 42 | ]) for i in tqdm(folders, desc=sub) 43 | ] for sub in (('raw', 'gt') if not self.add_blur_and_noise else ('gt', )) 44 | } 45 | if self.add_blur_and_noise: 46 | paths['raw'] = [self._add_blur_and_noise(img[0])[None, ...] for img in tqdm(paths['gt'], desc='blur')] 47 | else: 48 | paths['gt'] = [np.concatenate([folder] * 50) for folder in paths['gt']] 49 | 50 | 51 | return { 52 | "noisy_input": np.concatenate(paths['raw']), 53 | "ground_truth": np.concatenate(paths['gt']) 54 | } 55 | 56 | def _read_image(self, image_or_path: Union[str, np.ndarray]) -> np.ndarray: 57 | self.mean = np.mean(image_or_path, keepdims=True, dtype=np.float32)[None, ...] 58 | self.std = np.std(image_or_path, keepdims=True, dtype=np.float32)[None, ...] 59 | return image_or_path 60 | 61 | def _add_blur_and_noise(self, image: np.ndarray) -> np.ndarray: 62 | image = normalize(image) 63 | # TODO parametrize 64 | try: 65 | image, self.psf = add_microscope_blur_2d(image, size=17) 66 | except ValueError as e: 67 | raise ValueError(f"Failed to convolve image {image.shape}") from e 68 | image = add_poisson_gaussian_noise( 69 | image, 70 | alpha=0.001, 71 | sigma=0.1, 72 | sap=0, # 0.01 by default but it is not common to have salt and pepper 73 | quant_bits=10) 74 | return image * 255 75 | -------------------------------------------------------------------------------- /noise2same/dataset/getter.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Tuple, Optional, Dict 3 | 4 | import numpy as np 5 | import tifffile 6 | from omegaconf import DictConfig 7 | from skimage import io 8 | from torch.utils.data import Dataset, ConcatDataset 9 | from tqdm.auto import tqdm 10 | 11 | from . 
import bsd68, fmd, hanzi, imagenet, sidd, microtubules, planaria, ssi, synthetic, synthetic_grayscale 12 | from .util import training_augmentations_2d, training_augmentations_3d, validation_transforms_2d 13 | from noise2same.util import normalize_percentile 14 | 15 | 16 | def compute_pad_divisor(cfg: DictConfig) -> Optional[int]: 17 | """ 18 | Compute the number by which the padded image size should 19 | be divisible, so that it is suitable for the chosen backbone 20 | :param cfg: DictConfig, training/evaluation configuration object 21 | :return: Optional[int] 22 | """ 23 | if cfg.backbone_name == "unet": 24 | return 2 ** cfg.backbone.depth 25 | elif cfg.backbone_name in ("swinir", "bsp_swinir"): 26 | return cfg.backbone.window_size 27 | elif cfg.backbone_name == "swinia": 28 | return cfg.backbone.window_size * max(cfg.backbone.dilations) * max(cfg.backbone.shuffles) 29 | else: 30 | raise ValueError("Incorrect backbone name") 31 | 32 | 33 | def get_dataset(cfg: DictConfig, cwd: Path) -> Tuple[Dataset, Dataset]: 34 | """ 35 | Collect training and validation datasets specified in the configuration 36 | :param cfg: DictConfig, training/evaluation configuration object 37 | :param cwd: Path, project working directory 38 | :return: Tuple[Dataset, Dataset] 39 | """ 40 | 41 | dataset_valid = None 42 | 43 | pad_divisor = compute_pad_divisor(cfg) 44 | 45 | transforms = None 46 | transforms_valid = None 47 | if cfg.experiment.lower() in ("bsd68", "fmd", "synthetic", "synthetic_grayscale", "hanzi", "imagenet", "sidd", "ssi"): 48 | transforms = training_augmentations_2d(crop=cfg.training.crop) 49 | transforms_valid = validation_transforms_2d(crop=cfg.training.crop) 50 | 51 | if cfg.experiment.lower() == "bsd68": 52 | dataset_train = bsd68.BSD68DatasetPrepared( 53 | path=cwd / "data/BSD68/", 54 | mode="train", 55 | transforms=transforms, 56 | pad_divisor=pad_divisor, 57 | ) 58 | if cfg.training.validate: 59 | dataset_valid = bsd68.BSD68DatasetPrepared( 60 | path=cwd / "data/BSD68/", mode="val", 61 | pad_divisor=pad_divisor, 62 | ) 63 | 64 | elif cfg.experiment.lower() == "synthetic": 65 | dataset_train = synthetic.ImagenetSyntheticDataset( 66 | path=cwd / "data/Imagenet_val", 67 | noise_type=cfg.data.noise_type, 68 | noise_param=cfg.data.noise_param, 69 | transforms=transforms, 70 | pad_divisor=pad_divisor, 71 | standardize=cfg.data.standardize, 72 | ) 73 | if cfg.training.validate: 74 | dataset_valid = synthetic.Set14SyntheticDataset( 75 | path=cwd / "data/Set14", 76 | noise_type=cfg.data.noise_type, 77 | noise_param=cfg.data.noise_param, 78 | transforms=transforms_valid, 79 | pad_divisor=pad_divisor, 80 | standardize=cfg.data.standardize, 81 | ) 82 | 83 | elif cfg.experiment.lower() == "synthetic_grayscale": 84 | dataset_train = synthetic_grayscale.BSD400SyntheticDataset( 85 | path=cwd / "data/BSD400", 86 | noise_type=cfg.data.noise_type, 87 | noise_param=cfg.data.noise_param, 88 | transforms=transforms, 89 | pad_divisor=pad_divisor, 90 | standardize=cfg.data.standardize, 91 | ) 92 | if cfg.training.validate: 93 | dataset_valid = synthetic_grayscale.BSD68SyntheticDataset( 94 | path=cwd / "data/BSD68-test", 95 | noise_type=cfg.data.noise_type, 96 | noise_param=cfg.data.noise_param, 97 | transforms=transforms_valid, 98 | pad_divisor=pad_divisor, 99 | standardize=cfg.data.standardize, 100 | fixed=True, 101 | ) 102 | 103 | elif cfg.experiment.lower() == "fmd": 104 | dataset_train = fmd.FMDDatasetPrepared( 105 | path=cwd / "data/FMD", 106 | mode="train", 107 | transforms=transforms, 108 | 
pad_divisor=pad_divisor, 109 | part=cfg.data.part, 110 | add_blur_and_noise=cfg.data.add_blur_and_noise, 111 | ) 112 | if cfg.training.validate: 113 | dataset_valid = fmd.FMDDatasetPrepared( 114 | path=cwd / "data/FMD", 115 | mode="val", 116 | pad_divisor=pad_divisor, 117 | part=cfg.data.part, 118 | add_blur_and_noise=cfg.data.add_blur_and_noise, 119 | ) 120 | 121 | elif cfg.experiment.lower() == "hanzi": 122 | dataset_train = hanzi.HanziDatasetPrepared( 123 | path=cwd / "data/Hanzi/tiles", 124 | mode="training", 125 | transforms=transforms, 126 | version=cfg.data.version, 127 | noise_level=cfg.data.noise_level, 128 | pad_divisor=pad_divisor, 129 | ) 130 | if cfg.training.validate: 131 | dataset_valid = hanzi.HanziDatasetPrepared( 132 | path=cwd / "data/Hanzi/tiles", 133 | mode="validation", 134 | version=cfg.data.version, 135 | noise_level=cfg.data.noise_level, 136 | pad_divisor=pad_divisor, 137 | ) 138 | 139 | elif cfg.experiment.lower() == "imagenet": 140 | dataset_train = imagenet.ImagenetDatasetPrepared( 141 | path=cwd / "data/ImageNet", 142 | mode="train", 143 | transforms=transforms, 144 | version=cfg.data.version, 145 | pad_divisor=pad_divisor, 146 | ) 147 | if cfg.training.validate: 148 | dataset_valid = imagenet.ImagenetDatasetPrepared( 149 | path=cwd / "data/ImageNet", 150 | mode="val", 151 | version=cfg.data.version, 152 | pad_divisor=pad_divisor, 153 | ) 154 | 155 | elif cfg.experiment.lower() == "sidd": 156 | dataset_train = sidd.SIDDDatasetPrepared( 157 | path=cwd / "data/SIDD-NAFNet", 158 | mode="train", 159 | transforms=transforms, 160 | pad_divisor=pad_divisor, 161 | ) 162 | if cfg.training.validate: 163 | dataset_valid = sidd.SIDDDatasetPrepared( 164 | path=cwd / "data/SIDD-NAFNet", 165 | mode="val", 166 | pad_divisor=pad_divisor, 167 | ) 168 | 169 | elif cfg.experiment.lower() == "planaria": 170 | dataset_train = planaria.PlanariaDatasetPrepared( 171 | path=cwd / "data/Denoising_Planaria", 172 | mode="train", 173 | transforms=training_augmentations_3d(), 174 | pad_divisor=pad_divisor, 175 | ) 176 | if cfg.training.validate: 177 | dataset_valid = planaria.PlanariaDatasetPrepared( 178 | path=cwd / "data/Denoising_Planaria", 179 | mode="val", 180 | pad_divisor=pad_divisor, 181 | ) 182 | 183 | elif cfg.experiment.lower() == "microtubules": 184 | dataset_train = microtubules.MicrotubulesDataset( 185 | path=cwd / cfg.data.path, 186 | input_name=cfg.data.input_name, 187 | transforms=training_augmentations_3d(), 188 | tile_size=cfg.data.tile_size, 189 | tile_step=cfg.data.tile_step, 190 | add_blur_and_noise=cfg.data.add_blur_and_noise, 191 | pad_divisor=pad_divisor, 192 | ) 193 | 194 | elif cfg.experiment.lower() == "ssi": 195 | dataset_train = ssi.SSIDataset( 196 | path=cwd / cfg.data.path, 197 | input_name=cfg.data.input_name, 198 | transforms=transforms, 199 | pad_divisor=pad_divisor, 200 | ) 201 | else: 202 | # todo add other datasets 203 | raise ValueError(f"Unknown experiment: {cfg.experiment}") 204 | 205 | return dataset_train, dataset_valid 206 | 207 | 208 | def get_test_dataset_and_gt(cfg: DictConfig, cwd: Path) -> Tuple[Dataset, np.ndarray]: 209 | """ 210 | Collect test dataset and ground truth specified in the configuration 211 | :param cfg: DictConfig, training/evaluation configuration object 212 | :param cwd: Path, project working directory 213 | :return: Tuple[Dataset, np.ndarray] 214 | """ 215 | 216 | pad_divisor = compute_pad_divisor(cfg) 217 | 218 | if cfg.experiment.lower() == "bsd68": 219 | dataset = bsd68.BSD68DatasetPrepared( 220 | path=cwd / "data/BSD68/", 
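# the test split is a single .npy with the 68 noisy BSD68 images; the clean counterparts (bsd68_groundtruth.npy) are loaded right below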
221 | mode="test", 222 | pad_divisor=pad_divisor, 223 | ) 224 | gt = np.load( 225 | str(cwd / "data/BSD68/test/bsd68_groundtruth.npy"), allow_pickle=True 226 | ) 227 | 228 | elif cfg.experiment.lower() == "fmd": 229 | dataset = fmd.FMDDatasetPrepared( 230 | path=cwd / "data/FMD", 231 | mode="test", 232 | pad_divisor=pad_divisor, 233 | part=cfg.data.part, 234 | add_blur_and_noise=cfg.data.add_blur_and_noise, 235 | ) 236 | gt = dataset.ground_truth 237 | 238 | elif cfg.experiment.lower() == "synthetic": 239 | params = { 240 | "noise_type": cfg.data.noise_type, 241 | "noise_param": cfg.data.noise_param, 242 | "pad_divisor": pad_divisor, 243 | "standardize": cfg.data.standardize, 244 | } 245 | dataset = { 246 | "kodak": synthetic.KodakSyntheticDataset(path=cwd / "data/Kodak", **params), 247 | "bsd300": synthetic.BSD300SyntheticDataset(path=cwd / "data/BSD300/test", **params), 248 | "set14": synthetic.Set14SyntheticDataset(path=cwd / "data/Set14", **params), 249 | } 250 | gt = {name: [synthetic.read_image(p) for p in tqdm(ds.images, desc=name)] for name, ds in dataset.items()} 251 | 252 | # Repeat datasets for stable validation 253 | # https://github.com/TaoHuang2018/Neighbor2Neighbor/blob/2fff2978/train.py#L412 254 | repeats = {"kodak": 10, "bsd300": 3, "set14": 20} 255 | dataset = {name: ConcatDataset([ds] * repeats[name]) for name, ds in dataset.items()} 256 | gt = {name: ds * repeats[name] for name, ds in gt.items()} 257 | 258 | # Concatenate datasets together 259 | dataset = ConcatDataset(list(dataset.values())) 260 | gt = np.concatenate(list(gt.values())) 261 | assert len(dataset) == len(gt) 262 | 263 | elif cfg.experiment.lower() == "synthetic_grayscale": 264 | params = { 265 | "noise_type": cfg.data.noise_type, 266 | "noise_param": cfg.data.noise_param, 267 | "pad_divisor": pad_divisor, 268 | "standardize": cfg.data.standardize, 269 | } 270 | dataset = { 271 | "set12": synthetic_grayscale.Set12SyntheticDataset(path=cwd / "data/Set12", **params), 272 | # Fixed noise is lower intensity, because it was quantized to 8-bit 273 | "bsd68": synthetic_grayscale.BSD68SyntheticDataset(path=cwd / "data/BSD68-test", fixed=False, **params), 274 | } 275 | gt = {name: [synthetic.read_image(p) for p in 276 | tqdm(ds.ground_truth if ds.ground_truth is not None else ds.images, desc=name)] 277 | for name, ds in dataset.items()} 278 | 279 | # Repeat datasets for stable validation 280 | # https://github.com/TaoHuang2018/Neighbor2Neighbor/blob/2fff2978/train.py#L412 281 | # repeats = {"set12": 10, "bsd68": 10} 282 | repeats = {"set12": 20, "bsd68": 4} # 240 and 272 images respectively 283 | dataset = {name: ConcatDataset([ds] * repeats[name]) for name, ds in dataset.items()} 284 | gt = {name: ds * repeats[name] for name, ds in gt.items()} 285 | 286 | # Concatenate datasets together 287 | dataset = ConcatDataset(list(dataset.values())) 288 | gt = np.concatenate(list(gt.values())) 289 | assert len(dataset) == len(gt) 290 | 291 | elif cfg.experiment.lower() == "hanzi": 292 | dataset = hanzi.HanziDatasetPrepared( 293 | path=cwd / "data/Hanzi/tiles", 294 | mode="testing", 295 | pad_divisor=pad_divisor, 296 | noise_level=cfg.data.noise_level 297 | ) 298 | gt = np.load(str(cwd / "data/Hanzi/tiles/testing.npy"))[:, 0] 299 | 300 | elif cfg.experiment.lower() == "imagenet": 301 | dataset = imagenet.ImagenetDatasetTest( 302 | path=cwd / "data/ImageNet/", 303 | pad_divisor=pad_divisor, 304 | ) 305 | gt = [ 306 | np.load(p)[0] for p in tqdm(sorted((dataset.path / "test").glob("*.npy"))) 307 | ] 308 | 309 | elif 
cfg.experiment.lower() == "sidd": 310 | dataset = sidd.SIDDDatasetPrepared( 311 | path=cwd / "data/SIDD-NAFNet/", 312 | mode='test', 313 | pad_divisor=pad_divisor, 314 | ) 315 | gt = dataset.ground_truth 316 | 317 | elif cfg.experiment.lower() == "planaria": 318 | # This returns just a single image! 319 | # Use get_planaria_dataset_and_gt() instead 320 | dataset = planaria.PlanariaDatasetTiff( 321 | cwd 322 | / "data/Denoising_Planaria/test_data/condition_1/EXP278_Smed_fixed_RedDot1_sub_5_N7_m0012.tif", 323 | standardize=True, 324 | pad_divisor=pad_divisor, 325 | ) 326 | dataset.mean, dataset.std = 0, 1 327 | 328 | gt = tifffile.imread( 329 | cwd 330 | / "data/Denoising_Planaria/test_data/GT/EXP278_Smed_fixed_RedDot1_sub_5_N7_m0012.tif" 331 | ) 332 | gt = normalize_percentile(gt, 0.1, 99.9) 333 | 334 | elif cfg.experiment.lower() == "microtubules": 335 | dataset = microtubules.MicrotubulesDataset( 336 | path=cwd / cfg.data.path, 337 | input_name=cfg.data.input_name, 338 | # we can double the size of the tiles for validation 339 | tile_size=cfg.data.tile_size * 2, # 64 * 2 = 128 340 | tile_step=cfg.data.tile_step * 2, # 48 * 2 = 96 341 | add_blur_and_noise=cfg.data.add_blur_and_noise, # TODO add different noise by random seed? 342 | pad_divisor=pad_divisor, 343 | ) 344 | # dataset.mean, dataset.std = 0, 1 345 | 346 | gt = io.imread(str(cwd / "data/microtubules-simulation/ground-truth.tif")) 347 | gt = normalize_percentile(gt, 0.1, 99.9) 348 | 349 | elif cfg.experiment.lower() == "ssi": 350 | dataset = ssi.SSIDataset( 351 | path=cwd / cfg.data.path, 352 | input_name=cfg.data.input_name, 353 | pad_divisor=pad_divisor, 354 | ) 355 | gt = dataset.gt 356 | else: 357 | raise ValueError(f"Dataset {cfg.experiment} not found") 358 | 359 | return dataset, gt 360 | 361 | 362 | def get_planaria_dataset_and_gt(filename_gt: str) -> Tuple[Dict[str, Dataset], np.ndarray]: 363 | """ 364 | Collect Planaria dataset and ground truth 365 | :param filename_gt: str, Planaria dataset ground truth filename 366 | :return: Tuple[Dict[str, Dataset], np.ndarray] 367 | """ 368 | gt = tifffile.imread(filename_gt) 369 | gt = normalize_percentile(gt, 0.1, 99.9) 370 | datasets = {} 371 | for c in range(1, 4): 372 | datasets[f"c{c}"] = planaria.PlanariaDatasetTiff( 373 | filename_gt.replace("GT", f"condition_{c}"), 374 | standardize=True, 375 | ) 376 | datasets[f"c{c}"].mean, datasets[f"c{c}"].std = 0, 1 377 | 378 | return datasets, gt 379 | -------------------------------------------------------------------------------- /noise2same/dataset/hanzi.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import List, Union, Dict 4 | 5 | import numpy as np 6 | 7 | from noise2same.dataset.abc import AbstractNoiseDataset2D 8 | 9 | 10 | @dataclass 11 | class HanziDatasetPrepared(AbstractNoiseDataset2D): 12 | path: Union[Path, str] = "data//Hanzi/tiles/" 13 | mode: str = "training" 14 | version: int = 0 # two noisy copies exist (0, 1) 15 | noise_level: int = 3 # four noise levels (1, 2, 3, 4) 16 | 17 | def _validate(self) -> bool: 18 | assert self.mode in ("training", "testing", "validation") 19 | assert self.noise_level in (1, 2, 3, 4) 20 | assert self.version in (0, 1) 21 | return True 22 | 23 | def _get_images(self) -> Dict[str, Union[List[str], np.ndarray]]: 24 | data = np.load(self.path / f"{self.mode}.npy") 25 | return { 26 | "noisy_input": data[:, self.version * 4 + self.noise_level], 27 | "ground_truth": data[:, 0] 
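# layout of the prepared array: index 0 is the clean glyph, indices 1-4 and 5-8 hold the two noisy copies (version 0 and 1) at noise levels 1-4, matching version * 4 + noise_level above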
28 | } 29 | 30 | def _read_image(self, image_or_path: Union[str, np.ndarray]) -> np.ndarray: 31 | return image_or_path 32 | -------------------------------------------------------------------------------- /noise2same/dataset/imagenet.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import List, Union, Dict 4 | 5 | import numpy as np 6 | 7 | from noise2same.dataset.abc import AbstractNoiseDataset2D 8 | 9 | 10 | @dataclass 11 | class ImagenetDatasetPrepared(AbstractNoiseDataset2D): 12 | path: Union[Path, str] = "data/ImageNet" 13 | mode: str = "train" 14 | version: int = 0 # two noisy copies exist (0, 1) 15 | standardize_by_channel: bool = True 16 | 17 | def _validate(self) -> bool: 18 | assert self.mode in ("train", "val") 19 | assert self.version in (0, 1) 20 | return True 21 | 22 | def _get_images(self) -> Dict[str, Union[List[str], np.ndarray]]: 23 | data = np.load(self.path / f"{self.mode}.npy") 24 | return { 25 | "noisy_input": data[:, self.version + 1], 26 | "ground_truth": data[:, 0] 27 | } 28 | 29 | def _read_image(self, image_or_path: Union[str, np.ndarray]) -> np.ndarray: 30 | return image_or_path 31 | 32 | 33 | @dataclass 34 | class ImagenetDatasetTest(AbstractNoiseDataset2D): 35 | path: Union[Path, str] = "data/ImageNet" 36 | standardize_by_channel: bool = True 37 | 38 | def _get_images(self) -> Dict[str, Union[List[str], np.ndarray]]: 39 | return { 40 | "noisy_input": sorted((self.path / "test").glob("*.npy")) 41 | } 42 | 43 | def _read_image(self, image_or_path: Union[str, np.ndarray]) -> np.ndarray: 44 | return np.load(image_or_path)[1] 45 | -------------------------------------------------------------------------------- /noise2same/dataset/microtubules.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Any, Dict, List, Tuple, Union 4 | 5 | import numpy as np 6 | from skimage import io 7 | 8 | from noise2same.dataset.abc import AbstractNoiseDataset3DLarge 9 | from noise2same.dataset.util import ( 10 | add_microscope_blur_3d, 11 | add_poisson_gaussian_noise, 12 | normalize, 13 | ) 14 | 15 | 16 | @dataclass 17 | class MicrotubulesDataset(AbstractNoiseDataset3DLarge): 18 | path: Union[Path, str] = "data/microtubules-simulation" 19 | input_name: str = "input.tif" 20 | add_blur_and_noise: bool = False 21 | 22 | def _read_large_image(self): 23 | self.image = io.imread(str(self.path / self.input_name)).astype(np.float32) 24 | if self.add_blur_and_noise: 25 | print(f"Generating blur and noise for {self.input_name}") 26 | # self.image = normalize_percentile(self.image, 0.1, 99.9) 27 | self.image = normalize(self.image) 28 | # TODO parametrize 29 | self.image, self.psf = add_microscope_blur_3d(self.image, size=17) 30 | self.image = add_poisson_gaussian_noise( 31 | self.image, 32 | alpha=0.001, 33 | sigma=0.1, 34 | sap=0, # 0.01 by default but it is not common to have salt and pepper 35 | quant_bits=10, 36 | ) 37 | -------------------------------------------------------------------------------- /noise2same/dataset/planaria.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Any, Dict, List, Tuple, Union 4 | 5 | import numpy as np 6 | import tifffile 7 | from pytorch_toolbelt.inference.tiles import ImageSlicer 8 | 9 | from 
noise2same.dataset.abc import AbstractNoiseDataset3D, AbstractNoiseDataset3DLarge 10 | from noise2same.util import normalize_percentile 11 | 12 | 13 | @dataclass 14 | class PlanariaDatasetPrepared(AbstractNoiseDataset3D): 15 | path: Union[Path, str] = "data/Denoising_Planaria" 16 | mode: str = "train" 17 | train_size: float = 0.9 18 | standardize: bool = False # data was prepared and percentile normalized 19 | 20 | def _get_images(self) -> Union[List[str], np.ndarray]: 21 | data = np.load(self.path / "train_data/data_label.npz")["X"].astype(np.float32) 22 | if self.mode == "train": 23 | data = data[: int(len(data) * self.train_size)] 24 | else: 25 | data = data[int(len(data) * self.train_size) :] 26 | return data 27 | 28 | def _read_image(self, image_or_path: Union[str, np.ndarray]) -> np.ndarray: 29 | return image_or_path 30 | 31 | 32 | @dataclass 33 | class PlanariaDatasetTiff(AbstractNoiseDataset3DLarge): 34 | tile_size: int = 256 35 | tile_step: int = 192 36 | crop_border: int = 32 37 | weight: str = "pyramid" 38 | 39 | def _get_images(self) -> Union[List[str], np.ndarray]: 40 | self.image = tifffile.imread(self.path)[..., None] 41 | 42 | if self.standardize: 43 | self.mean = self.image.mean() 44 | self.std = self.image.std() 45 | self.image = (self.image - self.mean) / self.std 46 | else: 47 | self.image = normalize_percentile(self.image) 48 | 49 | self.tiler = ImageSlicer( 50 | self.image.shape, 51 | tile_size=(96, self.tile_size, self.tile_size), 52 | tile_step=(96, self.tile_step, self.tile_step), 53 | weight=self.weight, 54 | is_channels=True, 55 | crop_border=(0, self.crop_border, self.crop_border), 56 | ) 57 | return self.tiler.crops 58 | 59 | def _read_image(self, image_or_path: List[int]) -> Tuple[np.ndarray, List[int]]: 60 | image, crop = self.tiler.crop_tile(image=self.image, crop=image_or_path) 61 | return np.moveaxis(image, -1, 0), crop 62 | -------------------------------------------------------------------------------- /noise2same/dataset/sidd.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import List, Union, Dict 4 | import os 5 | import cv2 6 | import lmdb 7 | 8 | import numpy as np 9 | from noise2same.dataset.abc import AbstractNoiseDataset2D 10 | 11 | 12 | def paired_paths_from_lmdb(folders, keys) -> List[Dict[str, str]]: 13 | """Generate paired paths from lmdb files. 14 | Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is: 15 | lq.lmdb 16 | ├── data.mdb 17 | ├── lock.mdb 18 | ├── meta_info.txt 19 | The data.mdb and lock.mdb are standard lmdb files and you can refer to 20 | https://lmdb.readthedocs.io/en/release/ for more details. 21 | The meta_info.txt is a specified txt file to record the meta information 22 | of our datasets. It will be automatically created when preparing 23 | datasets by our provided dataset tools. 24 | Each line in the txt file records 25 | 1)image name (with extension), 26 | 2)image shape, 27 | 3)compression level, separated by a white space. 28 | Example: `baboon.png (120,125,3) 1` 29 | We use the image name without extension as the lmdb key. 30 | Note that we use the same key for the corresponding lq and gt images. 31 | Args: 32 | folders (list[str]): A list of folder path. The order of list should 33 | be [input_folder, gt_folder]. 34 | keys (list[str]): A list of keys identifying folders. The order should 35 | be in consistent with folders, e.g., ['lq', 'gt']. 
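For example, SIDDDatasetPrepared below calls paired_paths_from_lmdb([input_path, gt_path], ['lq', 'gt']).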
36 | Note that this key is different from lmdb keys. 37 | Returns: 38 | list[str]: Returned path list. 39 | """ 40 | assert len(folders) == 2, ( 41 | 'The len of folders should be 2 with [input_folder, gt_folder]. ' 42 | f'But got {len(folders)}') 43 | assert len(keys) == 2, ( 44 | 'The len of keys should be 2 with [input_key, gt_key]. ' 45 | f'But got {len(keys)}') 46 | input_folder, gt_folder = folders 47 | input_key, gt_key = keys 48 | 49 | if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')): 50 | raise ValueError( 51 | f'{input_key} folder and {gt_key} folder should both in lmdb ' 52 | f'formats. But received {input_key}: {input_folder}; ' 53 | f'{gt_key}: {gt_folder}') 54 | # ensure that the two meta_info files are the same 55 | with open(os.path.join(input_folder, 'meta_info.txt')) as fin: 56 | input_lmdb_keys = [line.split('.')[0] for line in fin] 57 | with open(os.path.join(gt_folder, 'meta_info.txt')) as fin: 58 | gt_lmdb_keys = [line.split('.')[0] for line in fin] 59 | if set(input_lmdb_keys) != set(gt_lmdb_keys): 60 | raise ValueError( 61 | f'Keys in {input_key}_folder and {gt_key}_folder are different.') 62 | else: 63 | paths = [] 64 | for lmdb_key in sorted(input_lmdb_keys): 65 | paths.append( 66 | dict([(f'{input_key}_path', lmdb_key), 67 | (f'{gt_key}_path', lmdb_key)])) 68 | return paths 69 | 70 | 71 | @dataclass 72 | class SIDDDatasetPrepared(AbstractNoiseDataset2D): 73 | path: Union[Path, str] = Path("data/SIDD-NAFNet") 74 | mode: str = "train" 75 | standardize_by_channel: bool = True 76 | 77 | def _validate(self) -> bool: 78 | assert self.mode in ("train", "val", "test") 79 | return True 80 | 81 | def _get_images(self) -> Dict[str, Union[List[str], np.ndarray]]: 82 | data_path = self.path / 'train' if self.mode == 'train' else self.path / 'val' 83 | input_path, gt_path = str(data_path / 'input_crops.lmdb'), str(data_path / 'gt_crops.lmdb') 84 | paired_paths = paired_paths_from_lmdb([input_path, gt_path], ['lq', 'gt']) 85 | input_db = lmdb.open( 86 | input_path, 87 | readonly=True, 88 | lock=False, 89 | readahead=False, 90 | map_size=8 * 1024 * 10485760, 91 | ) 92 | gt_db = lmdb.open( 93 | gt_path, 94 | readonly=True, 95 | lock=False, 96 | readahead=False, 97 | map_size=8 * 1024 * 10485760, 98 | ) 99 | input_data, gt_data = [], [] 100 | for paired_path in paired_paths[:100]: 101 | with input_db.begin(write=False) as txn: 102 | value_buf = txn.get(paired_path['lq_path'].encode('ascii')) 103 | input_data.append(np.expand_dims(cv2.imdecode(np.frombuffer(value_buf, np.uint8), cv2.IMREAD_COLOR), 0)) 104 | with gt_db.begin(write=False) as txn: 105 | value_buf = txn.get(paired_path['gt_path'].encode('ascii')) 106 | gt_data.append(np.expand_dims(cv2.imdecode(np.frombuffer(value_buf, np.uint8), cv2.IMREAD_COLOR), 0)) 107 | return { 108 | 'noisy_input': np.concatenate(input_data), 109 | 'ground_truth': np.concatenate(gt_data) 110 | } 111 | 112 | def _read_image(self, image_or_path: Union[str, np.ndarray]) -> np.ndarray: 113 | return image_or_path 114 | -------------------------------------------------------------------------------- /noise2same/dataset/ssi.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import List, Union, Dict 5 | 6 | import numpy as np 7 | from imageio import imread 8 | 9 | from noise2same.dataset.abc import AbstractNoiseDataset2D 10 | from noise2same.dataset.util import ( 11 | add_microscope_blur_2d, 12 | 
add_poisson_gaussian_noise, 13 | normalize, 14 | ) 15 | 16 | 17 | @dataclass 18 | class SSIDataset(AbstractNoiseDataset2D): 19 | path: Union[Path, str] = "data/ssi/" 20 | standardize_by_channel: bool = True 21 | input_name: str = "drosophila" 22 | 23 | def _get_images(self) -> Dict[str, Union[List[str], np.ndarray]]: 24 | try: 25 | files = [f for f in self.path.iterdir() if f.is_file()] 26 | except FileNotFoundError as e: 27 | print("File not found, cwd:", os.getcwd()) 28 | raise e 29 | 30 | filename = [f.name for f in files if self.input_name in f.name][0] 31 | filepath = self.path / filename 32 | 33 | image_clipped = imread(filepath) 34 | 35 | image_clipped = normalize(image_clipped.astype(np.float32)) 36 | blurred_image, psf_kernel = add_microscope_blur_2d(image_clipped) 37 | noisy_blurred_image = add_poisson_gaussian_noise( 38 | blurred_image, alpha=0.001, sigma=0.1, sap=0.01, quant_bits=10 39 | ) 40 | 41 | self.psf = psf_kernel 42 | self.gt = image_clipped 43 | 44 | return { 45 | "noisy_input": noisy_blurred_image[None, ...] 46 | } 47 | 48 | def _read_image(self, image_or_path: Union[str, np.ndarray]) -> np.ndarray: 49 | return image_or_path 50 | -------------------------------------------------------------------------------- /noise2same/dataset/synthetic.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Dict, List, Union, Sequence, Tuple 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from torch import Tensor as T 9 | 10 | from noise2same.dataset.abc import AbstractNoiseDataset2D 11 | 12 | 13 | def read_image(path: Union[str, Path]) -> np.ndarray: 14 | """ 15 | Read image from path with PIL and convert to np.uint8 16 | :param path: path to image 17 | :return: np.uint8 image [0, 255] 18 | """ 19 | im = Image.open(path) 20 | im = np.array(im, dtype=np.uint8) 21 | return im 22 | 23 | 24 | @dataclass 25 | class SyntheticDataset(AbstractNoiseDataset2D): 26 | noise_type: str = "gaussian" 27 | extension: str = "JPEG" 28 | noise_param: Union[int, Tuple[int, int]] = 25 29 | name: str = "" 30 | cached: str = "" 31 | 32 | def _validate(self) -> bool: 33 | assert self.noise_type in ("gaussian", "poisson", "none") 34 | assert isinstance(self.noise_param, int) or \ 35 | (isinstance(self.noise_param, Sequence) and len(self.noise_param) == 2) 36 | return True 37 | 38 | def _noise_param(self) -> float: 39 | if isinstance(self.noise_param, int): 40 | return self.noise_param 41 | else: 42 | return np.random.uniform(low=self.noise_param[0], high=self.noise_param[1]) 43 | 44 | def add_gaussian(self, x: T) -> T: 45 | """ 46 | Add Gaussian noise to image 47 | :param x: image [0, 1] 48 | 49 | Adapted from Neighbor2Neighbor https://github.com/TaoHuang2018/Neighbor2Neighbor/blob/2fff2978/train.py#L115 50 | """ 51 | noise = torch.FloatTensor(x.shape).normal_(mean=0.0, std=self._noise_param() / 255.0) 52 | return x + noise 53 | 54 | def add_poisson(self, x: T) -> T: 55 | """ 56 | Add Poisson noise to image 57 | :param x: image [0, 1] 58 | 59 | Adapted from Neighbor2Neighbor https://github.com/TaoHuang2018/Neighbor2Neighbor/blob/2fff2978/train.py#L124 60 | """ 61 | lam = self._noise_param() 62 | return torch.poisson(lam * x) / lam 63 | 64 | def _get_images(self) -> Dict[str, Union[List[str], np.ndarray]]: 65 | if self.cached: 66 | cached_path = self.path / self.cached 67 | if cached_path.exists(): 68 | print(f"Cache found in {cached_path}, reading images from
npy...\n") 69 | return {"noisy_input": np.load(self.path / self.cached, allow_pickle=True)} 70 | else: 71 | print(f"Cache not found in {cached_path}, read images from disk\n") 72 | return {"noisy_input": sorted(list(self.path.glob(f"*.{self.extension}")))} 73 | 74 | def add_noise(self, x: T): 75 | if self.noise_type == "gaussian": 76 | return self.add_gaussian(x) 77 | elif self.noise_type == "poisson": 78 | return self.add_poisson(x) 79 | else: 80 | return x 81 | 82 | def _read_image(self, image_or_path: Union[str, np.ndarray]) -> np.ndarray: 83 | im = image_or_path if isinstance(image_or_path, np.ndarray) else read_image(image_or_path) 84 | im = im.astype(np.float32) / 255.0 85 | return im 86 | 87 | def _apply_transforms(self, image: np.ndarray, mask: np.ndarray, ground_truth: np.ndarray = None) -> Dict[str, T]: 88 | ret = super()._apply_transforms(image, mask, ground_truth) 89 | # Add noise on a cropped image (much faster than on the full one) 90 | ret["image"] = self.add_noise(ret["image"]) 91 | return ret 92 | 93 | 94 | @dataclass 95 | class ImagenetSyntheticDataset(SyntheticDataset): 96 | path: Union[Path, str] = "data/Imagenet_val" 97 | extension: str = "JPEG" 98 | name: str = "imagenet" 99 | cached: str = "Imagenet_val.npy" 100 | 101 | 102 | @dataclass 103 | class KodakSyntheticDataset(SyntheticDataset): 104 | path: Union[Path, str] = "data/Kodak" 105 | extension: str = "png" 106 | name: str = "kodak" 107 | 108 | 109 | @dataclass 110 | class Set14SyntheticDataset(SyntheticDataset): 111 | path: Union[Path, str] = "data/Set14" 112 | extension: str = "png" 113 | name: str = "set14" 114 | 115 | 116 | @dataclass 117 | class BSD300SyntheticDataset(SyntheticDataset): 118 | path: Union[Path, str] = "data/BSD300/test" 119 | extension: str = "png" 120 | name: str = "bsd300" 121 | -------------------------------------------------------------------------------- /noise2same/dataset/synthetic_grayscale.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Dict, List, Union 4 | 5 | import numpy as np 6 | from torch import Tensor as T 7 | 8 | from noise2same.dataset.synthetic import SyntheticDataset 9 | 10 | 11 | @dataclass 12 | class SyntheticDatasetPrepared(SyntheticDataset): 13 | fixed: bool = False # if True, read prepared noisy images from disk 14 | 15 | def _get_images(self) -> Dict[str, Union[List[str], np.ndarray]]: 16 | path_original = self.path / "original" 17 | path_noisy = self.path / f"noise{self.noise_param}" if self.fixed else path_original 18 | if not path_noisy.exists(): 19 | print(f"Path {path_noisy} does not exist, generate random images " 20 | f"with {self.noise_type} noise {self.noise_param}") 21 | path_noisy = path_original 22 | self.fixed = False 23 | return {"noisy_input": sorted(list(path_noisy.glob(f"*.{self.extension}"))), 24 | "ground_truth": sorted(list(path_original.glob(f"*.{self.extension}")))} 25 | 26 | def add_noise(self, x: T) -> T: 27 | if self.fixed: 28 | return x 29 | return super().add_noise(x) 30 | 31 | 32 | @dataclass 33 | class BSD400SyntheticDataset(SyntheticDataset): 34 | path: Union[Path, str] = "data/BSD400" 35 | extension: str = "png" 36 | name: str = "bsd400" 37 | 38 | 39 | @dataclass 40 | class BSD68SyntheticDataset(SyntheticDatasetPrepared): 41 | path: Union[Path, str] = "data/BSD68-test/" 42 | extension: str = "png" 43 | name: str = "bsd68" 44 | 45 | 46 | @dataclass 47 | class Set12SyntheticDataset(SyntheticDataset): 48 | path: 
Union[Path, str] = "data/Set12" 49 | extension: str = "png" 50 | name: str = "set12" 51 | 52 | 53 | # @dataclass 54 | # class Urban100SyntheticDataset(SyntheticDataset): 55 | # path: Union[Path, str] = "data/Urban100" 56 | # extension: str = "png" 57 | # name: str = "urban100" 58 | 59 | -------------------------------------------------------------------------------- /noise2same/dataset/transforms.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | from numpy import ndarray 8 | from torch import Tensor as T 9 | 10 | Ints = Union[int, Tuple[int, ...], List[int]] 11 | Array = Union[ndarray, T] 12 | 13 | 14 | @dataclass 15 | class BaseTransform3D(ABC): 16 | p: float = 0.5 17 | axis: Ints = 0 18 | seed: int = 43 19 | k: int = 0 20 | done: bool = False 21 | channel_axis: Optional[Ints] = None 22 | 23 | def __post_init__(self): 24 | np.random.seed(self.seed) 25 | 26 | @abstractmethod 27 | def apply(self, x: ndarray) -> ndarray: 28 | raise NotImplementedError 29 | 30 | @abstractmethod 31 | def resample(self, x: ndarray) -> None: 32 | raise NotImplementedError 33 | 34 | def __call__(self, x: Array, resample: bool = False) -> Array: 35 | # If we do not resample, check if transform was done 36 | if not resample: 37 | if self.done: 38 | # if transform was applied before and we do not resample, always transform 39 | return self.apply(x) 40 | else: 41 | return x 42 | 43 | # If we resample 44 | else: 45 | self.resample(x) 46 | if np.random.uniform() < self.p: 47 | # transform with probability p 48 | self.done = True 49 | return self.apply(x) 50 | else: 51 | # otherwise return identity 52 | self.done = False 53 | return x 54 | 55 | 56 | class RandomFlip(BaseTransform3D): 57 | channel_axis: Ints = (0, 1) 58 | 59 | def resample(self, x: ndarray) -> None: 60 | dims = np.arange(x.ndim) 61 | dims[-1] = -1 62 | if self.channel_axis is not None: 63 | dims = np.delete(dims, self.channel_axis) 64 | self.axis = np.random.choice(dims) 65 | 66 | def apply(self, x: ndarray) -> ndarray: 67 | # .copy() solves negative stride issue 68 | return np.flip(x, axis=self.axis).copy() 69 | 70 | 71 | class RandomRotate90(BaseTransform3D): 72 | channel_axis: Ints = (0, 1) 73 | 74 | def apply(self, x: ndarray) -> ndarray: 75 | return np.rot90(x, k=self.k, axes=self.axis).copy() 76 | 77 | def resample(self, x: ndarray) -> None: 78 | dims = np.arange(x.ndim) 79 | dims[-1] = -1 80 | if self.channel_axis is not None: 81 | dims = np.delete(dims, self.channel_axis) 82 | self.k = np.random.choice(4) 83 | a = int(np.random.choice(len(dims))) 84 | self.axis = (dims[a], dims[a - 1]) 85 | 86 | 87 | class RandomCrop(BaseTransform3D): 88 | p: float = 1 89 | patch_size: Union[None, int, Tuple[int, ...]] = 64 90 | start: Optional[Union[int, List[int]]] = None 91 | 92 | def patch_tuple(self, x: Array) -> Tuple[int, ...]: 93 | """ 94 | Forms a correct tuple of patch shape from provided init argument `patch_size` 95 | :param x: array to crop a patch from 96 | :return: tuple with patch size 97 | """ 98 | if self.patch_size is None: 99 | # crop patch half a size of the original if None 100 | return tuple( 101 | s // 2 if s != 1 and i != self.channel_axis else None 102 | for i, s in enumerate(x.shape) 103 | ) 104 | if isinstance(self.patch_size, int): 105 | return tuple( 106 | self.patch_size if s != 1 and i != self.channel_axis else None 107 | 
for i, s in enumerate(x.shape) 108 | ) 109 | else: 110 | assert len(self.patch_size) == x.squeeze().ndim 111 | return self.patch_size 112 | 113 | def slice(self, x: Array) -> Tuple[slice, ...]: 114 | """ 115 | Returns tuple of slices to slice a given array 116 | :param x: array to slice 117 | :return: tuple of slices 118 | """ 119 | patch_size = self.patch_tuple(x) 120 | 121 | # Create slices from patch_size and start points 122 | slices = tuple( 123 | slice(s, s + p) if p is not None else slice(None) 124 | for s, p in zip(self.start, patch_size) 125 | ) 126 | 127 | return slices 128 | 129 | def resample(self, x: ndarray) -> None: 130 | patch_size = self.patch_tuple(x) 131 | self.start = [ 132 | np.random.choice(s - p) if p is not None else p 133 | for s, p in zip(x.shape, patch_size) 134 | ] 135 | 136 | def apply(self, x: Array) -> Array: 137 | s = self.slice(x) 138 | return x[s] 139 | 140 | 141 | class CenterCrop(RandomCrop): 142 | def slice(self, x: Array) -> Tuple[slice, ...]: 143 | patch_size = self.patch_tuple(x) 144 | center = [s // 2 if p is not None else p for s, p in zip(x.shape, patch_size)] 145 | return tuple( 146 | slice(c - int(np.floor(p / 2)), c + int(np.ceil(p / 2))) 147 | if p is not None 148 | else slice(None) 149 | for c, p in zip(center, patch_size) 150 | ) 151 | 152 | 153 | @dataclass 154 | class Compose: 155 | transforms: List[BaseTransform3D] 156 | debug: bool = False 157 | 158 | def __call__(self, x: ndarray, resample: bool = False): 159 | out = x.copy() 160 | for t in self.transforms: 161 | if t is not None: 162 | if self.debug: 163 | print(f"Apply {t}") 164 | out = t(out, resample=resample) 165 | return out 166 | 167 | 168 | @dataclass 169 | class ToTensor(BaseTransform3D): 170 | transpose: bool = False 171 | p: int = 1 172 | done: bool = True 173 | 174 | def resample(self, x: ndarray) -> None: 175 | self.done = True 176 | 177 | def apply(self, x: ndarray) -> T: 178 | out = x.copy() 179 | if self.transpose: 180 | out = np.moveaxis(out, -1, 0) 181 | out = torch.from_numpy(out) 182 | return out 183 | -------------------------------------------------------------------------------- /noise2same/dataset/util.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | from typing import Any, List, Optional, Tuple, Union 3 | 4 | import albumentations as albu 5 | import numpy as np 6 | from numpy.random.mtrand import normal, uniform 7 | from scipy.signal import convolve, convolve2d 8 | from skimage.exposure import rescale_intensity 9 | from skimage.util import random_noise 10 | 11 | from noise2same.dataset import transforms as t3d 12 | from noise2same.psf.microscope_psf import SimpleMicroscopePSF 13 | 14 | Ints = Optional[Union[int, List[int], Tuple[int, ...]]] 15 | 16 | 17 | def get_stratified_coords( 18 | box_size: int, 19 | shape: Tuple[int, ...], 20 | resample: bool = False, 21 | ) -> Tuple[List[int], ...]: 22 | """ 23 | Create stratified blind spot coordinates 24 | :param box_size: int, size of stratification box 25 | :param shape: tuple, image shape 26 | :param resample: bool, resample if out of box 27 | :return: 28 | """ 29 | box_count = [int(np.ceil(s / box_size)) for s in shape] 30 | coords = [] 31 | 32 | for ic in product(*[np.arange(bc) for bc in box_count]): 33 | sampled = False 34 | while not sampled: 35 | coord = tuple(np.random.rand() * box_size for _ in shape) 36 | coord = [int(i * box_size + c) for i, c in zip(ic, coord)] 37 | if all(c < s for c, s in zip(coord, shape)): 38 | coords.append(coord) 39 | sampled = True 40 | if not resample: 41 | break 42 | 43 | coords = tuple(zip(*coords)) # transpose (N, 3) -> (3, N) 44 | return coords 45 | 46 | 47 | def mask_like_image( 48 | image: np.ndarray, mask_percentage: float = 0.5, channels_last: bool = True 49 | ) -> np.ndarray: 50 | """ 51 | Generates a stratified mask of image.shape 52 | :param image: ndarray, reference image to mask 53 | :param mask_percentage: float, percentage of pixels to mask, default 0.5% 54 | :param channels_last: bool, true to process image as channel-last (256, 256, 3) 55 | :return: ndarray, mask 56 | """ 57 | # todo understand generator_val 58 | # https://github.com/divelab/Noise2Same/blob/8cdbfef5c475b9f999dcb1a942649af7026c887b/models.py#L130 59 | mask = np.zeros_like(image) 60 | n_channels = image.shape[-1 if channels_last else 0] 61 | channel_shape = image.shape[:-1] if channels_last else image.shape[1:] 62 | n_dim = len(channel_shape) 63 | # A likely mistake in the original implementation: np.sqrt is used for both 2D and 3D images 64 | # If we use square root for 3D images, we do not reach the required masking percentage 65 | # See test_dataset.py for checks 66 | box_size = np.round(np.power(100 / mask_percentage, 1 / n_dim)).astype(int) 67 | for c in range(n_channels): 68 | mask_coords = get_stratified_coords(box_size=box_size, shape=channel_shape) 69 | mask_coords = (mask_coords + (c,)) if channels_last else ((c,) + mask_coords) 70 | mask[mask_coords] = 1.0 71 | return mask 72 | 73 | 74 | def training_augmentations_2d(crop: int = 64): 75 | return [ 76 | albu.RandomCrop(width=crop, height=crop, p=1), 77 | albu.RandomRotate90(p=0.5), 78 | albu.Flip(p=0.5), 79 | ] 80 | 81 | 82 | def validation_transforms_2d(crop: int = 64): 83 | return [ 84 | albu.CenterCrop(width=crop, height=crop, p=1) 85 | ] 86 | 87 | 88 | def training_augmentations_3d(): 89 | return t3d.Compose( 90 | [ 91 | t3d.RandomRotate90(p=0.5, axis=(2, 3), channel_axis=(0, 1)), 92 | t3d.RandomFlip(p=0.5, axis=(2, 3), channel_axis=(0, 1)), 93 | ] 94 | ) 95 | 96 | 97 | def _raise(e): 98 | raise e 99 | 100 | 101 | class PadAndCropResizer(object): 102 | """ 103 | https://github.com/divelab/Noise2Same/blob/8cdbfef5c475b9f999dcb1a942649af7026c887b/utils/predict_utils.py#L115 104 | """ 105 | 106 | def __init__( 107 | self, mode: str = "reflect", div_n: Optional[int] = None, **kwargs: Any 108 | ): 109 | self.mode = mode 110 | self.kwargs = kwargs 111 | self.pad = None 112 | self.div_n = div_n 113 | 114 | def _normalize_exclude(self, exclude: Ints, n_dim: int): 115 | """Return normalized list of excluded axes.""" 116 | if exclude is None: 117 | return [] 118 | exclude_list = [exclude] if np.isscalar(exclude) else list(exclude) 119 | exclude_list = [d % n_dim for d in exclude_list] 120 | len(exclude_list) == len(np.unique(exclude_list)) or _raise(ValueError()) 121 | all((isinstance(d, int) and 0 <= d < n_dim for d in exclude_list)) or _raise( 122 | ValueError() 123 | ) 124 | return exclude_list 125 | 126 | def before(self, x: np.ndarray, div_n: int = None, exclude: Ints = None): 127 | def _split(v): 128 | a = v // 2 129 | return a, v - a 130 | 131 | if div_n is None: 132 | div_n = self.div_n 133 | assert div_n is not None 134 | 135 | exclude = self._normalize_exclude(exclude, x.ndim) 136 | self.pad = [ 137 | _split((div_n - s % div_n) % div_n) if (i not in exclude) else (0, 0) 138 | for i, s in enumerate(x.shape) 139 | ] 140 | x_pad = np.pad(x, self.pad, mode=self.mode, **self.kwargs) 141 | for i in exclude: 142 | del
self.pad[i] 143 | return x_pad 144 | 145 | def after(self, x: np.ndarray, exclude: Ints = None): 146 | 147 | pads = self.pad[: len(x.shape)] # ? 148 | crop = [slice(p[0], -p[1] if p[1] > 0 else None) for p in self.pad] 149 | for i in self._normalize_exclude(exclude, x.ndim): 150 | crop.insert(i, slice(None)) 151 | len(crop) == x.ndim or _raise(ValueError()) 152 | return x[tuple(crop)] 153 | 154 | 155 | # https://github.com/royerlab/ssi-code/blob/master/ssi/utils/io/datasets.py 156 | 157 | 158 | def normalize(image): 159 | return rescale_intensity( 160 | image.astype(np.float32), in_range="image", out_range=(0, 1) 161 | ) 162 | 163 | 164 | def add_poisson_gaussian_noise( 165 | image, 166 | alpha=5, 167 | sigma=0.01, 168 | sap=0.0, 169 | quant_bits=8, 170 | dtype=np.float32, 171 | clip=True, 172 | fix_seed=True, 173 | ): 174 | if fix_seed: 175 | np.random.seed(0) 176 | rnd = normal(size=image.shape) 177 | rnd_bool = uniform(size=image.shape) < sap 178 | 179 | noisy = image + np.sqrt(alpha * image + sigma ** 2) * rnd 180 | noisy = noisy * (1 - rnd_bool) + rnd_bool * uniform(size=image.shape) 181 | noisy = np.around((2 ** quant_bits) * noisy) / 2 ** quant_bits 182 | noisy = np.clip(noisy, 0, 1) if clip else noisy 183 | noisy = noisy.astype(dtype) 184 | return noisy 185 | 186 | 187 | def add_noise(image, intensity=5, variance=0.01, sap=0.0, dtype=np.float32, clip=True): 188 | np.random.seed(0) 189 | noisy = image 190 | if intensity is not None: 191 | noisy = np.random.poisson(image * intensity) / intensity 192 | noisy = random_noise(noisy, mode="gaussian", var=variance, seed=0, clip=clip) 193 | noisy = random_noise(noisy, mode="s&p", amount=sap, seed=0, clip=clip) 194 | noisy = noisy.astype(dtype) 195 | return noisy 196 | 197 | 198 | def add_blur_2d(image, k=17, sigma=5, multi_channel=False): 199 | from numpy import exp, pi, sqrt 200 | 201 | # generate a (2k+1)x(2k+1) gaussian kernel with mean=0 and sigma = s 202 | probs = [ 203 | exp(-z * z / (2 * sigma * sigma)) / sqrt(2 * pi * sigma * sigma) 204 | for z in range(-k, k + 1) 205 | ] 206 | psf_kernel = np.outer(probs, probs) 207 | 208 | def conv(_image): 209 | return convolve2d(_image, psf_kernel, mode="same").astype(np.float32) 210 | 211 | if multi_channel: 212 | image = np.moveaxis(image.copy(), -1, 0) 213 | return ( 214 | np.moveaxis(np.stack([conv(channel) for channel in image]), 0, -1), 215 | psf_kernel, 216 | ) 217 | else: 218 | return conv(image), psf_kernel 219 | 220 | 221 | def add_microscope_blur_2d( 222 | image: np.ndarray, dz: int = 0, multi_channel: bool = False, size: int = 17 223 | ): 224 | psf = SimpleMicroscopePSF() 225 | psf_xyz_array = psf.generate_xyz_psf(dxy=0.406, dz=0.406, xy_size=size, z_size=size) 226 | psf_kernel = psf_xyz_array[dz] 227 | psf_kernel /= psf_kernel.sum() 228 | 229 | def conv(_image): 230 | return convolve2d(_image, psf_kernel, mode="same").astype(np.float32) 231 | 232 | if multi_channel: 233 | image = np.moveaxis(image.copy(), -1, 0) 234 | return ( 235 | np.moveaxis(np.stack([conv(channel) for channel in image]), 0, -1), 236 | psf_kernel, 237 | ) 238 | else: 239 | return conv(image), psf_kernel 240 | 241 | 242 | def add_microscope_blur_3d(image, size: int = 17): 243 | psf = SimpleMicroscopePSF() 244 | psf_xyz_array = psf.generate_xyz_psf(dxy=0.406, dz=0.406, xy_size=size, z_size=size) 245 | psf_kernel = psf_xyz_array 246 | psf_kernel /= psf_kernel.sum() 247 | return convolve(image, psf_kernel, mode="same"), psf_kernel 248 | -------------------------------------------------------------------------------- 
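A minimal round-trip sketch for PadAndCropResizer (illustrative, not part of the repository; div_n=8 is an assumption matching a depth-3 UNet, and axis 0 is excluded as the channel axis):

import numpy as np
from noise2same.dataset.util import PadAndCropResizer

resizer = PadAndCropResizer(mode="reflect", div_n=8)
x = np.random.rand(1, 100, 97)        # (C, H, W); spatial sizes are not divisible by 8
x_pad = resizer.before(x, exclude=0)  # reflect-pads H and W up to (1, 104, 104)
y = x_pad                             # stand-in for a network forward pass
y = resizer.after(y, exclude=0)       # crops the padding back off
assert y.shape == x.shape             # (1, 100, 97) again

--------------------------------------------------------------------------------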
/noise2same/evaluator.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from pytorch_toolbelt.inference.tiles import TileMerger 6 | from torch.cuda.amp import autocast 7 | from torch.utils.data import DataLoader, Dataset 8 | from tqdm import tqdm 9 | import time 10 | 11 | from noise2same.dataset.util import PadAndCropResizer 12 | from noise2same.model import Noise2Same 13 | from noise2same.backbone.unet import UNet 14 | from noise2same.backbone.swinir import SwinIR 15 | 16 | 17 | class Evaluator(object): 18 | def __init__( 19 | self, 20 | model: Noise2Same, 21 | device: str = "cuda", 22 | checkpoint_path: Optional[str] = None, 23 | masked: bool = False, 24 | ): 25 | """ 26 | Model evaluator, describes inference for different data formats 27 | :param model: model architecture to evaluate 28 | :param device: str, device to run inference 29 | :param checkpoint_path: optional str, path to the model checkpoint 30 | :param masked: bool, whether to perform the forward pass with masking 31 | """ 32 | self.model = model 33 | self.device = device 34 | self.checkpoint_path = checkpoint_path 35 | if checkpoint_path is not None: 36 | self.load_checkpoint(checkpoint_path) 37 | self.masked = masked 38 | 39 | self.model.to(device) 40 | 41 | if isinstance(self.model.net, UNet): 42 | self.resizer = PadAndCropResizer( 43 | mode="reflect", div_n=2 ** self.model.net.depth 44 | ) 45 | elif isinstance(self.model.net, SwinIR): 46 | self.resizer = PadAndCropResizer( 47 | mode="reflect", div_n=self.model.net.window_size 48 | ) 49 | else: 50 | self.resizer = PadAndCropResizer(div_n=1) 51 | 52 | @torch.no_grad() 53 | def inference( 54 | self, 55 | loader: DataLoader, 56 | half: bool = False, 57 | empty_cache: bool = False, 58 | convolve: bool = False, 59 | key: str = "image", 60 | ) -> Tuple[List[Dict[str, np.ndarray]], List[int]]: 61 | """ 62 | Run inference for a given dataloader 63 | :param loader: DataLoader 64 | :param half: bool, whether to use half precision 65 | :param empty_cache: bool, whether to empty CUDA cache after each iteration 66 | :param convolve: bool, whether to convolve the output with a PSF 67 | :param key: str, key to use for the output [image, deconv] 68 | :return: tuple of a list of output dicts and a list of indices of successfully processed images 69 | """ 70 | self.model.eval() 71 | 72 | outputs = [] 73 | iterator = tqdm(loader, desc="inference", position=0, leave=True) 74 | times = [] 75 | errors_num = 0 76 | indices = [] 77 | for i, batch in enumerate(iterator): 78 | try: 79 | batch = {k: v.to(self.device) for k, v in batch.items()} 80 | start = time.time() 81 | with autocast(enabled=half): 82 | if self.masked: 83 | # TODO remove randomness 84 | # idea: use the same mask for all images? mask as tta?
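# (sketch of the idea above, not implemented here: seeding the mask generator per
# image, e.g. torch.manual_seed(i), or averaging predictions over a fixed set of
# masks as in test-time augmentation would make this branch deterministic)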
85 | out, _ = self.model.forward( 86 | batch["image"], mask=batch["mask"], convolve=convolve 87 | ) 88 | else: 89 | _, out = self.model.forward(batch["image"], convolve=convolve) 90 | out_raw = out[key] * batch["std"] + batch["mean"] 91 | 92 | out_raw = {"image": np.moveaxis(out_raw.detach().cpu().numpy(), 1, -1)} 93 | if self.model.lambda_proj > 0: 94 | out_raw.update( 95 | {"proj": np.moveaxis(out["proj"].detach().cpu().numpy(), 1, -1)} 96 | ) 97 | 98 | end = time.time() 99 | times.append(end - start) 100 | 101 | outputs.append(out_raw) 102 | iterator.set_postfix( 103 | { 104 | "shape": out_raw["image"].shape, 105 | "reserved": torch.cuda.memory_reserved(0) / (1024 ** 2), 106 | "allocated": torch.cuda.memory_allocated(0) / (1024 ** 2), 107 | } 108 | ) 109 | 110 | if empty_cache: 111 | torch.cuda.empty_cache() 112 | except RuntimeError: 113 | errors_num += 1 114 | print(f"Skipping image {i}") 115 | 116 | else: 117 | indices.append(i) 118 | 119 | print(f"Average inference time: {np.mean(times) * 1000:.2f} ms") 120 | print(f"Dropped image rate: {errors_num / len(loader)}") 121 | return outputs, indices # FIXME: returning indices is a workaround for images skipped after errors 122 | 123 | @torch.no_grad() 124 | def inference_single_image_dataset( 125 | self, 126 | dataset: Dataset, 127 | batch_size: int = 1, 128 | num_workers: int = 0, 129 | crop_border: int = 0, 130 | device: str = "cpu", 131 | half: bool = False, 132 | empty_cache: bool = False, 133 | key: str = "image", 134 | convolve: bool = False, 135 | ) -> np.ndarray: 136 | """ 137 | Run inference for a single image represented as Dataset 138 | Here, we assume that dataset was tiled and has a `tiler` attribute 139 | 140 | :param dataset: Dataset representing a single large tiled image 141 | :param batch_size: int, batch size for DataLoader 142 | :param num_workers: int, number of workers for DataLoader 143 | :param crop_border: int, border pixels to crop when merging tiles 144 | :param device: str, device where to accumulate merging tiles 145 | :param half: bool, whether to use half precision 146 | :param empty_cache: bool, whether to empty CUDA cache after each batch 147 | :param key: str, which output key to accumulate 148 | :param convolve: bool, whether to convolve the output with a PSF 149 | :return: numpy array, merged image 150 | """ 151 | assert hasattr(dataset, "tiler"), "Dataset should have a `tiler` attribute" 152 | 153 | self.model.eval() 154 | 155 | merger = TileMerger( 156 | image_shape=dataset.tiler.target_shape, 157 | channels=self.model.in_channels, 158 | weight=dataset.tiler.weight, 159 | device=device, 160 | crop_border=crop_border, 161 | default_value=0, 162 | ) 163 | # print(f'Created merger for image {merger.image.shape}') 164 | 165 | iterator = dataset 166 | if batch_size > 1: 167 | iterator = torch.utils.data.DataLoader( 168 | dataset, 169 | batch_size=batch_size, 170 | num_workers=num_workers, 171 | shuffle=False, 172 | pin_memory=False, 173 | drop_last=False, 174 | ) 175 | iterator = tqdm(iterator, desc="Predict") 176 | 177 | for i, batch in enumerate(iterator): 178 | # We can iterate over dataset, hence need to unsqueeze batch dim 179 | if batch_size == 1: 180 | batch["image"] = batch["image"][None, ...] 181 | batch["crop"] = batch["crop"][None, ...]
182 | 183 | # We don't need to move `crop` to device 184 | batch = { 185 | k: v.to(self.device) if k != "crop" else v for k, v in batch.items() 186 | } 187 | with autocast(enabled=half): 188 | pred_batch = ( 189 | self.model.forward(batch["image"], convolve=convolve)[1][key] 190 | * batch["std"] 191 | + batch["mean"] 192 | ) 193 | iterator.set_postfix( 194 | { 195 | "in_shape": tuple(batch["image"].shape), 196 | "out_shape": tuple(pred_batch.shape), 197 | "crop": batch["crop"], 198 | } 199 | ) 200 | 201 | merger.integrate_batch(batch=pred_batch, crop_coords=batch["crop"]) 202 | if empty_cache: 203 | torch.cuda.empty_cache() 204 | 205 | merger.merge_() 206 | return dataset.tiler.crop_to_original_size( 207 | merger.image["image"].cpu().numpy()[0] 208 | ) 209 | 210 | @torch.no_grad() 211 | def inference_single_image_tensor( 212 | self, 213 | image: torch.Tensor, 214 | standardize: bool = True, 215 | im_mean: Optional[float] = None, 216 | im_std: Optional[float] = None, 217 | ) -> torch.Tensor: 218 | """ 219 | Run inference for a single image represented as a tensor 220 | 221 | :param image: torch.Tensor 222 | :param standardize: bool, whether to subtract mean and divide by std 223 | :param im_mean: float, precalculated image mean 224 | :param im_std: float, precalculated image std 225 | :return: torch.Tensor 226 | """ 227 | 228 | if not standardize: 229 | im_mean, im_std = 0, 1 230 | 231 | if im_mean is None: 232 | im_mean = image.mean() 233 | 234 | if im_std is None: 235 | im_std = image.std() 236 | 237 | image = (image - im_mean) / im_std 238 | 239 | image = self.resizer.before(image, exclude=0)[None, ...] 240 | out = self.model(image.to(self.device)).detach().cpu() 241 | out = self.resizer.after(out[0], exclude=0) 242 | out = out * im_std + im_mean 243 | 244 | return out 245 | 246 | def load_checkpoint(self, checkpoint_path: str): 247 | checkpoint = torch.load(checkpoint_path, map_location="cpu") 248 | if "model" in checkpoint: 249 | checkpoint["model"].pop("mask_kernel.kernel", None) 250 | checkpoint = checkpoint["model"] 251 | self.model.load_state_dict(checkpoint, strict=False) 252 | -------------------------------------------------------------------------------- /noise2same/fft_conv.py: -------------------------------------------------------------------------------- 1 | # https://github.com/fkodom/fft-conv-pytorch 2 | from functools import partial 3 | from typing import Iterable, Tuple, Union 4 | 5 | import torch 6 | import torch.nn.functional as f 7 | from torch import Tensor, nn 8 | from torch.fft import irfftn, rfftn 9 | 10 | 11 | def complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor: 12 | """Multiplies two complex-valued tensors.""" 13 | # Scalar matrix multiplication of two tensors, over only the first channel 14 | # dimensions. Dimensions 3 and higher will have the same shape after multiplication. 15 | # We also allow for "grouped" multiplications, where multiple sections of channels 16 | # are multiplied independently of one another (required for group convolutions). 17 | scalar_matmul = partial(torch.einsum, "agc..., gbc... -> agb...") 18 | a = a.view(a.size(0), groups, -1, *a.shape[2:]) 19 | b = b.view(groups, -1, *b.shape[1:]) 20 | 21 | # Compute the real and imaginary parts independently, then manually insert them 22 | # into the output Tensor. This is fairly hacky but necessary for PyTorch 1.7.0, 23 | # because Autograd is not enabled for complex matrix operations yet. Not exactly 24 | # idiomatic PyTorch code, but it should work for all future versions (>= 1.7.0).
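# (added illustration: for scalars, (a.re + i a.im)(b.re + i b.im)
# = (a.re b.re - a.im b.im) + i (a.im b.re + a.re b.im);
# the two einsum calls below apply exactly this identity block-wise)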
25 | real = scalar_matmul(a.real, b.real) - scalar_matmul(a.imag, b.imag) 26 | imag = scalar_matmul(a.imag, b.real) + scalar_matmul(a.real, b.imag) 27 | c = torch.zeros(real.shape, dtype=torch.complex64, device=a.device) 28 | c.real, c.imag = real, imag 29 | 30 | return c.view(c.size(0), -1, *c.shape[3:]) 31 | 32 | 33 | def to_ntuple(val: Union[int, Iterable[int]], n: int) -> Tuple[int, ...]: 34 | """Casts to a tuple with length 'n'. Useful for automatically computing the 35 | padding and stride for convolutions, where users may only provide an integer. 36 | Args: 37 | val: (Union[int, Iterable[int]]) Value to cast into a tuple. 38 | n: (int) Desired length of the tuple 39 | Returns: 40 | (Tuple[int, ...]) Tuple of length 'n' 41 | """ 42 | if isinstance(val, Iterable): 43 | out = tuple(val) 44 | if len(out) == n: 45 | return out 46 | else: 47 | raise ValueError(f"Cannot cast tuple of length {len(out)} to length {n}.") 48 | else: 49 | return n * (val,) 50 | 51 | 52 | def fft_conv( 53 | signal: Tensor, 54 | kernel: Tensor, 55 | bias: Tensor = None, 56 | padding: Union[int, Iterable[int], str] = 0, 57 | stride: Union[int, Iterable[int]] = 1, 58 | groups: int = 1, 59 | padding_mode: str = "constant", 60 | ) -> Tensor: 61 | """Performs N-d convolution of Tensors using a fast Fourier transform, which 62 | is very fast for large kernel sizes. Also, optionally adds a bias Tensor after 63 | the convolution (in order to mimic the PyTorch direct convolution). 64 | Args: 65 | signal: (Tensor) Input tensor to be convolved with the kernel. 66 | kernel: (Tensor) Convolution kernel. 67 | bias: (Tensor) Bias tensor to add to the output. 68 | padding: (Union[int, Iterable[int]]) Number of zero samples to pad the 69 | input on the last dimension. 70 | stride: (Union[int, Iterable[int]]) Stride size for computing output values. 71 | groups: (int) Number of groups to split the channels into for a grouped convolution. 72 | padding_mode: (str) Padding mode to use from {constant, reflection, replication}. 73 | reflection not available for 3d. 74 | Returns: 75 | (Tensor) Convolved tensor 76 | """ 77 | # Cast stride to tuple. 78 | stride_ = to_ntuple(stride, n=signal.ndim - 2) 79 | 80 | if padding != "same": 81 | padding_ = to_ntuple(padding, n=signal.ndim - 2) 82 | signal_padding = [p for p in padding_[::-1] for _ in range(2)] 83 | else: 84 | # signal_padding = [ 85 | # (0, 0) if k <= s else ((k - s) // 2, k - (k - s) // 2) 86 | # for s, k, in zip(signal.shape[2:], kernel.shape[2:]) 87 | # ] 88 | # signal_padding = [p for pd in signal_padding[::-1] for p in pd] 89 | padding_ = [k // 2 for k in kernel.shape[2:]] 90 | 91 | signal_padding = [p for p in padding_[::-1] for _ in range(2)] 92 | # Pad the input signal & kernel tensors 93 | signal = f.pad(signal, signal_padding, mode=padding_mode) 94 | 95 | # Because PyTorch computes a *one-sided* FFT, we need the final dimension to 96 | # have *even* length. Just pad with one more zero if the final dimension is odd.
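# (added example: a final dimension of length 7 is padded to length 8 here, so rfftn
# returns 8 // 2 + 1 = 5 one-sided frequency bins; the pre-padding size is saved on
# the next line and the surplus sample is cropped away after the inverse transform)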
97 | signal_size = signal.size() # original signal size without padding to even 98 | if signal.size(-1) % 2 != 0: 99 | signal = f.pad(signal, [0, 1]) 100 | 101 | kernel_padding = [ 102 | pad 103 | for i in reversed(range(2, signal.ndim)) 104 | for pad in [0, signal.size(i) - kernel.size(i)] 105 | ] 106 | 107 | padded_kernel = f.pad(kernel, kernel_padding) 108 | assert ( 109 | padded_kernel.shape[1:] == signal.shape[1:] 110 | ), f"padded kernel shape {padded_kernel.shape} not equal to signal shape {signal.shape}" 111 | 112 | # Perform Fourier convolution -- FFT, matrix multiply, then IFFT 113 | # signal = signal.reshape(signal.size(0), groups, -1, *signal.shape[2:]) 114 | signal_fr = rfftn(signal.float(), dim=tuple(range(2, signal.ndim))) 115 | kernel_fr = rfftn(padded_kernel.float(), dim=tuple(range(2, signal.ndim))) 116 | 117 | kernel_fr.imag *= -1 118 | output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups) 119 | output = irfftn(output_fr, dim=tuple(range(2, signal.ndim))) 120 | 121 | # Remove extra padded values 122 | crop_slices = [slice(None), slice(None)] + [ 123 | slice( 124 | 0, 125 | (signal_size[i] - kernel.size(i) + (kernel.size(i) % 2)), 126 | # if padding != "same" 127 | # else None, 128 | stride_[i - 2], 129 | ) 130 | for i in range(2, signal.ndim) 131 | ] 132 | output = output[crop_slices].contiguous() 133 | 134 | # Optionally, add a bias term before returning. 135 | if bias is not None: 136 | bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1]) 137 | output += bias.view(bias_shape) 138 | 139 | return output 140 | 141 | 142 | class _FFTConv(nn.Module): 143 | """Base class for PyTorch FFT convolution layers.""" 144 | 145 | def __init__( 146 | self, 147 | in_channels: int, 148 | out_channels: int, 149 | kernel_size: Union[int, Iterable[int]], 150 | padding: Union[int, Iterable[int]] = 0, 151 | stride: Union[int, Iterable[int]] = 1, 152 | groups: int = 1, 153 | bias: bool = True, 154 | ndim: int = 1, 155 | ): 156 | """ 157 | Args: 158 | in_channels: (int) Number of channels in input tensors 159 | out_channels: (int) Number of channels in output tensors 160 | kernel_size: (Union[int, Iterable[int]]) Size of the kernel along each spatial dimension 161 | padding: (Union[int, Iterable[int]]) Number of zero samples to pad the 162 | input on the last dimension. 163 | stride: (Union[int, Iterable[int]]) Stride size for computing output values. 164 | bias: (bool) If True, includes bias, which is added after convolution 165 | """ 166 | super().__init__() 167 | self.in_channels = in_channels 168 | self.out_channels = out_channels 169 | self.kernel_size = kernel_size 170 | self.padding = padding 171 | self.stride = stride 172 | self.groups = groups 173 | self.use_bias = bias 174 | 175 | if in_channels % groups != 0: 176 | raise ValueError( 177 | "'in_channels' must be divisible by 'groups'. " 178 | f"Found: in_channels={in_channels}, groups={groups}." 179 | ) 180 | if out_channels % groups != 0: 181 | raise ValueError( 182 | "'out_channels' must be divisible by 'groups'. " 183 | f"Found: out_channels={out_channels}, groups={groups}."
184 | ) 185 | 186 | kernel_size = to_ntuple(kernel_size, ndim) 187 | self.weight = nn.Parameter( 188 | torch.randn(out_channels, in_channels // groups, *kernel_size) 189 | ) 190 | self.bias = ( 191 | nn.Parameter( 192 | torch.randn( 193 | out_channels, 194 | ) 195 | ) 196 | if bias 197 | else None 198 | ) 199 | 200 | def forward(self, signal): 201 | return fft_conv( 202 | signal, 203 | self.weight, 204 | bias=self.bias, 205 | padding=self.padding, 206 | stride=self.stride, 207 | groups=self.groups, 208 | ) 209 | 210 | 211 | FFTConv1d = partial(_FFTConv, ndim=1) 212 | FFTConv2d = partial(_FFTConv, ndim=2) 213 | FFTConv3d = partial(_FFTConv, ndim=3) 214 | -------------------------------------------------------------------------------- /noise2same/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .wrappers import * 2 | -------------------------------------------------------------------------------- /noise2same/ops/wrappers.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def resize(input, 8 | size=None, 9 | scale_factor=None, 10 | mode='nearest', 11 | align_corners=None, 12 | warning=True): 13 | if warning: 14 | if size is not None and align_corners: 15 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 16 | output_h, output_w = tuple(int(x) for x in size) 17 | if output_h > input_h or output_w > input_w: 18 | if ((output_h > 1 and output_w > 1 and input_h > 1 19 | and input_w > 1) and (output_h - 1) % (input_h - 1) 20 | and (output_w - 1) % (input_w - 1)): 21 | warnings.warn( 22 | f'When align_corners={align_corners}, ' 23 | 'the output would be more aligned if ' 24 | f'input size {(input_h, input_w)} is `x+1` and ' 25 | f'out size {(output_h, output_w)} is `nx+1`') 26 | if isinstance(size, torch.Size): 27 | size = tuple(int(x) for x in size) 28 | return F.interpolate(input, size, scale_factor, mode, align_corners) 29 | -------------------------------------------------------------------------------- /noise2same/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .esadam import ESAdam 2 | 3 | __all__ = ["ESAdam"] 4 | -------------------------------------------------------------------------------- /noise2same/optimizers/esadam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Adam 3 | 4 | 5 | class ESAdam(Adam): 6 | """ 7 | Implements a modified version of the Adam algorithm that adds noise to 8 | the parameters. 9 | https://github.com/royerlab/ssi-code/blob/master/ssi/optimisers/esadam.py 10 | """ 11 | 12 | def __init__(self, params, start_noise_level=0.001, **kwargs): 13 | super().__init__(params, **kwargs) 14 | 15 | self.start_noise_level = start_noise_level 16 | self.step_counter = 0 17 | 18 | def step(self, closure=None): 19 | """Performs a single optimization step. 20 | 21 | Arguments: 22 | closure (callable, optional): A closure that reevaluates the model 23 | and returns the loss.
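Note (added for clarity): after the standard Adam update below, every dense parameter is perturbed by lr * start_noise_level / (1 + step_counter) * N(0, I), so the injected exploration noise anneals over training.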
24 | """ 25 | loss = super().step(closure) 26 | 27 | for group in self.param_groups: 28 | for p in group["params"]: 29 | if p.grad is None: 30 | continue 31 | grad: torch.Tensor = p.grad.data 32 | if grad.is_sparse: 33 | continue 34 | 35 | step_size = group["lr"] 36 | 37 | p.data += ( 38 | step_size 39 | * (self.start_noise_level / (1 + self.step_counter)) 40 | * (torch.randn_like(p.data)) 41 | ) 42 | 43 | self.step_counter += 1 44 | 45 | return loss 46 | -------------------------------------------------------------------------------- /noise2same/psf/__init__.py: -------------------------------------------------------------------------------- 1 | from .microscope_psf import MicroscopePSF, SimpleMicroscopePSF 2 | from .psf_convolution import read_psf, PSF, PSFParameter 3 | 4 | __all__ = ["MicroscopePSF", "SimpleMicroscopePSF", "read_psf", 5 | "PSF", "PSFParameter"] 6 | -------------------------------------------------------------------------------- /noise2same/psf/psf_convolution.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from pathlib import Path 3 | from typing import Optional, Union 4 | 5 | import h5py 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as f 9 | from skimage import io 10 | from torch import nn 11 | 12 | from noise2same.fft_conv import FFTConv2d, FFTConv3d, fft_conv 13 | from noise2same.util import center_crop 14 | 15 | 16 | class PSF(nn.Module): 17 | def __init__( 18 | self, 19 | kernel_psf: np.ndarray, 20 | in_channels: int = 1, 21 | pad_mode="replicate", 22 | fft: Union[str, bool] = "auto", 23 | ): 24 | """ 25 | Point-spread function 26 | https://github.com/royerlab/ssi-code/blob/master/ssi/models/psf_convolution.py 27 | :param kernel_psf: 2D or 3D np.ndarray 28 | :param pad_mode: {"reflect", "replicate"} 29 | :param in_channels: int, number of channels to convolve 30 | """ 31 | super().__init__() 32 | self.kernel_size = kernel_psf.shape[0] 33 | self.n_dim = len(kernel_psf.shape) 34 | self.fft = fft 35 | if self.fft == "auto": 36 | # TODO run own benchmarks 37 | # Use FFT Conv if kernel has > 100 elements 38 | self.fft = self.kernel_size ** self.n_dim > 100 39 | # self.fft = (self.kernel_size > 21 and self.n_dim == 2) or ( 40 | # self.kernel_size > 7 and self.n_dim == 3 41 | # ) 42 | if isinstance(self.fft, str): 43 | raise ValueError(f"Invalid fft value {self.fft}") 44 | 45 | if self.n_dim == 3 and pad_mode == "reflect": 46 | # Not supported yet 47 | pad_mode = "replicate" 48 | 49 | self.pad_mode = pad_mode 50 | self.pad = (self.kernel_size - 1) // 2 51 | assert self.n_dim in (2, 3) 52 | 53 | if fft: 54 | conv = FFTConv2d if self.n_dim == 2 else FFTConv3d 55 | else: 56 | conv = nn.Conv2d if self.n_dim == 2 else nn.Conv3d 57 | 58 | self.psf = conv( 59 | in_channels=in_channels, 60 | out_channels=in_channels, 61 | kernel_size=self.kernel_size, 62 | stride=1, 63 | padding=0, 64 | bias=False, 65 | groups=in_channels, 66 | ) 67 | 68 | self.weights_init(kernel_psf) 69 | 70 | def weights_init(self, kernel_psf: np.ndarray): 71 | for name, f in self.named_parameters(): 72 | f.data.copy_(torch.from_numpy(kernel_psf)) 73 | 74 | def forward(self, x: torch.Tensor) -> torch.Tensor: 75 | x = f.pad(x, (self.pad,) * self.n_dim * 2, mode=self.pad_mode) 76 | return self.psf(x) 77 | 78 | 79 | class PSFParameter(nn.Module): 80 | def __init__( 81 | self, 82 | kernel_psf: np.ndarray, 83 | in_channels: int = 1, 84 | pad_mode="replicate", 85 | trainable=False, 86 | fft: Union[str, bool] = "auto", 
87 | auto_padding: bool = False, 88 | ): 89 | """ 90 | Parametrized trainable version of PSF 91 | :param kernel_psf: 92 | :param in_channels: 93 | :param pad_mode: 94 | :param trainable: 95 | :param auto_padding: (bool) If True, automatically computes padding based on the 96 | signal size, kernel size and stride. 97 | """ 98 | super().__init__() 99 | self.kernel_size = kernel_psf.shape 100 | self.n_dim = len(kernel_psf.shape) 101 | 102 | if self.n_dim == 3 and pad_mode == "reflect": 103 | # Not supported yet 104 | pad_mode = "replicate" 105 | 106 | self.in_channels = in_channels 107 | self.pad_mode = pad_mode 108 | 109 | self.pad = [k // 2 for k in self.kernel_size] 110 | assert self.n_dim in (2, 3) 111 | 112 | self.fft = fft 113 | if self.fft == "auto": 114 | # Use FFT Conv if kernel has > 100 elements 115 | self.fft = np.prod(self.kernel_size) > 100 116 | if isinstance(self.fft, str): 117 | raise ValueError(f"Invalid fft value {self.fft}") 118 | 119 | if not self.fft: 120 | auto_padding = False 121 | self.auto_padding = auto_padding 122 | 123 | self.psf = torch.from_numpy(kernel_psf.squeeze()[(None,) * 2]).float() 124 | self.psf = nn.Parameter(self.psf, requires_grad=trainable) 125 | 126 | def forward(self, x: torch.Tensor) -> torch.Tensor: 127 | 128 | if self.fft: 129 | conv = partial( 130 | fft_conv, 131 | padding_mode=self.pad_mode, 132 | padding="same" if self.auto_padding else self.pad, 133 | ) 134 | else: 135 | signal_padding = tuple(np.repeat(self.pad[::-1], 2)) 136 | x = f.pad(x, signal_padding, mode=self.pad_mode) 137 | conv = torch.conv2d if self.n_dim == 2 else torch.conv3d 138 | 139 | x = conv(x, self.psf, groups=self.in_channels, stride=1) 140 | return x 141 | 142 | 143 | def read_psf( 144 | path: Union[Path, str], psf_size: Optional[int] = None, normalize: bool = True 145 | ) -> np.ndarray: 146 | """ 147 | Reads PSF from .h5 or .tif file 148 | :param path: absolute path to file 149 | :param psf_size: int, optional, crop PSF to a cube of this size if provided 150 | :param normalize: bool, whether to divide the PSF by its sum 151 | :return: PSF as numpy array 152 | """ 153 | path = str(path) 154 | if path.endswith(".h5"): 155 | with h5py.File(path, "r") as f: 156 | psf = f["psf"][()] # read the dataset into memory before the file closes 157 | else: 158 | psf = io.imread(path) 159 | 160 | if psf_size is not None: 161 | psf = center_crop(psf, psf_size) 162 | 163 | if normalize: 164 | psf /= psf.sum() 165 | 166 | return psf 167 | -------------------------------------------------------------------------------- /noise2same/trainer.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from pathlib import Path 3 | from typing import Any, Dict, List, Optional, Tuple 4 | 5 | import numpy as np 6 | import torch 7 | import wandb 8 | from torch.cuda.amp import GradScaler, autocast 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm, trange 11 | 12 | from noise2same.evaluator import Evaluator 13 | from noise2same.model import Noise2Same 14 | from noise2same.util import ( 15 | detach_to_np, 16 | load_checkpoint_to_module, 17 | normalize_zero_one_dict, 18 | ) 19 | from evaluate import get_scores 20 | 21 | 22 | class Trainer(object): 23 | def __init__( 24 | self, 25 | model: Noise2Same, 26 | optimizer, 27 | scheduler=None, 28 | device: str = "cuda", 29 | checkpoint_path: str = "checkpoints", 30 | monitor: str = "val_rec_mse", 31 | experiment: str = None, 32 | check: bool = False, 33 | wandb_log: bool = True, 34 | amp: bool = False, 35 | info_padding: bool = False, 36 |
): 37 | 38 | self.model = model 39 | self.inner_model = model if not isinstance(model, torch.nn.DataParallel) else model.module 40 | self.optimizer = optimizer 41 | self.scheduler = scheduler 42 | self.device = device 43 | self.checkpoint_path = Path(checkpoint_path) 44 | self.monitor = monitor 45 | self.experiment = experiment 46 | self.check = check 47 | if check: 48 | wandb_log = False 49 | self.wandb_log = wandb_log 50 | self.info_padding = info_padding 51 | 52 | self.model.to(device) 53 | self.checkpoint_path.mkdir(parents=True, exist_ok=False) 54 | self.evaluator = Evaluator(model=self.inner_model, device=device) 55 | 56 | self.amp = amp 57 | self.scaler = GradScaler() if amp else None 58 | 59 | def optimizer_scheduler_step(self, loss: torch.Tensor): 60 | """ 61 | Step the optimizer and scheduler given the loss 62 | :param loss: 63 | :return: 64 | """ 65 | if self.scaler is not None: 66 | self.scaler.scale(loss).backward() 67 | self.scaler.step(self.optimizer) 68 | self.scaler.update() 69 | else: 70 | loss.backward() 71 | self.optimizer.step() 72 | 73 | if self.scheduler is not None: 74 | self.scheduler.step() 75 | 76 | def one_epoch( 77 | self, loader: DataLoader 78 | ) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]: 79 | self.model.train() 80 | iterator = tqdm(loader, desc="train") 81 | total_loss = Counter() 82 | images = {} 83 | for i, batch in enumerate(iterator): 84 | x = batch["image"].to(self.device) 85 | ground_truth = None if "ground_truth" not in batch else batch["ground_truth"].to(self.device) 86 | mask = batch["mask"].to(self.device) 87 | self.optimizer.zero_grad() 88 | 89 | # todo gradient accumulation 90 | with autocast(enabled=self.amp): 91 | if self.info_padding: 92 | # provide full size image to avoid zero padding 93 | padding = [ 94 | (b, a) 95 | for b, a in zip( 96 | loader.dataset.tiler.margin_start, 97 | loader.dataset.tiler.margin_end, 98 | ) 99 | ] + [(0, 0)] 100 | full_size_image = np.pad(loader.dataset.image, padding) 101 | full_size_image = torch.from_numpy( 102 | np.moveaxis(full_size_image, -1, 0) 103 | ).to(self.device) 104 | out_mask, out_raw = self.model.forward( 105 | x, mask=mask, crops=batch["crop"], full_size_image=full_size_image 106 | ) 107 | else: 108 | try: 109 | out_mask, out_raw = self.model.forward(x, mask=mask) 110 | except RuntimeError as e: 111 | raise RuntimeError(f"Batch {x.shape} failed on device {x.device}") from e 112 | 113 | loss, loss_log = self.inner_model.compute_losses_from_output( 114 | x, mask, out_mask, out_raw, ground_truth 115 | ) 116 | 117 | reg_loss, reg_loss_log = self.inner_model.compute_regularization_loss( 118 | out_mask if out_raw is None else out_raw, 119 | mean=batch["mean"].to(self.device), 120 | std=batch["std"].to(self.device), 121 | max_value=x.max(), 122 | ) 123 | 124 | if reg_loss_log: 125 | loss += reg_loss 126 | loss_log.update(reg_loss_log) 127 | self.optimizer_scheduler_step(loss) 128 | total_loss += loss_log 129 | iterator.set_postfix({k: v / (i + 1) for k, v in total_loss.items()}) 130 | 131 | if self.check and i > 3: 132 | break 133 | 134 | # Log last batch of images 135 | if i == len(iterator) - 1 or (self.check and i == 3): 136 | images = { 137 | "input": x 138 | } 139 | 140 | if out_mask is not None: 141 | images["out_mask"] = out_mask["image"] 142 | 143 | if out_raw is not None: 144 | images["out_raw"] = out_raw["image"] 145 | 146 | if out_mask is not None and "deconv" in out_mask: 147 | images["out_mask_deconv"] = out_mask["deconv"] 148 | if out_raw is not None and "deconv" in out_raw: 149 | 
images["out_raw_deconv"] = out_raw["deconv"] 150 | 151 | images = detach_to_np(images, mean=batch["mean"], std=batch["std"]) 152 | images = normalize_zero_one_dict(images) 153 | 154 | total_loss = {k: v / len(loader) for k, v in total_loss.items()} 155 | if self.scheduler is not None: 156 | total_loss["lr"] = self.scheduler.get_last_lr()[0] 157 | return total_loss, images 158 | 159 | @torch.no_grad() 160 | def validate( 161 | self, loader: DataLoader 162 | ) -> Tuple[Dict[str, float], Dict[str, np.ndarray]]: 163 | self.model.eval() 164 | iterator = tqdm(loader, desc="valid") 165 | 166 | total_loss = 0 167 | val_mse_log = [] 168 | images = {} 169 | for i, batch in enumerate(iterator): 170 | x = batch["image"].to(self.device) 171 | 172 | with autocast(enabled=self.amp): 173 | out_raw = self.model(x)[1]["image"] 174 | 175 | if "ground_truth" in batch.keys(): 176 | val_mse = torch.mean(torch.square(out_raw - batch["ground_truth"].to(self.device))) 177 | val_mse_log.append(val_mse.item()) 178 | 179 | rec_mse = torch.mean(torch.square(out_raw - x)) 180 | total_loss += rec_mse.item() 181 | 182 | iterator.set_postfix({"val_rec_mse": total_loss / (i + 1)}) 183 | 184 | if self.check and i > 3: 185 | break 186 | 187 | if i == len(iterator) - 1 or (self.check and i == 3): 188 | images = { 189 | "val_input": x, 190 | "val_out_raw": out_raw, 191 | } 192 | images = detach_to_np(images, mean=batch["mean"], std=batch["std"]) 193 | images = normalize_zero_one_dict(images) 194 | if len(val_mse_log) > 0: 195 | return {"val_rec_mse": total_loss / len(loader), "val_mse": np.mean(val_mse_log)}, images 196 | return {"val_rec_mse": total_loss / len(loader)}, images 197 | 198 | def inference(self, *args: Any, **kwargs: Any): 199 | return self.evaluator.inference(*args, **kwargs) 200 | 201 | def inference_single_image_dataset(self, *args: Any, **kwargs: Any): 202 | return self.evaluator.inference_single_image_dataset(*args, **kwargs) 203 | 204 | def inference_single_image_tensor(self, *args: Any, **kwargs: Any): 205 | return self.evaluator.inference_single_image_tensor(*args, **kwargs) 206 | 207 | def fit( 208 | self, 209 | n_epochs: int, 210 | loader_train: DataLoader, 211 | loader_valid: Optional[DataLoader] = None, 212 | ) -> List[Dict[str, float]]: 213 | 214 | iterator = trange(n_epochs, position=0, leave=True) 215 | history = [] 216 | best_loss = np.inf 217 | 218 | # if self.wandb_log: 219 | # wandb.watch(self.model) 220 | 221 | try: 222 | for i in iterator: 223 | loss, images = self.one_epoch(loader_train) 224 | if loader_valid is not None: 225 | loss_valid, images_valid = self.validate(loader_valid) 226 | loss.update(loss_valid) 227 | images.update(images_valid) 228 | 229 | # Log training 230 | if self.wandb_log: 231 | images_wandb = { 232 | # limit the number of uploaded images 233 | # if image is 3d, reduce it 234 | k: [ 235 | wandb.Image(im.max(0) if self.inner_model.n_dim == 3 else im) 236 | for im in v[:4] 237 | ] 238 | for k, v in images.items() 239 | } 240 | wandb.log({**images_wandb, **loss}) 241 | 242 | # Show progress 243 | iterator.set_postfix(loss) 244 | history.append(loss) 245 | 246 | # Save last model 247 | self.save_model("model_last") 248 | 249 | if self.check and i > 3: 250 | break 251 | 252 | # Save best model 253 | if self.monitor not in loss: 254 | print( 255 | f"Nothing to monitor! 
{self.monitor} not in recorded losses {list(loss.keys())}" 256 | ) 257 | continue 258 | 259 | if loss[self.monitor] < best_loss: 260 | print( 261 | f"Saved best model by {self.monitor}: {loss[self.monitor]:.4e} < {best_loss:.4e}" 262 | ) 263 | self.save_model() 264 | best_loss = loss[self.monitor] 265 | 266 | except KeyboardInterrupt: 267 | print("Interrupted") 268 | 269 | return history 270 | 271 | def save_model(self, name: str = "model"): 272 | torch.save( 273 | { 274 | "model": self.inner_model.state_dict(), 275 | "optimizer": self.optimizer.state_dict(), 276 | "scheduler": self.scheduler.state_dict() if self.scheduler is not None else None, 277 | }, 278 | f"{self.checkpoint_path}/{name}.pth", 279 | ) 280 | 281 | def load_model(self, path: Optional[str] = None): 282 | if path is None: 283 | path = self.checkpoint_path / "model.pth" 284 | load_checkpoint_to_module(self, path) 285 | -------------------------------------------------------------------------------- /noise2same/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from functools import partial 4 | from typing import Any, Dict, Tuple, Union 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | from matplotlib import pyplot as plt 10 | from numpy import ndarray 11 | from numpy.linalg import norm 12 | from scipy.fft import dct 13 | from skimage.metrics import ( 14 | mean_squared_error, 15 | peak_signal_noise_ratio, 16 | structural_similarity, 17 | ) 18 | 19 | 20 | def clean_plot(ax: np.ndarray) -> None: 21 | """ 22 | Plot axes without ticks in tight layout 23 | :param ax: ndarray of matplotlib axes 24 | :return: 25 | """ 26 | plt.setp(ax, xticks=[], yticks=[]) 27 | plt.tight_layout() 28 | plt.show() 29 | 30 | 31 | def fix_seed(seed: int = 56) -> None: 32 | """ 33 | Fix all random seeds for reproducibility 34 | :param seed: 35 | :return: 36 | """ 37 | random.seed(seed) 38 | os.environ["PYTHONHASHSEED"] = str(seed) 39 | np.random.seed(seed) 40 | torch.manual_seed(seed) 41 | torch.cuda.manual_seed(seed) 42 | torch.backends.cudnn.deterministic = True 43 | torch.backends.cudnn.benchmark = False 44 | 45 | 46 | def crop_as(x: np.ndarray, gt: np.ndarray) -> np.ndarray: 47 | """ 48 | Crops x to gt shape evenly from each side 49 | (assumes even padding to remove) 50 | :param x: 51 | :param gt: 52 | :return: cropped x 53 | """ 54 | diff = np.array(x.shape) - np.array(gt.shape) 55 | assert np.all(diff >= 0) 56 | top_left = diff // 2 57 | bottom_right = diff - top_left 58 | sl = tuple(slice(tl, s - br) for tl, s, br in zip(top_left, x.shape, bottom_right)) 59 | crop = x[sl] 60 | assert crop.shape == gt.shape 61 | return crop 62 | 63 | 64 | def center_crop(x: np.ndarray, size: int = 63) -> np.ndarray: 65 | """ 66 | Crops a central part of an array 67 | (used for PSF) 68 | :param x: source 69 | :param size: to crop 70 | :return: cropped array 71 | """ 72 | h = size // 2 73 | return x[tuple(slice(max(0, d // 2 - h), min(d // 2 + h + 1, d)) for d in x.shape)] 74 | 75 | 76 | def calculate_scores( 77 | gt: np.ndarray, 78 | x: np.ndarray, 79 | data_range: float = 1.0, 80 | normalize_pairs: bool = False, 81 | scale: bool = False, 82 | clip: bool = False, 83 | multichannel: bool = False, 84 | prefix: str = "", 85 | calculate_mi: bool = False, 86 | **kwargs: Any, 87 | ) -> Dict[str, float]: 88 | """ 89 | Calculates image reconstruction metrics 90 | :param gt: ndarray, the ground truth image 91 | :param x: ndarray, prediction 92 | :param data_range: The data range of the input image, 1 by default (0-1
normalized images) 93 | :param normalize_pairs: bool, normalize and affinely scale pairs gt-x (needed for Planaria dataset) 94 | :param scale: bool, scale images by min and max (needed for Imagenet dataset) 95 | :param clip: bool, clip an image to [0, data_range] 96 | :param multichannel: If True, treat the last dimension of the array as channels for SSIM. Similarity 97 | calculations are done independently for each channel then averaged. 98 | :param prefix: str, prefix for metric names 99 | :param calculate_mi: bool, calculate mutual information and spectral mutual information 100 | :param kwargs: kwargs for SSIM 101 | :return: 102 | """ 103 | x_ = crop_as(x, gt) 104 | assert gt.shape == x_.shape, f"Different shapes {gt.shape}, {x_.shape}" 105 | if scale: 106 | x_ = normalize_zero_one(x_) * data_range 107 | if normalize_pairs: 108 | gt, x_ = normalize_min_mse(gt, x_) 109 | if clip: 110 | x_ = np.clip(x_, 0, data_range) 111 | 112 | if prefix: 113 | prefix += "." 114 | 115 | metrics = { 116 | prefix + "rmse": np.sqrt(mean_squared_error(gt, x_)), 117 | prefix + "psnr": peak_signal_noise_ratio(gt, x_, data_range=data_range), 118 | prefix 119 | + "ssim": structural_similarity( 120 | gt, x_, data_range=data_range, channel_axis=-1 if multichannel else None, **kwargs, 121 | ), 122 | } 123 | 124 | if calculate_mi: 125 | metrics.update( 126 | { 127 | prefix + "mi": mutual_information(gt, x_), 128 | prefix + "smi": spectral_mutual_information(gt, x_), 129 | } 130 | ) 131 | 132 | return metrics 133 | 134 | 135 | # Normalization utils from Noise2Void 136 | def normalize_mi_ma( 137 | x: np.ndarray, 138 | mi: Union[float, np.ndarray], 139 | ma: Union[float, np.ndarray], 140 | clip: bool = False, 141 | eps: float = 1e-20, 142 | dtype: type = np.float32, 143 | ): 144 | """ 145 | 146 | :param x: 147 | :param mi: 148 | :param ma: 149 | :param clip: 150 | :param eps: 151 | :param dtype: 152 | :return: 153 | """ 154 | if dtype is not None: 155 | x = x.astype(dtype, copy=False) 156 | mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype, copy=False) 157 | ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype, copy=False) 158 | eps = dtype(eps) 159 | try: 160 | import numexpr 161 | 162 | x = numexpr.evaluate("(x - mi) / ( ma - mi + eps )") 163 | except ImportError: 164 | x = (x - mi) / (ma - mi + eps) 165 | if clip: 166 | x = np.clip(x, 0, 1) 167 | return x 168 | 169 | 170 | def normalize_percentile( 171 | x, 172 | p_min: float = 2.0, 173 | p_max: float = 99.8, 174 | axis: Union[int, Tuple[int, ...]] = None, 175 | clip: bool = False, 176 | eps: float = 1e-20, 177 | dtype: type = np.float32, 178 | ): 179 | """ 180 | Percentile-based image normalization. 181 | :param x: 182 | :param p_min: 183 | :param p_max: 184 | :param axis: 185 | :param clip: 186 | :param eps: 187 | :param dtype: 188 | :return: 189 | """ 190 | 191 | mi = np.percentile(x, p_min, axis=axis, keepdims=True) 192 | ma = np.percentile(x, p_max, axis=axis, keepdims=True) 193 | return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype) 194 | 195 | 196 | normalize_zero_one = partial(normalize_percentile, p_min=0, p_max=100, clip=True) 197 | 198 | 199 | def normalize_zero_one_dict(images: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: 200 | """ 201 | Normalizes all images in the given dictionary to the range [0, 1]. 
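Equivalent to applying normalize_zero_one (percentile normalization with p_min=0, p_max=100 and clipping) to every value independently.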
202 | """ 203 | return {k: normalize_zero_one(v) for k, v in images.items()} 204 | 205 | 206 | def normalize_min_mse(gt: np.ndarray, x: np.ndarray, normalize_gt: bool = True): 207 | """ 208 | Normalizes and affinely scales an image pair such that the MSE is minimized 209 | :param gt: ndarray, the ground truth image 210 | :param x: ndarray, the image that will be affinely scaled 211 | :param normalize_gt: bool, set to True of gt image should be normalized (default) 212 | :return: gt_scaled, x_scaled 213 | """ 214 | if normalize_gt: 215 | gt = normalize_percentile(gt, 0.1, 99.9, clip=False).astype( 216 | np.float32, copy=False 217 | ) 218 | x = x.astype(np.float32, copy=False) - np.mean(x) 219 | gt = gt.astype(np.float32, copy=False) - np.mean(gt) 220 | scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten()) 221 | return gt, scale * x 222 | 223 | 224 | def plot_3d(im: ndarray) -> None: 225 | """ 226 | Plot 3D image as three max projections 227 | :param im: image to plot 228 | :return: none 229 | """ 230 | fig = plt.figure(constrained_layout=False, figsize=(12, 7)) 231 | gs = fig.add_gridspec(nrows=3, ncols=5) 232 | 233 | ax_0 = fig.add_subplot(gs[:-1, :-1]) 234 | ax_0.imshow(np.max(im, 0)) 235 | 236 | ax_1 = fig.add_subplot(gs[-1, :-1]) 237 | ax_1.imshow(np.max(im, 1)) 238 | 239 | ax_2 = fig.add_subplot(gs[:-1, -1]) 240 | ax_2.imshow(np.rot90(np.max(im, 2))) 241 | 242 | plt.setp([ax_0, ax_1, ax_2], xticks=[], yticks=[]) 243 | plt.tight_layout() 244 | plt.show() 245 | 246 | 247 | def concat_projections(im: ndarray, axis: int = 1) -> ndarray: 248 | """ 249 | Do max projection of an image to all axes and concatenate them in 2D image 250 | Expects image to be a cube 251 | :param im: ND image 252 | :param axis: concatenate projections along it (0 - vertical concatenation, 1 - horizontal) 253 | :return: 2D concatenation of max projections 254 | """ 255 | projections = [] 256 | for i in range(im.ndim): 257 | p = np.max(im, axis=i) 258 | if i > 0 and axis > 0: 259 | p = np.rot90(p) 260 | projections.append(p) 261 | projections = np.concatenate(projections, axis=axis) 262 | return projections 263 | 264 | 265 | def concat_projections_3d(im: ndarray, projection_func: callable = np.max) -> ndarray: 266 | """ 267 | Do max projection of an image to all axes and concatenate them in 2D image 268 | :param im: ND image 269 | :param projection_func: function to make 2d from 3d, np.max by default 270 | :return: 2D concatenation of max projections 271 | """ 272 | projections = np.zeros((im.shape[0] + im.shape[1], im.shape[0] + im.shape[2])) 273 | shifts = [(0, 0), (im.shape[1], 0), (0, im.shape[2])] 274 | for i, s in enumerate(im.shape): 275 | p = projection_func(im, axis=i) 276 | if i == 2: 277 | p = np.rot90(p) 278 | ps = tuple(slice(0 + sh, d + sh) for d, sh in zip(p.shape, shifts[i])) 279 | projections[ps] = p 280 | return projections 281 | 282 | 283 | def plot_projections(im: ndarray, axis: int = 1) -> None: 284 | """ 285 | Plot batch projections from `concat_projections` 286 | :param im: ND image 287 | :param axis: concatenate projections along it (0 - vertical concatenation, 1 - horizontal) 288 | :return: 289 | """ 290 | projections = concat_projections(im, axis) 291 | fig, ax = plt.subplots() 292 | ax.imshow(projections) 293 | clean_plot(ax) 294 | 295 | 296 | def load_checkpoint_to_module(module, checkpoint_path: str): 297 | """ 298 | Loads PyTorch state checkpoint to module 299 | :param module: nn.Module 300 | :param checkpoint_path: str, path to checkpoint 301 | :return: 302 | """ 303 | 
checkpoint = torch.load(checkpoint_path) 304 | for attr, state_dict in checkpoint.items(): 305 | try: 306 | getattr(module, attr).load_state_dict(state_dict) 307 | except AttributeError: 308 | print( 309 | f"Attribute {attr} is present in the checkpoint but absent in the class, skipping it" 310 | ) 311 | 312 | 313 | def detach_to_np( 314 | images: Dict[str, torch.Tensor], mean: torch.Tensor, std: torch.Tensor 315 | ) -> Dict[str, np.ndarray]: 316 | """ 317 | Detaches and denormalizes all tensors in the given dictionary, then converts to np.array. 318 | """ 319 | return { 320 | k: np.moveaxis( 321 | (v.detach().cpu() * std + mean).numpy(), 322 | 1, 323 | -1, 324 | ) 325 | for k, v in images.items() 326 | } 327 | 328 | 329 | # Metrics from SSI 330 | def spectral_mutual_information(image_a, image_b, normalised=True): 331 | norm_image_a = image_a / norm(image_a.flatten(), 2) 332 | norm_image_b = image_b / norm(image_b.flatten(), 2) 333 | 334 | dct_norm_true_image = dct(dct(norm_image_a, axis=0), axis=1) 335 | dct_norm_test_image = dct(dct(norm_image_b, axis=0), axis=1) 336 | 337 | return mutual_information( 338 | dct_norm_true_image, dct_norm_test_image, normalised=normalised 339 | ) 340 | 341 | 342 | def mutual_information(image_a, image_b, bins=256, normalised=True): 343 | image_a = image_a.flatten() 344 | image_b = image_b.flatten() 345 | 346 | c_xy = np.histogram2d(image_a, image_b, bins)[0] 347 | mi = mutual_info_from_contingency(c_xy) 348 | mi = mi / joint_entropy_from_contingency(c_xy) if normalised else mi 349 | return mi 350 | 351 | 352 | def joint_entropy_from_contingency(contingency): 353 | # coordinates of non-zero entries in contingency table: 354 | nzx, nzy = np.nonzero(contingency) 355 | 356 | # non zero values: 357 | nz_val = contingency[nzx, nzy] 358 | 359 | # sum of all values in contingency table: 360 | contingency_sum = contingency.sum() 361 | 362 | # normalised contingency, i.e.
probability: 363 | p = nz_val / contingency_sum 364 | 365 | # log contingency: 366 | log_p = np.log2(p) 367 | 368 | # Joint entropy: 369 | joint_entropy = -p * log_p 370 | 371 | return joint_entropy.sum() 372 | 373 | 374 | def mutual_info_from_contingency(contingency): 375 | # coordinates of non-zero entries in contingency table: 376 | nzx, nzy = np.nonzero(contingency) 377 | 378 | # non zero values: 379 | nz_val = contingency[nzx, nzy] 380 | 381 | # sum of all values in contingency table: 382 | contingency_sum = contingency.sum() 383 | 384 | # marginals: 385 | pi = np.ravel(contingency.sum(axis=1)) 386 | pj = np.ravel(contingency.sum(axis=0)) 387 | 388 | # 389 | log_contingency_nm = np.log2(nz_val) 390 | contingency_nm = nz_val / contingency_sum 391 | # Don't need to calculate the full outer product, just for non-zeroes 392 | outer = pi.take(nzx).astype(np.int64, copy=False) * pj.take(nzy).astype( 393 | np.int64, copy=False 394 | ) 395 | log_outer = -np.log2(outer) + np.log2(pi.sum()) + np.log2(pj.sum()) 396 | mi = ( 397 | contingency_nm * (log_contingency_nm - np.log2(contingency_sum)) 398 | + contingency_nm * log_outer 399 | ) 400 | return mi.sum() 401 | 402 | 403 | def ssim(prediction, target): 404 | """ 405 | https://github.com/TaoHuang2018/Neighbor2Neighbor/blob/2fff2978/train.py#L258 406 | """ 407 | C1 = (0.01 * 255) ** 2 408 | C2 = (0.03 * 255) ** 2 409 | img1 = prediction.astype(np.float64) 410 | img2 = target.astype(np.float64) 411 | kernel = cv2.getGaussianKernel(11, 1.5) 412 | window = np.outer(kernel, kernel.transpose()) 413 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 414 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 415 | mu1_sq = mu1 ** 2 416 | mu2_sq = mu2 ** 2 417 | mu1_mu2 = mu1 * mu2 418 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 419 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 420 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 421 | ssim_map = ((2 * mu1_mu2 + C1) * 422 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 423 | (sigma1_sq + sigma2_sq + C2)) 424 | return ssim_map.mean() 425 | 426 | 427 | def calculate_ssim(target, ref): 428 | """ 429 | Calculate SSIM 430 | produces the same outputs as MATLAB's implementation 431 | https://github.com/TaoHuang2018/Neighbor2Neighbor/blob/2fff2978/train.py#L279 432 | img1, img2: [0, 255] 433 | """ 434 | img1 = np.array(target, dtype=np.float64) 435 | img2 = np.array(ref, dtype=np.float64) 436 | if not img1.shape == img2.shape: 437 | raise ValueError('Input images must have the same dimensions.') 438 | if img1.ndim == 2: 439 | return ssim(img1, img2) 440 | elif img1.ndim == 3: 441 | if img1.shape[2] == 3: 442 | ssims = [] 443 | for i in range(3): 444 | ssims.append(ssim(img1[:, :, i], img2[:, :, i])) 445 | return np.array(ssims).mean() 446 | elif img1.shape[2] == 1: 447 | return ssim(np.squeeze(img1), np.squeeze(img2)) 448 | else: 449 | raise ValueError('Wrong input image dimensions.') 450 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm @ git+git://github.com/rwightman/pytorch-image-models@9a1bd358c7e998799eed88b29842e3c9e5483e34 2 | segmentation_models_pytorch @ git+https://github.com/papkov/segmentation_models.pytorch 3 | pytorch_toolbelt @ git+https://github.com/papkov/pytorch-toolbelt 4 | omegaconf>=2.0.5 5 | hydra-core==1.1 6 | wandb>=0.10 7 | numpy~=1.19.2 8 | albumentations~=0.5.2 9 | scipy~=1.5.4 10 |
scikit-image~=0.18.0 11 | torch~=1.7.1 12 | tifffile~=2020.12.8 13 | matplotlib~=3.3.1 14 | tqdm~=4.59.0 15 | h5py 16 | pandas 17 | einops 18 | -------------------------------------------------------------------------------- /scripts/bsd68_swin.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J bsd68 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=4 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH -t 40:00:00 9 | #SBATCH --mem=50G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:tesla:4 12 | #SBATCH --nodelist=falcon1 13 | 14 | 15 | module load any/python/3.8.3-conda 16 | 17 | conda activate n2s_env 18 | 19 | cd swinia 20 | 21 | python train.py +experiment=bsd68 +backbone=swinia project=bsd68 \ 22 | training.batch_size=64 training.crop=64 training.steps=40000 \ 23 | model.mode=noise2self -------------------------------------------------------------------------------- /scripts/bsd68_unet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J bsd68 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH -t 10:00:00 9 | #SBATCH --mem=50G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:a100-80g:1 12 | #SBATCH --exclude=falcon2,falcon3 13 | 14 | module load any/python/3.8.3-conda 15 | 16 | conda activate n2s_env 17 | 18 | cd noise2same/noise2same.pytorch 19 | 20 | python train.py +experiment=bsd68 +backbone=unet project=bsd68 model.mode=noise2same 21 | -------------------------------------------------------------------------------- /scripts/deconvolution/fmd_unet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J cf_fish 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH -t 5:00:00 9 | #SBATCH --mem=50G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:tesla:1 12 | #SBATCH --exclude=falcon2,falcon3 13 | 14 | module load any/python/3.8.3-conda 15 | 16 | conda activate n2s_env 17 | 18 | cd swinia 19 | 20 | python train.py +experiment=fmd_deconvolution +backbone=unet project=fmd_deconvolution data.part=cf_fish \ 21 | model.mode=noise2same model.lambda_inv=0 model.lambda_inv_deconv=4 \ 22 | training.amp=False 23 | # model.lambda_bound=0.1 model.regularization_key=deconv 24 | 25 | -------------------------------------------------------------------------------- /scripts/deconvolution/microtubules_inv_mse_before_psf_boundary.sh: -------------------------------------------------------------------------------- 1 | python train.py +experiment=microtubules_generated +backbone=unet project=noise2same-ssi-mt-gen \ 2 | model.lambda_inv=0 model.lambda_inv_deconv=2 \ 3 | model.regularization_key=deconv model.lambda_bound=0.1 \ 4 | device=5 -------------------------------------------------------------------------------- /scripts/deconvolution/ssi.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Move to the working directory 4 | if [ "${PWD##*/}" == "scripts" ]; then 5 | cd ..
6 | fi 7 | 8 | for image in ./data/ssi/*.png; do 9 | [ -e "$image" ] || continue 10 | image_name="${image##*/}" 11 | echo "$image_name" 12 | python train.py +experiment=ssi +backbone=unet project=noise2same-ssi-paper \ 13 | data.input_name="$image_name" model.lambda_inv=2 model.lambda_inv_deconv=0 \ 14 | device=2 # change device accordingly 15 | done -------------------------------------------------------------------------------- /scripts/deconvolution/ssi_boundary.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Move to the working directory 4 | if [ "${PWD##*/}" == "scripts" ]; then 5 | cd .. 6 | fi 7 | 8 | for image in ./data/ssi/*.png; do 9 | [ -e "$image" ] || continue 10 | image_name="${image##*/}" 11 | echo "$image_name" 12 | python train.py +experiment=ssi +backbone=unet project=noise2same-ssi-paper \ 13 | data.input_name="$image_name" model.lambda_inv=2 model.lambda_inv_deconv=0 \ 14 | model.lambda_bound=0.1 \ 15 | device=3 # change device accordingly 16 | done -------------------------------------------------------------------------------- /scripts/deconvolution/ssi_inv_mse_before_and_after_psf_boundary.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Move to the working directory 4 | if [ "${PWD##*/}" == "scripts" ]; then 5 | cd .. 6 | fi 7 | 8 | for image in ./data/ssi/*.png; do 9 | [ -e "$image" ] || continue 10 | image_name="${image##*/}" 11 | echo "$image_name" 12 | python train.py +experiment=ssi project=noise2same-ssi-paper-rerun data.input_name="$image_name" \ 13 | model.lambda_inv=1 model.lambda_inv_deconv=1 \ 14 | model.lambda_bound=0.1 model.regularization_key=deconv \ 15 | device=0 # change device accordingly 16 | done 17 | -------------------------------------------------------------------------------- /scripts/deconvolution/ssi_inv_mse_before_psf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Move to the working directory 4 | if [ "${PWD##*/}" == "scripts" ]; then 5 | cd .. 6 | fi 7 | 8 | for image in ./data/ssi/*.png; do 9 | [ -e "$image" ] || continue 10 | image_name="${image##*/}" 11 | echo "$image_name" 12 | python train.py +experiment=ssi +backbone=unet project=noise2same-ssi-paper \ 13 | data.input_name="$image_name" model.lambda_inv=0 model.lambda_inv_deconv=2 \ 14 | device=2 # change device accordingly 15 | done 16 | -------------------------------------------------------------------------------- /scripts/deconvolution/ssi_inv_mse_before_psf_boundary.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Move to the working directory 4 | if [ "${PWD##*/}" == "scripts" ]; then 5 | cd .. 6 | fi 7 | 8 | for image in ./data/ssi/*.png; do 9 | [ -e "$image" ] || continue 10 | image_name="${image##*/}" 11 | echo "$image_name" 12 | python train.py +experiment=ssi +backbone=unet project=noise2same-ssi-paper \ 13 | data.input_name="$image_name" model.lambda_inv=0 model.lambda_inv_deconv=2 \ 14 | model.lambda_bound=0.1 model.regularization_key=deconv \ 15 | device=1 # change device accordingly 16 | done 17 | -------------------------------------------------------------------------------- /scripts/deconvolution/ssi_only_masked.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Move to the working directory 4 | if [ "${PWD##*/}" == "scripts" ]; then 5 | cd ..
6 | fi 7 | 8 | for image in ./data/ssi/*.png; do 9 | [ -e "$image" ] || continue 10 | image_name="${image##*/}" 11 | echo "$image_name" 12 | python train.py +experiment=ssi +backbone=unet project=noise2same-ssi-paper \ 13 | data.input_name="$image_name" device=0 model.only_masked=True \ 14 | training.monitor=bsp_mse 15 | done 16 | -------------------------------------------------------------------------------- /scripts/deconvolution/ssi_only_masked_boundary.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Move to the working directory 4 | if [ "${PWD##*/}" == "scripts" ]; then 5 | cd .. 6 | fi 7 | 8 | for image in ./data/ssi/*.png; do 9 | [ -e "$image" ] || continue 10 | image_name="${image##*/}" 11 | echo "$image_name" 12 | python train.py +experiment=ssi +backbone=unet project=noise2same-ssi-paper \ 13 | data.input_name="$image_name" device=0 model.only_masked=True \ 14 | training.monitor=bsp_mse model.lambda_bound=0.1 15 | done 16 | -------------------------------------------------------------------------------- /scripts/fmd_swinia.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J cf_fish 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH -t 40:00:00 9 | #SBATCH --mem=50G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:a100-80g:1 12 | #SBATCH --exclude=falcon2,falcon3 13 | 14 | 15 | module load any/python/3.8.3-conda 16 | 17 | conda activate n2s_env 18 | 19 | cd swinia 20 | 21 | python train.py +experiment=fmd +backbone=swinia project=fmd model.mode=noise2self data.part=cf_fish 22 | -------------------------------------------------------------------------------- /scripts/fmd_swinia_dp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J cf_fish 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=4 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH -t 40:00:00 9 | #SBATCH --mem=50G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:tesla:4 12 | #SBATCH --nodelist=falcon1 13 | 14 | 15 | module load any/python/3.8.3-conda 16 | module load cuda/11.3.1 17 | 18 | conda activate n2s_env 19 | 20 | cd swinia 21 | 22 | python train.py +experiment=fmd +backbone=swinia project=fmd model.mode=noise2self data.part=cf_fish -------------------------------------------------------------------------------- /scripts/fmd_unet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J tp_mice 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH -t 5:00:00 9 | #SBATCH --mem=50G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:tesla:1 12 | #SBATCH --nodelist=falcon1 13 | 14 | module load any/python/3.8.3-conda 15 | 16 | conda activate n2s_env 17 | 18 | cd swinia 19 | 20 | python train.py +experiment=fmd +backbone=unet project=fmd model.mode=noise2self data.part=tp_mice 21 | -------------------------------------------------------------------------------- /scripts/hanzi_swin.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J hanzi 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=4 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH
-t 80:00:00 9 | #SBATCH --mem=50G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:tesla:4 12 | #SBATCH --exclude=falcon2,falcon3,falcon4 13 | 14 | module load any/python/3.8.3-conda 15 | 16 | conda activate n2s_env 17 | 18 | cd noise2same/noise2same.pytorch 19 | 20 | python train.py +experiment=hanzi +backbone=swinia project=hanzi \ 21 | training.batch_size=64 backbone.window_size=8 model.mode=noise2self data.noise_level=3 -------------------------------------------------------------------------------- /scripts/hanzi_unet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J hanzi 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH -t 24:00:00 9 | #SBATCH --mem=50G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:a100-80g:1 12 | #SBATCH --exclude=falcon3 13 | 14 | module load any/python/3.8.3-conda 15 | 16 | conda activate n2s_env 17 | 18 | cd noise2same/noise2same.pytorch 19 | 20 | python train.py +experiment=hanzi +backbone=unet model.mode=noise2self project=hanzi 21 | -------------------------------------------------------------------------------- /scripts/imagenet_swin.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J inet 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=4 7 | #SBATCH -t 80:00:00 8 | #SBATCH --mem=50G 9 | #SBATCH --partition=gpu 10 | #SBATCH --gres=gpu:tesla:4 11 | #SBATCH --exclude=falcon2,falcon3,falcon4,falcon5 12 | 13 | module load any/python/3.8.3-conda 14 | 15 | conda activate n2s_env 16 | 17 | cd noise2same/noise2same.pytorch 18 | 19 | python train.py +experiment=imagenet +backbone=swinia project=imagenet \ 20 | backbone.window_size=8 model.mode=noise2self training.steps=80000 -------------------------------------------------------------------------------- /scripts/imagenet_unet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J inet 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH -t 10:00:00 9 | #SBATCH --mem=50G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:a100-80g:1 12 | #SBATCH --exclude=falcon3 13 | 14 | module load any/python/3.8.3-conda 15 | 16 | conda activate n2s_env 17 | 18 | cd noise2same/noise2same.pytorch 19 | 20 | python train.py +experiment=imagenet +backbone=unet project=imagenet model.mode=noise2self 21 | -------------------------------------------------------------------------------- /scripts/sidd_swinia.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J sidd 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH -t 10:00:00 9 | #SBATCH --mem=100G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:a100-80g:1 12 | 13 | module load any/python/3.8.3-conda 14 | 15 | conda activate n2s_env 16 | 17 | cd noise2same/noise2same.pytorch 18 | 19 | python train.py +experiment=sidd +backbone=swinia project=sidd model.mode=noise2self -------------------------------------------------------------------------------- /scripts/sidd_unet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J sidd 4 
| #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH -t 10:00:00 9 | #SBATCH --mem=100G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:a100-80g:1 12 | 13 | module load any/python/3.8.3-conda 14 | 15 | conda activate n2s_env 16 | 17 | cd noise2same/noise2same.pytorch 18 | 19 | python train.py +experiment=sidd +backbone=unet project=sidd model.mode=noise2self -------------------------------------------------------------------------------- /scripts/synthetic_grayscale_swinia.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J ggauss50 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH -t 80:00:00 9 | #SBATCH --mem=50G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:a100-80g:1 12 | #SBATCH --exclude=falcon2,falcon3 13 | 14 | module load any/python/3.8.3-conda 15 | module load cuda/11.3.1 16 | 17 | conda activate n2s_env 18 | 19 | cd swinia 20 | 21 | python train.py +experiment=synthetic_grayscale +backbone=swinia project=synthetic_grayscale \ 22 | backbone.window_size=8 model.mode=noise2self training.steps=50000 \ 23 | data.noise_type=gaussian data.noise_param=50 -------------------------------------------------------------------------------- /scripts/synthetic_grayscale_swinia_dp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J fggauss50 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=4 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH -t 40:00:00 9 | #SBATCH --mem=50G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:tesla:4 12 | #SBATCH --nodelist=falcon1 13 | 14 | module load any/python/3.8.3-conda 15 | module load cuda/11.3.1 16 | 17 | conda activate n2s_env 18 | 19 | cd swinia 20 | 21 | python train.py +experiment=synthetic_grayscale +backbone=swinia project=synthetic_grayscale \ 22 | backbone.window_size=8 model.mode=noise2self training.steps=50000 \ 23 | data.noise_type=gaussian data.noise_param=50 data.standardize=False -------------------------------------------------------------------------------- /scripts/synthetic_grayscale_unet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J n2s-g50 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH -t 5:00:00 9 | #SBATCH --mem=50G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:tesla:1 12 | #SBATCH --nodelist=falcon1 13 | 14 | module load any/python/3.8.3-conda 15 | module load cuda/11.3.1 16 | 17 | conda activate n2s_env 18 | 19 | cd swinia 20 | 21 | python train.py +experiment=synthetic_grayscale +backbone=unet project=synthetic_grayscale model.mode=noise2self \ 22 | data.noise_type=gaussian data.noise_param=50 23 | -------------------------------------------------------------------------------- /scripts/synthetic_swinia.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J poiss30 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH -t 80:00:00 9 | #SBATCH --mem=50G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:a100-80g:1 12 | #SBATCH 
--exclude=falcon2,falcon3 13 | 14 | module load any/python/3.8.3-conda 15 | module load cuda/11.3.1 16 | 17 | conda activate n2s_env 18 | 19 | cd swinia 20 | 21 | python train.py +experiment=synthetic +backbone=swinia project=synthetic \ 22 | backbone.window_size=8 model.mode=noise2self training.steps=80000 \ 23 | data.noise_type=poisson data.noise_param=30 data.standardize=False -------------------------------------------------------------------------------- /scripts/synthetic_swinia_dp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J fpoiss550 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=4 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH -t 40:00:00 9 | #SBATCH --mem=50G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:tesla:4 12 | #SBATCH --nodelist=falcon4 13 | 14 | module load any/python/3.8.3-conda 15 | module load cuda/11.3.1 16 | 17 | conda activate n2s_env 18 | 19 | cd swinia 20 | 21 | python train.py +experiment=synthetic +backbone=swinia project=synthetic \ 22 | backbone.window_size=8 model.mode=noise2self training.steps=50000 training.val_batch_size=4 \ 23 | data.noise_type=poisson data.noise_param=[5,50] data.standardize=False -------------------------------------------------------------------------------- /scripts/synthetic_unet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J n2s-p550 4 | #SBATCH --output=slurm_outputs/slurm-%x.%j.out 5 | #SBATCH -N 1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH -t 15:00:00 9 | #SBATCH --mem=50G 10 | #SBATCH --partition=gpu 11 | #SBATCH --gres=gpu:tesla:1 12 | #SBATCH --exclude=falcon2,falcon3,falcon4,falcon5,falcon6 13 | 14 | module load any/python/3.8.3-conda 15 | module load cuda/11.3.1 16 | 17 | conda activate n2s_env 18 | 19 | cd swinia 20 | 21 | python train.py +experiment=synthetic +backbone=unet project=synthetic model.mode=noise2same \ 22 | data.noise_type=poisson data.noise_param=[5,50] \ 23 | data.standardize=False # important for poisson noise 24 | -------------------------------------------------------------------------------- /tests/test_contrast.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | from noise2same.contrast import PixelContrastLoss 7 | 8 | 9 | class MyTestCase(unittest.TestCase): 10 | def test_masking(self, bs: int = 2, n_emb: int = 8, h: int = 4, w: int = 4): 11 | 12 | t = torch.randint(0, 10, (bs, n_emb, h, w)) 13 | mask = torch.randn((bs, 1, h, w)).ge(0.5) 14 | mask[0, 0, 0, 0] = True 15 | mask[-1, -1, -1, -1] = True 16 | 17 | t_flat = rearrange(t, "b e h w -> (b h w) e", b=bs, h=h, w=w) 18 | mask = rearrange(mask, "b e h w -> (b h w e)", b=bs, h=h, w=w) 19 | print(t_flat.shape, mask.shape) 20 | t_masked = rearrange(t_flat[mask], "(b m) e -> b m e", b=bs) 21 | 22 | self.assertTrue(torch.all(torch.eq(t[0, :, 0, 0], t_masked[0, 0]))) 23 | self.assertTrue(torch.all(torch.eq(t[-1, :, -1, -1], t_masked[-1, -1]))) 24 | 25 | def test_loss(self, bs: int = 2, n_emb: int = 8, h: int = 4, w: int = 4): 26 | # mask = torch.randn(bs * h * w).ge(0.5) 27 | # out_raw = torch.randn(bs * n_emb * h * w) 28 | # out_mask = out_raw.clone() 29 | # out_mask[mask] = out_mask[mask] + torch.randn_like(out_mask)[mask] / 100 30 | # 31 | # out_raw = out_raw.reshape(bs, n_emb, h, w) 32 | # out_mask = 
out_mask.reshape(bs, n_emb, h, w) 33 | 34 | mask = torch.randn((bs, 1, h, w)).ge(0.5).float() 35 | out_raw = torch.randn((bs, n_emb, h, w)) 36 | out_mask = torch.randn((bs, n_emb, h, w)) 37 | 38 | loss = PixelContrastLoss() 39 | res = loss(out_raw, out_mask, mask) 40 | print(res.mean()) 41 | 42 | 43 | if __name__ == "__main__": 44 | unittest.main() 45 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from albumentations import PadIfNeeded 5 | 6 | from noise2same.dataset.util import mask_like_image 7 | from noise2same.util import crop_as 8 | 9 | 10 | class TestDataset(unittest.TestCase): 11 | def test_crop_as(self): 12 | for divisor in (2, 4, 8, 16, 32, 64): 13 | pad = PadIfNeeded( 14 | min_height=None, 15 | min_width=None, 16 | pad_height_divisor=divisor, 17 | pad_width_divisor=divisor, 18 | ) 19 | 20 | image = np.random.uniform(size=(180, 180, 1)) 21 | padded = pad(image=image)["image"] 22 | cropped = crop_as(padded, image) 23 | print(padded.shape, cropped.shape) 24 | self.assertEqual(cropped.shape, image.shape) 25 | self.assertTrue(np.all(cropped == image)) 26 | 27 | def test_mask_2d(self, mask_percentage: float = 0.5): 28 | shape = (64, 64, 3) 29 | img = np.zeros(shape) 30 | mask = mask_like_image(img, mask_percentage=mask_percentage, channels_last=True) 31 | result = mask.mean() * 100 32 | self.assertAlmostEqual(mask_percentage, result, delta=0.1) 33 | 34 | def test_mask_3d(self, mask_percentage: float = 0.5): 35 | shape = (1, 16, 64, 64) 36 | img = np.zeros(shape) 37 | mask = mask_like_image( 38 | img, mask_percentage=mask_percentage, channels_last=False 39 | ) 40 | result = mask.mean() * 100 41 | self.assertAlmostEqual(mask_percentage, result, delta=0.1) 42 | 43 | 44 | if __name__ == "__main__": 45 | unittest.main() 46 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from scipy.misc import ascent 5 | from skimage.metrics import ( 6 | mean_squared_error, 7 | peak_signal_noise_ratio, 8 | structural_similarity, 9 | ) 10 | 11 | from noise2same.util import normalize_min_mse 12 | 13 | 14 | class MetricsTestCase(unittest.TestCase): 15 | def test_mse(self): 16 | y, x1, x2 = self._get_test_images() 17 | mse1 = mean_squared_error(*normalize_min_mse(y, x1)) 18 | mse2 = mean_squared_error(*normalize_min_mse(y, x2)) 19 | 20 | self.assertAlmostEqual(mse1, mse2, delta=1e-6) 21 | 22 | def test_psnr(self): 23 | y, x1, x2 = self._get_test_images() 24 | psnr1 = peak_signal_noise_ratio(*normalize_min_mse(y, x1)) 25 | psnr2 = peak_signal_noise_ratio(*normalize_min_mse(y, x2)) 26 | 27 | self.assertAlmostEqual(psnr1, psnr2, delta=1e-6) 28 | 29 | def test_ssim(self): 30 | y, x1, x2 = self._get_test_images() 31 | ssim1 = structural_similarity(*normalize_min_mse(y, x1)) 32 | ssim2 = structural_similarity(*normalize_min_mse(y, x2)) 33 | 34 | self.assertAlmostEqual(ssim1, ssim2, delta=1e-6) 35 | 36 | def _get_test_images(self): 37 | # ground truth image 38 | y = ascent().astype(np.float32) 39 | # input image to compare to 40 | x1 = y + 30 * np.random.normal(0, 1, y.shape) 41 | # a scaled and shifted version of x1 42 | x2 = 2 * x1 + 100 43 | 44 | return y, x1, x2 45 | 46 | 47 | if __name__ == "__main__": 48 | unittest.main() 49 | 
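
A note on the metric tests above: they depend on `noise2same.util.normalize_min_mse` fitting the prediction to the ground truth with a least-squares affine transform before MSE/PSNR/SSIM are computed, which is why `x1` and its scaled, shifted copy `2 * x1 + 100` must score identically. A minimal sketch of that idea (an illustration only — the repository's actual implementation may differ, e.g. by percentile-normalizing the ground truth first):

```python
import numpy as np


def normalize_min_mse_sketch(gt: np.ndarray, x: np.ndarray):
    """Fit a, b minimizing ||a * x + b - gt||^2 and return (gt, a * x + b)."""
    # Any affine rescaling of x yields the same fitted values,
    # so the metrics become invariant to scale and offset.
    a, b = np.polyfit(x.ravel(), gt.ravel(), deg=1)
    return gt, a * x + b
```
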
-------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.utils.data import DataLoader, RandomSampler 7 | 8 | from noise2same import model, trainer 9 | from noise2same.dataset.dummy import DummyDataset3DLarge 10 | 11 | 12 | class ModelTestCase(unittest.TestCase): 13 | def test_info_padding_forward_identity( 14 | self, psf_size: int = 9, n_dim: int = 3, device: str = "cuda" 15 | ): 16 | # Set up identity PSF 17 | psf = np.zeros((psf_size,) * n_dim) 18 | np.put(psf, psf_size ** n_dim // 2, 1) 19 | 20 | # Set up identity model and optimizer 21 | mdl = model.Noise2Same( 22 | n_dim=n_dim, 23 | in_channels=1, 24 | base_channels=4, 25 | psf=psf, 26 | psf_pad_mode="constant", 27 | backbone=torch.nn.Identity(), 28 | ) 29 | mdl.to(device) 30 | mdl.train() 31 | 32 | # Fetch dummy data batch 33 | dataset = DummyDataset3DLarge(n_dim=n_dim, image_size=230) 34 | loader = DataLoader( 35 | dataset, 36 | batch_size=4, 37 | num_workers=1, 38 | shuffle=True, 39 | pin_memory=True, 40 | drop_last=True, 41 | ) 42 | batch = next(iter(loader)) 43 | 44 | padding = [ 45 | (b, a) 46 | for b, a in zip( 47 | loader.dataset.tiler.margin_start, 48 | loader.dataset.tiler.margin_end, 49 | ) 50 | ] + [(0, 0)] 51 | print(padding) 52 | 53 | x = batch["image"].to(device) 54 | mask = batch["mask"].to(device) 55 | 56 | full_size_image = np.pad(loader.dataset.image, padding) 57 | full_size_image = torch.from_numpy(np.moveaxis(full_size_image, -1, 0)).to( 58 | device 59 | ) 60 | 61 | out_mask, out_raw = mdl.forward( 62 | x, mask, crops=batch["crop"], full_size_image=full_size_image 63 | ) 64 | loss, loss_log = mdl.compute_losses_from_output(x, mask, out_mask, out_raw) 65 | 66 | self.assertAlmostEqual(loss_log["rec_mse"], 0) 67 | 68 | def test_info_padding_forward_backpropagation( 69 | self, psf_size: int = 9, n_dim: int = 3, device: str = "cuda" 70 | ): 71 | torch.autograd.set_detect_anomaly(True) 72 | 73 | # Set up PSF 74 | shape = (psf_size,) * n_dim 75 | psf = np.random.rand(*shape) 76 | 77 | # Set up model and optimizer 78 | mdl = model.Noise2Same( 79 | n_dim=n_dim, 80 | in_channels=1, 81 | base_channels=4, 82 | psf=psf, 83 | psf_pad_mode="constant", 84 | ) 85 | mdl.to(device) 86 | mdl.train() 87 | 88 | optimizer = torch.optim.Adam(mdl.parameters()) 89 | optimizer.zero_grad() 90 | 91 | # Fetch dummy data batch 92 | dataset = DummyDataset3DLarge(n_dim=n_dim, image_size=256) 93 | loader = DataLoader( 94 | dataset, 95 | batch_size=4, 96 | num_workers=1, 97 | shuffle=True, 98 | pin_memory=True, 99 | drop_last=True, 100 | ) 101 | batch = next(iter(loader)) 102 | 103 | padding = [ 104 | (b, a) 105 | for b, a in zip( 106 | loader.dataset.tiler.margin_start, 107 | loader.dataset.tiler.margin_end, 108 | ) 109 | ] + [(0, 0)] 110 | 111 | x = batch["image"].to(device) 112 | mask = batch["mask"].to(device) 113 | 114 | full_size_image = np.pad(loader.dataset.image, padding) 115 | full_size_image = torch.from_numpy(np.moveaxis(full_size_image, -1, 0)).to( 116 | device 117 | ) 118 | 119 | out_mask, out_raw = mdl.forward( 120 | x, mask, crops=batch["crop"], full_size_image=full_size_image 121 | ) 122 | loss, loss_log = mdl.compute_losses_from_output(x, mask, out_mask, out_raw) 123 | 124 | loss.backward() 125 | optimizer.step() 126 | self.assertTrue((torch.isnan(loss).sum() == 0).detach().cpu().numpy()) 127 | 
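
A note on the identity PSF built above with `np.put(psf, psf_size ** n_dim // 2, 1)`: for an odd kernel size `s`, the flat index `s ** n // 2 == (s ** n - 1) / 2` is exactly the row-major index of the centre voxel `(s // 2, ..., s // 2)`, so the kernel is a delta function, the PSF convolution is an identity, and `rec_mse` must be zero. A quick check of the index arithmetic:

```python
import numpy as np

s, n = 9, 3  # odd kernel size, 3D, as in the test
psf = np.zeros((s,) * n)
np.put(psf, s ** n // 2, 1)  # same trick as in the test
# Centre voxel: (s // 2) * (s**2 + s + 1) == (s**3 - 1) // 2 for odd s
assert psf[s // 2, s // 2, s // 2] == 1
```
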
-------------------------------------------------------------------------------- /tests/test_psf.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from noise2same.psf.psf_convolution import PSF, PSFParameter 7 | 8 | 9 | class PSFTestCase(unittest.TestCase): 10 | def test_psf_fft(self, s: int = 7): 11 | kernel = np.random.rand(s, s, s) 12 | patch = torch.rand(1, 1, 64, 64, 64) 13 | 14 | psf = PSF(kernel, fft=False) 15 | psf_fft = PSF(kernel, fft=True) 16 | 17 | psf_out = psf(patch) 18 | psf_fft_out = psf_fft(patch) 19 | 20 | self.assertTrue(patch.shape == psf_out.shape) 21 | self.assertTrue(torch.allclose(psf_out, psf_fft_out)) 22 | 23 | def test_psf_delta(self, s: int = 7): 24 | kernel = np.zeros((s, s, s)) 25 | kernel[s // 2, s // 2, s // 2] = 1 # delta function 26 | 27 | patch = torch.rand(1, 1, 64, 64, 64) 28 | 29 | psf = PSF(kernel, fft=False) 30 | psf_out = psf(patch) 31 | 32 | self.assertTrue(torch.allclose(psf_out, patch)) 33 | 34 | def test_psf_fft_delta(self, s: int = 7): 35 | kernel = np.zeros((s, s, s)) 36 | kernel[s // 2, s // 2, s // 2] = 1 # delta function 37 | 38 | patch = torch.rand(1, 1, 64, 64, 64) 39 | 40 | psf = PSF(kernel, fft=True) 41 | psf_out = psf(patch) 42 | 43 | # FFT round-off keeps the error slightly above 1e-7, hence the looser tolerance 44 | self.assertTrue(torch.allclose(psf_out, patch, atol=1e-6)) 45 | 46 | def test_psf_parameter(self): 47 | kernel = np.random.rand(7, 7, 7) 48 | patch = torch.rand(1, 1, 64, 64, 64) 49 | 50 | psf = PSF(kernel, fft=True) 51 | psf_param = PSFParameter(kernel, fft=True) 52 | 53 | psf_out = psf(patch) 54 | psf_param_out = psf_param(patch) 55 | 56 | self.assertTrue(patch.shape == psf_out.shape) 57 | self.assertTrue(torch.allclose(psf_out, psf_param_out)) 58 | 59 | def test_large_psf(self): 60 | kernel = np.random.rand(128, 256, 512) 61 | patch = torch.rand(1, 1, 128, 128, 128) 62 | 63 | psf = PSFParameter(kernel, fft=True) 64 | psf_out = psf(patch) 65 | 66 | self.assertTrue(patch.shape == psf_out.shape, f"Output shape: {psf_out.shape}") 67 | 68 | def test_psf_auto_padding(self): 69 | kernel = np.random.rand(7, 7, 7) 70 | patch = torch.rand(1, 1, 64, 64, 64) 71 | 72 | psf = PSFParameter(kernel, fft=True, auto_padding=False) 73 | psf_auto = PSFParameter(kernel, fft=True, auto_padding=True) 74 | 75 | psf_out = psf(patch) 76 | psf_auto_out = psf_auto(patch) 77 | 78 | self.assertTrue(patch.shape == psf_auto_out.shape) 79 | self.assertTrue(torch.allclose(psf_out, psf_auto_out)) 80 | 81 | 82 | if __name__ == "__main__": 83 | unittest.main() 84 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import traceback 3 | from pathlib import Path 4 | from random import randint 5 | from time import sleep 6 | 7 | import hydra 8 | import torch 9 | import wandb 10 | from hydra.utils import get_original_cwd 11 | from omegaconf import DictConfig, OmegaConf 12 | from torch.utils.data import DataLoader, RandomSampler 13 | 14 | import evaluate 15 | import noise2same.trainer 16 | from noise2same import model, util 17 | from noise2same.dataset.getter import get_dataset, get_test_dataset_and_gt 18 | from noise2same.optimizers.esadam import ESAdam 19 | from utils import parametrize_backbone_and_head 20 | 21 | 22 | def exponential_decay( 23 | decay_rate: float = 0.5, decay_steps: int = 5000, staircase: bool = True 24 | ): 25 | """ 26 | Lambda for 
torch.optim.lr_scheduler.LambdaLR mimicking tf.train.exponential_decay: 27 | decayed_learning_rate = learning_rate * 28 | decay_rate ^ (global_step / decay_steps) 29 | 30 | :param decay_rate: float, multiplication factor 31 | :param decay_steps: int, number of steps after which the rate is multiplied by decay_rate 32 | :param staircase: bool, if True, use integer division for global_step / decay_steps 33 | :return: lambda(epoch) 34 | """ 35 | 36 | def _lambda(epoch: int): 37 | exp = epoch / decay_steps 38 | if staircase: 39 | exp = int(exp) 40 | return decay_rate ** exp 41 | 42 | return _lambda 43 | 44 | 45 | @hydra.main(config_path="config", config_name="config", version_base="1.1") 46 | def main(cfg: DictConfig) -> None: 47 | 48 | # trying to fix: unable to open shared memory object in read-write mode 49 | # torch.multiprocessing.set_sharing_strategy("file_system") 50 | 51 | # Prevent concurrent runs from writing to the same log folder 52 | sleep(randint(1, 5)) 53 | 54 | if "backbone_name" not in cfg.keys(): 55 | print("Please specify a backbone with `+backbone=name`") 56 | return 57 | 58 | if "experiment" not in cfg.keys(): 59 | print("Please specify an experiment with `+experiment=name`") 60 | return 61 | 62 | print(OmegaConf.to_yaml(cfg)) 63 | # os.environ["CUDA_VISIBLE_DEVICES"] = f"{cfg.device}" 64 | print(f"Run backbone {cfg.backbone_name} on experiment {cfg.experiment}, working in {os.getcwd()}") 65 | cwd = Path(get_original_cwd()) 66 | 67 | util.fix_seed(cfg.seed) 68 | 69 | # flatten 2-level config 70 | d_cfg = {} 71 | for group, group_dict in dict(cfg).items(): 72 | if isinstance(group_dict, DictConfig): 73 | for param, value in dict(group_dict).items(): 74 | d_cfg[f"{group}.{param}"] = value 75 | else: 76 | d_cfg[group] = group_dict 77 | 78 | if not cfg.check: 79 | wandb.init(project=cfg.project, config=d_cfg, settings=wandb.Settings(start_method="fork")) 80 | wandb.run.summary.update({'training_dir': os.getcwd()}) 81 | 82 | # Data 83 | dataset_train, dataset_valid = get_dataset(cfg, cwd) 84 | num_samples = cfg.training.batch_size * cfg.training.steps_per_epoch 85 | loader_train = DataLoader( 86 | dataset_train, 87 | batch_size=cfg.training.batch_size, 88 | num_workers=cfg.training.num_workers, 89 | sampler=RandomSampler(dataset_train, replacement=True, num_samples=num_samples), 90 | pin_memory=True, 91 | drop_last=True, 92 | ) 93 | 94 | loader_valid = None 95 | if cfg.training.validate: 96 | n_samples_val = int(cfg.training.val_partition * len(dataset_valid)) 97 | loader_valid = DataLoader( 98 | torch.utils.data.random_split( 99 | dataset_valid, 100 | [n_samples_val, len(dataset_valid) - n_samples_val], 101 | generator=torch.Generator().manual_seed(42) 102 | )[0], 103 | batch_size=cfg.training.val_batch_size, 104 | num_workers=cfg.training.num_workers, 105 | shuffle=False, 106 | pin_memory=True, 107 | drop_last=False, 108 | ) 109 | 110 | # Read PSF from dataset if available or by path 111 | psf = getattr(dataset_train, "psf", None) 112 | if psf is None and getattr(cfg, "psf", None) is not None: 113 | psf = cwd / cfg.psf.path 114 | print(f"Read PSF from {psf}") 115 | 116 | backbone, head = parametrize_backbone_and_head(cfg) 117 | 118 | # Model 119 | mdl = model.Noise2Same( 120 | n_dim=cfg.data.n_dim, 121 | in_channels=cfg.data.n_channels, 122 | psf=psf, 123 | psf_size=cfg.psf.psf_size if "psf" in cfg else None, 124 | psf_pad_mode=cfg.psf.psf_pad_mode if "psf" in cfg else None, 125 | psf_fft=cfg.psf.psf_fft if "psf" in cfg else None, 126 | backbone=backbone, 127 | head=head, 128 | **cfg.model, 129 | ) 130 | 131 | if 
torch.cuda.device_count() > 1: 132 | print(f'Using data parallel with {torch.cuda.device_count()} GPUs') 133 | mdl = torch.nn.DataParallel(mdl) 134 | 135 | # Optimization 136 | if cfg.optim.optimizer == "adam": 137 | optimizer = torch.optim.Adam(mdl.parameters(), lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay) 138 | elif cfg.optim.optimizer == "esadam": 139 | optimizer = ESAdam(mdl.parameters(), lr=cfg.optim.lr) 140 | else: 141 | raise ValueError(f"Unknown optimizer {cfg.optim.optimizer}") 142 | 143 | if cfg.optim.scheduler == "cosine": 144 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 145 | optimizer, 146 | cfg.training.steps, 147 | eta_min=cfg.optim.eta_min, 148 | ) 149 | else: 150 | scheduler = torch.optim.lr_scheduler.LambdaLR( 151 | optimizer, 152 | lr_lambda=exponential_decay( 153 | decay_rate=cfg.optim.decay_rate, 154 | decay_steps=cfg.optim.decay_steps, 155 | staircase=cfg.optim.staircase, 156 | ), 157 | ) 158 | 159 | # Trainer 160 | trainer = noise2same.trainer.Trainer( 161 | model=mdl, 162 | optimizer=optimizer, 163 | scheduler=scheduler, 164 | check=cfg.check, 165 | monitor=cfg.training.monitor, 166 | experiment=cfg.experiment, 167 | amp=cfg.training.amp, 168 | info_padding=cfg.training.info_padding 169 | ) 170 | 171 | n_epochs = cfg.training.steps // cfg.training.steps_per_epoch 172 | try: 173 | history = trainer.fit( 174 | n_epochs, loader_train, loader_valid if cfg.training.validate else None 175 | ) 176 | except KeyboardInterrupt: 177 | print("Training interrupted") 178 | except RuntimeError: 179 | if not cfg.check: 180 | wandb.run.summary["error"] = "RuntimeError" 181 | traceback.print_exc() 182 | 183 | if cfg.evaluate: 184 | test_dataset, ground_truth = get_test_dataset_and_gt(cfg, cwd) 185 | 186 | scores = evaluate.evaluate(trainer.evaluator, ground_truth, cfg.experiment, cwd, Path(os.getcwd()), 187 | dataset=test_dataset, half=cfg.training.amp, num_workers=cfg.training.num_workers) 188 | 189 | if not cfg.check: 190 | wandb.log(scores) 191 | wandb.run.summary.update(scores) 192 | 193 | if not cfg.check: 194 | wandb.finish() 195 | 196 | 197 | if __name__ == "__main__": 198 | main() 199 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn 2 | from torch.nn import Identity 3 | from omegaconf import DictConfig 4 | from typing import Tuple 5 | 6 | from noise2same.backbone import SwinIR, UNet, RegressionHead 7 | from noise2same.backbone.swinia import SwinIA 8 | from noise2same.backbone.bsp_swinir import BSpSwinIR 9 | from noise2same.dataset.getter import compute_pad_divisor 10 | 11 | 12 | def recalculate_img_size(cfg: DictConfig) -> int: 13 | """ 14 | Recalculate image size with respect to future padding 15 | :param cfg: DictConfig, training/evaluation configuration object 16 | :return: int 17 | """ 18 | pad_divisor = compute_pad_divisor(cfg) 19 | if cfg.training.crop % pad_divisor: 20 | return (cfg.training.crop // pad_divisor + 1) * pad_divisor 21 | else: 22 | return cfg.training.crop 23 | 24 | 25 | def parametrize_backbone_and_head(cfg: DictConfig) -> Tuple[torch.nn.Module, torch.nn.Module]: 26 | """ 27 | Create backbone and head according to the configuration 28 | :param cfg: DictConfig, training/evaluation configuration object 29 | :return: Tuple[torch.nn.Module, torch.nn.Module] 30 | """ 31 | head = Identity() 32 | if cfg.backbone_name == 'unet': 33 | backbone = UNet( 34 | in_channels=cfg.data.n_channels, 35 | 
**cfg.backbone 36 | ) 37 | head = RegressionHead( 38 | in_channels=cfg.backbone.base_channels, 39 | out_channels=cfg.data.n_channels, 40 | n_dim=cfg.data.n_dim 41 | ) 42 | elif cfg.backbone_name == 'swinir': 43 | assert cfg.data.n_dim == 2 44 | backbone = SwinIR( 45 | in_chans=cfg.data.n_channels, 46 | img_size=recalculate_img_size(cfg), 47 | **cfg.backbone 48 | ) 49 | elif cfg.backbone_name == 'bsp_swinir': 50 | assert cfg.data.n_dim == 2 51 | backbone = BSpSwinIR( 52 | in_chans=cfg.data.n_channels, 53 | img_size=recalculate_img_size(cfg), 54 | **cfg.backbone 55 | ) 56 | elif cfg.backbone_name == 'swinia': 57 | assert cfg.data.n_dim == 2 58 | backbone = SwinIA( 59 | in_chans=cfg.data.n_channels, 60 | input_size=recalculate_img_size(cfg), 61 | **cfg.backbone 62 | ) 63 | else: 64 | raise ValueError(f"Unknown backbone {cfg.backbone_name}") 65 | return backbone, head 66 | -------------------------------------------------------------------------------- /weights/README.md: -------------------------------------------------------------------------------- 1 | # Pretrained weights 2 | 3 | You can download pretrained weights into this folder 4 | from [Drive](https://drive.google.com/drive/folders/1mh-Df_3dQ-kzfmZHQIzikv0A0Uc4fVGg?usp=sharing). 5 | --------------------------------------------------------------------------------
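
To use the downloaded weights, a minimal sketch follows; the file name and checkpoint layout here are assumptions — see `evaluate.py` for the loading logic this repository actually uses:

```python
import torch

from noise2same import model

# Hypothetical checkpoint name; any .pth file placed in weights/ works
checkpoint = torch.load("weights/bsd68_unet.pth", map_location="cpu")
# Checkpoints may be flat state dicts or nested under a "model" key (assumption)
state_dict = checkpoint.get("model", checkpoint)

mdl = model.Noise2Same(n_dim=2, in_channels=1)  # must match the training config
mdl.load_state_dict(state_dict)
mdl.eval()
```
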