├── results
│   ├── lpd_unet
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── normalisation.py
│   │   ├── lpd.py
│   │   ├── unet.py
│   │   └── lpd_modules.py
│   ├── 3D_final
│   │   ├── results_analysis_3D.ipynb
│   │   └── result_util_3D.py
│   └── 2D_final
│       ├── results_analysis_2D.ipynb
│       └── result_util.py
├── configs
│   ├── baseline
│   │   ├── prior
│   │   │   ├── quadratic.yaml
│   │   │   └── relative_difference.yaml
│   │   └── osem.yaml
│   ├── sampling
│   │   ├── dds_proj
│   │   │   ├── osem.yaml
│   │   │   ├── map.yaml
│   │   │   └── anchor.yaml
│   │   ├── dds_3D.yaml
│   │   ├── dps.yaml
│   │   ├── naive.yaml
│   │   └── dds.yaml
│   ├── score_based_model
│   │   ├── vesde_image_scale.yaml
│   │   ├── vpsde_image_scale.yaml
│   │   ├── vpsde_image_scale_guided.yaml
│   │   └── vesde_image_scale_guided.yaml
│   ├── dataset
│   │   ├── brainweb3D.yaml
│   │   └── brainweb2Ddataset.yaml
│   ├── baseline.yaml
│   ├── test_reconstruction.yaml
│   ├── final_reconstruction.yaml
│   ├── 3D_reconstruction.yaml
│   └── default_config.py
├── src
│   ├── third_party_models
│   │   ├── __init__.py
│   │   └── openai_unet
│   │       ├── __init__.py
│   │       ├── README.md
│   │       └── nn_utils.py
│   ├── sirf
│   │   ├── __init__.py
│   │   ├── herman_meyer.py
│   │   ├── utils.py
│   │   ├── datasets.py
│   │   └── dip.py
│   ├── brainweb_2d
│   │   ├── __init__.py
│   │   ├── get_test_subset.py
│   │   ├── tumor_generator.py
│   │   ├── get_noisy_train_brainweb_2D.py
│   │   ├── get_OOD_noisy_brainweb_2D.py
│   │   ├── get_validation_subset.py
│   │   ├── get_true_test_brainweb_2D.py
│   │   ├── get_noisy_test_brainweb_2D.py
│   │   ├── get_true_train_brainweb_2D.py
│   │   ├── brainweb.py
│   │   └── lpd_modules.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── metrics.py
│   │   ├── losses.py
│   │   ├── ema.py
│   │   ├── trainer.py
│   │   ├── sde.py
│   │   ├── nll.py
│   │   └── exp_utils.py
│   ├── samplers
│   │   ├── __init__.py
│   │   └── base_sampler.py
│   └── __init__.py
├── diagram.png
├── modifications.png
├── scripts
│   ├── postcreaterequirements.yml
│   ├── postCreateCommand.sh
│   └── req.txt
├── .devcontainer
│   └── devcontainer.json
├── main_score_based_models_train.py
├── coordinators
│   ├── 3D_dip_baseline.py
│   ├── 3D_bsrem_rdp_rdpz_baseline.py
│   ├── 3D_reconstruction.py
│   ├── 2D_rdp_baseline.py
│   ├── test_reconstruction.py
│   └── final_reconstruction.py
└── README.md

--------------------------------------------------------------------------------
/results/lpd_unet/README.md:
--------------------------------------------------------------------------------
Taken from asdsadsd

--------------------------------------------------------------------------------
/configs/baseline/prior/quadratic.yaml:
--------------------------------------------------------------------------------
name: quadratic
penalty: 100

--------------------------------------------------------------------------------
/configs/baseline/osem.yaml:
--------------------------------------------------------------------------------
num_subsets: 1
num_epochs: 1000
num_img_log: 5

--------------------------------------------------------------------------------
/src/third_party_models/__init__.py:
--------------------------------------------------------------------------------
from .openai_unet import OpenAiUNetModel

--------------------------------------------------------------------------------
/src/third_party_models/openai_unet/__init__.py:
--------------------------------------------------------------------------------
from .unet import OpenAiUNetModel

--------------------------------------------------------------------------------
/configs/sampling/dds_proj/osem.yaml:
--------------------------------------------------------------------------------
name: osem
num_epochs: null
num_subsets: 1

--------------------------------------------------------------------------------
/configs/sampling/dds_proj/map.yaml:
--------------------------------------------------------------------------------
name: map
num_epochs: null
num_subsets: 1
beta: null

--------------------------------------------------------------------------------
/configs/baseline/prior/relative_difference.yaml:
--------------------------------------------------------------------------------
name: relative_difference
penalty: 100
gamma: 2

--------------------------------------------------------------------------------
/configs/sampling/dds_proj/anchor.yaml:
--------------------------------------------------------------------------------
name: anchor
num_epochs: null
num_subsets: 1
beta: null

--------------------------------------------------------------------------------
/src/sirf/__init__.py:
--------------------------------------------------------------------------------
from .datasets import *
from .herman_meyer import *
from .projection import *

--------------------------------------------------------------------------------
/configs/score_based_model/vesde_image_scale.yaml:
--------------------------------------------------------------------------------
path: path_to/vesde/version_2/
ema: True
name: vesde_image_scale

--------------------------------------------------------------------------------
/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Imraj-Singh/Score-Based-Generative-Models-for-PET-Image-Reconstruction/HEAD/diagram.png

--------------------------------------------------------------------------------
/configs/score_based_model/vpsde_image_scale.yaml:
--------------------------------------------------------------------------------
path: path_to/vpsde_image_scale/version_7
ema: True
name: vpsde_image_scale

--------------------------------------------------------------------------------
/configs/sampling/dds_3D.yaml:
--------------------------------------------------------------------------------
name: dds_3D
defaults:
  - dds
num_iterations: 5
num_subsets: 28
lambd: null
beta: null

--------------------------------------------------------------------------------
/configs/score_based_model/vpsde_image_scale_guided.yaml:
--------------------------------------------------------------------------------
path: path_to/guided/vpsde/version_4/
ema: True
name: vpsde_image_scale_guided

--------------------------------------------------------------------------------
/modifications.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Imraj-Singh/Score-Based-Generative-Models-for-PET-Image-Reconstruction/HEAD/modifications.png

--------------------------------------------------------------------------------
/configs/score_based_model/vesde_image_scale_guided.yaml:
--------------------------------------------------------------------------------
path: path_to/guided/vesde/version_3/
ema: True
name: vesde_image_scale_guided

--------------------------------------------------------------------------------
/configs/dataset/brainweb3D.yaml:
--------------------------------------------------------------------------------
name: brainweb3D
count_level: low
realisation: 0
tracer: FDG
base_path: /home/user/sirf/src/sirf/brainweb_3D

--------------------------------------------------------------------------------
/results/lpd_unet/__init__.py:
--------------------------------------------------------------------------------
from .unet import get_unet_model
from .lpd import get_lpd_model
from .normalisation import Normalisation
from .lpd_modules import LPDForwardFunction2D

--------------------------------------------------------------------------------
/src/brainweb_2d/__init__.py:
--------------------------------------------------------------------------------
from .brainweb import BrainWebClean, BrainWebOSEM, BrainWebScoreTrain, BrainWebSupervisedTrain
from .lpd_modules import LPDForwardFunction2D, LPDAdjointFunction2D

--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
from .sde import *
from .ema import *
from .losses import *
from .metrics import *
from .trainer import *
from .exp_utils import *
from .nll import *

--------------------------------------------------------------------------------
/configs/dataset/brainweb2Ddataset.yaml:
--------------------------------------------------------------------------------
name: brainweb2D
base_path: path_to/src/brainweb_2d/
part: subset_test  # "validation", "subset_test_tumour"
poisson_scale: 2.5
img_xy_dim: 128
img_z_dim: 1
num_samples: 45

--------------------------------------------------------------------------------
/configs/sampling/dps.yaml:
--------------------------------------------------------------------------------
name: dps
num_steps: 1000
eps: 1e-04
pct_chain_elapsed: 0
num_img_log: 10
log_freq: 10
batch_size: 8
add_corrector: False
guidance_strength: null
use_osem_nll: False
penalty: null

--------------------------------------------------------------------------------
/configs/sampling/naive.yaml:
--------------------------------------------------------------------------------
name: naive
num_steps: 1000
eps: 1e-04
pct_chain_elapsed: 0
num_img_log: 50
log_freq: 10
batch_size: 8
add_corrector: False
guidance_strength: null
use_osem_nll: False
penalty: null

--------------------------------------------------------------------------------
/configs/sampling/dds.yaml:
--------------------------------------------------------------------------------
name: dds
num_steps: 50
stochasticity: 0.1
eps: 0.1
penalty: null
pct_chain_elapsed: 0
num_img_log: 10
log_freq: 10
batch_size: 8
beta_unit: False
add_corrector: False
guidance_strength: null
use_osem_nll: False

--------------------------------------------------------------------------------
/src/samplers/__init__.py:
--------------------------------------------------------------------------------
from .base_sampler import BaseSampler
from .utils import Euler_Maruyama_sde_predictor, Langevin_sde_corrector, chain_simple_init
from .utils import soft_diffusion_sde_predictor, soft_diffusion_momentum_sde_predictor
from .utils import decomposed_diffusion_sampling_sde_predictor

--------------------------------------------------------------------------------
/scripts/postcreaterequirements.yml:
--------------------------------------------------------------------------------
name: base
channels:
  - anaconda
  - conda-forge
dependencies:
  - scikit-image
  - parallelproj>=1.3.4
  - matplotlib>=3.2.1
  - pydantic>=1.10
  - scipy>=1.2
  - nibabel
  - h5py
  - pandas
  - seaborn
  - pip
  - pip:
    - tensorboard
    - ml-collections

--------------------------------------------------------------------------------
/scripts/postCreateCommand.sh:
--------------------------------------------------------------------------------
mamba env update --file scripts/postcreaterequirements.yml

pip install cupy-cuda11x tensorboardX hydra-core --upgrade

pip3 install torch torchvision torchaudio

cd ~/

git clone https://github.com/gschramm/pyparallelproj.git

# single quotes defer expansion of PYTHONPATH to shell start-up
echo 'export PYTHONPATH="${PYTHONPATH}:/home/user/pyparallelproj/"' | sudo tee -a "/home/user/.bashrc"

source "/home/user/.bashrc"

--------------------------------------------------------------------------------
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
{
    "name": "SIRF Container",
    "image": "imrajs/sirf_base:latest",
    "workspaceMount": "source=${localWorkspaceFolder},target=/home/user/sirf,type=bind",
    "workspaceFolder": "/home/user/sirf",
    "forwardPorts": [9002, 9999, 8890],
    "extensions": ["ms-python.python", "ms-toolsai.jupyter"],
    "runArgs": ["--gpus=all", "--init", "--network=host", "--shm-size=5gb"],
    "postCreateCommand": "bash scripts/postCreateCommand.sh",
    "hostRequirements": {"cpus": 16, "memory": "32gb", "storage": "64gb"}
}

--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
#from .dataset_reconstruction import PET_2D_reconstruction_dict, pet_acq_model
from .utils import loss_fn, ExponentialMovingAverage, score_model_simple_trainer, PSNR, SSIM, SDE, VPSDE, VESDE
from .utils import HeatDiffusion, get_standard_score, get_standard_sde, get_standard_sampler
from .utils import poisson_nll, osem_nll, get_osem, get_map, get_anchor, kl_div
from .third_party_models import OpenAiUNetModel
from .brainweb_2d import *
from .samplers import BaseSampler, Euler_Maruyama_sde_predictor, Langevin_sde_corrector
from .samplers import soft_diffusion_momentum_sde_predictor
from .sirf import SIRF3DDataset, SIRF3DProjection, herman_meyer_order

--------------------------------------------------------------------------------
/src/utils/metrics.py:
--------------------------------------------------------------------------------
import numpy as np

from skimage.metrics import structural_similarity

def PSNR(reconstruction, ground_truth, data_range=None):
    gt = np.asarray(ground_truth)
    mse = np.mean((np.asarray(reconstruction) - gt)**2)
    if mse == 0.:
        return float('inf')
    if data_range is None:
        data_range = np.max(gt) - np.min(gt)
    return 20*np.log10(data_range) - 10*np.log10(mse)

def SSIM(reconstruction, ground_truth, data_range=None):
    gt = np.asarray(ground_truth)
    if data_range is None:
        data_range = np.max(gt) - np.min(gt)
    return structural_similarity(reconstruction, gt, data_range=data_range)
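
A quick sanity-check sketch of these metrics on synthetic arrays (values are illustrative):

    import numpy as np
    from src.utils.metrics import PSNR, SSIM

    gt = np.zeros((128, 128)); gt[32:96, 32:96] = 1.0
    recon = gt + 0.05 * np.random.default_rng(0).standard_normal(gt.shape)
    print(PSNR(recon, gt))  # ~26 dB for noise std 0.05 and data_range 1
    print(SSIM(recon, gt))  # structural similarity in [-1, 1]
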
--------------------------------------------------------------------------------
/configs/baseline.yaml:
--------------------------------------------------------------------------------
hydra:
  run:
    dir: OSEM/${model.name}/${dataset.name}/tumour_${dataset.tumour}_scale_${dataset.poisson_scale}/${baseline.prior.name}/${hydra.job.override_dirname}
  sweep:
    dir: OSEM/${model.name}/${dataset.name}/tumour_${dataset.tumour}_scale_${dataset.poisson_scale}/${baseline.prior.name}/
    subdir: ${hydra.job.override_dirname}
  job:
    config:
      override_dirname:
        exclude_keys:
          - dataset.name
          - dataset.tumour
          - dataset.poisson_scale
          - baseline.prior.name
defaults:
  - _self_
  - dataset: brainweb2Ddataset
  - baseline: osem
  - baseline/prior: quadratic
  - override hydra/job_logging: colorlog
  - override hydra/hydra_logging: colorlog
seed: 1
num_images: 1
device: cuda

--------------------------------------------------------------------------------
/configs/test_reconstruction.yaml:
--------------------------------------------------------------------------------
hydra:
  run:
    dir: TEST_SINGLE/${dataset.part}_${dataset.poisson_scale}/${score_based_model.name}/${sampling.name}/${now:%Y-%m-%d-%H-%M-%S}/${hydra.job.override_dirname}
  sweep:
    dir: TEST/${dataset.part}_${dataset.poisson_scale}/${score_based_model.name}/${sampling.name}/${now:%Y-%m-%d-%H-%M-%S}/
    subdir: ${hydra.job.override_dirname}
  job:
    config:
      override_dirname:
        exclude_keys:
          - dataset.name
          - dataset.part
          - dataset.poisson_scale
          - sampling.name
          - sampling
          - score_based_model.name
          - score_based_model

defaults:
  - _self_
  - dataset: brainweb2Ddataset
  - score_based_model: vpsde_image_scale
  - sampling: naive
seed: 42
num_images: 1
device: cuda

--------------------------------------------------------------------------------
/src/brainweb_2d/get_test_subset.py:
--------------------------------------------------------------------------------
import torch

if __name__=="__main__":

    names = ["test", "test_tumour"]
    noise_levels = ["_2.5", "_5", "_7.5", "_10", "_50", "_100"]
    samples = [5,15,25,35,45,55,65,75]

    for name in names:
        for noise in noise_levels:
            data = torch.load(f"path_to/noisy/noisy_{name+noise}.pt")
            subset_data = {}
            for key in data.keys():
                subset_data[key] = data[key][samples, ...]
            torch.save(subset_data, f"path_to/noisy/noisy_subset_{name+noise}.pt")

    for name in names:
        data = torch.load(f"path_to/clean/clean_{name}.pt")
        subset_data = {}
        for key in data.keys():
            subset_data[key] = data[key][samples, ...]
        torch.save(subset_data, f"path_to/clean/clean_subset_{name}.pt")

--------------------------------------------------------------------------------
/src/third_party_models/openai_unet/README.md:
--------------------------------------------------------------------------------
### Open AI guided diffusion model (https://github.com/openai/guided-diffusion).

As used in **Diffusion Models Beat GANs on Image Synthesis** (https://arxiv.org/pdf/2105.05233.pdf).

I added output rescaling as used by Song so that the output is in the right range from the start of training.
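
A minimal sketch of that rescaling (hypothetical variable names; in Song et al.'s convention the raw network output is divided by the marginal standard deviation sigma(t) of the forward SDE, so an O(1) output already has the scale expected of the score):

    raw = unet(x, t)                        # unscaled UNet output
    _, std = sde.marginal_prob(x, t)        # marginal std sigma(t)
    score = raw / std[:, None, None, None]  # output used as the score estimate
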
Defaults for the UNet

    res = dict(
        image_size=64,
        num_channels=128,
        num_res_blocks=2,
        num_heads=4,
        num_heads_upsample=-1,
        num_head_channels=-1,
        attention_resolutions="16,8",
        channel_mult="",
        dropout=0.0,
        class_cond=False,
        use_checkpoint=False,
        use_scale_shift_norm=True,
        resblock_updown=False,
        use_fp16=False,
        use_new_attention_order=False,
    )

--------------------------------------------------------------------------------
/configs/final_reconstruction.yaml:
--------------------------------------------------------------------------------
hydra:
  run:
    dir: FINAL_SINGLE/${dataset.part}_${dataset.poisson_scale}/${score_based_model.name}/${sampling.name}/${now:%Y-%m-%d-%H-%M-%S}
  sweep:
    dir: FINAL_RERUN/${dataset.part}_${dataset.poisson_scale}/${score_based_model.name}/${sampling.name}/${now:%Y-%m-%d-%H-%M-%S}
  job:
    config:
      override_dirname:
        exclude_keys:
          - dataset.name
          - dataset.part
          - dataset.poisson_scale
          - sampling.name
          - sampling
          - score_based_model.name
          - score_based_model

defaults:
  - _self_
  - dataset: brainweb2Ddataset
  - score_based_model: vesde_image_scale
  - sampling: naive
  - override hydra/job_logging: colorlog
  - override hydra/hydra_logging: colorlog
seed: 42
num_images: 1
device: cuda
dump_path: path_to/dump

--------------------------------------------------------------------------------
/results/3D_final/results_analysis_3D.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from result_util_3D import get_unique_swept_datafit_strengths, save_sweep_dicts, get_sweep_mean_results\n",
    "import torch\n",
    "\n",
    "unique_swept_datafit_strengths = get_unique_swept_datafit_strengths(base_path=\"/home/user/sirf/coordinators/BSREM/**/volume.pt\")\n",
    "\n",
    "save_sweep_dicts(unique_swept_datafit_strengths)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

--------------------------------------------------------------------------------
/src/sirf/herman_meyer.py:
--------------------------------------------------------------------------------
import math

def herman_meyer_order(n):
    # Assuming that the subsets are in geometrical order
    n_variable = n
    i = 2
    factors = []
    while i * i <= n_variable:
        if n_variable % i:
            i += 1
        else:
            n_variable //= i
            factors.append(i)
    if n_variable > 1:
        factors.append(n_variable)
    n_factors = len(factors)
    order = [0 for _ in range(n)]
    value = 0
    for factor_n in range(n_factors):
        n_rep_value = 0
        if factor_n == 0:
            n_change_value = 1
        else:
            n_change_value = math.prod(factors[:factor_n])
        for element in range(n):
            mapping = value
            n_rep_value += 1
            if n_rep_value >= n_change_value:
                value = value + 1
                n_rep_value = 0
            if value == factors[factor_n]:
                value = 0
            order[element] = order[element] + math.prod(factors[factor_n+1:]) * mapping
    return order
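
A usage sketch: for 12 subsets (prime factorisation [2, 2, 3]) the function yields the classic Herman-Meyer subset access order.

    from src.sirf.herman_meyer import herman_meyer_order

    print(herman_meyer_order(12))
    # [0, 6, 3, 9, 1, 7, 4, 10, 2, 8, 5, 11]
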
--------------------------------------------------------------------------------
/configs/3D_reconstruction.yaml:
--------------------------------------------------------------------------------
hydra:
  run:
    dir: FINAL_3D/${dataset.name}_${dataset.count_level}_${dataset.realisation}_${dataset.tracer}/${sampling.name}_iters_${sampling.num_iterations}_subsets_${sampling.num_subsets}_lambda_${sampling.lambd}_beta_${sampling.beta}/
  sweep:
    dir: FINAL_3D/${dataset.name}_${dataset.count_level}_${dataset.realisation}_${dataset.tracer}/${sampling.name}_iters_${sampling.num_iterations}_subsets_${sampling.num_subsets}_lambda_${sampling.lambd}_beta_${sampling.beta}/
  job:
    config:
      override_dirname:
        exclude_keys:
          - dataset
          - dataset.name
          - dataset.count_level
          - dataset.realisation
          - sampling.name
          - sampling
          - sampling.num_iterations
          - sampling.num_subsets
          - sampling.lambd
          - sampling.beta
          - score_based_model.name
          - score_based_model

defaults:
  - _self_
  - dataset: brainweb3D
  - score_based_model: vpsde_image_scale
  - sampling: dds_3D
seed: 42
num_images: 1
device: cuda
dump_path: path_to/coordinators/dump

--------------------------------------------------------------------------------
/results/2D_final/results_analysis_2D.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from result_util import *\n",
    "unique_swept_datafit_strengths = get_unique_swept_datafit_strengths(base_path=\"/home/user/sirf/coordinators/RDP/test_10/*.pt\")\n",
    "save_sweep_dicts(unique_swept_datafit_strengths=unique_swept_datafit_strengths)\n",
    "unique_swept_datafit_strengths = get_unique_swept_datafit_strengths(base_path=\"E:/projects/pet_score_model/coordinators/RDP/**/*.pt\")\n",
    "save_sweep_dicts(unique_swept_datafit_strengths=unique_swept_datafit_strengths)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

--------------------------------------------------------------------------------
/results/lpd_unet/normalisation.py:
--------------------------------------------------------------------------------
import torch

class Normalisation:
    def __init__(self, type_norm):
        # type_norm: none, data_mean, data_corrected_mean, osem_mean, osem_max
        if type_norm == "none":
            self.norm = self.none
        elif type_norm == "data_mean":
            self.norm = self.data_mean
        elif type_norm == "data_corrected_mean":
            self.norm = self.data_corrected_mean
        elif type_norm == "osem_mean":
            self.norm = self.osem_mean
        elif type_norm == "osem_max":
            self.norm = self.osem_max
        else:
            raise ValueError("normalisation type not recognised")

    def __call__(self, osem, measurements, contamination_factor):
        return self.norm(osem, measurements, contamination_factor)

    def none(self, osem, measurements, contamination_factor):
        norm = torch.ones_like(contamination_factor[:,0])
        return norm

    def data_mean(self, osem, measurements, contamination_factor):
        norm = (measurements).mean(dim=[1,2])
        return norm

    def data_corrected_mean(self, osem, measurements, contamination_factor):
        norm = (measurements - contamination_factor[...,None]).mean(dim=[1,2])
        return norm

    def osem_mean(self, osem, measurements, contamination_factor):
        norm = osem.mean(dim=[1,2,3])
        return norm

    def osem_max(self, osem, measurements, contamination_factor):
        norm = osem.view(osem.shape[0], -1).max(dim=-1).values
        return norm

--------------------------------------------------------------------------------
/src/utils/losses.py:
--------------------------------------------------------------------------------
"""
from: https://github.com/educating-dip/score_based_model_baselines/blob/main/src/utils/losses.py
"""


import torch

from .sde import HeatDiffusion

def loss_fn(model, x, sde, eps=1e-5):

    """
    The loss function for training score-based generative models.
    Args:
        model: A PyTorch model instance that represents a
            time-dependent score-based model.
        x: A mini-batch of training data.
        sde: the forward sde
        eps: A tolerance value for numerical stability.
    """
    guided = (x.shape[1] == 2)
    if guided: # guided
        x_mri = x[:, 1, :,:].unsqueeze(1)
        x = x[:, 0, :, :].unsqueeze(1)

    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
    z = torch.randn_like(x)

    mean, std = sde.marginal_prob(x, random_t) # for VESDE the mean is just x
    perturbed_x = mean + z * std[:, None, None, None]

    if guided:
        x_input = torch.cat([perturbed_x, x_mri], dim=1)
    else:
        x_input = perturbed_x

    score = model(x_input, random_t)

    if isinstance(sde, HeatDiffusion):
        """
        The loss function for training score-based generative models.
        Using the soft diffusion target from
        Daras et al. (2022) [https://arxiv.org/pdf/2209.05442.pdf]
        """
        r_t = x - perturbed_x
        mean_model, _ = sde.marginal_prob(score - r_t, random_t)
        loss = torch.mean(torch.sum((std[:,None,None,None].pow(-1)*mean_model)**2, dim=(1,2,3)))

    else:
        loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1,2,3)))

    return loss
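
A self-contained sketch of calling this loss. The `DummySDE` below is a stand-in exposing only `marginal_prob`, not the repo's actual `VESDE`/`VPSDE` classes, and the zero-score lambda is a placeholder network:

    import torch
    from src.utils.losses import loss_fn

    class DummySDE:
        # stand-in SDE: mean = x, std = t (illustrative, not a real diffusion schedule)
        def marginal_prob(self, x, t):
            return x, t

    score = lambda x_in, t: torch.zeros_like(x_in)  # placeholder score network
    x = torch.randn(4, 1, 128, 128)                 # unguided batch (1 channel)
    print(loss_fn(score, x, DummySDE()).item())     # ~128*128 = 16384 for the zero score
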
--------------------------------------------------------------------------------
/src/sirf/utils.py:
--------------------------------------------------------------------------------
import numpy as np

def normalize(x, inplace=False):
    # Exploding pixel at edge of FOV we need to ignore...
    #mask = np.zeros_like(x)
    #mask[:, 50:201, 50:201] = 1
    #x = mask*x
    if inplace:
        x -= x.min()
        x /= x.max()
    else:
        x = x - x.min()
        x = x / x.max()
    return x

class ComputeImageMetrics:

    def __init__(self, ROIs_a, ROIs_b, emissions_a, emissions_b, names_a, names_b):

        self.emissions_a = emissions_a
        self.emissions_b = emissions_b

        # Only use voxels where mask value is 1
        def _threshold_ROI_mask(ROIs_mask):

            for i in range(len(ROIs_mask)):
                ROIs_mask[i][ROIs_mask[i] < 1] = 0
            return ROIs_mask

        self.ROIs_a = _threshold_ROI_mask(ROIs_a)
        self.ROIs_b = _threshold_ROI_mask(ROIs_b)
        self.names_a = names_a
        self.names_b = names_b

    def _compute_std(self, x):
        # STANDARD DEVIATION
        # standard deviation of the voxel values within each background ROI
        STDval = []
        for i in range(len(self.ROIs_b)):
            STDval.append(np.std(x[np.nonzero(
                self.ROIs_b[i]
            )]))
        return STDval

    def _compute_crc(self, x):
        # CONTRAST RECOVERY COEFFICIENT
        # abar = ROI average uptake
        # bbar = background average uptake
        # atrue, btrue = true ROI and background emissions
        # CRC = 1/R \sum_{r=1}^{R} (abar/bbar - 1)/(atrue/btrue - 1)
        CRCval = []
        for i in range(len(self.ROIs_a)):
            abar = np.mean(
                x[np.nonzero(self.ROIs_a[i]
                )]
            )
            bbar = np.mean(x[np.nonzero(
                self.ROIs_b[i]
            )]
            )
            atrue = self.emissions_a[i]
            btrue = self.emissions_b[i]
            CRCval.append((abar / bbar - 1) / (atrue / btrue - 1))
        return CRCval

    def get_all_metrics(self, x):

        return self._compute_crc(x), self._compute_std(x)

--------------------------------------------------------------------------------
/src/sirf/datasets.py:
--------------------------------------------------------------------------------
import torch
import sirf.STIR as pet
import os

class SIRF3DDataset():
    def __init__(self, base_path, name, tracer, count_level, realisation, num_subsets):
        if "thorax3D" == name:
            self.thorax(base_path, tracer, count_level, realisation, num_subsets)
        elif "brainweb3D" == name:
            self.brainweb(base_path, tracer, count_level, realisation, num_subsets)
        else:
            raise NotImplementedError("Dataset not implemented")

    def brainweb(self, base_path, tracer, count_level, realisation, num_subsets):
        if count_level == "low":
            count = "4e+07"
        elif count_level == "high":
            count = "2e+08"
        assert tracer in ["FDG", "Amyloid"], "must be FDG or Amyloid"
        bin_eff = pet.AcquisitionData(base_path + f"/{tracer}_bin_eff_hr.hs")
        sensitivity_image_sirf = pet.ImageData(base_path + f"/{tracer}_sensitivity_image.hv")
        osem = pet.ImageData(base_path + f"/noisy/{tracer}_osem_{count}_{realisation}.hv")
        self.osem = torch.tensor(osem.as_array()).to("cuda").float().unsqueeze(1)
        noisy_measurements_name = base_path + \
            f"/noisy/{tracer}_noisy_measurements_{count}_{realisation}.hs"

        measurements = pet.AcquisitionData(noisy_measurements_name)

        self.image_sirf = osem.get_uniform_copy(1.)
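        # image_sirf is a uniform (all-ones) template image, used below when setting up each subset objective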
        views = measurements.shape[2]

        self.objectives_sirf = []
        for i in range(num_subsets):
            subset_idxs = list(range(views))[i:][::num_subsets]
            # Get subset of data
            noisy_measurements_sirf = measurements.get_subset(subset_idxs)
            bin_eff_subset = bin_eff.get_subset(subset_idxs)

            # SET UP THE ACQUISITION MODEL
            sensitivity_factors = pet.AcquisitionSensitivityModel(bin_eff_subset)
            acquisition_model = pet.AcquisitionModelUsingParallelproj()
            acquisition_model.set_acquisition_sensitivity(sensitivity_factors)
            acquisition_model.set_up(noisy_measurements_sirf, osem)

            objective_sirf = pet.make_Poisson_loglikelihood(noisy_measurements_sirf, acq_model=acquisition_model)
            objective_sirf.set_up(self.image_sirf)
            self.objectives_sirf.append(objective_sirf)
            print("Objective function", i+1, "of", num_subsets, "set up")


        self.sensitivity_image = torch.tensor(sensitivity_image_sirf.clone().as_array()).to("cuda").float().unsqueeze(1)
        self.fov = torch.zeros_like(self.sensitivity_image)
        self.fov[self.sensitivity_image != 0] = 1.
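
A usage sketch mirroring the coordinator scripts that consume this class (paths are placeholders; requires SIRF and the preprocessed BrainWeb 3D data):

    dataset = SIRF3DDataset(base_path="path_to/src/sirf/brainweb_3D", name="brainweb3D",
                            tracer="FDG", count_level="low", realisation=0, num_subsets=28)
    print(dataset.osem.shape)            # OSEM initialisation as a CUDA tensor, [z, 1, y, x]
    print(len(dataset.objectives_sirf))  # one subset Poisson log-likelihood objective per subset
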
--------------------------------------------------------------------------------
/main_score_based_models_train.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from datetime import datetime
import yaml
from torch.utils.data import DataLoader

from src import (get_standard_sde, get_standard_score, BrainWebScoreTrain,
                 score_model_simple_trainer)



#from configs.ellipses_configs import get_config
from configs.default_config import get_default_configs

def coordinator():

    config = get_default_configs()

    if config.guided_p_uncond is not None:
        print("Train Guided Score Model")
        assert config.model.in_channels == 2, "input channels = 2 for guided model"

    sde = get_standard_sde(config)
    score_model = get_standard_score(config, sde, use_ema=True, load_model=False)

    brain_dataset = BrainWebScoreTrain(base_path="/localdata/AlexanderDenker/pet_data/",  # "E:/projects/pet_score_model/src/brainweb_2d/"
                                       guided=True if config.guided_p_uncond is not None else False,
                                       normalisation=config.normalisation)

    train_dl = DataLoader(brain_dataset, batch_size=config.training.batch_size, num_workers=6)

    print(f" # Parameters: {sum(p.numel() for p in score_model.parameters())}")
    today = datetime.now()

    if config.guided_p_uncond is not None:
        log_dir = '/localdata/AlexanderDenker/pet_score_based/guided/' + config.sde.type

    else:
        log_dir = '/localdata/AlexanderDenker/pet_score_based/' + config.sde.type

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    found_version = False
    version_num = 0
    while not found_version:
        if os.path.isdir(os.path.join(log_dir, "version_" + str(version_num))):
            version_num += 1
        else:
            found_version = True

    log_dir = os.path.join(log_dir, "version_" + str(version_num))
    os.makedirs(log_dir)

    with open(os.path.join(log_dir, 'report.yaml'), 'w') as file:
        yaml.dump(config, file)

    score_model_simple_trainer(
        score=score_model.to(config.device),
        sde=sde,
        train_dl=train_dl,
        optim_kwargs={
            'epochs': config.training.epochs,
            'lr': config.training.lr,
            'ema_warm_start_steps': config.training.ema_warm_start_steps,
            'log_freq': config.training.log_freq,
            'ema_decay': config.training.ema_decay
        },
        val_kwargs={
            'batch_size': config.validation.batch_size,
            'num_steps': config.validation.num_steps,
            'snr': config.validation.snr,
            'eps': config.validation.eps,
            'sample_freq': config.validation.sample_freq
        },
        device=config.device,
        log_dir=log_dir,
        guided_p_uncond=config.guided_p_uncond
    )


if __name__ == '__main__':
    coordinator()

--------------------------------------------------------------------------------
/coordinators/3D_dip_baseline.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import torch, sys, os, time
import sirf.STIR as pet
sys.path.append(os.path.dirname(os.getcwd()))
from src import SIRF3DProjection, SIRF3DDataset
pet.set_verbosity(0)
pet.AcquisitionData.set_storage_scheme("memory")
#pet.MessageRedirector(info=None, warn=None, errr=None)

def check_folder_create(path, folder_name):
    CHECK_FOLDER = os.path.isdir(path+folder_name)
    if not CHECK_FOLDER:
        os.makedirs(path+folder_name)
        print("created folder : ", path+folder_name)


# name: brainweb3D
# count_level: low
# realisation: 0
# tracer: FDG
# base_path: /home/user/sirf/src/sirf/brainweb_3D

# name: thorax3D
# count_level: low
# realisation: 0
# tracer: 0
# base_path: /home/user/sirf/D690XCATnonTOF

# GET THE DATA
# measurements_subsets_sirf, acquisition_models_sirf, sensitivity_image, fov, image_sirf, osem, measurements_subsets

name = "brainweb3D"
count_level = "low"
realisation = "0"
tracer = "FDG"
base_path = "path_to/src/sirf/brainweb_3D"
num_subsets = 28
num_iterations = 11200
betas = [0.0,0.2,0.05,0.075,0.3]
max_norm = 1000
for tracer in ["FDG", "Amyloid"]:
    for realisation in range(5):
        dataset = SIRF3DDataset(base_path, name, tracer, count_level, realisation, num_subsets)
        dataset_name = f"{name}_{count_level}_{realisation}_{tracer}"
        check_folder_create("path_to/coordinators/DIP/", dataset_name)
        for beta in betas:
            projection = SIRF3DProjection(image_sirf = dataset.image_sirf.clone(),
                                          objectives_sirf = dataset.objectives_sirf,
                                          sensitivity_image = dataset.sensitivity_image.clone(),
                                          fov = dataset.fov.clone(),
                                          num_subsets = num_subsets,
                                          num_iterations = num_iterations)


            projection.set_beta(beta)
            print(f"beta: {projection.beta}")
            x_new = projection.get_DIP(path = f"path_to/coordinators/DIP/{dataset_name}/",
                                       x = dataset.osem.clone(),
                                       beta = beta,
                                       lr = 1e-3,
                                       max_norm = max_norm)

            check_folder_create(f"path_to/coordinators/DIP/{dataset_name}/", f"DIP_beta_{beta}")
            path = f"path_to/coordinators/DIP/{dataset_name}/DIP_beta_{beta}/"
            res_dict = {"objective_values": projection.objective_values,
                        "times": projection.times}
            torch.save(res_dict, path+"dict.pt")

--------------------------------------------------------------------------------
/configs/default_config.py:
--------------------------------------------------------------------------------
import ml_collections


def get_default_configs():

    config = ml_collections.ConfigDict()
    config.device = "cuda"
    config.seed = 1


    config.guided_p_uncond = 0.1 # 0.1 # None
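    # guided_p_uncond: probability of dropping the MRI guidance channel per sample during
    # training (classifier-free guidance); None trains an unguided model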
    config.normalisation = "image_scale"
    # sde configs
    config.sde = sde = ml_collections.ConfigDict()
    sde.type = "vesde" # "vpsde", "vesde", "heatdiffusion"

    # the largest noise scale sigma_max was chosen according to Technique 1 from [https://arxiv.org/pdf/2006.09011.pdf]
    if sde.type == "vesde":
        sde.sigma_min = 0.01
        sde.sigma_max = 40. # 40 for vesde, 0.1 for heat diffusion
    if sde.type == "vpsde":
        # only for vpsde
        sde.beta_min = 0.1
        sde.beta_max = 10

    if sde.type == "heatdiffusion":
        # used for HeatDiffusion
        sde.T_max = 64

    # training configs
    config.training = training = ml_collections.ConfigDict()
    training.batch_size = 32
    training.epochs = 2000
    training.log_freq = 25
    training.lr = 1e-4
    training.ema_decay = 0.999
    training.ema_warm_start_steps = 50 # only start updating ema after this amount of steps

    # validation configs
    config.validation = validation = ml_collections.ConfigDict()
    validation.batch_size = 8
    validation.snr = 0.05
    validation.num_steps = 500
    validation.eps = 1e-4
    validation.sample_freq = 0 #10

    # sampling configs
    config.sampling = sampling = ml_collections.ConfigDict()
    sampling.batch_size = 1
    sampling.snr = 0.05
    sampling.num_steps = 1000
    sampling.eps = 1e-4
    sampling.sampling_strategy = "predictor_corrector"
    sampling.start_time_step = 0

    sampling.load_model_from_path = "/localdata/AlexanderDenker/pet_score_based/checkpoints/version_02"
    sampling.model_name = "model.pt"


    # data configs - specify in other configs
    config.data = ml_collections.ConfigDict()
    config.data.im_size = 128

    # forward operator config - specify in other configs
    config.forward_op = ml_collections.ConfigDict()

    # model configs
    config.model = model = ml_collections.ConfigDict()
    model.model_name = 'OpenAiUNetModel'
    if config.guided_p_uncond is None:
        model.in_channels = 1
    else:
        model.in_channels = 2
    model.model_channels = 64
    model.out_channels = 1
    model.num_res_blocks = 3
    model.attention_resolutions = [32, 16, 8]
    model.channel_mult = (1, 2, 2, 4, 4)
    model.conv_resample = True
    model.dims = 2
    model.num_heads = 4
    model.num_head_channels = -1
    model.num_heads_upsample = -1
    model.use_scale_shift_norm = True
    model.resblock_updown = False
    model.use_new_attention_order = False
    model.max_period = 0.005


    return config

--------------------------------------------------------------------------------
/coordinators/3D_bsrem_rdp_rdpz_baseline.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import torch, sys, os, time
import sirf.STIR as pet
sys.path.append(os.path.dirname(os.getcwd()))
from src import herman_meyer_order, SIRF3DProjection, SIRF3DDataset
pet.set_verbosity(0)
pet.AcquisitionData.set_storage_scheme("memory")
#pet.MessageRedirector(info=None, warn=None, errr=None)

def check_folder_create(path, folder_name):
    CHECK_FOLDER = os.path.isdir(path+folder_name)
    if not CHECK_FOLDER:
        os.makedirs(path+folder_name)
        print("created folder : ", path+folder_name)


# name: brainweb3D
# count_level: low
# realisation: 0
# tracer: FDG
# base_path: /home/user/sirf/src/sirf/brainweb_3D
# name: thorax3D
# count_level: low
# realisation: 0
# tracer: 0
# base_path: /home/user/sirf/D690XCATnonTOF

# GET THE DATA
# measurements_subsets_sirf, acquisition_models_sirf, sensitivity_image, fov, image_sirf, osem, measurements_subsets

name = "brainweb3D"
count_level = "low"
realisation = "0"
tracer = "FDG"
base_path = "path_to/src/sirf/brainweb_3D"
num_subsets = 28
num_iterations = 2800

for tracer in ["FDG", "Amyloid"]:
    for realisation in range(5):
        dataset = SIRF3DDataset(base_path, name, tracer, count_level, realisation, num_subsets)
        dataset_name = f"{name}_{count_level}_{realisation}_{tracer}"
        check_folder_create("path_to/coordinators/BSREM/", dataset_name)
        for prior in ["rdpz", "rdp"]:
            if prior == "rdp":
                betas = [0.5,0.767,1.18,1.81,2.77,4.25,6.52,10.]
            elif prior == "rdpz":
                betas = [10.,15.3,23.5,36.1,55.4,85.,130.,200.]
            for beta in betas:
                projection = SIRF3DProjection(image_sirf = dataset.image_sirf.clone(),
                                              objectives_sirf = dataset.objectives_sirf,
                                              sensitivity_image = dataset.sensitivity_image.clone(),
                                              fov = dataset.fov.clone(),
                                              num_subsets = num_subsets,
                                              num_iterations = num_iterations)
                prior_name = f"{prior}_beta_{beta}"
                check_folder_create(f"path_to/coordinators/BSREM/{dataset_name}/", prior_name)
                full_path = f"path_to/coordinators/BSREM/{dataset_name}/{prior_name}/"
                projection.set_beta(beta)
                print(f"beta: {projection.beta}")
                x_new = projection.get_bsrem(full_path, dataset.osem.clone(), eta = 0.1, prior = prior, image_diff_tol = 1e-5)
                plt.plot(projection.objective_values)
                plt.title("Objective value")
                plt.savefig(f"{full_path}/objective.png", dpi=300, bbox_inches="tight")
                plt.close()

                torch.save(x_new.squeeze().cpu(), f"{full_path}/volume.pt")

--------------------------------------------------------------------------------
/src/brainweb_2d/tumor_generator.py:
--------------------------------------------------------------------------------
import numpy as np
from skimage.morphology import area_closing, isotropic_erosion, isotropic_dilation
from skimage.draw import ellipse
from skimage.filters import gaussian


def Generate2DTumors(image):
    seg = np.zeros_like(image)
    foreground, background = 1, 0
    seg[image <= 0] = background
    seg[image > 0] = foreground
    seg = isotropic_erosion(area_closing(seg, 1e4), 20)

    n_ellipses = np.random.choice([1,1,1,2,2,3])
    # (r, c, r_radius, c_radius, shape=None, rotation=0.0)
    shape = image.shape
    tumour_rois = []
    background = np.zeros_like(image)
    for n_ellipse in range(n_ellipses):
        r_radius = max(np.random.poisson(10.), 4)
        c_radius = max(np.random.poisson(10.), 4)

        tmp_seg = isotropic_erosion(seg, max(r_radius, c_radius))
        tmp_coords = np.where(tmp_seg == 1)
        if tmp_coords[0].shape[0] == 0:
            break
        tmp_coord_idx = np.random.randint(0, tmp_coords[0].shape[0])
        rr_tmp, cc_tmp = ellipse(tmp_coords[0][tmp_coord_idx],
                                 tmp_coords[1][tmp_coord_idx],
                                 r_radius,
                                 c_radius,
                                 shape=shape,
                                 rotation=np.random.uniform(0, 2*np.pi))

        tumour = np.zeros_like(image)
        tumour[rr_tmp, cc_tmp] = 1.
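        # binary ellipse mask, kept as a tumour ROI before the intensity scaling and smoothing below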
        tumour_rois.append(tumour)

        intensity_factor = image.max()*np.random.uniform(1.3, 1.8)
        tmp_tumour = tumour * intensity_factor
        tmp_tumour = gaussian(tmp_tumour, sigma=np.random.uniform(1., 3.0))

        image = np.maximum(tmp_tumour, image)

        tmp_background = isotropic_dilation(tumour, 20)*1.
        tmp_background *= seg
        background += tmp_background
    while len(tumour_rois) < 3:
        tumour_rois.append(np.zeros_like(image))
    background[background > 0] = 1.
    for n_ellipse in range(n_ellipses):
        tmp_rmv = isotropic_dilation(tumour_rois[n_ellipse], 8)*1.
        background *= -1*(tmp_rmv-1)
    return image, background, np.stack(tumour_rois)




if __name__=="__main__":
    import torch
    import matplotlib.pyplot as plt
    clean = torch.load("path_to/examples/data/clean/test_subset_clean.pt")
    image_clean = clean['reference'][2].squeeze().detach().numpy()
    from skimage.transform import resize
    image_clean = resize(image_clean, (image_clean.shape[0] * 2, image_clean.shape[1] * 2),
                         anti_aliasing=True)
    image, background, tumour_rois = Generate2DTumors(image_clean)
    print(tumour_rois.shape)
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    add = np.zeros_like(image)
    for tumour in tumour_rois:
        add += tumour
    fig.colorbar(ax[0].imshow(image), ax=ax[0])
    fig.colorbar(ax[1].imshow(image_clean + add*3 + background*3), ax=ax[1])
    fig.colorbar(ax[2].imshow(image - image_clean + image.max()*background), ax=ax[2])
    plt.show()

--------------------------------------------------------------------------------
/src/third_party_models/openai_unet/nn_utils.py:
--------------------------------------------------------------------------------
"""
Various utilities for neural networks.
"""

import math

import torch
import torch.nn as nn


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
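    For example, a tensor of shape [N, C, H, W] is reduced to a tensor of shape [N].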
76 | """ 77 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 78 | 79 | 80 | def normalization(channels): 81 | """ 82 | Make a standard normalization layer. 83 | :param channels: number of input channels. 84 | :return: an nn.Module for normalization. 85 | """ 86 | return GroupNorm32(32, channels) 87 | 88 | 89 | def timestep_embedding(timesteps, dim, max_period=10000): 90 | """ 91 | Create sinusoidal timestep embeddings. 92 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 93 | These may be fractional. 94 | :param dim: the dimension of the output. 95 | :param max_period: controls the minimum frequency of the embeddings. 96 | :return: an [N x dim] Tensor of positional embeddings. 97 | """ 98 | half = dim // 2 99 | freqs = torch.exp( 100 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 101 | ).to(device=timesteps.device) 102 | args = timesteps[:, None].float() * freqs[None] 103 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 104 | if dim % 2: 105 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 106 | return embedding 107 | 108 | -------------------------------------------------------------------------------- /src/sirf/dip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | class Block(nn.Module): 6 | def __init__(self, ch_in, ch_out, stride): 7 | super(Block, self).__init__() 8 | self.block = nn.Sequential( 9 | nn.Conv3d(ch_in, ch_out, 3, stride=stride, padding=1), 10 | nn.BatchNorm3d(ch_out), 11 | nn.LeakyReLU(inplace=True)) 12 | def forward(self, x): 13 | return self.block(x) 14 | 15 | class PETUNet(nn.Module): 16 | def __init__(self, ch = 16, size = (90,128,128)): 17 | super(PETUNet, self).__init__() 18 | # Encoder 19 | self.block1 = Block(ch_in = 1, ch_out = ch, stride = 1) 20 | self.block2 = Block(ch_in = ch, ch_out = ch, stride = 1) 21 | self.stridedblock1 = Block(ch_in = ch, ch_out = ch, stride = 2) 22 | 23 | self.block3 = Block(ch_in = ch, ch_out = 2*ch, stride = 1) 24 | self.block4 = Block(ch_in = 2*ch, ch_out = 2*ch, stride = 1) 25 | self.stridedblock2 = Block(ch_in = 2*ch, ch_out = 2*ch, stride = 2) 26 | 27 | self.block5 = Block(ch_in = 2*ch, ch_out = 2*2*ch, stride = 1) 28 | self.block6 = Block(ch_in = 2*2*ch, ch_out = 2*2*ch, stride = 1) 29 | self.stridedblock3 = Block(ch_in = 2*2*ch, ch_out = 2*2*ch, stride = 2) 30 | 31 | self.block7 = Block(ch_in = 2*2*ch, ch_out = 2*2*2*ch, stride = 1) 32 | self.block8 = Block(ch_in = 2*2*2*ch, ch_out = 2*2*2*ch, stride = 1) 33 | 34 | # Decoder 35 | self.block9 = Block(ch_in = 2*2*2*ch, ch_out = 2*2*ch, stride = 1) 36 | self.upsample1 = nn.Upsample(size=tuple(math.ceil(si/4) for si in size), mode='trilinear', align_corners=True) 37 | 38 | self.block10 = Block(ch_in = 2*2*2*ch, ch_out = 2*2*ch, stride = 1) 39 | self.block11 = Block(ch_in = 2*2*ch, ch_out = 2*ch, stride = 1) 40 | self.upsample2 = nn.Upsample(size=tuple(math.ceil(si/2) for si in size), mode='trilinear', align_corners=True) 41 | 42 | self.block12 = Block(ch_in = 2*2*ch, ch_out = 2*ch, stride = 1) 43 | self.block13 = Block(ch_in = 2*ch, ch_out = ch, stride = 1) 44 | self.upsample3 = nn.Upsample(size=size, mode='trilinear', align_corners=True) 45 | 46 | self.block14 = Block(ch_in = 2*ch, ch_out = ch, stride = 1) 47 | self.block15 = Block(ch_in = ch, ch_out = 1, stride = 1) 48 | 49 | self.output = nn.ReLU() 50 | 51 | def forward(self, x): 52 | # Encoder 53 | x = 
--------------------------------------------------------------------------------
/src/sirf/dip.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import math

class Block(nn.Module):
    def __init__(self, ch_in, ch_out, stride):
        super(Block, self).__init__()
        self.block = nn.Sequential(
            nn.Conv3d(ch_in, ch_out, 3, stride=stride, padding=1),
            nn.BatchNorm3d(ch_out),
            nn.LeakyReLU(inplace=True))
    def forward(self, x):
        return self.block(x)

class PETUNet(nn.Module):
    def __init__(self, ch = 16, size = (90,128,128)):
        super(PETUNet, self).__init__()
        # Encoder
        self.block1 = Block(ch_in = 1, ch_out = ch, stride = 1)
        self.block2 = Block(ch_in = ch, ch_out = ch, stride = 1)
        self.stridedblock1 = Block(ch_in = ch, ch_out = ch, stride = 2)

        self.block3 = Block(ch_in = ch, ch_out = 2*ch, stride = 1)
        self.block4 = Block(ch_in = 2*ch, ch_out = 2*ch, stride = 1)
        self.stridedblock2 = Block(ch_in = 2*ch, ch_out = 2*ch, stride = 2)

        self.block5 = Block(ch_in = 2*ch, ch_out = 2*2*ch, stride = 1)
        self.block6 = Block(ch_in = 2*2*ch, ch_out = 2*2*ch, stride = 1)
        self.stridedblock3 = Block(ch_in = 2*2*ch, ch_out = 2*2*ch, stride = 2)

        self.block7 = Block(ch_in = 2*2*ch, ch_out = 2*2*2*ch, stride = 1)
        self.block8 = Block(ch_in = 2*2*2*ch, ch_out = 2*2*2*ch, stride = 1)

        # Decoder
        self.block9 = Block(ch_in = 2*2*2*ch, ch_out = 2*2*ch, stride = 1)
        self.upsample1 = nn.Upsample(size=tuple(math.ceil(si/4) for si in size), mode='trilinear', align_corners=True)

        self.block10 = Block(ch_in = 2*2*2*ch, ch_out = 2*2*ch, stride = 1)
        self.block11 = Block(ch_in = 2*2*ch, ch_out = 2*ch, stride = 1)
        self.upsample2 = nn.Upsample(size=tuple(math.ceil(si/2) for si in size), mode='trilinear', align_corners=True)

        self.block12 = Block(ch_in = 2*2*ch, ch_out = 2*ch, stride = 1)
        self.block13 = Block(ch_in = 2*ch, ch_out = ch, stride = 1)
        self.upsample3 = nn.Upsample(size=size, mode='trilinear', align_corners=True)

        self.block14 = Block(ch_in = 2*ch, ch_out = ch, stride = 1)
        self.block15 = Block(ch_in = ch, ch_out = 1, stride = 1)

        self.output = nn.ReLU()

    def forward(self, x):
        # Encoder
        x = self.block1(x)
        x_skip_1 = self.block2(x)
        x = self.stridedblock1(x_skip_1)

        x = self.block3(x)
        x_skip_2 = self.block4(x)
        x = self.stridedblock2(x_skip_2)

        x = self.block5(x)
        x_skip_3 = self.block6(x)
        x = self.stridedblock3(x_skip_3)

        x = self.block7(x)
        x = self.block8(x)

        # Decoder
        x = self.block9(x)
        x = self.upsample1(x)

        x = torch.cat([x, x_skip_3], dim=1)
        x = self.block10(x)
        x = self.block11(x)
        x = self.upsample2(x)

        x = torch.cat([x, x_skip_2], dim=1)
        x = self.block12(x)
        x = self.block13(x)
        x = self.upsample3(x)

        x = torch.cat([x, x_skip_1], dim=1)
        x = self.block14(x)
        x = self.block15(x)
        return self.output(x)

if __name__ == "__main__":
    model = PETUNet()
    print(sum(p.numel() for p in model.parameters()))
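
A shape sketch: the `size` argument must match the input volume, since the decoder's `nn.Upsample` target sizes are precomputed from it (90 x 128 x 128 is the class default):

    net = PETUNet(ch=16, size=(90, 128, 128))
    x = torch.zeros(1, 1, 90, 128, 128)  # [batch, channel, z, y, x]
    print(net(x).shape)                  # torch.Size([1, 1, 90, 128, 128])
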
--------------------------------------------------------------------------------
/src/utils/ema.py:
--------------------------------------------------------------------------------

import torch


# Taken from https://github.com/yang-song/score_sde_pytorch/blob/cb1f359f4aadf0ff9a5e122fe8fffc9451fd6e44/models/ema.py#L10
class ExponentialMovingAverage:
    """
    Maintains (exponential) moving average of a set of parameters.
    """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the result of
                `model.parameters()`.
            decay: The exponential decay.
            use_num_updates: Whether to use number of updates when computing
                averages.
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        self.shadow_params = [p.clone().detach()
                              for p in parameters if p.requires_grad]
        self.collected_params = []

    def update(self, parameters):
        """
        Update currently maintained parameters.
        Call this every time the parameters are updated, such as the result of
        the `optimizer.step()` call.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the same set of
                parameters used to initialize this object.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            parameters = [p for p in parameters if p.requires_grad]
            for s_param, param in zip(self.shadow_params, parameters):
                s_param.sub_(one_minus_decay * (s_param - param))

    def copy_to(self, parameters):
        """
        Copy current parameters into given collection of parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages.
        """
        parameters = [p for p in parameters if p.requires_grad]
        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                param.data.copy_(s_param.data)

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def state_dict(self):
        return dict(decay=self.decay, num_updates=self.num_updates,
                    shadow_params=self.shadow_params)

    def load_state_dict(self, state_dict):
        self.decay = state_dict['decay']
        self.num_updates = state_dict['num_updates']
        self.shadow_params = state_dict['shadow_params']
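
The intended usage pattern, pieced together from the docstrings above (`model` and `train_dl` are placeholders):

    ema = ExponentialMovingAverage(model.parameters(), decay=0.999)
    for batch in train_dl:               # sketch of a training loop
        ...                              # forward, backward, optimizer.step()
        ema.update(model.parameters())   # track the averaged weights after each step

    ema.store(model.parameters())        # stash the current training weights
    ema.copy_to(model.parameters())      # validate/save with the EMA weights
    ema.restore(model.parameters())      # put the training weights back
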
* mask[:,None,None] 47 | loss = loss_fn(score, x, sde) 48 | optimizer.zero_grad() 49 | loss.backward() 50 | optimizer.step() 51 | 52 | avg_loss += loss.item() * x.shape[0] 53 | num_items += x.shape[0] 54 | if idx % optim_kwargs['log_freq'] == 0: 55 | writer.add_scalar('train/loss', loss.item(), epoch*len(train_dl) + idx) 56 | if epoch == 0 and idx == optim_kwargs['ema_warm_start_steps']: # NB: assumes the warm-start step is reached within the first epoch 57 | ema = ExponentialMovingAverage(score.parameters(), decay=optim_kwargs['ema_decay']) 58 | if idx > optim_kwargs['ema_warm_start_steps'] or epoch > 0: 59 | ema.update(score.parameters()) 60 | 61 | print('Average Loss: {:.5f}'.format(avg_loss / num_items)) 62 | writer.add_scalar('train/mean_loss_per_epoch', avg_loss / num_items, epoch + 1) 63 | torch.save(score.state_dict(), os.path.join(log_dir,'model.pt')) 64 | torch.save(ema.state_dict(), os.path.join(log_dir, 'ema_model.pt')) 65 | if val_kwargs['sample_freq'] > 0: 66 | if epoch % val_kwargs['sample_freq'] == 0: 67 | score.eval() 68 | 69 | predictor = functools.partial(Euler_Maruyama_sde_predictor, nloglik = None) 70 | corrector = functools.partial(Langevin_sde_corrector, nloglik = None) 71 | 72 | sample_kwargs={ 73 | 'num_steps': val_kwargs['num_steps'], 74 | 'start_time_step': 0, 75 | 'batch_size': val_kwargs['batch_size'] if guided_p_uncond is None else x.shape[0], 76 | 'im_shape': [1, *x.shape[2:]], 77 | 'eps': val_kwargs['eps'], 78 | 'predictor': {'aTweedy': False}, 79 | 'corrector': {'corrector_steps': 1} 80 | } 81 | 82 | if guided_p_uncond is not None: 83 | sample_kwargs['predictor'] = { 84 | "guidance_imgs": x[:,1,...].unsqueeze(1), 85 | "guidance_strength": 0.4 86 | } 87 | sample_kwargs['corrector'] = { 88 | "guidance_imgs": x[:,1,...].unsqueeze(1), 89 | "guidance_strength": 0.4 90 | } 91 | 92 | sampler = BaseSampler( 93 | score=score, 94 | sde=sde, 95 | predictor=predictor, 96 | corrector=corrector, 97 | init_chain_fn=None, 98 | sample_kwargs=sample_kwargs, 99 | device=device) 100 | x_mean, _ = sampler.sample(logging=False) 101 | 102 | if guided_p_uncond is not None: 103 | x_mean = torch.cat([x_mean[:,[0],...], x[:,[1],...]], dim=0) 104 | sample_grid = torchvision.utils.make_grid(x_mean, normalize=True, scale_each=True, nrow = x.shape[0]) 105 | writer.add_image('unconditional samples', sample_grid, global_step=epoch) 106 | else: 107 | sample_grid = torchvision.utils.make_grid(x_mean, normalize=True, scale_each=True) 108 | writer.add_image('unconditional samples', sample_grid, global_step=epoch) 109 | -------------------------------------------------------------------------------- /src/samplers/base_sampler.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Inspired by https://github.com/yang-song/score_sde_pytorch/blob/main/sampling.py 3 | ''' 4 | from typing import Optional, Any, Dict, Tuple 5 | 6 | import os 7 | import torchvision 8 | import numpy as np 9 | import torch 10 | import datetime 11 | 12 | from tqdm import tqdm 13 | from torch import Tensor 14 | from torch.utils.tensorboard import SummaryWriter 15 | 16 | from ..utils import SDE, PSNR, SSIM 17 | from ..third_party_models import OpenAiUNetModel 18 | 19 | class BaseSampler: 20 | def __init__(self, 21 | score: OpenAiUNetModel, 22 | sde: SDE, 23 | predictor: callable, 24 | sample_kwargs: Dict, 25 | init_chain_fn: Optional[callable] = None, 26 | corrector: Optional[callable] = None, 27 | device: Optional[Any] = None 28 | ) -> None: 29 | 30 | self.score = score 31 | self.sde = sde 32 | self.predictor = predictor 33 | self.init_chain_fn = 
init_chain_fn 34 | self.sample_kwargs = sample_kwargs 35 | self.corrector = corrector 36 | self.device = device 37 | 38 | def sample(self, 39 | logg_kwargs: Dict = {}, 40 | logging: bool = True 41 | ) -> Tensor: 42 | if logging: 43 | writer = SummaryWriter(log_dir=os.path.join(logg_kwargs['log_dir'], str(logg_kwargs['sample_num']))) 44 | 45 | time_steps = np.linspace(1., self.sample_kwargs['eps'], self.sample_kwargs['num_steps']) 46 | 47 | step_size = time_steps[0] - time_steps[1] 48 | if self.sample_kwargs['start_time_step'] == 0: 49 | t = torch.ones(self.sample_kwargs['batch_size'], device=self.device) 50 | 51 | init_x = self.sde.prior_sampling([self.sample_kwargs['batch_size'], *self.sample_kwargs['im_shape']]).to(self.device) 52 | 53 | else: 54 | init_x = self.init_chain_fn(time_steps=time_steps).to(self.device) 55 | 56 | if logging: 57 | writer.add_image('init_x', torchvision.utils.make_grid(init_x, 58 | normalize=True, scale_each=True), global_step=0) 59 | if logg_kwargs['ground_truth'] is not None: writer.add_image( 60 | 'ground_truth', torchvision.utils.make_grid(logg_kwargs['ground_truth'], 61 | normalize=True, scale_each=True), global_step=0) 62 | if logg_kwargs['osem'] is not None: writer.add_image( 63 | 'osem', torchvision.utils.make_grid(logg_kwargs['osem'], 64 | normalize=True, scale_each=True), global_step=0) 65 | 66 | x = init_x 67 | for i in tqdm(range(self.sample_kwargs['start_time_step'], self.sample_kwargs['num_steps'])): 68 | time_step = torch.ones(self.sample_kwargs['batch_size'], device=self.device) * time_steps[i] 69 | x, x_mean, norm_factors = self.predictor( 70 | score=self.score, 71 | sde=self.sde, 72 | x=x, 73 | time_step=time_step, 74 | step_size=step_size, 75 | datafitscale=i/self.sample_kwargs['num_steps'], 76 | **self.sample_kwargs['predictor'] 77 | ) 78 | 79 | if self.corrector is not None: 80 | x = self.corrector( 81 | x=x, 82 | score=self.score, 83 | sde=self.sde, 84 | time_step=time_step, 85 | datafitscale=i/self.sample_kwargs['num_steps'], 86 | **self.sample_kwargs['corrector'] 87 | ) 88 | 89 | if logging: 90 | if (i - self.sample_kwargs['start_time_step']) % logg_kwargs['num_img_in_log'] == 0: 91 | writer.add_image('reco', torchvision.utils.make_grid(x_mean, normalize=True, scale_each=True), i) 92 | writer.add_scalar('PSNR', PSNR(x_mean[0, 0].cpu().numpy()*norm_factors[0,0].cpu().numpy(), logg_kwargs['ground_truth'][0, 0].cpu().numpy()), i) 93 | writer.add_scalar('SSIM', SSIM(x_mean[0, 0].cpu().numpy()*norm_factors[0,0].cpu().numpy(), logg_kwargs['ground_truth'][0, 0].cpu().numpy()), i) 94 | if logging: 95 | return x_mean, writer 96 | else: 97 | return x_mean, None 98 | -------------------------------------------------------------------------------- /src/utils/sde.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Based on the variance exploding (VE) and variance preserving (VP) SDEs. 3 | The derivations are given in [https://arxiv.org/pdf/2011.13456.pdf] Appendix C. 4 | Based on: https://github.com/yang-song/score_sde_pytorch/blob/main/sde_lib.py 5 | 6 | Adapted from: https://github.com/educating-dip/score_based_model_baselines/blob/main/src/utils/sde.py 7 | ''' 8 | from typing import Any, Optional 9 | import torch 10 | import numpy as np 11 | import abc 12 | #import torch_dct as dct 13 | 14 | class SDE(abc.ABC): 15 | """ 16 | SDE abstract class. Functions are designed for a mini-batch of inputs. 17 | """ 18 | def __init__(self): 19 | """ 20 | Construct an SDE. 
21 | """ 22 | super().__init__() 23 | 24 | def diffusion_coeff(self, t): 25 | """ 26 | Outputs f 27 | """ 28 | pass 29 | 30 | def sde(self, x, t): 31 | """ 32 | Outputs f and G 33 | """ 34 | pass 35 | 36 | def marginal_prob(self, x, t): 37 | """ 38 | Parameters to determine the marginal distribution of the SDE, $p_{0t}(x(t)|x(0))$. 39 | """ 40 | pass 41 | 42 | def marginal_prob_std(self, t): 43 | pass 44 | 45 | def marginal_prob_mean(self, t): 46 | """ 47 | Outputs the scaling factor of mean of p_{0t}(x(t)|x(0)) (for VE-SDE and VP-SDE the mean is a scaled x(0)) 48 | """ 49 | pass 50 | 51 | def prior_sampling(self, shape): 52 | """ 53 | Generate one sample from the prior distribution, $p_T(x)$. 54 | """ 55 | pass 56 | 57 | 58 | class VESDE(SDE): 59 | def __init__(self, sigma_min: float = 0.01, sigma_max: float = 50): 60 | """ 61 | Construct a Variance Exploding SDE. 62 | 63 | Args: 64 | sigma_min: smallest sigma. 65 | sigma_max: largest sigma. 66 | """ 67 | super().__init__() 68 | self.sigma_min = sigma_min 69 | self.sigma_max = sigma_max 70 | 71 | def diffusion_coeff(self, t): 72 | sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t 73 | diffusion = sigma * torch.sqrt( 74 | torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)), device=t.device)) 75 | return diffusion 76 | 77 | def sde(self, x, t): 78 | 79 | drift = torch.zeros_like(x) 80 | diffusion = self.diffusion_coeff(t) 81 | return drift, diffusion 82 | 83 | def marginal_prob(self, x, t): 84 | 85 | """ 86 | mean and standard deviation of p_{0t}(x(t) | x(0)) 87 | """ 88 | std = self.marginal_prob_std(t) 89 | mean = x 90 | return mean, std 91 | 92 | def marginal_prob_std(self, t): 93 | """ 94 | standard deviation of p_{0t}(x(t) | x(0)) is used: 95 | - in the UNET as a scaling of the output 96 | """ 97 | std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t 98 | return std 99 | 100 | def marginal_prob_mean(self, t): 101 | mean = torch.ones_like(t) 102 | return mean 103 | 104 | def prior_sampling(self, shape): 105 | return torch.randn(*shape) * self.sigma_max 106 | 107 | 108 | class VPSDE(SDE): 109 | def __init__(self, beta_min: float = 0.1, beta_max: float = 20): 110 | """ 111 | Construct a Variance Preserving SDE. 112 | 113 | Args: 114 | beta_min: value of beta(0) 115 | beta_max: value of beta(1) 116 | """ 117 | 118 | super().__init__() 119 | self.beta_min = beta_min 120 | self.beta_max = beta_max 121 | 122 | def diffusion_coeff(self, t): 123 | beta_t = self.beta_min + t*(self.beta_max - self.beta_min) 124 | return torch.sqrt(beta_t) 125 | 126 | def sde(self, x, t): 127 | beta_t = self.beta_min + t*(self.beta_max - self.beta_min) 128 | drift = -0.5 * beta_t[:, None, None, None] * x 129 | 130 | diffusion = self.diffusion_coeff(t) 131 | return drift, diffusion 132 | 133 | def marginal_prob(self, x, t): 134 | """ 135 | mean and standard deviation of p_{0t}(x(t) | x(0)) 136 | """ 137 | std = self.marginal_prob_std(t) 138 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_max - self.beta_min) - 0.5 * t * self.beta_min 139 | mean = torch.exp(log_mean_coeff[:, None, None, None]) * x 140 | return mean, std 141 | 142 | def marginal_prob_std(self, t): 143 | """ 144 | standard deviation of p_{0t}(x(t) | x(0)) is used: 145 | - in the UNET as a scaling of the output 146 | """ 147 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_max - self.beta_min) - 0.5 * t * self.beta_min 148 | std = torch.sqrt(1. - torch.exp(2. 
* log_mean_coeff)) 149 | return std 150 | 151 | def marginal_prob_mean(self, t): 152 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_max - self.beta_min) - 0.5 * t * self.beta_min 153 | mean = torch.exp(log_mean_coeff) 154 | 155 | return mean 156 | 157 | def prior_sampling(self, shape): 158 | return torch.randn(*shape) -------------------------------------------------------------------------------- /results/lpd_unet/lpd.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from DIVal: 3 | https://jleuschn.github.io/docs.dival/dival.reconstructors.learnedpd_reconstructor.html 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | 10 | def get_lpd_model(n_iter, op, op_adj): 11 | return PrimalDualNet(n_iter = n_iter, op = op, op_adj = op_adj, op_init=None, 12 | n_primal=5, n_dual=5, use_sigmoid=False, n_layer=4, 13 | internal_ch=32, kernel_size=3, 14 | batch_norm=True, prelu=False, lrelu_coeff=0.2) 15 | 16 | 17 | # UNet signature, for reference: in_ch, out_ch, channels, skip_channels, use_sigmoid=True, use_norm=True 18 | 19 | class IterativeBlock(nn.Module): 20 | def __init__(self, n_in=3, n_out=1, n_memory=5, n_layer=3, internal_ch=32, 21 | kernel_size=3, batch_norm=True, prelu=False, lrelu_coeff=0.2): 22 | super(IterativeBlock, self).__init__() 23 | assert kernel_size % 2 == 1 24 | padding = (kernel_size - 1) // 2 25 | modules = [] 26 | if batch_norm: 27 | modules.append(nn.BatchNorm2d(n_in + n_memory)) 28 | for i in range(n_layer-1): 29 | input_ch = (n_in + n_memory) if i == 0 else internal_ch 30 | modules.append(nn.Conv2d(input_ch, internal_ch, 31 | kernel_size=kernel_size, padding=padding)) 32 | if batch_norm: 33 | modules.append(nn.BatchNorm2d(internal_ch)) 34 | if prelu: 35 | modules.append(nn.PReLU(internal_ch, init=0.0)) 36 | else: 37 | modules.append(nn.LeakyReLU(lrelu_coeff, inplace=True)) 38 | modules.append(nn.Conv2d(internal_ch, n_out + n_memory, 39 | kernel_size=kernel_size, padding=padding)) 40 | self.block = nn.Sequential(*modules) 41 | self.relu = nn.LeakyReLU(lrelu_coeff, inplace=True) # remove? 
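# Each IterativeBlock implements one learned primal/dual update: it concatenates the current state with its memory channels, applies n_layer convolutions (optionally batch-normalised, with PReLU or LeakyReLU activations), and returns an additive update of n_out + n_memory channels.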
42 | 43 | def forward(self, x): 44 | upd = self.block(x) 45 | return upd 46 | 47 | class PrimalDualNet(nn.Module): 48 | def __init__(self, n_iter, op, op_adj, op_init=None, n_primal=5, n_dual=5, 49 | use_sigmoid=False, n_layer=4, internal_ch=32, kernel_size=3, 50 | batch_norm=True, prelu=False, lrelu_coeff=0.2): 51 | super(PrimalDualNet, self).__init__() 52 | self.n_iter = n_iter 53 | self.op = op 54 | self.op_adj = op_adj 55 | self.op_init = op_init 56 | self.n_primal = n_primal 57 | self.n_dual = n_dual 58 | self.use_sigmoid = use_sigmoid 59 | 60 | self.primal_blocks = nn.ModuleList() 61 | self.dual_blocks = nn.ModuleList() 62 | for it in range(n_iter): 63 | self.dual_blocks.append(IterativeBlock( 64 | n_in=3, n_out=1, n_memory=self.n_dual-1, n_layer=n_layer, 65 | internal_ch=internal_ch, kernel_size=kernel_size, 66 | batch_norm=batch_norm, prelu=prelu, lrelu_coeff=lrelu_coeff)) 67 | self.primal_blocks.append(IterativeBlock( 68 | n_in=2, n_out=1, n_memory=self.n_primal-1, n_layer=n_layer, 69 | internal_ch=internal_ch, kernel_size=kernel_size, 70 | batch_norm=batch_norm, prelu=prelu, lrelu_coeff=lrelu_coeff)) 71 | 72 | def forward(self, osem, y, projector, attn_factors, norm, contamination_factor): 73 | primal_cur = osem.repeat(1, self.n_primal, 1, 1)/norm[...,None,None,None] 74 | dual_cur = torch.ones(y.shape[0], self.n_dual, 75 | *projector._coincidence_descriptor.sinogram_spatial_shape[:-1], 76 | device=y.device) 77 | y = (y/norm[...,None,None]).view(*dual_cur[:,[0],:,:].shape) 78 | for i in range(self.n_iter): 79 | # A (x_unorm) + b 80 | primal_evalop = self.op.apply(primal_cur[:, 1:2, ...]*norm[:,None,None,None], projector, attn_factors) + contamination_factor[...,None] 81 | # y_norm 82 | primal_evalop = (primal_evalop/norm[...,None,None]).view(*dual_cur[:,[0],:,:].shape) 83 | dual_update = torch.cat([dual_cur, primal_evalop, y], dim=1) 84 | dual_update = self.dual_blocks[i](dual_update) 85 | dual_cur = dual_cur + dual_update 86 | # NB: currently only linear op supported 87 | # for non-linear op: [d/dx self.op(primal_cur[0:1, ...])]* 88 | # A* (y_unorm - b) 89 | dual_evalop = self.op_adj.apply(dual_cur[:, 0:1, ...].view(y.shape[0], 1, -1)*norm[...,None,None] - contamination_factor[...,None], projector, attn_factors) 90 | dual_evalop = dual_evalop/norm[...,None,None,None] 91 | primal_update = torch.cat([primal_cur, dual_evalop], dim=1) 92 | primal_update = self.primal_blocks[i](primal_update) 93 | primal_cur = primal_cur + primal_update 94 | 95 | x = primal_cur[:, 0:1, ...] 
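# the first primal channel holds the image estimate; the remaining n_primal-1 channels act as memory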
96 | if self.use_sigmoid: 97 | x = torch.sigmoid(x) 98 | return x*norm[...,None,None,None] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Score-Based Generative Models for PET Image Reconstruction 2 | [![arXiv](https://img.shields.io/badge/arXiv-2308.14190-b31b1b.svg)](https://arxiv.org/abs/2308.14190) 3 | 4 | Official code for [Score-Based Generative Models (SGM) for PET Image Reconstruction](https://arxiv.org/abs/2308.14190) (MELBA, 2024) by [Imraj RD Singh](https://www.imraj.dev/), [Alexander Denker](http://www.math.uni-bremen.de/zetem/cms/detail.php?template=parse_title&person=AlexanderDenker), [Riccardo Barbano](https://scholar.google.com/citations?user=6jYGiC0AAAAJ), [Željko Kereta](http://www0.cs.ucl.ac.uk/people/Z.Kereta.html), [Bangti Jin](https://www.math.cuhk.edu.hk/people/academic-staff/btjin), [Kris Thielemans](https://iris.ucl.ac.uk/iris/browse/profile?upi=KTHIE60), [Peter Maass](https://user.math.uni-bremen.de/pmaass/), [Simon Arridge](https://iris.ucl.ac.uk/iris/browse/profile?upi=SRARR14). 5 | 6 | I. Singh, A. Denker and R. Barbano contributed equally. 7 | 8 | In this work we address PET-specific challenges such as: the non-negativity of measurements/images, the varying dynamic range of the underlying radio-tracer distributions, and low-count Poisson noise on the measurements, which requires a Poisson Log-Likelihood (PLL) data fit. Further, we develop methods for 3D reconstruction, propose a guided variant that uses a Magnetic Resonance (MR) image, and accelerate the method using subsets. 9 | 10 | Our modifications can be summarised with the following diagram: 11 | ![Alt text](/modifications.png) 12 | The sections pertain to those in the [paper](tbd). The most suitable of the proposed reconstruction methods, a PET variant of Decomposed Diffusion Sampling (PET-DDS; DDS was proposed for MRI and CT [here](https://doi.org/10.48550/arXiv.2303.05754)), was extended to 3D, and its reconstruction steps are illustrated below: 13 | 14 | ![Alt text](/diagram.png) 15 | 16 | ## Use of open-source repositories 17 | 18 | The presented work develops and adapts code from various repositories; specific contributions are indicated at the top of the source files. The most important repositories include: 19 | * [SGM sampling methods for inverse problems](https://github.com/educating-dip/score_based_model_baselines) 20 | * [pyParallelProj for 2D experiments data generation](https://github.com/gschramm/pyparallelproj) 21 | * [SIRF-exercises for 3D experiments data generation](https://github.com/SyneRBI/SIRF-Exercises) 22 | * [Normalised supervised PET baselines](https://github.com/Imraj-Singh/pet_supervised_normalisation) 23 | * [DIVal for supervised deep learning architectures and training scripts](https://github.com/jleuschn/dival) 24 | * [Guided diffusion repository for the diffusion model architecture](https://github.com/openai/guided-diffusion) 25 | * [Deep image prior comparison](https://github.com/educating-dip/pet_deep_image_prior) 26 | 27 | We thank the authors of the aforementioned repositories for their open-source development and contributions. 28 | 29 | ## Datasets and Reproducibility 30 | 31 | The results of this work are *in-silico* simulations based on the [BrainWeb dataset](https://brainweb.bic.mni.mcgill.ca/), and all datasets are freely available for download/generation. 
For the 2D work, and for training the score model, we use the dataset available [here](https://zenodo.org/records/10509379), which can be downloaded through [pyParallelProj](https://github.com/gschramm/pyparallelproj). For the 3D work we use the dataset available [here](https://github.com/casperdcl/brainweb). 32 | 33 | Files for the generation of 2D data can be found in [src/brainweb_2d/](src/brainweb_2d/). For 3D data generation we provide a Jupyter notebook [src/sirf/brainweb_3D.ipynb](src/sirf/brainweb_3D.ipynb). 34 | 35 | Training the score model requires running the script [main_score_based_models_train.py](main_score_based_models_train.py). All reconstruction experiments can be found in [coordinators/](coordinators/), and all results can be processed with the files in [results/](results/). 36 | 37 | For reproducibility we provide a devcontainer that uses Docker to containerise the development environment required for this work. The files are located in [.devcontainer/](.devcontainer/); they call scripts that set up the conda environments defined by the files in [scripts/](scripts/), and a full list of pinned dependencies is given in [req.txt](scripts/req.txt). Please note that this project requires [SIRF](https://github.com/SyneRBI/SIRF) for the 3D work. 38 | 39 | ## Citation 40 | BibTeX: 41 | ``` 42 | @article{melba:2024:001:singh, 43 | title = "Score-Based Generative Models for PET Image Reconstruction", 44 | author = "Singh, Imraj RD and Denker, Alexander and Barbano, Riccardo and Kereta, Željko and Jin, Bangti and Thielemans, Kris and Maass, Peter and Arridge, Simon", 45 | journal = "Machine Learning for Biomedical Imaging", 46 | volume = "2", 47 | issue = "Special Issue for Generative Models", 48 | year = "2024", 49 | pages = "547--585", 50 | issn = "2766-905X", 51 | doi = "https://doi.org/10.59275/j.melba.2024-5d51", 52 | url = "https://melba-journal.org/2024:001" 53 | } 54 | ``` 55 | 56 | -------------------------------------------------------------------------------- /coordinators/3D_reconstruction.py: -------------------------------------------------------------------------------- 1 | import hydra, torch, yaml, sys, os, functools, time 2 | import numpy as np 3 | sys.path.append(os.path.dirname(os.getcwd())) 4 | from src import (get_standard_sampler, get_standard_score, get_standard_sde, SIRF3DProjection, SIRF3DDataset) 5 | from omegaconf import DictConfig, OmegaConf 6 | import sirf.STIR as pet 7 | pet.set_verbosity(0) 8 | pet.AcquisitionData.set_storage_scheme("memory") 9 | pet.MessageRedirector(info=None, warn=None, errr=None) 10 | import gc 11 | # python 3D_reconstruction.py --multirun dataset=thorax3Dmedium sampling.lambd=0.1 12 | 13 | @hydra.main(config_path='../configs', config_name='3D_reconstruction') 14 | def reconstruction(config : DictConfig) -> None: 15 | print(OmegaConf.to_yaml(config)) 16 | timestr = time.strftime("%Y%m%d_%H%M%S_") 17 | dump_name = config.dump_path + "/"+ timestr + ".tmp" 18 | with open(dump_name, "xt") as f: 19 | f.write(os.getcwd()) 20 | f.close() 21 | ###### SET SEED ###### 22 | if config.seed is not None: 23 | torch.manual_seed(config.seed) 24 | np.random.seed(config.seed) 25 | 26 | ###### GET SCORE MODEL ###### 27 | # open the yaml config file 28 | with open(os.path.join(config.score_based_model.path, "report.yaml"), "r") as stream: 29 | ml_collection = yaml.load(stream, Loader=yaml.UnsafeLoader) 30 | guided = False if ml_collection.guided_p_uncond is None else True 31 | if guided: raise NotImplementedError 32 | 
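# note: guided (MR-conditioned) score models are not supported by this 3D coordinator, hence the early failure above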
# get the sde 33 | sde = get_standard_sde(ml_collection) 34 | # get the score model 35 | score_model = get_standard_score(ml_collection, sde, 36 | use_ema = config.score_based_model.ema, 37 | load_path = config.score_based_model.path) 38 | score_model.eval() 39 | score_model.to(config.device) 40 | 41 | ###### SOLVING REVERSE SDE ###### 42 | num_subsets = config.sampling.num_subsets 43 | num_iterations = config.sampling.num_iterations 44 | count_level = config.dataset.count_level 45 | 46 | # GET THE DATA 47 | # measurements_subsets_sirf, acquisition_models_sirf, sensitivity_image, fov, image_sirf, osem, measurements_subsets 48 | 49 | 50 | dataset = SIRF3DDataset(config.dataset.base_path, 51 | config.dataset.name, 52 | config.dataset.tracer, 53 | count_level, 54 | config.dataset.realisation, 55 | num_subsets) 56 | 57 | image = dataset.osem 58 | config.sampling.batch_size = len(image) 59 | # estimate scaling factors from osem 60 | scale_factors = [] 61 | for i in range(config.sampling.batch_size): 62 | emission_volume = torch.where(image[i] > 0.01*image[i].max(), 1, 0).sum() 63 | scale_factor = image[i].sum()/emission_volume 64 | # fewer than 100 voxels in the emission volume: set the scale factor to 0 65 | if emission_volume < 100: 66 | scale_factor = 0 67 | scale_factors.append(scale_factor) 68 | scale_factors = torch.tensor(scale_factors).to(config.device) 69 | # Remove outliers from scale factors 70 | scale_factors[scale_factors < scale_factors.mean()*0.05] = 0 71 | scale_factors = scale_factors.unsqueeze(1).unsqueeze(2).unsqueeze(3) 72 | 73 | 74 | if config.sampling.name == "dds_3D": 75 | projection = SIRF3DProjection(image_sirf = dataset.image_sirf.clone(), 76 | objectives_sirf = dataset.objectives_sirf, 77 | sensitivity_image = dataset.sensitivity_image.clone(), 78 | fov = dataset.fov.clone(), 79 | num_subsets = num_subsets, 80 | num_iterations = num_iterations) 81 | if config.sampling.lambd is not None and config.sampling.beta is None: 82 | print(f"Using DDS with anchor with lambda {config.sampling.lambd}") 83 | projection.set_lambd(config.sampling.lambd) 84 | nll_partial = functools.partial(projection.get_anchor, 85 | scale_factor=scale_factors) 86 | elif config.sampling.lambd is not None and config.sampling.beta is not None: 87 | print(f"Using DDS with anchor and rdpz with lambda {config.sampling.lambd} and beta {config.sampling.beta}") 88 | projection.set_lambd(config.sampling.lambd) 89 | projection.set_beta(config.sampling.beta) 90 | nll_partial = functools.partial(projection.get_anchor_rdpz, 91 | scale_factor=scale_factors) 92 | else: raise NotImplementedError("3D only DDS with anchor or anchor+rdpz") 93 | else: raise NotImplementedError("3D only tested with DDS") 94 | 95 | 96 | logg_kwargs = {'log_dir': "./tb", 'num_img_in_log': None, 97 | 'sample_num': None, 'ground_truth': None, 'osem': None} 98 | sampler = get_standard_sampler( 99 | config=config, 100 | score=score_model, 101 | sde=sde, 102 | nll=nll_partial, 103 | im_shape=(1,128,128), 104 | guidance_imgs=None, 105 | device=config.device) 106 | t0 = time.time() 107 | recon, _ = sampler.sample(logg_kwargs=logg_kwargs, logging=False) 108 | t1 = time.time() 109 | recon = torch.clamp(recon, min=0) 110 | recon = recon*scale_factors.to(config.device) 111 | torch.save(recon.detach().cpu().squeeze().swapaxes(1,2).flip(1,2), "volume.pt") 112 | recon_dict = {"pll_values": projection.pll_values, 113 | "objective_values": projection.objective_values, 114 | "pll_values_last": projection.pll_values_last, 115 | "objective_values_last": 
projection.objective_values_last, 116 | "time": t1-t0} 117 | torch.save(recon_dict, "recon_dict.pt") 118 | os.remove(dump_name) 119 | gc.collect() 120 | if __name__ == '__main__': 121 | reconstruction() -------------------------------------------------------------------------------- /src/brainweb_2d/get_noisy_train_brainweb_2D.py: -------------------------------------------------------------------------------- 1 | import pyparallelproj.coincidences as coincidences 2 | import pyparallelproj.subsets as subsets 3 | import pyparallelproj.petprojectors as petprojectors 4 | import pyparallelproj.resolution_models as resolution_models 5 | import pyparallelproj.algorithms as algorithms 6 | import cupy as xp 7 | import cupyx.scipy.ndimage as ndi 8 | from brainweb import BrainWebClean 9 | import torch, os 10 | from tqdm import tqdm 11 | 12 | # Adapted from https://github.com/gschramm/pyparallelproj/blob/main/examples/00_projections_and_reconstruction/02_osem.py 13 | 14 | if __name__ == "__main__": 15 | coincidence_descriptor = coincidences.GEDiscoveryMICoincidenceDescriptor( 16 | num_rings=1, 17 | sinogram_spatial_axis_order=coincidences. 18 | SinogramSpatialAxisOrder['RVP'], 19 | xp=xp) 20 | 21 | mu_projector = petprojectors.PETJosephProjector(coincidence_descriptor, 22 | (128, 128, 1), (-127, -127, 0), 23 | (2, 2, 2)) 24 | projector = petprojectors.PETJosephProjector(coincidence_descriptor, 25 | (128, 128, 1), (-127, -127, 0), 26 | (2, 2, 2)) 27 | res_model = resolution_models.GaussianImageBasedResolutionModel( 28 | (128, 128, 1), tuple(4.5 / (2.35 * x) for x in (2, 2, 2)), xp, ndi) 29 | 30 | projector.image_based_resolution_model = res_model 31 | subsetter = subsets.SingoramViewSubsetter(coincidence_descriptor, 34) 32 | projector.subsetter = subsetter 33 | xp.random.seed(42) 34 | tumour = True 35 | trues_per_volumes = [5, 10, 50] 36 | if tumour: 37 | dataset = BrainWebClean(path_to_files="path_to/clean/clean_train_tumour.pt") 38 | else: 39 | dataset = BrainWebClean(path_to_files="path_to/clean/clean_train.pt") 40 | 41 | for trues_per_volume in trues_per_volumes: 42 | osem_pts = [] 43 | scaling_factor_pts = [] 44 | noisy_data_pts = [] 45 | contamination_pts = [] 46 | attenuation_pts = [] 47 | print(f"Trues per volume {trues_per_volume}") 48 | for idx in tqdm(range(len(dataset))): 49 | y, mu, gt = dataset[idx] 50 | gt = xp.from_dlpack(gt.cuda().squeeze().unsqueeze(-1)) 51 | mu = xp.from_dlpack(mu.cuda().squeeze().unsqueeze(-1)) 52 | y = xp.from_dlpack(y.cuda().squeeze()) 53 | # simulate the attenuation factors (exp(-fwd(attenuation_image))) 54 | attenuation_factors = xp.exp(-mu_projector.forward(mu)) 55 | projector.multiplicative_corrections = attenuation_factors * 1. 
/ 30 56 | 57 | # scale the image such that we get a certain true count per emission voxel value 58 | emission_volume = xp.where(gt > 0)[0].shape[0] * 8 59 | current_trues_per_volume = float(y.sum() / emission_volume) 60 | 61 | scaling_factor = (trues_per_volume / current_trues_per_volume) 62 | 63 | image_fwd_scaled = y*scaling_factor 64 | 65 | # simulate a constant background contamination 66 | contamination_scale = image_fwd_scaled.mean() 67 | contamination = xp.full(projector.output_shape, 68 | contamination_scale, 69 | dtype=xp.float32) 70 | 71 | # generate noisy data 72 | data = xp.random.poisson(image_fwd_scaled + contamination).astype(xp.uint16).astype(xp.float32) 73 | 74 | 75 | reconstructor = algorithms.OSEM(data, contamination, projector, verbose=False) 76 | reconstructor.run(1, evaluate_cost=False) 77 | 78 | osem_x = reconstructor.x 79 | 80 | osem_pts.append(torch.from_dlpack(osem_x[:,:,0])[None][None].float().cpu()) 81 | scaling_factor_pts.append(torch.tensor(scaling_factor)[None][None].float().cpu()) 82 | noisy_data_pts.append(torch.from_dlpack(data)[None][None].float().cpu()) 83 | contamination_pts.append(torch.tensor(contamination_scale)[None][None].float().cpu()) 84 | attenuation_pts.append(torch.from_dlpack(attenuation_factors)[None][None].float().cpu()) 85 | 86 | osem_reconstruction = torch.cat(osem_pts) 87 | scaling_factor = torch.cat(scaling_factor_pts) 88 | noisy_data = torch.cat(noisy_data_pts) 89 | contamination_scales = torch.cat(contamination_pts) 90 | attenuation_factors = torch.cat(attenuation_pts) 91 | 92 | save_dict = {'osem': osem_reconstruction, 93 | 'scale_factor': scaling_factor, 94 | 'measurements': noisy_data, 95 | 'contamination_factor': contamination_scales, 96 | 'attn_factors': attenuation_factors} 97 | 98 | if tumour: 99 | torch.save(save_dict, f"path_to/noisy/noisy_train_tumour_{trues_per_volume}.pt") 100 | else: 101 | torch.save(save_dict, f"path_to/noisy/noisy_train_{trues_per_volume}.pt") 102 | del osem_pts, scaling_factor_pts, noisy_data_pts, contamination_pts, attenuation_pts 103 | del osem_reconstruction, scaling_factor, noisy_data, contamination_scales, attenuation_factors 104 | del save_dict 105 | -------------------------------------------------------------------------------- /src/brainweb_2d/get_OOD_noisy_brainweb_2D.py: -------------------------------------------------------------------------------- 1 | import pyparallelproj.coincidences as coincidences 2 | import pyparallelproj.subsets as subsets 3 | import pyparallelproj.petprojectors as petprojectors 4 | import pyparallelproj.resolution_models as resolution_models 5 | import pyparallelproj.algorithms as algorithms 6 | import os 7 | import cupy as xp 8 | import cupyx.scipy.ndimage as ndi 9 | from brainweb import BrainWebClean 10 | import torch 11 | from tqdm import tqdm 12 | 13 | # Adapted from https://github.com/gschramm/pyparallelproj/blob/main/examples/00_projections_and_reconstruction/02_osem.py 14 | 15 | if __name__ == "__main__": 16 | coincidence_descriptor = coincidences.GEDiscoveryMICoincidenceDescriptor( 17 | num_rings=1, 18 | sinogram_spatial_axis_order=coincidences. 
19 | SinogramSpatialAxisOrder['RVP'], 20 | xp=xp) 21 | 22 | mu_projector = petprojectors.PETJosephProjector(coincidence_descriptor, 23 | (128, 128, 1), (-127, -127, 0), 24 | (2, 2,2 )) 25 | projector = petprojectors.PETJosephProjector(coincidence_descriptor, 26 | (128, 128, 1), (-127, -127, 0), 27 | (2, 2, 2)) 28 | res_model = resolution_models.GaussianImageBasedResolutionModel( 29 | (128, 128, 1), tuple(4.5 / (2.35 * x) for x in (2, 2, 2)), xp, ndi) 30 | 31 | projector.image_based_resolution_model = res_model 32 | datasets = ["test","test_tumour"] 33 | xp.random.seed(0) 34 | file_directory = "path_to/" 35 | for name in datasets: 36 | 37 | path_to_files= name+"_clean.pt" 38 | dataset = BrainWebClean(path_to_files=path_to_files) 39 | trues_per_volumes = [5, 7.5, 100] 40 | for trues_per_volume in trues_per_volumes: 41 | osem_pts = [] 42 | scaling_factor_pts = [] 43 | noisy_data_pts = [] 44 | contamination_pts = [] 45 | attenuation_pts = [] 46 | 47 | print(f"Dataset {name} trues per volume {trues_per_volume}") 48 | for idx in tqdm(range(len(dataset))): 49 | y, mu, gt = dataset[idx] 50 | gt = xp.asarray(gt.squeeze().unsqueeze(-1).numpy()) 51 | mu = xp.asarray(mu.squeeze().unsqueeze(-1).numpy()) 52 | y = xp.asarray(y.numpy().squeeze()) 53 | # simulate the attenuation factors (exp(-fwd(attenuation_image))) 54 | attenuation_factors = xp.exp(-mu_projector.forward(mu)) 55 | projector.multiplicative_corrections = attenuation_factors * 1. / 30 56 | 57 | # scale the image such that we get a certain true count per emission voxel value 58 | emission_volume = xp.where(gt > 0)[0].shape[0] * 8 59 | current_trues_per_volume = float(y.sum() / emission_volume) 60 | 61 | scaling_factor = (trues_per_volume / current_trues_per_volume) 62 | 63 | image_fwd_scaled = y*scaling_factor 64 | 65 | # simulate a constant background contamination 66 | contamination_scale = image_fwd_scaled.mean() 67 | contamination = xp.full(projector.output_shape, 68 | contamination_scale, 69 | dtype=xp.float32) 70 | 71 | # generate noisy data 72 | data = xp.random.poisson(image_fwd_scaled + contamination).astype(xp.uint16).astype(xp.float32) 73 | 74 | subsetter = subsets.SingoramViewSubsetter(coincidence_descriptor, 34) 75 | projector.subsetter = subsetter 76 | 77 | reconstructor = algorithms.OSEM(data, contamination, projector, verbose=False) 78 | reconstructor.run(1, evaluate_cost=False) 79 | 80 | osem_x = reconstructor.x 81 | osem_pts.append(torch.from_dlpack(osem_x[:,:,0])[None,None].float().cuda()) 82 | scaling_factor_pts.append(torch.tensor(scaling_factor)[None,None].float().cuda()) 83 | noisy_data_pts.append(torch.from_dlpack(data)[None,None].float().cuda()) 84 | contamination_pts.append(torch.tensor(contamination_scale)[None,None].float().cuda()) 85 | attenuation_pts.append(torch.from_dlpack(attenuation_factors)[None,None].float().cuda()) 86 | 87 | 88 | osem_reconstruction = torch.cat(osem_pts) 89 | scaling_factor = torch.cat(scaling_factor_pts) 90 | noisy_data = torch.cat(noisy_data_pts) 91 | contamination_scales = torch.cat(contamination_pts) 92 | attenuation_factors = torch.cat(attenuation_pts) 93 | 94 | save_dict = {'osem': osem_reconstruction, 95 | 'scale_factor': scaling_factor, 96 | 'measurements': noisy_data, 97 | 'contamination_factor': contamination_scales, 98 | 'attn_factors': attenuation_factors} 99 | torch.save(save_dict, file_directory + f"/noisy/{name}_noisy_{trues_per_volume}.pt") 100 | del osem_pts, scaling_factor_pts, noisy_data_pts, contamination_pts, attenuation_pts 101 | del osem_reconstruction, 
scaling_factor, noisy_data, contamination_scales, attenuation_factors 102 | del save_dict 103 | -------------------------------------------------------------------------------- /src/brainweb_2d/get_validation_subset.py: -------------------------------------------------------------------------------- 1 | import pyparallelproj.coincidences as coincidences 2 | import pyparallelproj.subsets as subsets 3 | import pyparallelproj.petprojectors as petprojectors 4 | import pyparallelproj.resolution_models as resolution_models 5 | import pyparallelproj.algorithms as algorithms 6 | import cupy as xp 7 | import cupyx.scipy.ndimage as ndi 8 | from brainweb import BrainWebClean 9 | import torch, os 10 | from tqdm import tqdm 11 | 12 | # Adapted from https://github.com/gschramm/pyparallelproj/blob/main/examples/00_projections_and_reconstruction/02_osem.py 13 | 14 | if __name__ == "__main__": 15 | coincidence_descriptor = coincidences.GEDiscoveryMICoincidenceDescriptor( 16 | num_rings=1, 17 | sinogram_spatial_axis_order=coincidences. 18 | SinogramSpatialAxisOrder['RVP'], 19 | xp=xp) 20 | 21 | mu_projector = petprojectors.PETJosephProjector(coincidence_descriptor, 22 | (128, 128, 1), (-127, -127, 0), 23 | (2, 2, 2)) 24 | projector = petprojectors.PETJosephProjector(coincidence_descriptor, 25 | (128, 128, 1), (-127, -127, 0), 26 | (2, 2, 2)) 27 | res_model = resolution_models.GaussianImageBasedResolutionModel( 28 | (128, 128, 1), tuple(4.5 / (2.35 * x) for x in (2, 2, 2)), xp, ndi) 29 | 30 | projector.image_based_resolution_model = res_model 31 | subsetter = subsets.SingoramViewSubsetter(coincidence_descriptor, 34) 32 | projector.subsetter = subsetter 33 | xp.random.seed(42) 34 | dataset = BrainWebClean(path_to_files="path_to/clean/clean_train.pt", mri=True) 35 | n_validation = 8 36 | validation_indices = torch.randperm(len(dataset))[:n_validation] 37 | dataset = torch.utils.data.Subset(dataset, validation_indices) 38 | trues_per_volumes = [2.5, 5, 7.5, 10, 50, 100] 39 | 40 | for trues_per_volume in trues_per_volumes: 41 | osem_pts = [] 42 | scaling_factor_pts = [] 43 | noisy_data_pts = [] 44 | contamination_pts = [] 45 | attenuation_pts = [] 46 | print(f"Trues per volume {trues_per_volume}") 47 | for idx in tqdm(range(len(dataset))): 48 | y, mu, gt, _ = dataset[idx] 49 | gt = xp.from_dlpack(gt.cuda().squeeze().unsqueeze(-1)) 50 | mu = xp.from_dlpack(mu.cuda().squeeze().unsqueeze(-1)) 51 | y = xp.from_dlpack(y.cuda().squeeze()) 52 | # simulate the attenuation factors (exp(-fwd(attenuation_image))) 53 | attenuation_factors = xp.exp(-mu_projector.forward(mu)) 54 | projector.multiplicative_corrections = attenuation_factors * 1. 
/ 30 55 | 56 | # scale the image such that we get a certain true count per emission voxel value 57 | emission_volume = xp.where(gt > 0)[0].shape[0] * 8 58 | current_trues_per_volume = float(y.sum() / emission_volume) 59 | 60 | scaling_factor = (trues_per_volume / current_trues_per_volume) 61 | 62 | image_fwd_scaled = y*scaling_factor 63 | 64 | # simulate a constant background contamination 65 | contamination_scale = image_fwd_scaled.mean() 66 | contamination = xp.full(projector.output_shape, 67 | contamination_scale, 68 | dtype=xp.float32) 69 | 70 | # generate noisy data 71 | data = xp.random.poisson(image_fwd_scaled + contamination).astype(xp.uint16).astype(xp.float32) 72 | 73 | 74 | reconstructor = algorithms.OSEM(data, contamination, projector, verbose=False) 75 | reconstructor.run(1, evaluate_cost=False) 76 | 77 | osem_x = reconstructor.x 78 | 79 | osem_pts.append(torch.from_dlpack(osem_x[:,:,0])[None][None].float().cpu()) 80 | scaling_factor_pts.append(torch.tensor(scaling_factor)[None][None].float().cpu()) 81 | noisy_data_pts.append(torch.from_dlpack(data)[None][None].float().cpu()) 82 | contamination_pts.append(torch.tensor(contamination_scale)[None][None].float().cpu()) 83 | attenuation_pts.append(torch.from_dlpack(attenuation_factors)[None][None].float().cpu()) 84 | 85 | osem_reconstruction = torch.cat(osem_pts) 86 | scaling_factor = torch.cat(scaling_factor_pts) 87 | noisy_data = torch.cat(noisy_data_pts) 88 | contamination_scales = torch.cat(contamination_pts) 89 | attenuation_factors = torch.cat(attenuation_pts) 90 | 91 | save_dict = {'osem': osem_reconstruction, 92 | 'scale_factor': scaling_factor, 93 | 'measurements': noisy_data, 94 | 'contamination_factor': contamination_scales, 95 | 'attn_factors': attenuation_factors} 96 | torch.save(save_dict, f"path_to/noisy/noisy_validation_{trues_per_volume}.pt") 97 | del osem_pts, scaling_factor_pts, noisy_data_pts, contamination_pts, attenuation_pts 98 | del osem_reconstruction, scaling_factor, noisy_data, contamination_scales, attenuation_factors 99 | del save_dict 100 | 101 | ys = [] 102 | mus = [] 103 | gts = [] 104 | mris = [] 105 | for i in range(n_validation): 106 | y, mu, gt, mri = dataset[i] 107 | ys.append(y) 108 | mus.append(mu) 109 | gts.append(gt) 110 | mris.append(mri) 111 | ys = torch.stack(ys) 112 | mus = torch.stack(mus) 113 | gts = torch.stack(gts) 114 | mris = torch.stack(mris) 115 | print(gts.shape) 116 | print(validation_indices) 117 | save_dict = {'clean_measurements': ys, 'mu': mus, 'reference': gts, 'mri': mris} 118 | import matplotlib.pyplot as plt 119 | fig, ax = plt.subplots(1, 8) 120 | for i in range(n_validation): 121 | ax[i].imshow(gts[i].squeeze()) 122 | plt.show() 123 | torch.save(save_dict, f"path_to/clean/clean_validation.pt") 124 | -------------------------------------------------------------------------------- /src/utils/nll.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import pyparallelproj.algorithms as algorithms 4 | import pyparallelproj.subsets as subsets 5 | 6 | import cupy as xp 7 | import time 8 | from ..sirf.herman_meyer import herman_meyer_order 9 | 10 | from ..brainweb_2d import LPDForwardFunction2D, LPDAdjointFunction2D 11 | 12 | 13 | def kl_div(x, acq_model, attn_factors, contamination, measurements, scale_factor): 14 | fwd_proj = LPDForwardFunction2D.apply(torch.clamp(x, 0) * scale_factor, acq_model, attn_factors/30.) 
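# NB: dividing attn_factors by 30 mirrors the multiplicative correction (attenuation_factors * 1./30) applied when the noisy data were generated in the brainweb_2d scripts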
15 | fwd_proj = fwd_proj + contamination 16 | kl = torch.sum(measurements * torch.log(measurements/fwd_proj + 1e-6) - measurements + fwd_proj, axis=[-1]) 17 | return kl, scale_factor 18 | 19 | def poisson_nll(x, acq_model, attn_factors, contamination, measurements, scale_factor, loss): 20 | """ print(scale_factor.squeeze().tolist()) 21 | x = torch.clamp(x, 0) 22 | sens_img = LPDAdjointFunction2D.apply(torch.ones_like(measurements), acq_model, attn_factors).detach() 23 | scale_factor = (measurements-torch.ones_like(measurements)*contamination).sum(dim=(1,2))/(x*sens_img).sum(dim=(1,2,3)) 24 | scale_factor = scale_factor.detach()[:,None,None,None] 25 | scale_factor = torch.clamp(scale_factor, 1e-6, 1000) 26 | print(scale_factor.squeeze().tolist()) """ 27 | fwd_proj = LPDForwardFunction2D.apply(torch.clamp(x, 0) * scale_factor, acq_model, attn_factors/30.) 28 | fwd_proj = fwd_proj + contamination[:,None] 29 | #grad = sens_img - LPDAdjointFunction2D.apply((measurements) / (fwd_proj + 1e-9), acq_model, attn_factors/30.) 30 | loss_vals = loss(fwd_proj, measurements) 31 | #grad = torch.zeros_like(x) 32 | return loss_vals, scale_factor 33 | 34 | def osem_nll(x, scale_factor, osem): 35 | loss_vals = (x * scale_factor - osem)**2 36 | return torch.sum(loss_vals, axis=[-1, -2]), scale_factor 37 | 38 | # TODO: Instead of OSEM, do gradient descent on min datafit(Ax,y) + beta*||x-x0hat||^2 39 | def get_osem(x, acq_model, attn_factors, contamination, measurements, scale_factor, num_subsets, num_epochs): 40 | # SET THE SUBSETTER 41 | subsetter = subsets.SingoramViewSubsetter(acq_model._coincidence_descriptor, num_subsets) 42 | acq_model.subsetter = subsetter 43 | x_mean = [] 44 | for sample in range(x.shape[0]): 45 | # UPDATE THE MULTIPLICATIVE CORRECTIONS 46 | a_f = xp.asarray(attn_factors[sample,0,:]) 47 | acq_model.multiplicative_corrections = a_f/30. 
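# contamination is stored as one scalar per sample; it is expanded to a constant sinogram for OSEM below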
48 | c = xp.asarray(contamination[sample,0]) 49 | m = xp.asarray(measurements[sample,0,:]) 50 | reconstructor = algorithms.OSEM(data = m, 51 | contamination = c*xp.ones_like(m), 52 | data_operator = acq_model, 53 | verbose=False) 54 | reconstructor.setup(xp.asarray(x[sample,0,:,:].unsqueeze(-1)*scale_factor[sample, ...])) 55 | reconstructor.run(num_epochs, evaluate_cost=False) 56 | x_mean.append(torch.from_dlpack(reconstructor.x)) 57 | x_mean = torch.stack(x_mean).squeeze().unsqueeze(1)/scale_factor 58 | return x_mean, scale_factor 59 | 60 | def pll_gradient(x, measurements, acq_model, attn_factor, contamination_factor): 61 | tmp = measurements / (LPDForwardFunction2D.apply(x, acq_model, attn_factor) + contamination_factor) - torch.ones_like(measurements) 62 | return LPDAdjointFunction2D.apply(tmp, acq_model, attn_factor) 63 | 64 | def rdp_rolled_components(x): 65 | rows = [1,1,1,0,0,0,-1,-1,-1] 66 | columns = [1,0,-1,1,0,-1,1,0,-1] 67 | x_neighbours = x.clone().repeat(1,9,1,1) 68 | for i in range(9): 69 | x_neighbours[:,[i]] = torch.roll(x, shifts=(rows[i], columns[i]), dims=(-2, -1)) 70 | return x_neighbours 71 | 72 | def get_preconditioner(x, x_neighbours, sens_img, beta): 73 | first = torch.clamp(sens_img,1e-9)/(x+1e-9) 74 | x = x.repeat(1,9,1,1) 75 | second = (16*(x_neighbours**2))/torch.clamp((x + x_neighbours + 2 * torch.abs(x-x_neighbours))**3,1e-9) 76 | return 1/(first - beta*second.sum(dim=1, keepdim=True)) 77 | 78 | def rdp_gradient(x, x_neighbours): 79 | x = x.repeat(1,9,1,1) 80 | numerator = (x - x_neighbours)*(2 * torch.abs(x-x_neighbours) + x + 3 * x_neighbours) 81 | denominator = torch.clamp((x + x_neighbours + 2 * torch.abs(x - x_neighbours))**2,1e-9) 82 | return - (numerator/denominator).sum(dim=1, keepdim=True) 83 | 84 | def get_map(x, acq_model, attn_factors, contamination, measurements, scale_factor, num_subsets, num_epochs, beta): 85 | # from A Concave Prior Penalizing Relative Differences 86 | # for Maximum-a-Posteriori Reconstruction 87 | # in Emission Tomography Eq. 
15 88 | sens_img = LPDAdjointFunction2D.apply(torch.ones_like(measurements), acq_model, attn_factors/30.).detach() 89 | x_old = torch.clamp(scale_factor*x.clone().detach(),0) 90 | for _ in range(num_epochs): 91 | x_neighbours = rdp_rolled_components(x_old) 92 | preconditioner = get_preconditioner(x_old, x_neighbours, sens_img, beta).detach() 93 | gradient_1 = pll_gradient(x_old, measurements, acq_model, attn_factors/30., contamination).detach() 94 | gradient_2 = beta * rdp_gradient(x_old, x_neighbours).detach() 95 | x_new = torch.clamp(x_old.detach() + \ 96 | preconditioner * (gradient_1 + gradient_2), 0) 97 | x_old = x_new 98 | return x_new/scale_factor, scale_factor 99 | 100 | def get_anchor(x, acq_model, attn_factors, contamination, measurements, scale_factor, num_subsets, num_epochs, beta): 101 | x_anchor = scale_factor*x.clone().detach() 102 | sens_img = LPDAdjointFunction2D.apply(torch.ones_like(measurements), acq_model, attn_factors/30.).detach() 103 | x_old = scale_factor*x.clone().detach() 104 | for _ in range(num_epochs): 105 | preconditioner = (x_old + 1e-9)/torch.clamp(sens_img,1e-9) 106 | gradient_1 = pll_gradient(x_old, measurements, acq_model, attn_factors/30., contamination).detach() 107 | gradient_2 = beta * (x_anchor-x_old).detach() 108 | x_new = torch.clamp(x_old.detach() + \ 109 | preconditioner * (gradient_1 + gradient_2), 1e-9) 110 | 111 | x_old = x_new 112 | 113 | return x_new/scale_factor, scale_factor -------------------------------------------------------------------------------- /src/brainweb_2d/get_true_test_brainweb_2D.py: -------------------------------------------------------------------------------- 1 | import nibabel as nib 2 | import pyparallelproj.coincidences as coincidences 3 | import pyparallelproj.petprojectors as petprojectors 4 | import pyparallelproj.resolution_models as resolution_models 5 | from tqdm import tqdm 6 | import cupy as xp 7 | import torch, os 8 | import cupyx.scipy.ndimage as ndi 9 | from tumor_generator import Generate2DTumors 10 | 11 | 12 | # Adapted from https://github.com/gschramm/pyparallelproj/blob/main/examples/00_projections_and_reconstruction/02_osem.py 13 | 14 | if __name__=="__main__": 15 | coincidence_descriptor = coincidences.GEDiscoveryMICoincidenceDescriptor( 16 | num_rings=1, 17 | sinogram_spatial_axis_order=coincidences. 
18 | SinogramSpatialAxisOrder['RVP'], 19 | xp=xp) 20 | 21 | mu_projector = petprojectors.PETJosephProjector(coincidence_descriptor, 22 | (256,256,1), 23 | (-127.5, -127.5, 0), 24 | (1,1,2)) 25 | 26 | true_projector = petprojectors.PETJosephProjector(coincidence_descriptor, 27 | (256,256,1), 28 | (-127.5, -127.5, 0), 29 | (1,1,2)) 30 | 31 | res_model = resolution_models.GaussianImageBasedResolutionModel((256,256,1), 32 | tuple(4.5 / (2.35 * x) for x in (1,1,2)), xp, ndi) 33 | 34 | true_projector.image_based_resolution_model = res_model 35 | 36 | bool_tumour = [True, False] 37 | 38 | for tumour in bool_tumour: 39 | 40 | clean_data_pts = [] 41 | mu_ref_pts = [] 42 | image_ref_pts = [] 43 | mri_ref_pts = [] 44 | if tumour: 45 | background_pts = [] 46 | tumour_rois_pts = [] 47 | 48 | nii_pet = nib.as_closest_canonical(nib.load(f'path_to/examples/data/brainweb_petmr/subject04/sim_0/true_pet.nii.gz')) 49 | nii_mu = nib.as_closest_canonical(nib.load(f'path_to/examples/data/brainweb_petmr/subject04/mu.nii.gz')) 50 | nii_mri = nib.as_closest_canonical(nib.load(f'path_to/examples/data/brainweb_petmr/subject04/t1.nii.gz')) 51 | 52 | # pet image resolution [1,1,2] mm 53 | image_gt = xp.array(nii_pet.get_fdata(), dtype=xp.float32) 54 | image_gt = (image_gt[:, :, ::2] + image_gt[:, :, 1::2])/2 55 | 56 | # pet image resolution [2,2,2] mm 57 | image_ref = (image_gt[::2, :, :] + image_gt[1::2, :, :])/2 58 | image_ref = (image_ref[:, ::2, :] + image_ref[:, 1::2, :])/2 59 | 60 | # mu image resolution [1,1,2] mm 61 | mu_gt = xp.array(nii_mu.get_fdata(), dtype=xp.float32) 62 | mu_gt = (mu_gt[:, :, ::2] + mu_gt[:, :, 1::2]) /2 63 | 64 | # mu image resolution [2,2,2] mm 65 | mu_ref = (mu_gt[::2, :, :] + mu_gt[1::2, :, :])/2 66 | mu_ref = (mu_ref[:, ::2, :] + mu_ref[:, 1::2, :])/2 67 | 68 | # mri image resolution [1,1,2] mm 69 | mri_gt = xp.array(nii_mri.get_fdata(), dtype=xp.float32) 70 | mri_gt = (mri_gt[:, :, ::2] + mri_gt[:, :, 1::2]) /2 71 | 72 | # mri image resolution [2,2,2] mm 73 | mri_ref = (mri_gt[::2, :, :] + mri_gt[1::2, :, :])/2 74 | mri_ref = (mri_ref[:, ::2, :] + mri_ref[:, 1::2, :])/2 75 | 76 | for slice_number in range(image_ref.shape[-1]): 77 | # ENSURE THERE ARE AT LEAST 2000 NON-ZERO PIXELS IN SLICE 78 | if len(xp.nonzero(image_ref[:, :, [slice_number]])[0]) > 2000: 79 | attenuation_factors = xp.exp(-mu_projector.forward(mu_gt[:, :, [slice_number]])) 80 | true_projector.multiplicative_corrections = attenuation_factors * 1./30 81 | if tumour: 82 | image_gt_slice, background, tumour_rois = Generate2DTumors(xp.asnumpy(image_gt[:, :, slice_number])) 83 | image_gt_slice = xp.expand_dims(xp.asarray(image_gt_slice),-1) 84 | image_ref_slice = (image_gt_slice[::2, :, :] + image_gt_slice[1::2, :, :])/2 85 | image_ref_slice = (image_ref_slice[:, ::2, :] + image_ref_slice[:, 1::2, :])/2 86 | tumour_rois = (tumour_rois[:, ::2, :] + tumour_rois[:, 1::2, :])/2 87 | tumour_rois = (tumour_rois[:, :, ::2] + tumour_rois[:, :, 1::2])/2 88 | tumour_rois[tumour_rois < 1.] = 0. 89 | background = (background[::2, :] + background[1::2, :])/2 90 | background = (background[:, ::2] + background[:, 1::2])/2 91 | background[background < 1.] = 0. 
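# thresholding the average-pooled masks at 1 keeps only pixels that were entirely inside the ROI at the original resolution (assuming binary 0/1 masks from the tumour generator)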
92 | background_pts.append(torch.from_numpy(background)[None][None].float().cuda()) 93 | tumour_rois_pts.append(torch.from_numpy(tumour_rois)[None].float().cuda()) 94 | #print("Has tumour") 95 | else: 96 | image_gt_slice = image_gt[..., [slice_number]] 97 | image_ref_slice = image_ref[..., [slice_number]] 98 | #print("Has no tumour") 99 | 100 | clean_data = true_projector.forward(image_gt_slice) 101 | clean_data_pts.append(torch.from_dlpack(clean_data[None][None]).float().cuda()) 102 | mu_ref_pts.append(torch.from_dlpack(mu_ref[:, :, slice_number][None][None]).float().cuda()) 103 | image_ref_pts.append(torch.from_dlpack(image_ref_slice[:, :, 0][None][None]).float().cuda()) 104 | mri_ref_pts.append(torch.from_dlpack(mri_ref[:, :, slice_number][None][None]).float().cuda()) 105 | 106 | clean_data_pts = torch.cat(clean_data_pts) 107 | mu_ref_pts = torch.cat(mu_ref_pts) 108 | image_ref_pts = torch.cat(image_ref_pts) 109 | mri_ref_pts = torch.cat(mri_ref_pts) 110 | if tumour: 111 | background_pts = torch.cat(background_pts) 112 | tumour_rois_pts = torch.cat(tumour_rois_pts) 113 | recon_dict = {'clean_measurements': clean_data_pts, 'mu': mu_ref_pts, 'reference': image_ref_pts, 'background': background_pts, 'tumour_rois': tumour_rois_pts, 'mri': mri_ref_pts} 114 | torch.save(recon_dict, "path_to/clean/clean_test_tumour.pt") 115 | else: 116 | recon_dict = {'clean_measurements': clean_data_pts, 'mu': mu_ref_pts, 'reference': image_ref_pts, 'mri': mri_ref_pts} 117 | torch.save(recon_dict, "path_to/clean/clean_test.pt") -------------------------------------------------------------------------------- /src/brainweb_2d/get_noisy_test_brainweb_2D.py: -------------------------------------------------------------------------------- 1 | import pyparallelproj.coincidences as coincidences 2 | import pyparallelproj.subsets as subsets 3 | import pyparallelproj.petprojectors as petprojectors 4 | import pyparallelproj.resolution_models as resolution_models 5 | import pyparallelproj.algorithms as algorithms 6 | import cupy as xp 7 | import cupyx.scipy.ndimage as ndi 8 | from brainweb import BrainWebClean 9 | import torch, os 10 | from tqdm import tqdm 11 | 12 | # Adapted from https://github.com/gschramm/pyparallelproj/blob/main/examples/00_projections_and_reconstruction/02_osem.py 13 | 14 | if __name__ == "__main__": 15 | coincidence_descriptor = coincidences.GEDiscoveryMICoincidenceDescriptor( 16 | num_rings=1, 17 | sinogram_spatial_axis_order=coincidences. 
18 | SinogramSpatialAxisOrder['RVP'], 19 | xp=xp) 20 | 21 | mu_projector = petprojectors.PETJosephProjector(coincidence_descriptor, 22 | (128, 128, 1), (-127, -127, 0), 23 | (2, 2, 2)) 24 | projector = petprojectors.PETJosephProjector(coincidence_descriptor, 25 | (128, 128, 1), (-127, -127, 0), 26 | (2, 2, 2)) 27 | res_model = resolution_models.GaussianImageBasedResolutionModel( 28 | (128, 128, 1), tuple(4.5 / (2.35 * x) for x in (2, 2, 2)), xp, ndi) 29 | 30 | projector.image_based_resolution_model = res_model 31 | 32 | subsetter = subsets.SingoramViewSubsetter(coincidence_descriptor, 34) 33 | projector.subsetter = subsetter 34 | xp.random.seed(42) 35 | n_realisations = 10 36 | bool_tumour = [True, False] 37 | trues_per_volumes = [100, 50, 10, 7.5, 5, 2.5] 38 | for tumour in bool_tumour: 39 | 40 | if tumour: 41 | dataset = BrainWebClean(path_to_files="path_to/clean/clean_test_tumour.pt") 42 | else: 43 | dataset = BrainWebClean(path_to_files="path_to/clean/clean_test.pt") 44 | 45 | for trues_per_volume in trues_per_volumes: 46 | osem_pts = [] 47 | scaling_factor_pts = [] 48 | noisy_data_pts = [] 49 | contamination_pts = [] 50 | attenuation_pts = [] 51 | print(f"Trues per volume {trues_per_volume}, tumour {tumour}") 52 | for idx in tqdm(range(len(dataset))): 53 | osem_tmps = [] 54 | scaling_factor_tmps = [] 55 | noisy_data_tmps = [] 56 | contamination_tmps = [] 57 | attenuation_tmps = [] 58 | y, mu, gt = dataset[idx] 59 | gt = xp.from_dlpack(gt.squeeze().unsqueeze(-1).to("cuda")) 60 | mu = xp.from_dlpack(mu.squeeze().unsqueeze(-1).to("cuda")) 61 | y = xp.from_dlpack(y.squeeze().to("cuda")) 62 | for _ in range(n_realisations): 63 | # simulate the attenuation factors (exp(-fwd(attenuation_image))) 64 | attenuation_factors = xp.exp(-mu_projector.forward(mu)) 65 | projector.multiplicative_corrections = attenuation_factors * 1. 
/ 30 66 | 67 | # scale the image such that we get a certain true count per emission voxel value 68 | emission_volume = xp.where(gt > 0)[0].shape[0] * 8 69 | current_trues_per_volume = float(y.sum() / emission_volume) 70 | 71 | scaling_factor = (trues_per_volume / current_trues_per_volume) 72 | 73 | image_fwd_scaled = y*scaling_factor 74 | 75 | # simulate a constant background contamination 76 | contamination_scale = image_fwd_scaled.mean() 77 | contamination = xp.full(projector.output_shape, 78 | contamination_scale, 79 | dtype=xp.float32) 80 | 81 | # generate noisy data 82 | data = xp.random.poisson(image_fwd_scaled + contamination).astype(xp.uint16).astype(xp.float32) 83 | 84 | 85 | reconstructor = algorithms.OSEM(data, contamination, projector, verbose=False) 86 | reconstructor.run(1, evaluate_cost=False) 87 | 88 | osem_x = reconstructor.x 89 | 90 | osem_tmps.append(torch.from_dlpack(osem_x[:,:,0])[None].float().cuda()) 91 | scaling_factor_tmps.append(torch.tensor(scaling_factor)[None].float().cuda()) 92 | noisy_data_tmps.append(torch.from_dlpack(data)[None].float().cuda()) 93 | contamination_tmps.append(torch.tensor(contamination_scale)[None].float().cuda()) 94 | attenuation_tmps.append(torch.from_dlpack(attenuation_factors)[None].float().cuda()) 95 | 96 | """ import matplotlib.pyplot as plt 97 | fig, ax = plt.subplots(1, 3, figsize=(15, 5)) 98 | ax[0].imshow(osem_x[:,:,0].get()) 99 | ax[1].imshow(mu[:,:,0].get()) 100 | ax[2].imshow(gt[:,:,0].get()) 101 | plt.show() 102 | exit() """ 103 | 104 | osem_pts.append(torch.cat(osem_tmps)[None]) 105 | scaling_factor_pts.append(torch.cat(scaling_factor_tmps)[None]) 106 | noisy_data_pts.append(torch.cat(noisy_data_tmps)[None]) 107 | contamination_pts.append(torch.cat(contamination_tmps)[None]) 108 | attenuation_pts.append(torch.cat(attenuation_tmps)[None]) 109 | 110 | osem_reconstruction = torch.cat(osem_pts) 111 | scaling_factor = torch.cat(scaling_factor_pts) 112 | noisy_data = torch.cat(noisy_data_pts) 113 | contamination_scales = torch.cat(contamination_pts) 114 | attenuation_factors = torch.cat(attenuation_pts) 115 | 116 | save_dict = {'osem': osem_reconstruction, 117 | 'scale_factor': scaling_factor, 118 | 'measurements': noisy_data, 119 | 'contamination_factor': contamination_scales, 120 | 'attn_factors': attenuation_factors} 121 | if tumour: 122 | torch.save(save_dict, f"E:/projects/pet_score_model/src/brainweb_2d/noisy/noisy_test_tumour_{trues_per_volume}.pt") 123 | else: 124 | torch.save(save_dict, f"E:/projects/pet_score_model/src/brainweb_2d/noisy/noisy_test_{trues_per_volume}.pt") 125 | 126 | del osem_pts, scaling_factor_pts, noisy_data_pts, contamination_pts, attenuation_pts 127 | del osem_reconstruction, scaling_factor, noisy_data, contamination_scales, attenuation_factors 128 | del save_dict 129 | -------------------------------------------------------------------------------- /src/brainweb_2d/get_true_train_brainweb_2D.py: -------------------------------------------------------------------------------- 1 | import nibabel as nib 2 | import pyparallelproj.coincidences as coincidences 3 | import pyparallelproj.petprojectors as petprojectors 4 | import pyparallelproj.resolution_models as resolution_models 5 | from tqdm import tqdm 6 | import cupy as xp 7 | import torch, os 8 | import cupyx.scipy.ndimage as ndi 9 | from tumor_generator import Generate2DTumors 10 | 11 | # Adapted from https://github.com/gschramm/pyparallelproj/blob/main/examples/00_projections_and_reconstruction/02_osem.py 12 | 13 | if __name__=="__main__": 14 | 
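# same single-ring GE Discovery MI geometry as the other brainweb_2d scripts; ground truth is simulated on a 256x256 grid with 1x1x2 mm voxels and later downsampled to 128x128 with 2x2x2 mm voxels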
coincidence_descriptor = coincidences.GEDiscoveryMICoincidenceDescriptor( 15 | num_rings=1, 16 | sinogram_spatial_axis_order=coincidences. 17 | SinogramSpatialAxisOrder['RVP'], 18 | xp=xp) 19 | 20 | mu_projector = petprojectors.PETJosephProjector(coincidence_descriptor, 21 | (256,256,1), 22 | (-127.5, -127.5, 0), 23 | (1,1,2)) 24 | true_projector = petprojectors.PETJosephProjector(coincidence_descriptor, 25 | (256,256,1), 26 | (-127.5, -127.5, 0), 27 | (1,1,2)) 28 | 29 | res_model = resolution_models.GaussianImageBasedResolutionModel((256,256,1), 30 | tuple(4.5 / (2.35 * x) for x in (1,1,2)), xp, ndi) 31 | 32 | true_projector.image_based_resolution_model = res_model 33 | 34 | subjects = [5, 6, 18, 20, 38, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54] 35 | tumour = True 36 | if tumour: 37 | print('Getting true data with tumours...') 38 | clean_data_pts = [] 39 | mu_ref_pts = [] 40 | image_ref_pts = [] 41 | mri_ref_pts = [] 42 | if tumour: 43 | background_pts = [] 44 | tumour_rois_pts = [] 45 | for subject_number in tqdm(subjects): 46 | for sim_number in range(3): 47 | nii_pet = nib.as_closest_canonical(nib.load(f'path_to/examples/data/brainweb_petmr/subject{subject_number:02}/sim_{sim_number}/true_pet.nii.gz')) 48 | nii_mu = nib.as_closest_canonical(nib.load(f'path_to/examples/data/brainweb_petmr/subject{subject_number:02}/mu.nii.gz')) 49 | nii_mri = nib.as_closest_canonical(nib.load(f'path_to/examples/data/brainweb_petmr/subject{subject_number:02}/t1.nii.gz')) 50 | # pet image resolution [1,1,2] mm 51 | image_gt = xp.array(nii_pet.get_fdata(), dtype=xp.float32) 52 | image_gt = (image_gt[:, :, ::2] + image_gt[:, :, 1::2])/2 53 | 54 | # pet image resolution [2,2,2] mm 55 | image_ref = (image_gt[::2, :, :] + image_gt[1::2, :, :])/2 56 | image_ref = (image_ref[:, ::2, :] + image_ref[:, 1::2, :])/2 57 | 58 | # mu image resolution [1,1,2] mm 59 | mu_gt = xp.array(nii_mu.get_fdata(), dtype=xp.float32) 60 | mu_gt = (mu_gt[:, :, ::2] + mu_gt[:, :, 1::2]) /2 61 | 62 | # mu image resolution [2,2,2] mm 63 | mu_ref = (mu_gt[::2, :, :] + mu_gt[1::2, :, :])/2 64 | mu_ref = (mu_ref[:, ::2, :] + mu_ref[:, 1::2, :])/2 65 | 66 | # mri image resolution [1,1,2] mm 67 | mri_gt = xp.array(nii_mri.get_fdata(), dtype=xp.float32) 68 | mri_gt = (mri_gt[:, :, ::2] + mri_gt[:, :, 1::2]) /2 69 | 70 | # mri image resolution [2,2,2] mm 71 | mri_ref = (mri_gt[::2, :, :] + mri_gt[1::2, :, :])/2 72 | mri_ref = (mri_ref[:, ::2, :] + mri_ref[:, 1::2, :])/2 73 | 74 | for slice_number in range(image_ref.shape[-1]): 75 | # ENSURE THERE ARE AT LEAST 2000 NON-ZERO PIXELS IN SLICE 76 | if len(xp.nonzero(image_ref[:, :, slice_number])[0]) > 2000: 77 | attenuation_factors = xp.exp(-mu_projector.forward(mu_gt[:, :, [slice_number]])) 78 | true_projector.multiplicative_corrections = attenuation_factors * 1./30 79 | if tumour: 80 | image_gt_slice, background, tumour_rois = Generate2DTumors(xp.asnumpy(image_gt[:, :, slice_number])) 81 | image_gt_slice = xp.expand_dims(xp.asarray(image_gt_slice),-1) 82 | image_ref_slice = (image_gt_slice[::2, :, :] + image_gt_slice[1::2, :, :])/2 83 | image_ref_slice = (image_ref_slice[:, ::2, :] + image_ref_slice[:, 1::2, :])/2 84 | tumour_rois = (tumour_rois[:, ::2, :] + tumour_rois[:, 1::2, :])/2 85 | tumour_rois = (tumour_rois[:, :, ::2] + tumour_rois[:, :, 1::2])/2 86 | tumour_rois[tumour_rois < 1.] = 0. 87 | background = (background[::2, :] + background[1::2, :])/2 88 | background = (background[:, ::2] + background[:, 1::2])/2 89 | background[background < 1.] = 0. 
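# NOTE: the ROI masks generated on the 1 mm grid are brought to the 2 mm
# reconstruction grid by pairwise averaging; values below 1 are then zeroed,
# so only voxels fully covered by the original ROI survive the downsampling.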
90 | background_pts.append(torch.from_numpy(background)[None][None].float().cuda()) 91 | tumour_rois_pts.append(torch.from_numpy(tumour_rois)[None].float().cuda()) 92 | else: 93 | image_gt_slice = image_gt[..., [slice_number]] 94 | image_ref_slice = image_ref[..., [slice_number]] 95 | 96 | clean_data = true_projector.forward(image_gt[:, :, [slice_number]]) 97 | clean_data_pts.append(torch.from_dlpack(clean_data[None][None]).float().cuda()) 98 | mu_ref_pts.append(torch.from_dlpack(mu_ref[:, :, slice_number][None][None]).float().cuda()) 99 | image_ref_pts.append(torch.from_dlpack(image_ref[:, :, slice_number][None][None]).float().cuda()) 100 | mri_ref_pts.append(torch.from_dlpack(mri_ref[:, :, slice_number][None][None]).float().cuda()) 101 | 102 | clean_data_pts = torch.cat(clean_data_pts) 103 | mu_ref_pts = torch.cat(mu_ref_pts) 104 | image_ref_pts = torch.cat(image_ref_pts) 105 | mri_ref_pts = torch.cat(mri_ref_pts) 106 | 107 | if tumour: 108 | background_pts = torch.cat(background_pts) 109 | tumour_rois_pts = torch.cat(tumour_rois_pts) 110 | recon_dict = {'clean_measurements': clean_data_pts, 'mu': mu_ref_pts, 'reference': image_ref_pts, 'background': background_pts, 'tumour_rois': tumour_rois_pts, 'mri': mri_ref_pts} 111 | torch.save(recon_dict, "path_to/clean/clean_train_tumour.pt") 112 | else: 113 | recon_dict = {'clean_measurements': clean_data_pts, 'mu': mu_ref_pts, 'reference': image_ref_pts, 'mri': mri_ref_pts} 114 | torch.save(recon_dict, "path_to/clean/clean_train.pt") 115 | 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /src/utils/exp_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from: https://github.com/educating-dip/score_based_model_baselines/blob/main/src/utils/exp_utils.py 3 | 4 | """ 5 | 6 | import os 7 | import time 8 | import torch 9 | import functools 10 | from math import ceil 11 | from pathlib import Path 12 | 13 | from .sde import VESDE, VPSDE, HeatDiffusion 14 | from .ema import ExponentialMovingAverage 15 | 16 | from ..third_party_models import OpenAiUNetModel 17 | from ..samplers import (BaseSampler, Euler_Maruyama_sde_predictor, Langevin_sde_corrector, 18 | chain_simple_init, decomposed_diffusion_sampling_sde_predictor) 19 | 20 | def get_standard_score(config, sde, use_ema, load_path = None, load_model=True): 21 | if load_model: 22 | assert load_path is not None, "set load path" 23 | 24 | if str(config.model.model_name).lower() == 'OpenAiUNetModel'.lower(): 25 | score = OpenAiUNetModel( 26 | image_size=config.data.im_size, 27 | in_channels=config.model.in_channels, 28 | model_channels=config.model.model_channels, 29 | out_channels=config.model.out_channels, 30 | num_res_blocks=config.model.num_res_blocks, 31 | attention_resolutions=config.model.attention_resolutions, 32 | marginal_prob_std=None if isinstance(sde,HeatDiffusion) else sde.marginal_prob_std, 33 | channel_mult=config.model.channel_mult, 34 | conv_resample=config.model.conv_resample, 35 | dims=config.model.dims, 36 | num_heads=config.model.num_heads, 37 | num_head_channels=config.model.num_head_channels, 38 | num_heads_upsample=config.model.num_heads_upsample, 39 | use_scale_shift_norm=config.model.use_scale_shift_norm, 40 | resblock_updown=config.model.resblock_updown, 41 | use_new_attention_order=config.model.use_new_attention_order, 42 | max_period=config.model.max_period 43 | ) 44 | else: 45 | raise NotImplementedError 46 | 47 | if load_model: 48 | print(f'load score model from 
path: {load_path}') 49 | if use_ema: 50 | ema = ExponentialMovingAverage(score.parameters(), decay=0.999) 51 | ema.load_state_dict(torch.load(os.path.join(load_path,'ema_model.pt'))) 52 | ema.copy_to(score.parameters()) 53 | else: 54 | score.load_state_dict(torch.load(os.path.join(load_path, config.sampling.model_name))) 55 | 56 | return score 57 | 58 | def get_standard_sde(config): 59 | 60 | if config.sde.type.lower() == 'vesde': 61 | sde = VESDE( 62 | sigma_min=config.sde.sigma_min, 63 | sigma_max=config.sde.sigma_max 64 | ) 65 | elif config.sde.type.lower() == 'vpsde': 66 | sde = VPSDE( 67 | beta_min=config.sde.beta_min, 68 | beta_max=config.sde.beta_max 69 | ) 70 | elif config.sde.type.lower() == "heatdiffusion": 71 | sde = HeatDiffusion( 72 | sigma_min=config.sde.sigma_min, 73 | sigma_max=config.sde.sigma_max, 74 | T_max=config.sde.T_max 75 | ) 76 | 77 | else: 78 | raise NotImplementedError 79 | 80 | return sde 81 | 82 | def get_standard_sampler(config, score, sde, nll, im_shape, observation=None, 83 | osem=None, guidance_imgs=None, device=None): 84 | """ 85 | nll should be a function of x, i.e. a functools.partial with fixed norm_factors, attn_factors, contamination, measurements 86 | 87 | """ 88 | if config.sampling.name.lower() == 'naive': 89 | predictor = functools.partial( 90 | Euler_Maruyama_sde_predictor, 91 | nloglik = nll) 92 | sample_kwargs = { 93 | 'num_steps': int(config.sampling.num_steps), 94 | 'start_time_step': ceil(float(config.sampling.pct_chain_elapsed) * int(config.sampling.num_steps)), 95 | 'batch_size': config.sampling.batch_size, 96 | 'im_shape': im_shape, 97 | 'eps': config.sampling.eps, 98 | 'predictor': {'aTweedy': False, 'penalty': float(config.sampling.penalty), "guidance_imgs": guidance_imgs, "guidance_strength": config.sampling.guidance_strength}, 99 | 'corrector': {} 100 | } 101 | elif config.sampling.name.lower() == 'dps': 102 | predictor = functools.partial( 103 | Euler_Maruyama_sde_predictor, 104 | nloglik = nll) 105 | sample_kwargs = { 106 | 'num_steps': int(config.sampling.num_steps), 107 | 'batch_size': config.sampling.batch_size, 108 | 'start_time_step': ceil(float(config.sampling.pct_chain_elapsed) * int(config.sampling.num_steps)), 109 | 'im_shape': im_shape, 110 | 'eps': config.sampling.eps, 111 | 'predictor': {'aTweedy': True, 'penalty': float(config.sampling.penalty), "guidance_imgs": guidance_imgs, "guidance_strength": config.sampling.guidance_strength}, 112 | 'corrector': {}, 113 | } 114 | elif config.sampling.name.lower() == 'dds' or config.sampling.name.lower() == 'dds_3d': 115 | predictor = functools.partial( 116 | decomposed_diffusion_sampling_sde_predictor, 117 | nloglik = nll) 118 | sample_kwargs = { 119 | 'num_steps': int(config.sampling.num_steps), 120 | 'batch_size': config.sampling.batch_size, 121 | 'start_time_step': ceil(float(config.sampling.pct_chain_elapsed) * int(config.sampling.num_steps)), 122 | 'im_shape': im_shape, 123 | 'eps': config.sampling.eps, 124 | 'predictor': {"guidance_imgs": guidance_imgs, 125 | "guidance_strength": config.sampling.guidance_strength, 126 | 'use_simplified_eqn': True, 127 | 'eta': config.sampling.stochasticity}, 128 | 'corrector': {}, 129 | } 130 | else: 131 | raise NotImplementedError 132 | 133 | corrector = None 134 | if config.sampling.add_corrector: 135 | corrector = functools.partial(Langevin_sde_corrector, 136 | nloglik = nll ) 137 | sample_kwargs['corrector']['corrector_steps'] = 1 138 | sample_kwargs['corrector']['penalty'] = float(config.sampling.penalty) 139 | 140 | init_chain_fn = 
None 141 | if sample_kwargs['start_time_step'] > 0: 142 | init_chain_fn = functools.partial( 143 | chain_simple_init, 144 | sde=sde, 145 | osem=osem, 146 | start_time_step=sample_kwargs['start_time_step'], 147 | im_shape=im_shape, 148 | batch_size=sample_kwargs['batch_size'], 149 | device=device 150 | ) 151 | 152 | sampler = BaseSampler( 153 | score=score, 154 | sde=sde, 155 | predictor=predictor, 156 | corrector=corrector, 157 | init_chain_fn=init_chain_fn, 158 | sample_kwargs=sample_kwargs, 159 | device=config.device, 160 | ) 161 | 162 | return sampler -------------------------------------------------------------------------------- /results/lpd_unet/unet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from Dival: 3 | https://jleuschn.github.io/docs.dival/dival.reconstructors.networks.unet.html 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | 10 | 11 | def get_unet_model(in_ch=1, out_ch=1, scales=5, skip=4, 12 | channels=(32, 32, 64, 64, 128, 128), use_sigmoid=False, 13 | use_norm=True): 14 | assert (1 <= scales <= 6) 15 | skip_channels = [skip] * (scales) 16 | return UNet(in_ch=in_ch, out_ch=out_ch, channels=channels[:scales], 17 | skip_channels=skip_channels, use_sigmoid=use_sigmoid, 18 | use_norm=use_norm) 19 | 20 | 21 | class UNet(nn.Module): 22 | 23 | def __init__(self, in_ch, out_ch, channels, skip_channels, 24 | use_sigmoid=True, use_norm=True): 25 | super(UNet, self).__init__() 26 | assert (len(channels) == len(skip_channels)) 27 | self.scales = len(channels) 28 | self.use_sigmoid = use_sigmoid 29 | self.down = nn.ModuleList() 30 | self.up = nn.ModuleList() 31 | self.inc = InBlock(in_ch, channels[0], use_norm=use_norm) 32 | for i in range(1, self.scales): 33 | self.down.append(DownBlock(in_ch=channels[i - 1], 34 | out_ch=channels[i], 35 | use_norm=use_norm)) 36 | for i in range(1, self.scales): 37 | self.up.append(UpBlock(in_ch=channels[-i], 38 | out_ch=channels[-i - 1], 39 | skip_ch=skip_channels[-i], 40 | use_norm=use_norm)) 41 | self.outc = OutBlock(in_ch=channels[0], 42 | out_ch=out_ch) 43 | 44 | def forward(self, x0, norm): 45 | xs = [self.inc(x0/norm[:,None,None,None]), ] 46 | for i in range(self.scales - 1): 47 | xs.append(self.down[i](xs[-1])) 48 | x = xs[-1] 49 | for i in range(self.scales - 1): 50 | x = self.up[i](x, xs[-2 - i]) 51 | return self.outc(x)*norm[:,None,None,None] 52 | 53 | 54 | class DownBlock(nn.Module): 55 | 56 | def __init__(self, in_ch, out_ch, kernel_size=3, use_norm=True): 57 | super(DownBlock, self).__init__() 58 | to_pad = int((kernel_size - 1) / 2) 59 | if use_norm: 60 | self.conv = nn.Sequential( 61 | nn.Conv2d(in_ch, out_ch, kernel_size, 62 | stride=2, padding=to_pad), 63 | nn.BatchNorm2d(out_ch), 64 | nn.LeakyReLU(0.2, inplace=True), 65 | nn.Conv2d(out_ch, out_ch, kernel_size, 66 | stride=1, padding=to_pad), 67 | nn.BatchNorm2d(out_ch), 68 | nn.LeakyReLU(0.2, inplace=True)) 69 | else: 70 | self.conv = nn.Sequential( 71 | nn.Conv2d(in_ch, out_ch, kernel_size, 72 | stride=2, padding=to_pad), 73 | nn.LeakyReLU(0.2, inplace=True), 74 | nn.Conv2d(out_ch, out_ch, kernel_size, 75 | stride=1, padding=to_pad), 76 | nn.LeakyReLU(0.2, inplace=True)) 77 | 78 | def forward(self, x): 79 | x = self.conv(x) 80 | return x 81 | 82 | 83 | 84 | class InBlock(nn.Module): 85 | 86 | def __init__(self, in_ch, out_ch, kernel_size=3, use_norm=True): 87 | super(InBlock, self).__init__() 88 | to_pad = int((kernel_size - 1) / 2) 89 | if use_norm: 90 | self.conv = nn.Sequential( 91 | 
nn.Conv2d(in_ch, out_ch, kernel_size, 92 | stride=1, padding=to_pad), 93 | nn.BatchNorm2d(out_ch), 94 | nn.LeakyReLU(0.2, inplace=True)) 95 | else: 96 | self.conv = nn.Sequential( 97 | nn.Conv2d(in_ch, out_ch, kernel_size, 98 | stride=1, padding=to_pad), 99 | nn.LeakyReLU(0.2, inplace=True)) 100 | 101 | 102 | def forward(self, x): 103 | x = self.conv(x) 104 | return x 105 | 106 | 107 | 108 | 109 | class UpBlock(nn.Module): 110 | 111 | def __init__(self, in_ch, out_ch, skip_ch=4, kernel_size=3, use_norm=True): 112 | super(UpBlock, self).__init__() 113 | to_pad = int((kernel_size - 1) / 2) 114 | self.skip = skip_ch > 0 115 | if skip_ch == 0: 116 | skip_ch = 1 117 | if use_norm: 118 | self.conv = nn.Sequential( 119 | nn.BatchNorm2d(in_ch + skip_ch), 120 | nn.Conv2d(in_ch + skip_ch, out_ch, kernel_size, stride=1, 121 | padding=to_pad), 122 | nn.BatchNorm2d(out_ch), 123 | nn.LeakyReLU(0.2, inplace=True), 124 | nn.Conv2d(out_ch, out_ch, kernel_size, 125 | stride=1, padding=to_pad), 126 | nn.BatchNorm2d(out_ch), 127 | nn.LeakyReLU(0.2, inplace=True)) 128 | else: 129 | self.conv = nn.Sequential( 130 | nn.Conv2d(in_ch + skip_ch, out_ch, kernel_size, stride=1, 131 | padding=to_pad), 132 | nn.LeakyReLU(0.2, inplace=True), 133 | nn.Conv2d(out_ch, out_ch, kernel_size, 134 | stride=1, padding=to_pad), 135 | nn.LeakyReLU(0.2, inplace=True)) 136 | 137 | if use_norm: 138 | self.skip_conv = nn.Sequential( 139 | nn.Conv2d(out_ch, skip_ch, kernel_size=1, stride=1), 140 | nn.BatchNorm2d(skip_ch), 141 | nn.LeakyReLU(0.2, inplace=True)) 142 | else: 143 | self.skip_conv = nn.Sequential( 144 | nn.Conv2d(out_ch, skip_ch, kernel_size=1, stride=1), 145 | nn.LeakyReLU(0.2, inplace=True)) 146 | 147 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', 148 | align_corners=True) 149 | self.concat = Concat() 150 | 151 | def forward(self, x1, x2): 152 | x1 = self.up(x1) 153 | x2 = self.skip_conv(x2) 154 | if not self.skip: 155 | x2 = x2 * 0 156 | x = self.concat(x1, x2) 157 | x = self.conv(x) 158 | return x 159 | 160 | 161 | class Concat(nn.Module): 162 | 163 | def __init__(self): 164 | super(Concat, self).__init__() 165 | 166 | def forward(self, *inputs): 167 | inputs_shapes2 = [x.shape[2] for x in inputs] 168 | inputs_shapes3 = [x.shape[3] for x in inputs] 169 | 170 | if (np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and 171 | np.all(np.array(inputs_shapes3) == min(inputs_shapes3))): 172 | inputs_ = inputs 173 | else: 174 | target_shape2 = min(inputs_shapes2) 175 | target_shape3 = min(inputs_shapes3) 176 | 177 | inputs_ = [] 178 | for inp in inputs: 179 | diff2 = (inp.size(2) - target_shape2) // 2 180 | diff3 = (inp.size(3) - target_shape3) // 2 181 | inputs_.append(inp[:, :, diff2: diff2 + target_shape2, 182 | diff3:diff3 + target_shape3]) 183 | return torch.cat(inputs_, dim=1) 184 | 185 | 186 | class OutBlock(nn.Module): 187 | def __init__(self, in_ch, out_ch): 188 | super(OutBlock, self).__init__() 189 | self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1) 190 | 191 | def forward(self, x): 192 | x = self.conv(x) 193 | return x 194 | 195 | 196 | def __len__(self): 197 | return len(self._modules) -------------------------------------------------------------------------------- /src/brainweb_2d/brainweb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BrainWebClean(torch.utils.data.Dataset): 5 | def __init__(self, path_to_files="path_to/test_dict.pt", mri=False): 6 | 7 | self.path_to_files = path_to_files 8 | self.data = 
torch.load(path_to_files, map_location=torch.device('cpu')) 9 | self.mri = mri 10 | 11 | def __len__(self): 12 | return self.data['clean_measurements'].shape[0] 13 | 14 | def __getitem__(self, idx): 15 | y = self.data["clean_measurements"][idx, ...] 16 | mu = self.data["mu"][idx, ...] 17 | gt = self.data["reference"][idx, ...] 18 | if self.mri: 19 | mri = self.data["mri"][idx, ...] 20 | return y, mu, gt, mri 21 | return y, mu, gt 22 | 23 | class BrainWebOSEM(torch.utils.data.Dataset): 24 | def __init__(self, part, noise_level, base_path="path_to/src/brainweb_2d/", static_path = None, device="cpu", guided=False): 25 | assert noise_level in [2.5, 5, 7.5, 10, 50, 100, "2.5", "5", "7.5", "10", "50", "100"], "noise level has to be 2.5, 5, 7.5, 10, 50, 100" 26 | assert part in ["train", "test", "test_tumour", "subset_test_tumour", "subset_test", "validation"], 'part has to be "train", "test", "test_tumour", "subset_test_tumour", "subset_test", "validation"' 27 | 28 | self.part = part 29 | self.noise_level = noise_level 30 | self.guided = guided 31 | 32 | self.base_path = base_path 33 | # dict_keys(['osem', 'scale_factor', 'measurements', 'contamination_factor', 'attn_factors']) 34 | self.noisy = torch.load(base_path+"noisy/noisy_"+ self.part + "_" + str(noise_level)+".pt", map_location=torch.device(device)) 35 | # dict_keys(['clean_measurements', 'mu', 'reference']) 36 | self.clean = torch.load(base_path+"clean/clean_"+part+".pt", map_location=torch.device(device)) 37 | if static_path is not None: 38 | # dict_keys(['osem', 'scale_factor', 'measurements', 'contamination_factor', 'attn_factors']) 39 | self.noisy = torch.load(static_path, map_location=torch.device(device)) 40 | # dict_keys(['clean_measurements', 'mu', 'reference']) 41 | self.clean = torch.load(base_path+"clean/"+part+"_clean.pt", map_location=torch.device(device)) 42 | if "tumour" in part: 43 | self.tumour = True 44 | else: 45 | self.tumour = False 46 | def __len__(self): 47 | return self.clean["reference"].shape[0] 48 | 49 | def __getitem__(self, idx): 50 | 51 | reference = self.clean["reference"][idx, ...].float() 52 | scale_factor = self.noisy["scale_factor"][idx] 53 | 54 | reference = reference*scale_factor[[0]] 55 | 56 | if self.guided: 57 | reference = torch.cat((reference, self.clean["mri"][idx, ...].float()), dim=0) 58 | osem = self.noisy["osem"][idx, ...].float() 59 | 60 | 61 | norm = 1 62 | 63 | measurements = self.noisy["measurements"][idx, ...].float() 64 | contamination_factor = self.noisy["contamination_factor"][idx] 65 | attn_factors = self.noisy["attn_factors"][idx, ...].float() 66 | 67 | if self.part == "subset_test": 68 | measurements = measurements 69 | contamination_factor = contamination_factor 70 | attn_factors = attn_factors 71 | osem = osem 72 | 73 | if norm == 0: 74 | norm = torch.ones_like(norm) 75 | osem = torch.zeros_like(osem) 76 | reference = torch.zeros_like(reference) 77 | attn_factors = torch.ones_like(attn_factors) 78 | 79 | if self.tumour: 80 | background = self.clean["background"][idx, ...].float() 81 | tumour_rois = self.clean["tumour_rois"][idx, ...].float() 82 | return reference, scale_factor, osem, norm, measurements, contamination_factor, attn_factors, background, tumour_rois 83 | return reference, scale_factor, osem, norm, measurements, contamination_factor, attn_factors 84 | 85 | 86 | class BrainWebSupervisedTrain(torch.utils.data.Dataset): 87 | def __init__(self, noise_level, base_path="path_to/pyparallelproj/examples/data/", device="cpu", guided=False): 88 | assert noise_level in [5, 10, 
50, "5", "10", "50"], "noise level has to be 5, 10, 50" 89 | self.base_path = base_path 90 | 91 | # dict_keys(['clean_measurements', 'mu', 'reference', 'mri]) 92 | clean = torch.load(base_path+"clean/clean_train.pt", map_location=torch.device(device)) 93 | # dict_keys(['osem', 'scale_factor', 'measurements', 'contamination_factor', 'attn_factors']) 94 | self.noisy = torch.load(base_path+"noisy/noisy_train_"+str(noise_level)+".pt", map_location=torch.device(device)) 95 | self.reference = clean["reference"] 96 | self.guided = guided 97 | if self.guided: 98 | self.mri = clean["mri"] 99 | 100 | def __len__(self): 101 | return self.reference 102 | 103 | def __getitem__(self, idx): 104 | reference = self.reference[idx,...].float()*self.noisy["scale_factor"][idx] 105 | 106 | osem = self.noisy["osem"][idx, ...].float() 107 | 108 | measurements = self.noisy["measurements"][idx, ...].float() 109 | 110 | contamination_factor = self.noisy["contamination_factor"][idx] 111 | 112 | attn_factors = self.noisy["attn_factors"][idx, ...].float() 113 | 114 | if self.guided: 115 | mri = self.mri[idx,...].float() 116 | return reference, mri, osem, measurements, contamination_factor, attn_factors 117 | 118 | return reference, osem, measurements, contamination_factor, attn_factors 119 | 120 | 121 | 122 | class BrainWebScoreTrain(torch.utils.data.Dataset): 123 | def __init__(self, base_path="path_to/pyparallelproj/examples/data/", device="cpu", guided=False, normalisation="data_scale"): 124 | 125 | self.base_path = base_path 126 | # dict_keys(['clean_measurements', 'mu', 'reference', 'mri]) 127 | clean = torch.load(base_path+"clean/clean_train.pt", map_location=torch.device(device)) 128 | print(clean.keys()) 129 | self.reference = clean["reference"] 130 | self.clean_measurements = clean["clean_measurements"] 131 | self.mri = clean["mri"] 132 | self.guided = guided 133 | 134 | self.normalisation = normalisation 135 | 136 | def __len__(self): 137 | return self.reference.shape[0] 138 | 139 | def __getitem__(self, idx): 140 | reference = self.reference[idx, ...].float() 141 | 142 | if self.normalisation == "data_scale": 143 | emission_volume = torch.where(reference > 0)[0].shape[0] * 8 # 2 x 2 x 2 144 | current_trues_per_volume = float(self.clean_measurements[idx].sum() / emission_volume) 145 | elif self.normalisation == "image_scale": 146 | emission_volume = torch.where(reference > 0)[0].shape[0] 147 | current_trues_per_volume = float(reference.sum() / emission_volume) 148 | else: 149 | raise NotImplementedError 150 | 151 | reference = reference/current_trues_per_volume 152 | 153 | reference = reference* (0.5 + torch.rand(1)) 154 | 155 | mri = self.mri[idx, ...].float() 156 | if self.guided: 157 | return torch.cat((reference, mri), dim=0) 158 | return reference 159 | 160 | 161 | if __name__ == "__main__": 162 | 163 | dataset = BrainWebScoreTrain(base_path="path_to/", normalisation="image_scale") 164 | import matplotlib.pyplot as plt 165 | import numpy as np 166 | 167 | for i in range(10): 168 | batch = dataset[i] 169 | print(batch.min(), batch.max()) -------------------------------------------------------------------------------- /results/3D_final/result_util_3D.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from glob import glob 3 | import os, sys 4 | sys.path.append("/home/user/sirf/") 5 | from src import PSNR, SSIM 6 | import sirf.STIR as pet 7 | 8 | def get_crc_std_psnr_ssim(vols, tumours, backgrounds, ref, img_idx): 9 | crcs = [] 10 | std_background = 
backgrounds.sum(dim=0, keepdim=True).repeat(len(vols),1,1,1) 11 | stds = vols[torch.nonzero(std_background, as_tuple=True)].std(dim = 0).mean().item() 12 | 13 | for i in range(len(tumours)): 14 | tumour_roi_idx = torch.nonzero(tumours[i], as_tuple=True) 15 | background_roi_idx = torch.nonzero(backgrounds[i], as_tuple=True) 16 | cr_refs = ref[tumour_roi_idx].mean() / ref[background_roi_idx].mean() 17 | crc_r = [] 18 | for i in range(len(vols)): 19 | cr_est = vols[i][tumour_roi_idx].mean() / vols[i][background_roi_idx].mean() 20 | crc_r.append(((cr_est-1)/(cr_refs-1)).item()) 21 | crcs.append(sum(crc_r)/len(crc_r)) 22 | 23 | psnrs_r = [] 24 | ssims_r = [] 25 | for i in range(len(vols)): 26 | psnrs_r.append(PSNR((vols[i]).numpy(), (ref).numpy())) 27 | ssims_r.append(SSIM((vols[i]).numpy(), (ref).numpy())) 28 | return crcs, stds, sum(psnrs_r)/len(psnrs_r), sum(ssims_r)/len(ssims_r), vols[0,img_idx] 29 | 30 | def get_sweep_mean_results(path, img_id=3): 31 | result = torch.load(path) 32 | datafit_strengths = [] 33 | psnrs = [] 34 | ssims = [] 35 | crcs = [] 36 | stds = [] 37 | show_images = [] 38 | for datafit_strength in result.keys(): 39 | datafit_strengths.append(float(datafit_strength)) 40 | psnrs.append(result[datafit_strength]["psnr"]) 41 | ssims.append(result[datafit_strength]["ssim"]) 42 | crcs.append(result[datafit_strength]["crc"]) 43 | stds.append(result[datafit_strength]["std"]) 44 | show_images.append(result[datafit_strength]["show_images"]) 45 | psnrs = [x for _, x in sorted(zip(datafit_strengths, psnrs))] 46 | ssims = [x for _, x in sorted(zip(datafit_strengths, ssims))] 47 | crcs = [x for _, x in sorted(zip(datafit_strengths, crcs))] 48 | stds = [x for _, x in sorted(zip(datafit_strengths, stds))] 49 | datafit_strengths = sorted(datafit_strengths) 50 | return {"datafit_strengths": datafit_strengths, "psnr": psnrs, "ssim": ssims, "crc": crcs, "std": stds, "show_images": show_images} 51 | 52 | 53 | def save_sweep_dicts(unique_swept_datafit_strengths): 54 | for name in unique_swept_datafit_strengths.keys(): 55 | if "brainweb3D" in name: 56 | tumours = torch.load("/home/user/sirf/src/sirf/brainweb_3D/tumours.pt") 57 | backgrounds = torch.load("/home/user/sirf/src/sirf/brainweb_3D/backgrounds.pt") 58 | if "FDG" in name: 59 | ref = torch.from_numpy(pet.ImageData(f"/home/user/sirf/src/sirf/brainweb_3D/FDG_PET_lr.hv").as_array()) 60 | # Scaled reference as linear model and scaled data 61 | ref = ref * 0.0018871473105862091 / 4.0024639524408965 62 | elif "Amyloid" in name: 63 | ref = torch.from_numpy(pet.ImageData(f"/home/user/sirf/src/sirf/brainweb_3D/Amyloid_PET_lr.hv").as_array()) 64 | # Scaled reference as linear model and scaled data 65 | ref = ref * 0.0024918709984937458 / 4.004922937685068 66 | else: raise NotImplementedError("Not a valid tracer") 67 | img_idx = [12,31,47,55,71] 68 | sweep_dict = {} 69 | folder_name = name.split("_")[0] + "_" + name.split("_")[1] + "_" + name.split("_")[2] 70 | file_name = "" 71 | for i in name.split("_")[3:]: 72 | file_name += i + "_" 73 | file_name = file_name[:-1] 74 | if "dds_3D" in name: 75 | file_name += "_beta_" + name.split("_")[-1] 76 | check_folder_create("3D_dicts/", folder_name) 77 | while len(unique_swept_datafit_strengths[name]) > 0: 78 | result_path = unique_swept_datafit_strengths[name][0] 79 | if "dds" in name: 80 | # is lambda not beta in this case 81 | beta = result_path.split("/")[-3].split("_")[-3] 82 | elif "DIP" in name: 83 | beta = result_path.split("/")[-2].split("_")[-1] 84 | else: 85 | _, _, beta = 
result_path.split("/")[-2].split("_") 86 | sweep_dict[str(beta)] = {} 87 | beta_paths = [] 88 | # Get all the realisations 89 | for beta_path in unique_swept_datafit_strengths[name]: 90 | if "dds" in name: 91 | # is lambda not beta in this case 92 | b = beta_path.split("/")[-3].split("_")[-3] 93 | elif "DIP" in name: 94 | b = beta_path.split("/")[-2].split("_")[-1] 95 | else: 96 | _, _, b = beta_path.split("/")[-2].split("_") 97 | if b == beta: 98 | beta_paths.append(beta_path) 99 | vols = [] 100 | for b_p in beta_paths: 101 | unique_swept_datafit_strengths[name].remove(b_p) 102 | vols.append(torch.load(b_p)) 103 | vols = torch.stack(vols) 104 | crc, std, psnr, ssim, show_images = get_crc_std_psnr_ssim(vols, tumours, backgrounds, ref, img_idx) 105 | sweep_dict[str(beta)]["crc"] = crc 106 | sweep_dict[str(beta)]["std"] = std 107 | sweep_dict[str(beta)]["psnr"] = psnr 108 | sweep_dict[str(beta)]["ssim"] = ssim 109 | sweep_dict[str(beta)]["show_images"] = show_images 110 | torch.save(sweep_dict, f"3D_dicts/{folder_name}/{file_name}.pt") 111 | 112 | 113 | def check_folder_create(path, folder_name): 114 | CHECK_FOLDER = os.path.isdir(path+folder_name) 115 | if not CHECK_FOLDER: 116 | os.makedirs(path+folder_name) 117 | print("created folder : ", path+folder_name) 118 | 119 | def get_unique_swept_datafit_strengths(base_path="/home/user/sirf/coordinators/FINAL_3D/**/volume.pt"): 120 | result_paths = glob(base_path, recursive=True) 121 | num_results = len(result_paths) 122 | # FIND ALL THE UNIQUE DATAFIT STRENGTH SWEEPS 123 | unique_swept_datafit_strengths = {} 124 | while len(result_paths) > 0: 125 | result_path = result_paths[0] 126 | if "dds_3D" in result_path: 127 | dataset = result_path.split("/")[-4] 128 | prior = result_path.split("/")[-3] 129 | data_name, count_level, _, tracer = dataset.split("_") 130 | prior_name = prior.split("_")[0] 131 | prior_beta = prior.split("_")[-1] 132 | name = f"{data_name}_{count_level}_{tracer}_{prior_name}_{prior_beta}" 133 | elif "DIP" in result_path: 134 | dataset = result_path.split("/")[-3] 135 | prior = result_path.split("/")[-2] 136 | prior_beta = prior.split("_")[2] 137 | data_name, count_level, _, tracer = dataset.split("_") 138 | prior_name = prior.split("_")[0] 139 | name = f"{data_name}_{count_level}_{tracer}_{prior_name}_{prior_beta}" 140 | else: 141 | dataset = result_path.split("/")[-3] 142 | prior = result_path.split("/")[-2] 143 | data_name, count_level, _, tracer = dataset.split("_") 144 | prior_name = prior.split("_")[0] 145 | name = f"{data_name}_{count_level}_{tracer}_{prior_name}" 146 | if name in unique_swept_datafit_strengths.keys(): 147 | unique_swept_datafit_strengths[name].append(result_path) 148 | else: 149 | unique_swept_datafit_strengths[name] = [] 150 | unique_swept_datafit_strengths[name].append(result_path) 151 | result_paths.remove(result_path) 152 | print(f"Altogether we have {num_results} results, and {len(unique_swept_datafit_strengths.keys())} individual sweeps.") 153 | return unique_swept_datafit_strengths 154 | 155 | 156 | -------------------------------------------------------------------------------- /coordinators/2D_rdp_baseline.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch, sys, os 3 | import pyparallelproj.coincidences as coincidences 4 | import pyparallelproj.petprojectors as petprojectors 5 | import pyparallelproj.resolution_models as resolution_models 6 | import cupy as xp 7 | import cupyx.scipy.ndimage as ndi 8 | from 
tqdm import tqdm 9 | sys.path.append(os.path.dirname(os.getcwd())) 10 | from src import BrainWebOSEM, LPDForwardFunction2D, LPDAdjointFunction2D 11 | 12 | 13 | detector_efficiency = 1./30 14 | coincidence_descriptor = coincidences.GEDiscoveryMICoincidenceDescriptor( 15 | num_rings=1, 16 | sinogram_spatial_axis_order=coincidences.SinogramSpatialAxisOrder['RVP'],xp=xp) 17 | acq_model = petprojectors.PETJosephProjector(coincidence_descriptor, 18 | (128, 128, 1), (-127.0, -127.0, 0.0), 19 | (2., 2., 2.)) 20 | res_model = resolution_models.GaussianImageBasedResolutionModel( 21 | (128, 128, 1), tuple(4.5 / (2.35 * x) for x in (2., 2., 2.)), xp, ndi) 22 | acq_model.image_based_resolution_model = res_model 23 | 24 | 25 | 26 | def pnll_gradient(x, measurements, acq_model, attn_factor, contamination_factor): # Poisson log-likelihood gradient: adjoint(y / y_bar - 1) 27 | tmp = measurements / (LPDForwardFunction2D.apply(x, acq_model, attn_factor) + contamination_factor) - torch.ones_like(measurements) 28 | return LPDAdjointFunction2D.apply(tmp, acq_model, attn_factor) 29 | 30 | def rdp_rolled_components(x): # 3x3 neighbourhood (incl. centre) via circular shifts 31 | rows = [1,1,1,0,0,0,-1,-1,-1] 32 | columns = [1,0,-1,1,0,-1,1,0,-1] 33 | x_neighbours = x.clone().repeat(1,9,1,1) 34 | for i in range(9): 35 | x_neighbours[:,[i]] = torch.roll(x, shifts=(rows[i], columns[i]), dims=(-2, -1)) 36 | return x_neighbours 37 | 38 | def get_preconditioner(x, x_neighbours, sens_img, beta): 39 | # from A Concave Prior Penalizing Relative Differences 40 | # for Maximum-a-Posteriori Reconstruction 41 | # in Emission Tomography Eq. 15 42 | first = sens_img/x 43 | x = x.repeat(1,9,1,1) 44 | second = (16*(x_neighbours**2))/torch.clamp(((x + x_neighbours + 2 * torch.abs(x-x_neighbours))**3),0) 45 | return 1/(first - beta*second.sum(dim=1, keepdim=True)) 46 | 47 | def rdp_gradient(x, x_neighbours): # analytic gradient of the relative difference prior (gamma = 2) 48 | x = x.repeat(1,9,1,1) 49 | numerator = (x - x_neighbours)*(2 * torch.abs(x-x_neighbours) + x + 3 * x_neighbours) 50 | denominator = (x + x_neighbours + 2 * torch.abs(x - x_neighbours))**2 51 | return - (numerator/denominator).sum(dim=1, keepdim=True) 52 | 53 | def obj_value(x, x_neighbours, measurements, acq_model, attn_factors, contamination_factor, beta): 54 | y_pred = LPDForwardFunction2D.apply(x, acq_model, attn_factors) + contamination_factor 55 | kl = - (measurements*torch.log(measurements/y_pred+1e-9) + (y_pred-measurements)).sum(-1) 56 | x = x.repeat(1,9,1,1) 57 | numerator = (x - x_neighbours)**2 58 | denominator = (x + x_neighbours + 2 * torch.abs(x - x_neighbours)) 59 | rdp = - beta*(numerator/denominator).sum(dim=1, keepdim=True).sum(-1).sum(-1) 60 | return (kl + rdp).mean() 61 | 62 | def compute_kl_div(recons, measurements, acq_model, attn_factors, contamination_factor): 63 | kldiv_r = [] 64 | for r in range(len(recons)): 65 | y_pred = LPDForwardFunction2D.apply(recons[[r]], acq_model, attn_factors[[r]]) + contamination_factor[[r],..., None] 66 | kl = (measurements[[r]]*torch.log(measurements[[r]]/y_pred+1e-9)+ (y_pred-measurements[[r]])).sum() 67 | if kl.isnan(): 68 | print("KL is nan") 69 | kldiv_r.append(kl) 70 | return torch.asarray(kldiv_r).cpu() 71 | 72 | 73 | parts = ["test", "test_tumour"] 74 | noises = [10] 75 | for noise in noises: 76 | for part in parts: 77 | dataset = BrainWebOSEM(part=part, 78 | noise_level=noise, 79 | base_path="path_to/src/brainweb_2d/") 80 | subset = list(range(2, len(dataset), 4)) 81 | dataset = torch.utils.data.Subset(dataset, subset) 82 | betas = [0.11,0.1,0.09,0.075,0.05,0.025,0.01,0.001] 83 | for beta in betas: 84 | save_recon = [] 85 | save_ref = [] 86 | save_kldivs = [] 87 | 
save_lesion_rois = [] 88 | save_background_rois = [] 89 | print(f"beta: {beta}") 90 | idx = 0 91 | for batch in dataset: 92 | idx += 1 93 | # [0] reference, [1] scale_factor, [2] osem, [3] norm, [4] measurements, 94 | # [5] contamination_factor, [6] attn_factors 95 | gt = batch[0].to("cuda:0").unsqueeze(1)[...] 96 | osem = batch[2].to("cuda:0").unsqueeze(1)[...] 97 | measurements = batch[4].to("cuda:0").unsqueeze(1)[...] 98 | contamination_factor = batch[5].to("cuda:0")[:,None,None] 99 | attn_factors = batch[6].to("cuda:0").unsqueeze(1)[...]*detector_efficiency 100 | sens_img = LPDAdjointFunction2D.apply(torch.ones_like(measurements), acq_model, attn_factors).detach() 101 | x_old = osem.clone().detach() 102 | grad_norm = [] 103 | objective_values = [] 104 | x_neighbours = rdp_rolled_components(x_old) 105 | prev_obj = obj_value(x_old, x_neighbours, measurements, acq_model, attn_factors, contamination_factor, beta) 106 | for i in tqdm(range(1000)): 107 | x_neighbours = rdp_rolled_components(x_old) 108 | preconditioner = get_preconditioner(x_old, x_neighbours, sens_img, beta).detach() 109 | gradient = (pnll_gradient(x_old, measurements, acq_model, attn_factors, contamination_factor) 110 | + beta * rdp_gradient(x_old, x_neighbours)).detach() 111 | 112 | x_new = torch.clamp(x_old.detach() + \ 113 | preconditioner * gradient, 0) 114 | x_old = x_new 115 | new_obj = obj_value(x_new, x_neighbours, measurements, acq_model, attn_factors, contamination_factor, beta) 116 | if (new_obj - prev_obj)**2 < 1e-6: 117 | break 118 | grad_norm.append(torch.norm(gradient).item()) 119 | objective_values.append(new_obj.item()) 120 | prev_obj = new_obj.clone() 121 | kldiv_r = compute_kl_div(recons = x_new, 122 | measurements=measurements, 123 | acq_model=acq_model, 124 | attn_factors=attn_factors, 125 | contamination_factor=contamination_factor) 126 | 127 | lesion_roi = batch[-1].to("cuda:0") 128 | background_roi = batch[-2].to("cuda:0") 129 | 130 | save_recon.append(x_new.squeeze().cpu()) 131 | save_ref.append(gt.squeeze().cpu()) 132 | save_kldivs.append(kldiv_r) 133 | save_lesion_rois.append(lesion_roi.squeeze().cpu()) 134 | save_background_rois.append(background_roi.squeeze().cpu()) 135 | 136 | print("iterations: ", i) 137 | fig, ax = plt.subplots(1,5, figsize=(25,5)) 138 | fig.colorbar(ax[0].imshow(gt[0,0].cpu().numpy())) 139 | ax[0].set_title("Ground truth") 140 | ax[0].axis("off") 141 | fig.colorbar(ax[1].imshow(sens_img[0,0].cpu().numpy())) 142 | ax[1].set_title("Sensitivity image") 143 | ax[1].axis("off") 144 | fig.colorbar(ax[2].imshow(osem[0,0].cpu().numpy())) 145 | ax[2].set_title("OSEM") 146 | ax[2].axis("off") 147 | fig.colorbar(ax[3].imshow(x_new[0,0].cpu().numpy())) 148 | ax[3].set_title(f"Penalised MAP beta: {beta}") 149 | ax[3].axis("off") 150 | ax[4].plot(objective_values) 151 | ax[4].set_title("Objective value") 152 | plt.savefig(f"path_to/coordinators/RDP/{part}_{noise}/rdp_baseline_image_{idx}_beta_{beta}.png", dpi=300, bbox_inches="tight") 153 | plt.close() 154 | #plt.show() 155 | 156 | results = {"images": torch.stack(save_recon).cpu(), 157 | "ref": torch.stack(save_ref).cpu(), 158 | "kldiv": torch.stack(save_kldivs).cpu(), 159 | "lesion_rois": torch.stack(save_lesion_rois).cpu(), 160 | "background_rois": torch.stack(save_background_rois).cpu(), 161 | "beta": beta} 162 | 163 | name = f"rdp_baseline_beta_{beta}" 164 | torch.save(results, f"path_to/coordinators/RDP/{part}_{noise}/{name}.pt") -------------------------------------------------------------------------------- 
/coordinators/test_reconstruction.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import functools 4 | import numpy as np 5 | import yaml 6 | import sys, os 7 | sys.path.append(os.path.dirname(os.getcwd())) 8 | from src import (BrainWebOSEM, get_standard_score, get_standard_sde, 9 | LPDForwardFunction2D, get_standard_sampler, osem_nll, 10 | get_osem, get_map, get_anchor, kl_div) 11 | from omegaconf import DictConfig, OmegaConf 12 | import torchvision 13 | import matplotlib.pyplot as plt 14 | from src import PSNR, SSIM 15 | 16 | import cupy as xp 17 | 18 | # not used in this script 19 | #detector_efficiency = 1./30 20 | 21 | 22 | def get_acq_model(): 23 | import pyparallelproj.coincidences as coincidences 24 | import pyparallelproj.petprojectors as petprojectors 25 | import pyparallelproj.resolution_models as resolution_models 26 | import cupyx.scipy.ndimage as ndi 27 | 28 | """ 29 | create forward operator 30 | """ 31 | coincidence_descriptor = coincidences.GEDiscoveryMICoincidenceDescriptor( 32 | num_rings=1, 33 | sinogram_spatial_axis_order=coincidences.SinogramSpatialAxisOrder['RVP'],xp=xp) 34 | acq_model = petprojectors.PETJosephProjector(coincidence_descriptor, 35 | (128, 128, 1), (-127.0, -127.0, 0.0), (2., 2., 2.)) 36 | res_model = resolution_models.GaussianImageBasedResolutionModel( 37 | (128, 128, 1), tuple(4.5 / (2.35 * x) for x in (2., 2., 2.)), xp, ndi) 38 | acq_model.image_based_resolution_model = res_model 39 | return acq_model 40 | 41 | 42 | def estimate_scale_factor(osem, measurements, contamination, normalisation_type): 43 | scale_factors = [] 44 | for i in range(osem.shape[0]): 45 | if normalisation_type == "data_scale": 46 | emission_volume = torch.where(osem[i] > 0.01*osem[i].max(), 1, 0).sum() * 8 47 | scale_factor = (measurements[i] - contamination[i]).sum()/emission_volume 48 | scale_factors.append(scale_factor) 49 | elif normalisation_type == "image_scale": 50 | emission_volume = torch.where(osem[i] > 0.01*osem[i].max(), 1, 0).sum() 51 | scale_factor = osem[i].sum()/emission_volume 52 | scale_factors.append(scale_factor) 53 | else: 54 | raise NotImplementedError 55 | return torch.tensor(scale_factors) 56 | 57 | def compute_kl_div(recons, measurements, acq_model, attn_factors, contamination_factor): 58 | kldiv_r = [] 59 | for r in range(len(recons)): 60 | y_pred = LPDForwardFunction2D.apply(recons[[r]], acq_model, attn_factors[[r]]) + contamination_factor[[r]] 61 | kl = (measurements[[r]]*torch.log(measurements[[r]]/y_pred+1e-9)+ (y_pred-measurements[[r]])).sum() 62 | if kl.isnan(): 63 | print(torch.log(measurements[[r]]/y_pred[[r]]+1e-9).sum()) 64 | print("KL is nan") 65 | kldiv_r.append(kl) 66 | return torch.asarray(kldiv_r).cpu() 67 | 68 | @hydra.main(config_path='../configs', config_name='test_reconstruction') 69 | def reconstruction(config : DictConfig) -> None: 70 | print(OmegaConf.to_yaml(config)) 71 | 72 | ###### SET SEED ###### 73 | if config.seed is not None: 74 | torch.manual_seed(config.seed) 75 | np.random.seed(config.seed) 76 | 77 | ###### GET SCORE MODEL ###### 78 | # open the yaml config file 79 | with open(os.path.join(config.score_based_model.path, "report.yaml"), "r") as stream: 80 | ml_collection = yaml.load(stream, Loader=yaml.UnsafeLoader) 81 | guided = False if ml_collection.guided_p_uncond is None else True 82 | # get the sde 83 | sde = get_standard_sde(ml_collection) 84 | # get the score model 85 | score_model = get_standard_score(ml_collection, sde, 86 | use_ema = 
config.score_based_model.ema, 87 | load_path = config.score_based_model.path) 88 | score_model.eval() 89 | score_model.to(config.device) 90 | 91 | ###### GET ACQUISITION MODEL AND DATA ###### 92 | # get the acquisition model 93 | acq_model = get_acq_model() 94 | # get the data 95 | dataset = BrainWebOSEM(part=config.dataset.part, 96 | noise_level=config.dataset.poisson_scale, 97 | base_path=config.dataset.base_path, 98 | guided=guided) 99 | test_loader = torch.utils.data.DataLoader(dataset, 100 | batch_size=8, shuffle=False) 101 | 102 | config.sampling.batch_size = 8 103 | 104 | ###### SOLVING REVERSE SDE ###### 105 | img_shape = (config.dataset.img_z_dim, 106 | config.dataset.img_xy_dim, config.dataset.img_xy_dim) 107 | 108 | for idx, batch in enumerate(test_loader): 109 | # [0] reference, [1] scale_factor, [2] osem, [3] norm, [4] measurements, [5] contamination_factor, [6] attn_factors 110 | if guided: 111 | gt = batch[0][:, [0], ...] 112 | guided_img = batch[0][:, [1], ...].to(config.device) 113 | else: 114 | gt = batch[0][:, [0], ...] 115 | guided_img = None 116 | 117 | print("Normalisation type: ", ml_collection.normalisation) 118 | attn_factors=batch[6][:,[0],...].to(config.device) 119 | contamination_factor=batch[5][:,[0],None].to(config.device) 120 | measurements=batch[4][:,[0],...].to(config.device) 121 | osem=batch[2][:,[0],...] 122 | gt=batch[0][:, [0], ...] 123 | 124 | # estimate scaling factors from measurements 125 | scale_factor = estimate_scale_factor(osem=osem, 126 | measurements=measurements, contamination=contamination_factor, 127 | normalisation_type=ml_collection.normalisation)[:, None, None, None].to(config.device) 128 | 129 | if config.sampling.use_osem_nll: 130 | nll_partial = functools.partial(osem_nll, 131 | scale_factor=scale_factor, 132 | osem=osem.to(config.device)) 133 | elif config.sampling.name == "dds": 134 | if config.sampling.dds_proj.name == "osem": 135 | nll_partial = functools.partial(get_osem, 136 | acq_model=acq_model, 137 | attn_factors=attn_factors, 138 | contamination=contamination_factor, 139 | measurements=measurements, 140 | scale_factor=scale_factor, 141 | num_subsets=config.sampling.dds_proj.num_subsets, 142 | num_epochs=config.sampling.dds_proj.num_epochs) 143 | elif config.sampling.dds_proj.name == "map": 144 | nll_partial = functools.partial(get_map, 145 | acq_model=acq_model, 146 | attn_factors=attn_factors, 147 | contamination=contamination_factor, 148 | measurements=measurements, 149 | scale_factor=scale_factor, 150 | num_subsets=config.sampling.dds_proj.num_subsets, 151 | num_epochs=config.sampling.dds_proj.num_epochs, 152 | beta = config.sampling.dds_proj.beta) 153 | elif config.sampling.dds_proj.name == "anchor": 154 | nll_partial = functools.partial(get_anchor, 155 | acq_model=acq_model, 156 | attn_factors=attn_factors, 157 | contamination=contamination_factor, 158 | measurements=measurements, 159 | scale_factor=scale_factor, 160 | num_subsets=config.sampling.dds_proj.num_subsets, 161 | num_epochs=config.sampling.dds_proj.num_epochs, 162 | beta = config.sampling.dds_proj.beta) 163 | else: 164 | raise NotImplementedError 165 | else: 166 | nll_partial = functools.partial(kl_div, 167 | acq_model=acq_model, 168 | attn_factors=attn_factors, 169 | contamination=contamination_factor, 170 | measurements=measurements, 171 | scale_factor=scale_factor) 172 | 173 | 174 | logg_kwargs = {'log_dir': "./tb", 'num_img_in_log': 10, 175 | 'sample_num':idx, 'ground_truth': gt, 'osem': osem} 176 | 177 | sampler = get_standard_sampler( 178 | 
config=config, 179 | score=score_model, 180 | sde=sde, 181 | nll=nll_partial, 182 | im_shape=img_shape, 183 | guidance_imgs=guided_img if guided else None, 184 | device=config.device) 185 | 186 | recon, writer = sampler.sample(logg_kwargs=logg_kwargs) 187 | recon = torch.clamp(recon, min=0) 188 | recon = recon*scale_factor.to(config.device) 189 | 190 | fig, axes = plt.subplots(1,recon.shape[0]) 191 | 192 | for idx, ax in enumerate(axes.ravel()): 193 | ax.imshow(recon[idx,0,:,:].cpu().numpy()) 194 | 195 | plt.show() 196 | 197 | fig, (ax1, ax2, ax3) = plt.subplots(1,3) 198 | psnr = PSNR(recon[4].squeeze().cpu().numpy(), gt[4].squeeze().cpu().numpy()) 199 | ax1.imshow(gt[4,0,:,:].cpu().numpy()) 200 | ax2.imshow(recon[4,0,:,:].cpu().numpy()) 201 | ax2.set_title(str(np.round(float(psnr), 4))) 202 | psnr = PSNR(osem[4].squeeze().cpu().numpy(), gt[4].squeeze().cpu().numpy()) 203 | ax3.imshow(osem[4,0,:,:].cpu().numpy()) 204 | ax3.set_title(str(np.round(float(psnr), 4))) 205 | 206 | plt.show() 207 | 208 | kldiv_r = kl_div(x = recon, 209 | acq_model=acq_model, 210 | attn_factors=attn_factors, 211 | contamination=contamination_factor, 212 | measurements=measurements, 213 | scale_factor=1.)[0] 214 | 215 | writer.add_image( 216 | 'final_reco', torchvision.utils.make_grid(recon, 217 | normalize=True, scale_each=True), global_step=0) 218 | for i in range(config.sampling.batch_size): 219 | writer.add_scalar('PSNR_per_validation_img', PSNR(recon[i].squeeze().cpu().numpy(), gt[i].squeeze().cpu().numpy()), global_step=i) 220 | writer.add_scalar('SSIM_per_validation_img', SSIM(recon[i].squeeze().cpu().numpy(), gt[i].squeeze().cpu().numpy()), global_step=i) 221 | writer.add_scalar('kldiv_r', kldiv_r[i].mean(), global_step=i) 222 | writer.close() 223 | 224 | if __name__ == '__main__': 225 | reconstruction() -------------------------------------------------------------------------------- /results/2D_final/result_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from glob import glob 3 | import os, sys 4 | sys.path.append("/home/user/sirf/") 5 | from src import PSNR, SSIM 6 | import numpy as np 7 | 8 | def get_std_crc_ssim_psnr(result): 9 | images = result["images"] 10 | refs = result["ref"].unsqueeze(1) 11 | if "lesion_rois" in result.keys(): 12 | if result["lesion_rois"].shape[-1] == result["ref"].shape[-1]: 13 | lesion_rois = result["lesion_rois"] 14 | background_rois = result["background_rois"] 15 | else: 16 | lesion_rois = torch.zeros_like(refs) 17 | background_rois = torch.zeros_like(refs) 18 | lesion_rois[refs!=0] = 1 19 | background_rois[refs!=0] = 1 20 | else: 21 | lesion_rois = torch.zeros_like(refs) 22 | background_rois = torch.zeros_like(refs) 23 | lesion_rois[refs!=0] = 1 24 | background_rois[refs!=0] = 1 25 | psnrs = [] 26 | ssims = [] 27 | crcs = [] 28 | stds = [] 29 | for img_idx in range(images.shape[0]): 30 | image = images[img_idx].cpu().numpy() 31 | ref = refs[img_idx].squeeze().cpu().numpy() 32 | lesion_roi = lesion_rois[img_idx].cpu().numpy() 33 | background_roi = background_rois[img_idx].squeeze().cpu().numpy() 34 | psnr_r = [] 35 | ssim_r = [] 36 | crc_r = [] 37 | b_bar_r = [] 38 | for realisation in image: 39 | psnr_r.append(torch.asarray(PSNR(realisation,ref))) 40 | ssim_r.append(torch.asarray(SSIM(realisation, ref))) 41 | if background_roi.sum() != 0: 42 | background_idx = np.nonzero(background_roi) 43 | b_bar = realisation[background_idx] 44 | b_t = ref[background_idx] 45 | crc_t = [] 46 | for i in range(len(lesion_roi)): 47 
| if lesion_roi[i,:,:].sum() != 0: 48 | tumour_roi_idx = np.nonzero(lesion_roi[i,:,:]) 49 | a_bar = realisation[tumour_roi_idx] 50 | a_t = ref[tumour_roi_idx] 51 | if a_bar.mean() == 0 and b_bar.mean() == 0: 52 | crc_t.append(np.array([0.0])) 53 | else: 54 | crc_t.append((a_bar.mean()/b_bar.mean() - 1) / (a_t.mean()/b_t.mean() - 1)) 55 | crc_r.append(torch.asarray(crc_t).mean()) 56 | b_bar_r.append(torch.asarray(b_bar)) 57 | std = (torch.std(torch.stack(b_bar_r), dim=0)/torch.clamp(torch.stack(b_bar_r).mean(0),1e-9)).mean() 58 | psnrs.append(torch.asarray(psnr_r)) 59 | ssims.append(torch.asarray(ssim_r)) 60 | crcs.append(torch.asarray(crc_r)) 61 | stds.append(torch.asarray(std)) 62 | return torch.stack(psnrs), torch.stack(ssims), torch.stack(crcs), torch.stack(stds) 63 | 64 | def get_sweep_mean_results(path, img_id=3): 65 | result = torch.load(path) 66 | datafit_strengths = [] 67 | kldivs = [] 68 | psnrs = [] 69 | ssims = [] 70 | crcs = [] 71 | stds = [] 72 | show_images = [] 73 | for datafit_strength in result.keys(): 74 | kldiv = [] 75 | psnr = [] 76 | ssim = [] 77 | crc = [] 78 | std = [] 79 | images = [] 80 | datafit_strengths.append(float(datafit_strength)) 81 | for image_num in result[datafit_strength].keys(): 82 | kldiv.append(result[datafit_strength][image_num]["kldiv"]) 83 | psnr.append(result[datafit_strength][image_num]["psnr"]) 84 | ssim.append(result[datafit_strength][image_num]["ssim"]) 85 | crc.append(result[datafit_strength][image_num]["crc"]) 86 | std.append(result[datafit_strength][image_num]["std"]) 87 | images.append(result[datafit_strength][image_num]["images"]) 88 | kldivs.append(sum(kldiv) / len(kldiv)) 89 | psnrs.append(sum(psnr) / len(psnr)) 90 | ssims.append(sum(ssim) / len(ssim)) 91 | crcs.append(sum(crc) / len(crc)) 92 | stds.append(sum(std) / len(std)) 93 | show_images.append(images[img_id]) 94 | kldivs = [x for _, x in sorted(zip(datafit_strengths, kldivs))] 95 | psnrs = [x for _, x in sorted(zip(datafit_strengths, psnrs))] 96 | ssims = [x for _, x in sorted(zip(datafit_strengths, ssims))] 97 | crcs = [x for _, x in sorted(zip(datafit_strengths, crcs))] 98 | stds = [x for _, x in sorted(zip(datafit_strengths, stds))] 99 | datafit_strengths = sorted(datafit_strengths) 100 | return {"datafit_strengths": datafit_strengths, "kldivs": kldivs, "psnrs": psnrs, "ssims": ssims, "crcs": crcs, "stds": stds, "show_images": show_images} 101 | 102 | 103 | def get_individual_dict(result): 104 | psnr, ssim, crc, std = get_std_crc_ssim_psnr(result) 105 | n_images = len(result["ref"]) 106 | individual_dict = {} 107 | for i in range(n_images): 108 | individual_dict[str(i)] = {} 109 | # Mean accross realisations 110 | individual_dict[str(i)]["kldiv"] = result["kldiv"][i].mean() 111 | individual_dict[str(i)]["psnr"] = psnr[i].mean() 112 | individual_dict[str(i)]["ssim"] = ssim[i].mean() 113 | individual_dict[str(i)]["crc"] = crc[i].mean() 114 | # Std accross realisations 115 | individual_dict[str(i)]["std"] = std[i] 116 | # save the first realisation 117 | individual_dict[str(i)]["images"] = result["images"][i][0] 118 | return individual_dict 119 | 120 | def save_sweep_dicts(unique_swept_datafit_strengths): 121 | for name in unique_swept_datafit_strengths.keys(): 122 | sweep_dict = {} 123 | for result_path in unique_swept_datafit_strengths[name]: 124 | result = torch.load(result_path) 125 | if "naive" in result_path or "dps" in result_path: 126 | datafit_strength = result["penalty"] 127 | elif "dds" in result_path: 128 | if "osem_num_epochs" in result_path: 129 | datafit_strength 
= result["num_epochs"] 130 | elif "anchor_num_epochs" in result_path: 131 | datafit_strength = result["beta"] 132 | elif "rdp" in result_path: 133 | datafit_strength = result["beta"] 134 | else: 135 | raise NotImplementedError 136 | individual_dict = get_individual_dict(result) 137 | sweep_dict[str(datafit_strength)] = individual_dict 138 | if "tumour" in name: 139 | name = name.replace("tumour_","") 140 | torch.save(sweep_dict, f"tumour/{name}.pt") 141 | else: 142 | torch.save(sweep_dict, f"non_tumour/{name}.pt") 143 | 144 | def get_unique_swept_datafit_strengths(base_path="E:/projects/pet_score_model/coordinators/SBM_2/**/*.pt"): 145 | result_paths = glob(base_path, recursive=True) 146 | num_results = len(result_paths) 147 | # FIND ALL THE UNIQUE DATAFIT STRENGTH SWEEPS 148 | unique_swept_datafit_strengths = {} 149 | while len(result_paths) > 0: 150 | result_path = result_paths[0] 151 | result = torch.load(result_path) 152 | identifers = [] 153 | if "tumour" in result_path: 154 | identifers.append("tumour") 155 | if "vesde" in result_path: 156 | identifers.append("vesde") 157 | if "vpsde" in result_path: 158 | identifers.append("vpsde") 159 | if "OSEMNLL_" in result_path: 160 | identifers.append("OSEMNLL") 161 | if "dps" in result_path: 162 | identifers.append("dps") 163 | if "naive" in result_path: 164 | identifers.append("naive") 165 | if "dds" in result_path: 166 | identifers.append("dds") 167 | if "osem_num_epochs" in result_path: 168 | identifers.append("osem") 169 | elif "anchor_num_epochs" in result_path: 170 | identifers.append("anchor") 171 | identifers.append("epoch_" + str(result["num_epochs"])) 172 | else: raise NotImplementedError("DDS needs to be osem or anchor") 173 | if "guided" in result_path: 174 | identifers.append("guided") 175 | identifers.append("gstrength_" + str(result["gstrength"])) 176 | if "rdp" in result_path: 177 | identifers.append("rdp") 178 | name = "" 179 | for identifer in identifers: 180 | name += identifer + "_" 181 | if name in unique_swept_datafit_strengths.keys(): 182 | unique_swept_datafit_strengths[name].append(result_path) 183 | else: 184 | unique_swept_datafit_strengths[name] = [] 185 | unique_swept_datafit_strengths[name].append(result_path) 186 | result_paths.remove(result_path) 187 | print(f"Altogether we have {num_results} results, and {len(unique_swept_datafit_strengths.keys())} individual sweeps.") 188 | return unique_swept_datafit_strengths 189 | 190 | 191 | -------------------------------------------------------------------------------- /results/lpd_unet/lpd_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import cupy as xp 4 | 5 | 6 | class LPDForwardFunction2D(torch.autograd.Function): 7 | """PET forward projection function 8 | 9 | We can implement our own custom autograd Functions by subclassing 10 | torch.autograd.Function and implementing the forward and backward passes 11 | which operate on Tensors. 12 | """ 13 | 14 | @staticmethod 15 | def forward(ctx, x, projector, attn_factors): 16 | """ 17 | In the forward pass we receive a Tensor containing the input and return 18 | a Tensor containing the output. ctx is a context object that can be used 19 | to stash information for backward computation. 
20 | """ 21 | 22 | ctx.set_materialize_grads(False) 23 | ctx.projector = projector 24 | ctx.attn_factors = attn_factors 25 | 26 | x_inp = x.detach() 27 | # convert pytorch input tensor into cupy array 28 | cp_x = xp.ascontiguousarray(xp.from_dlpack(x_inp)) 29 | 30 | # a custom function that maps from cupy array to cupy array 31 | batch, channels = cp_x.shape[:2] 32 | b_y = [] 33 | for sample in range(batch): 34 | projector.multiplicative_corrections = xp.from_dlpack(attn_factors[sample, 0, ...]) 35 | c_y = [] 36 | for channel in range(channels): 37 | c_y.append(projector.forward(cp_x[sample, channel, :, :][:, :, None])) 38 | b_y.append(xp.stack(c_y)) 39 | # convert torch array to cupy array 40 | return torch.from_dlpack(xp.stack(b_y)) 41 | 42 | @staticmethod 43 | def backward(ctx, grad_output: torch.Tensor): 44 | """ 45 | In the backward pass we receive a Tensor containing the gradient of the loss 46 | with respect to the output, and we need to compute the gradient of the loss 47 | with respect to the input. 48 | 49 | For details on how to implement the backward pass, see 50 | https://pytorch.org/docs/stable/notes/extending.html#how-to-use 51 | """ 52 | 53 | if grad_output is None: 54 | return None, None, None 55 | else: 56 | projector = ctx.projector 57 | attn_factors = ctx.attn_factors 58 | 59 | cp_y = xp.from_dlpack(grad_output.detach()) 60 | 61 | # a custom function that maps from cupy array to cupy array 62 | batch, channels = cp_y.shape[:2] 63 | b_x = [] 64 | for sample in range(batch): 65 | projector.multiplicative_corrections = xp.from_dlpack(attn_factors[sample, 0, ...]) 66 | c_x = [] 67 | for channel in range(channels): 68 | c_x.append(projector.adjoint(cp_y[sample, channel, :])[..., 0]) 69 | b_x.append(xp.stack(c_x)) 70 | b_x = xp.stack(b_x) 71 | # convert torch array to cupy array 72 | return torch.from_dlpack(b_x), None, None 73 | 74 | 75 | 76 | class LPDAdjointFunction2D(torch.autograd.Function): 77 | """PET forward projection function 78 | 79 | We can implement our own custom autograd Functions by subclassing 80 | torch.autograd.Function and implementing the forward and backward passes 81 | which operate on Tensors. 82 | """ 83 | 84 | @staticmethod 85 | def forward(ctx, y, projector, attn_factors): 86 | """ 87 | In the forward pass we receive a Tensor containing the input and return 88 | a Tensor containing the output. ctx is a context object that can be used 89 | to stash information for backward computation. 90 | """ 91 | 92 | ctx.set_materialize_grads(False) 93 | ctx.projector = projector 94 | ctx.attn_factors = attn_factors 95 | 96 | 97 | # convert pytorch input tensor into cupy array 98 | cp_y = xp.ascontiguousarray(xp.from_dlpack(y.detach())) 99 | 100 | # a custom function that maps from cupy array to cupy array 101 | batch, channels = cp_y.shape[:2] 102 | b_x = [] 103 | for sample in range(batch): 104 | projector.multiplicative_corrections = xp.from_dlpack(attn_factors[sample, 0, ...]) 105 | c_x = [] 106 | for channel in range(channels): 107 | c_x.append(projector.adjoint(cp_y[sample, channel, :])[..., 0]) 108 | b_x.append(xp.stack(c_x)) 109 | b_x = xp.stack(b_x) 110 | # convert torch array to cupy array 111 | return torch.from_dlpack(b_x) 112 | 113 | @staticmethod 114 | def backward(ctx, grad_output: torch.Tensor): 115 | """ 116 | In the backward pass we receive a Tensor containing the gradient of the loss 117 | with respect to the output, and we need to compute the gradient of the loss 118 | with respect to the input. 
119 | 120 | For details on how to implement the backward pass, see 121 | https://pytorch.org/docs/stable/notes/extending.html#how-to-use 122 | """ 123 | 124 | if grad_output is None: 125 | return None, None, None 126 | else: 127 | projector = ctx.projector 128 | attn_factors = ctx.attn_factors 129 | 130 | cp_x = xp.from_dlpack(grad_output.detach()) 131 | 132 | # apply the forward projector sample-by-sample and channel-by-channel (cupy in, cupy out) 133 | batch, channels = cp_x.shape[:2] 134 | b_y = [] 135 | for sample in range(batch): 136 | projector.multiplicative_corrections = xp.from_dlpack(attn_factors[sample, 0, ...]) 137 | c_y = [] 138 | for channel in range(channels): 139 | c_y.append(projector.forward(cp_x[sample, channel, :, :][:, :, None])) 140 | b_y.append(xp.stack(c_y)) 141 | b_y = xp.stack(b_y) 142 | # convert the stacked cupy array back to a torch tensor 143 | return torch.from_dlpack(b_y), None, None 144 | 145 | 146 | 147 | if __name__ == "__main__": 148 | import os 149 | from pyparallelproj import petprojectors, coincidences, resolution_models 150 | import cupyx.scipy.ndimage as ndi 151 | from brainweb import BrainWebOSEM 152 | from tqdm import tqdm 153 | import matplotlib.pyplot as plt 154 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 155 | 156 | detector_efficiency = 1./30 157 | coincidence_descriptor = coincidences.GEDiscoveryMICoincidenceDescriptor( 158 | num_rings=1, 159 | sinogram_spatial_axis_order=coincidences. 160 | SinogramSpatialAxisOrder['RVP'], 161 | xp=xp) 162 | 163 | projector = petprojectors.PETJosephProjector(coincidence_descriptor, 164 | (128, 128, 1), (-127, -127, 0), 165 | (2., 2., 2.)) 166 | 167 | res_model = resolution_models.GaussianImageBasedResolutionModel( 168 | (128, 128, 1), tuple(4.5 / (2.35 * x) for x in (2., 2., 2.)), xp, ndi) 169 | 170 | projector.image_based_resolution_model = res_model 171 | 172 | dataset_list = [] 173 | dataset = BrainWebOSEM(part="test", noise_level=5) 174 | 175 | train_size = int(0.9 * len(dataset)) 176 | val_size = len(dataset) - train_size 177 | 178 | train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)) 179 | batch_size = 10 180 | train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=0) 181 | 182 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 183 | 184 | 185 | sinogram = torch.nn.Parameter(torch.ones(batch_size, 1, 112880, 186 | device=device)) 187 | image = torch.nn.Parameter(torch.ones(batch_size, 1, 188 | *projector._image_shape[:-1], 189 | device=device)) 190 | optimizer = torch.optim.Adam([sinogram, image], lr=1e0) 191 | 192 | batch = next(iter(train_dl)) 193 | reference = batch[0] 194 | reference = reference.to(device) 195 | 196 | scale_factor = batch[1] 197 | scale_factor = scale_factor.to(device) 198 | 199 | osem = batch[2] 200 | osem = osem.to(device) 201 | 202 | norm = batch[3] 203 | norm = norm.to(device) 204 | 205 | measurements = batch[4] 206 | measurements = measurements.to(device) 207 | contamination_factor = batch[5] 208 | contamination_factor = contamination_factor.to(device) 209 | 210 | attn_factors = batch[6] 211 | attn_factors = attn_factors.to(device)*detector_efficiency 212 | 213 | for epoch in tqdm(range(100)): 214 | optimizer.zero_grad() 215 | 216 | y_bar = LPDForwardFunction2D.apply(image, projector, attn_factors) 217 | loss = ((y_bar - measurements)**2).sum() 218 | loss.backward() 219 | optimizer.step() 220 | if epoch==0: 221 | print(loss.item()) 222 | print(loss.item()) 223 | fig, ax =
plt.subplots(1, 2) 224 | print(image.max(), image.min()) 225 | ax[0].imshow(reference[0, 0, ...].detach().cpu().numpy()) 226 | ax[1].imshow(y_bar[0, 0, ...].detach().cpu().numpy()) 227 | plt.show() 228 | 229 | 230 | 231 | 232 | 233 | -------------------------------------------------------------------------------- /src/brainweb_2d/lpd_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import cupy as xp 4 | 5 | # Loosely based on https://github.com/educating-dip/pet_deep_image_prior/blob/main/src/deep_image_prior/torch_wrapper.py 6 | 7 | class LPDForwardFunction2D(torch.autograd.Function): 8 | """PET forward projection function 9 | 10 | We can implement our own custom autograd Functions by subclassing 11 | torch.autograd.Function and implementing the forward and backward passes 12 | which operate on Tensors. 13 | """ 14 | 15 | @staticmethod 16 | def forward(ctx, x, projector, attn_factors): 17 | """ 18 | In the forward pass we receive a Tensor containing the input and return 19 | a Tensor containing the output. ctx is a context object that can be used 20 | to stash information for backward computation. 21 | """ 22 | 23 | ctx.set_materialize_grads(False) 24 | ctx.projector = projector 25 | ctx.attn_factors = attn_factors 26 | 27 | x_inp = x.detach() 28 | # convert pytorch input tensor into cupy array 29 | cp_x = xp.ascontiguousarray(xp.from_dlpack(x_inp)) 30 | 31 | # apply the forward projector sample-by-sample and channel-by-channel (cupy in, cupy out) 32 | batch, channels = cp_x.shape[:2] 33 | b_y = [] 34 | for sample in range(batch): 35 | projector.multiplicative_corrections = xp.from_dlpack(attn_factors[sample, 0, ...]) 36 | c_y = [] 37 | for channel in range(channels): 38 | c_y.append(projector.forward(cp_x[sample, channel, :, :][:, :, None])) 39 | b_y.append(xp.stack(c_y)) 40 | # convert the stacked cupy array back to a torch tensor 41 | return torch.from_dlpack(xp.stack(b_y)) 42 | 43 | @staticmethod 44 | def backward(ctx, grad_output: torch.Tensor): 45 | """ 46 | In the backward pass we receive a Tensor containing the gradient of the loss 47 | with respect to the output, and we need to compute the gradient of the loss 48 | with respect to the input. 49 | 50 | For details on how to implement the backward pass, see 51 | https://pytorch.org/docs/stable/notes/extending.html#how-to-use 52 | """ 53 | 54 | if grad_output is None: 55 | return None, None, None 56 | else: 57 | projector = ctx.projector 58 | attn_factors = ctx.attn_factors 59 | 60 | cp_y = xp.from_dlpack(grad_output.detach()) 61 | 62 | # apply the adjoint (back-projection) sample-by-sample and channel-by-channel (cupy in, cupy out) 63 | batch, channels = cp_y.shape[:2] 64 | b_x = [] 65 | for sample in range(batch): 66 | projector.multiplicative_corrections = xp.from_dlpack(attn_factors[sample, 0, ...]) 67 | c_x = [] 68 | for channel in range(channels): 69 | c_x.append(projector.adjoint(cp_y[sample, channel, :])[..., 0]) 70 | b_x.append(xp.stack(c_x)) 71 | b_x = xp.stack(b_x) 72 | # convert the stacked cupy array back to a torch tensor 73 | return torch.from_dlpack(b_x), None, None 74 | 75 | 76 | 77 | class LPDAdjointFunction2D(torch.autograd.Function): 78 | """PET adjoint (back-)projection function 79 | 80 | We can implement our own custom autograd Functions by subclassing 81 | torch.autograd.Function and implementing the forward and backward passes 82 | which operate on Tensors.
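A minimal usage sketch (hypothetical tensors, assuming a ``projector`` and ``attn_factors`` configured as in the ``__main__`` demo at the bottom of this file, where the single-ring sinogram has 112880 bins):

    y = torch.ones(1, 1, 112880, device="cuda", requires_grad=True)
    x = LPDAdjointFunction2D.apply(y, projector, attn_factors)
    x.sum().backward()  # gradients flow back through the forward projection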
83 | """ 84 | 85 | @staticmethod 86 | def forward(ctx, y, projector, attn_factors): 87 | """ 88 | In the forward pass we receive a Tensor containing the input and return 89 | a Tensor containing the output. ctx is a context object that can be used 90 | to stash information for backward computation. 91 | """ 92 | 93 | ctx.set_materialize_grads(False) 94 | ctx.projector = projector 95 | ctx.attn_factors = attn_factors 96 | 97 | 98 | # convert pytorch input tensor into cupy array 99 | cp_y = xp.ascontiguousarray(xp.from_dlpack(y.detach())) 100 | 101 | # a custom function that maps from cupy array to cupy array 102 | batch, channels = cp_y.shape[:2] 103 | b_x = [] 104 | for sample in range(batch): 105 | projector.multiplicative_corrections = xp.from_dlpack(attn_factors[sample, 0, ...]) 106 | c_x = [] 107 | for channel in range(channels): 108 | c_x.append(projector.adjoint(cp_y[sample, channel, :])[..., 0]) 109 | b_x.append(xp.stack(c_x)) 110 | b_x = xp.stack(b_x) 111 | # convert torch array to cupy array 112 | return torch.from_dlpack(b_x) 113 | 114 | @staticmethod 115 | def backward(ctx, grad_output: torch.Tensor): 116 | """ 117 | In the backward pass we receive a Tensor containing the gradient of the loss 118 | with respect to the output, and we need to compute the gradient of the loss 119 | with respect to the input. 120 | 121 | For details on how to implement the backward pass, see 122 | https://pytorch.org/docs/stable/notes/extending.html#how-to-use 123 | """ 124 | 125 | if grad_output is None: 126 | return None, None, None 127 | else: 128 | projector = ctx.projector 129 | attn_factors = ctx.attn_factors 130 | 131 | cp_x = xp.from_dlpack(grad_output.detach()) 132 | 133 | # a custom function that maps from cupy array to cupy array 134 | batch, channels = cp_x.shape[:2] 135 | b_y = [] 136 | for sample in range(batch): 137 | projector.multiplicative_corrections = xp.from_dlpack(attn_factors[sample, 0, ...]) 138 | c_y = [] 139 | for channel in range(channels): 140 | c_y.append(projector.forward(cp_x[sample, channel, :, :][:, :, None])) 141 | b_y.append(xp.stack(c_y)) 142 | b_y = xp.stack(b_y) 143 | # convert torch array to cupy array 144 | return torch.from_dlpack(b_y), None, None 145 | 146 | 147 | 148 | if __name__ == "__main__": 149 | import os 150 | from pyparallelproj import petprojectors, coincidences, resolution_models 151 | import cupyx.scipy.ndimage as ndi 152 | from brainweb import BrainWebOSEM 153 | from tqdm import tqdm 154 | import matplotlib.pyplot as plt 155 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 156 | 157 | detector_efficiency = 1./30 158 | coincidence_descriptor = coincidences.GEDiscoveryMICoincidenceDescriptor( 159 | num_rings=1, 160 | sinogram_spatial_axis_order=coincidences. 
161 | SinogramSpatialAxisOrder['RVP'], 162 | xp=xp) 163 | 164 | projector = petprojectors.PETJosephProjector(coincidence_descriptor, 165 | (128, 128, 1), (-127, -127, 0), 166 | (2., 2., 2.)) 167 | 168 | res_model = resolution_models.GaussianImageBasedResolutionModel( 169 | (128, 128, 1), tuple(4.5 / (2.35 * x) for x in (2., 2., 2.)), xp, ndi) 170 | 171 | projector.image_based_resolution_model = res_model 172 | 173 | dataset_list = [] 174 | dataset = BrainWebOSEM(part="test", noise_level=5) 175 | 176 | train_size = int(0.9 * len(dataset)) 177 | val_size = len(dataset) - train_size 178 | 179 | train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)) 180 | batch_size = 10 181 | train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=0) 182 | 183 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 184 | 185 | 186 | sinogram = torch.nn.Parameter(torch.ones(batch_size, 1, 112880, 187 | device=device)) 188 | image = torch.nn.Parameter(torch.ones(batch_size, 1, 189 | *projector._image_shape[:-1], 190 | device=device)) 191 | optimizer = torch.optim.Adam([sinogram, image], lr=1e0) 192 | 193 | batch = next(iter(train_dl)) 194 | reference = batch[0] 195 | reference = reference.to(device) 196 | 197 | scale_factor = batch[1] 198 | scale_factor = scale_factor.to(device) 199 | 200 | osem = batch[2] 201 | osem = osem.to(device) 202 | 203 | norm = batch[3] 204 | norm = norm.to(device) 205 | 206 | measurements = batch[4] 207 | measurements = measurements.to(device) 208 | contamination_factor = batch[5] 209 | contamination_factor = contamination_factor.to(device) 210 | 211 | attn_factors = batch[6] 212 | attn_factors = attn_factors.to(device)*detector_efficiency 213 | 214 | for epoch in tqdm(range(100)): 215 | optimizer.zero_grad() 216 | 217 | y_bar = LPDForwardFunction2D.apply(image, projector, attn_factors) 218 | loss = ((y_bar - measurements)**2).sum() 219 | loss.backward() 220 | optimizer.step() 221 | if epoch==0: 222 | print(loss.item()) 223 | print(loss.item()) 224 | fig, ax = plt.subplots(1, 2) 225 | print(image.max(), image.min()) 226 | ax[0].imshow(reference[0, 0, ...].detach().cpu().numpy()) 227 | ax[1].imshow(y_bar[0, 0, ...].detach().cpu().numpy()) 228 | plt.show() 229 | 230 | 231 | 232 | 233 | 234 | -------------------------------------------------------------------------------- /coordinators/final_reconstruction.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import functools 4 | import numpy as np 5 | import yaml 6 | import sys, os 7 | sys.path.append(os.path.dirname(os.getcwd())) 8 | from src import (BrainWebOSEM, get_standard_score, get_standard_sde, 9 | get_standard_sampler, osem_nll, get_osem, get_map, get_anchor, kl_div) 10 | from omegaconf import DictConfig, OmegaConf 11 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 12 | import time 13 | import matplotlib.pyplot as plt 14 | 15 | import cupy as xp 16 | 17 | detector_efficiency = 1./30 18 | 19 | def get_acq_model(): 20 | import pyparallelproj.coincidences as coincidences 21 | import pyparallelproj.petprojectors as petprojectors 22 | import pyparallelproj.resolution_models as resolution_models 23 | import cupyx.scipy.ndimage as ndi 24 | """ 25 | create forward operator 26 | """ 27 | coincidence_descriptor = coincidences.GEDiscoveryMICoincidenceDescriptor( 28 | num_rings=1, 29 | 
sinogram_spatial_axis_order=coincidences.SinogramSpatialAxisOrder['RVP'],xp=xp) 30 | acq_model = petprojectors.PETJosephProjector(coincidence_descriptor, 31 | (128, 128, 1), (-127.0, -127.0, 0.0), (2., 2., 2.)) 32 | res_model = resolution_models.GaussianImageBasedResolutionModel( 33 | (128, 128, 1), tuple(4.5 / (2.35 * x) for x in (2., 2., 2.)), xp, ndi) 34 | acq_model.image_based_resolution_model = res_model 35 | return acq_model 36 | 37 | def estimate_scale_factor(osem, measurements, contamination, normalisation_type): 38 | scale_factors = [] 39 | for i in range(osem.shape[0]): 40 | if normalisation_type == "data_scale": 41 | emission_volume = torch.where(osem[i] > 0.01*osem[i].max(), 1, 0).sum() * 8 42 | scale_factor = (measurements[i] - contamination[i]).sum()/emission_volume 43 | scale_factors.append(scale_factor) 44 | elif normalisation_type == "image_scale": 45 | emission_volume = torch.where(osem[i] > 0.01*osem[i].max(), 1, 0).sum() 46 | scale_factor = osem[i].sum()/emission_volume 47 | scale_factors.append(scale_factor) 48 | else: 49 | raise NotImplementedError 50 | return torch.tensor(scale_factors) 51 | 52 | @hydra.main(config_path='../configs', config_name='final_reconstruction') 53 | def reconstruction(config : DictConfig) -> None: 54 | print(OmegaConf.to_yaml(config)) 55 | 56 | # Generate a unique filename; the file directories specify the SDE and data 57 | name = "" 58 | results = {} 59 | if config.sampling.use_osem_nll: 60 | name = "OSEMNLL_" 61 | if config.sampling.name in ("dps", "naive"): 62 | name = name + "penalty_" + str(config.sampling.penalty) + "_" 63 | results["penalty"] = config.sampling.penalty 64 | if config.sampling.name == "dds": 65 | if config.sampling.dds_proj.name == "osem": 66 | name = name + "osem_num_epochs_" + str(config.sampling.dds_proj.num_epochs) + "_" 67 | results["num_epochs"] = config.sampling.dds_proj.num_epochs 68 | if config.sampling.dds_proj.name == "anchor": 69 | name = name + "anchor_num_epochs_" + str(config.sampling.dds_proj.num_epochs) + "_" 70 | results["num_epochs"] = config.sampling.dds_proj.num_epochs 71 | name = name + "_beta_" + str(config.sampling.dds_proj.beta) + "_" 72 | results["beta"] = config.sampling.dds_proj.beta 73 | if "guided" in config.score_based_model.name: 74 | name = name + "_gstrength_" + str(config.sampling.guidance_strength) + "_" 75 | results["gstrength"] = config.sampling.guidance_strength 76 | 77 | timestr = time.strftime("%Y%m%d_%H%M%S_") 78 | dump_name = config.dump_path + "/"+ timestr + name + ".tmp" 79 | with open(dump_name, "xt") as f: 80 | f.write(os.getcwd()) 81 | f.close() 82 | 83 | ###### SET SEED ###### 84 | if config.seed is not None: 85 | torch.manual_seed(config.seed) 86 | np.random.seed(config.seed) 87 | 88 | ###### GET SCORE MODEL ###### 89 | # open the yaml config file 90 | with open(os.path.join(config.score_based_model.path, "report.yaml"), "r") as stream: 91 | ml_collection = yaml.load(stream, Loader=yaml.UnsafeLoader) 92 | guided = ml_collection.guided_p_uncond is not None 93 | # get the sde 94 | sde = get_standard_sde(ml_collection) 95 | # get the score model 96 | score_model = get_standard_score(ml_collection, sde, 97 | use_ema = config.score_based_model.ema, 98 | load_path = config.score_based_model.path) 99 | score_model.eval() 100 | score_model.to(config.device) 101 | 102 | ###### GET ACQUISITION MODEL AND DATA ###### 103 | # get the acquisition model 104 | acq_model = get_acq_model() 105 | # get the data 106 | dataset =
BrainWebOSEM(part=config.dataset.part, 107 | noise_level=config.dataset.poisson_scale, 108 | base_path=config.dataset.base_path, 109 | guided=guided) 110 | subset = list(range(2, len(dataset), 4)) 111 | dataset = torch.utils.data.Subset(dataset, subset) 112 | test_loader = torch.utils.data.DataLoader(dataset, 113 | batch_size=1, shuffle=False) 114 | # there are 10 realisations per slice, so use them as one batch 115 | config.sampling.batch_size = 10 116 | 117 | ###### SOLVING REVERSE SDE ###### 118 | img_shape = (config.dataset.img_z_dim, 119 | config.dataset.img_xy_dim, config.dataset.img_xy_dim) 120 | 121 | save_recon = [] 122 | save_ref = [] 123 | save_kldivs = [] 124 | if "tumour" in config.dataset.part: 125 | save_lesion_rois = [] 126 | save_background_rois = [] 127 | if guided: 128 | save_guided = [] 129 | print("Normalisation type: ", ml_collection.normalisation) 130 | print("Length of test loader: ", len(test_loader)) 131 | for idx, batch in enumerate(test_loader): 132 | # [0] reference, [1] scale_factor, [2] osem, [3] norm, [4] measurements, 133 | # [5] contamination_factor, [6] attn_factors 134 | # FIRST STEP 135 | # swap axes so the realisations form the batch dimension 136 | osem = torch.swapaxes(batch[2], 0, 1).to(config.device) 137 | measurements = torch.swapaxes(batch[4], 0, 1).to(config.device) 138 | contamination_factor = torch.swapaxes(batch[5], 0, 1)[:,[0],None].to(config.device) 139 | attn_factors = torch.swapaxes(batch[6], 0, 1).to(config.device) 140 | 141 | gt = batch[0][:, [0], ...] 142 | if guided: 143 | guided_img = batch[0][:, [1], ...].repeat(config.sampling.batch_size, 1, 1, 1).to(config.device) 144 | 145 | # estimate scaling factors from measurements 146 | scale_factor = estimate_scale_factor(osem=osem, 147 | measurements=measurements, contamination=contamination_factor, 148 | normalisation_type=ml_collection.normalisation)[:, None, None, None].to(config.device) 149 | 150 | if config.sampling.use_osem_nll: 151 | nll_partial = functools.partial(osem_nll, 152 | scale_factor=scale_factor, 153 | osem=osem.to(config.device)) 154 | elif config.sampling.name == "dds": 155 | if config.sampling.dds_proj.name == "osem": 156 | nll_partial = functools.partial(get_osem, 157 | acq_model=acq_model, 158 | attn_factors=attn_factors, 159 | contamination=contamination_factor, 160 | measurements=measurements, 161 | scale_factor=scale_factor, 162 | num_subsets=config.sampling.dds_proj.num_subsets, 163 | num_epochs=config.sampling.dds_proj.num_epochs) 164 | elif config.sampling.dds_proj.name == "map": 165 | nll_partial = functools.partial(get_map, 166 | acq_model=acq_model, 167 | attn_factors=attn_factors, 168 | contamination=contamination_factor, 169 | measurements=measurements, 170 | scale_factor=scale_factor, 171 | num_subsets=config.sampling.dds_proj.num_subsets, 172 | num_epochs=config.sampling.dds_proj.num_epochs, 173 | beta = config.sampling.dds_proj.beta) 174 | elif config.sampling.dds_proj.name == "anchor": 175 | nll_partial = functools.partial(get_anchor, 176 | acq_model=acq_model, 177 | attn_factors=attn_factors, 178 | contamination=contamination_factor, 179 | measurements=measurements, 180 | scale_factor=scale_factor, 181 | num_subsets=config.sampling.dds_proj.num_subsets, 182 | num_epochs=config.sampling.dds_proj.num_epochs, 183 | beta = config.sampling.dds_proj.beta) 184 | else: 185 | raise NotImplementedError 186 | else: 187 | nll_partial = functools.partial(kl_div, 188 | acq_model=acq_model, 189 | attn_factors=attn_factors, 190 | contamination=contamination_factor, 191 | measurements=measurements, 192
| scale_factor=scale_factor) 193 | 194 | logg_kwargs = {'log_dir': "./tb", 195 | 'num_img_in_log': config.sampling.batch_size, 'sample_num':idx, 196 | 'ground_truth': gt, 'osem': None} 197 | 198 | sampler = get_standard_sampler( 199 | config=config, 200 | score=score_model, 201 | sde=sde, 202 | nll=nll_partial, 203 | im_shape=img_shape, 204 | guidance_imgs=guided_img if guided else None, 205 | device=config.device) 206 | 207 | recon, _ = sampler.sample(logg_kwargs=logg_kwargs, logging=False) 208 | recon = torch.clamp(recon, min=0) 209 | recon = recon*scale_factor 210 | 211 | kldiv_r = kl_div(x = recon, 212 | acq_model=acq_model, 213 | attn_factors=attn_factors, 214 | contamination=contamination_factor, 215 | measurements=measurements, 216 | scale_factor=1.)[0].squeeze() 217 | 218 | if "tumour" in config.dataset.part: 219 | lesion_roi = batch[-1].to(config.device) 220 | background_roi = batch[-2].to(config.device) 221 | save_lesion_rois.append(lesion_roi.squeeze().cpu()) 222 | save_background_rois.append(background_roi.squeeze().cpu()) 223 | 224 | save_recon.append(recon.squeeze().cpu()) 225 | save_ref.append(gt.squeeze().cpu()) 226 | save_kldivs.append(kldiv_r) 227 | if guided: 228 | save_guided.append(guided_img.squeeze()[0].cpu()) 229 | 230 | if "tumour" in config.dataset.part: 231 | results["images"] = torch.stack(save_recon).cpu() 232 | results["ref"] = torch.stack(save_ref).cpu() 233 | results["kldiv"] = torch.stack(save_kldivs).cpu() 234 | results["lesion_rois"] = torch.stack(save_lesion_rois).cpu() 235 | results["background_rois"] = torch.stack(save_background_rois).cpu() 236 | else: 237 | results["images"] = torch.stack(save_recon).cpu() 238 | results["ref"] = torch.stack(save_ref).cpu() 239 | results["kldiv"] = torch.stack(save_kldivs).cpu() 240 | 241 | torch.save(results, name+".pt") 242 | os.remove(dump_name) 243 | 244 | if __name__ == '__main__': 245 | reconstruction() -------------------------------------------------------------------------------- /scripts/req.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name <env> --file <this file> 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=conda_forge 5 | _openmp_mutex=4.5=2_kmp_llvm 6 | absl-py=1.4.0=pypi_0 7 | alsa-lib=1.2.8=h166bdaf_0 8 | antlr4-python3-runtime=4.9.3=pypi_0 9 | anyio=3.7.0=pyhd8ed1ab_1 10 | aom=3.5.0=h27087fc_0 11 | appdirs=1.4.4=pyhd3eb1b0_0 12 | argon2-cffi=21.3.0=pyhd8ed1ab_0 13 | argon2-cffi-bindings=21.2.0=py310h5764c6d_3 14 | array-api-compat=1.3=pyhd8ed1ab_0 15 | asttokens=2.2.1=pyhd8ed1ab_0 16 | attr=2.5.1=h166bdaf_1 17 | attrs=23.1.0=pyh71513ae_1 18 | backcall=0.2.0=pyh9f0ad1d_0 19 | backports=1.0=pyhd8ed1ab_3 20 | backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 21 | beautifulsoup4=4.12.2=pyha770c72_0 22 | blas=1.0=openblas 23 | bleach=6.0.0=pyhd8ed1ab_0 24 | blosc=1.21.4=h0f2a231_0 25 | boltons=23.0.0=py310h06a4308_0 26 | brotli=1.0.9=h166bdaf_8 27 | brotli-bin=1.0.9=h166bdaf_8 28 | brotlipy=0.7.0=py310h7f8727e_1002 29 | brunsli=0.1=h9c3ff4c_0 30 | bzip2=1.0.8=h7b6447c_0 31 | c-ares=1.19.1=hd590300_0 32 | c-blosc2=2.9.2=hb4ffafa_0 33 | ca-certificates=2023.5.7=hbcca054_0 34 | cachetools=5.3.1=pypi_0 35 | cairo=1.16.0=hbbf8b49_1016 36 | certifi=2023.5.7=pyhd8ed1ab_0 37 | cffi=1.15.1=py310h5eee18b_3 38 | cfitsio=4.2.0=hd9d235c_0 39 | charls=2.4.2=h59595ed_0 40 | charset-normalizer=2.0.4=pyhd3eb1b0_0 41 | cloudpickle=2.0.0=pyhd3eb1b0_0 42 | cmake=3.27.2=pypi_0 43 | colorama=0.4.6=pyhd8ed1ab_0 44 |
comm=0.1.3=pyhd8ed1ab_0 45 | conda=23.3.1=py310hff52083_0 46 | conda-content-trust=0.1.3=py310h06a4308_0 47 | conda-package-handling=2.0.2=py310h06a4308_0 48 | conda-package-streaming=0.7.0=py310h06a4308_0 49 | contextlib2=21.6.0=pypi_0 50 | contourpy=1.0.7=py310hdf3cbec_0 51 | coverage=7.2.7=py310h2372a71_0 52 | cryptography=38.0.4=py310h600f1e7_0 53 | cupy-cuda11x=12.2.0=pypi_0 54 | cycler=0.11.0=pyhd8ed1ab_0 55 | cytoolz=0.12.0=py310h5eee18b_0 56 | dask-core=2022.7.0=py310h06a4308_0 57 | dav1d=1.2.1=hd590300_0 58 | dbus=1.13.6=h5008d03_3 59 | debugpy=1.6.7=py310heca2aa9_0 60 | decorator=5.1.1=pyhd8ed1ab_0 61 | defusedxml=0.7.1=pyhd8ed1ab_0 62 | deprecation=2.1.0=pyh9f0ad1d_0 63 | docopt=0.6.2=py_1 64 | entrypoints=0.4=pyhd8ed1ab_0 65 | exceptiongroup=1.1.1=pyhd8ed1ab_0 66 | executing=1.2.0=pyhd8ed1ab_0 67 | expat=2.5.0=hcb278e6_1 68 | fastrlock=0.8.1=pypi_0 69 | fftw=3.3.9=h27cfd23_1 70 | filelock=3.12.2=pypi_0 71 | flit-core=3.9.0=pyhd8ed1ab_0 72 | fmt=9.1.0=h924138e_0 73 | font-ttf-dejavu-sans-mono=2.37=hab24e00_0 74 | font-ttf-inconsolata=3.000=h77eed37_0 75 | font-ttf-source-code-pro=2.038=h77eed37_0 76 | font-ttf-ubuntu=0.83=hab24e00_0 77 | fontconfig=2.14.2=h14ed4e7_0 78 | fonts-conda-ecosystem=1=0 79 | fonts-conda-forge=1=0 80 | fonttools=4.39.4=py310h2372a71_0 81 | freetype=2.12.1=hca18f0e_1 82 | fsspec=2022.11.0=py310h06a4308_0 83 | gettext=0.21.1=h27087fc_0 84 | giflib=5.2.1=h0b41bf4_3 85 | glib=2.76.3=hfc55251_0 86 | glib-tools=2.76.3=hfc55251_0 87 | google-auth=2.22.0=pypi_0 88 | google-auth-oauthlib=1.0.0=pypi_0 89 | graphite2=1.3.13=h58526e2_1001 90 | grpcio=1.57.0=pypi_0 91 | gst-plugins-base=1.22.3=h938bd60_1 92 | gstreamer=1.22.3=h977cf35_1 93 | h5py=3.7.0=py310he06866b_0 94 | harfbuzz=7.3.0=hdb3a94d_0 95 | hdf5=1.10.6=h3ffc7dd_1 96 | hydra-core=1.3.2=pypi_0 97 | icu=72.1=hcb278e6_0 98 | idna=3.4=py310h06a4308_0 99 | imagecodecs=2023.1.23=py310h241fb82_2 100 | imageio=2.19.3=py310h06a4308_0 101 | importlib-metadata=6.6.0=pyha770c72_0 102 | importlib_metadata=6.6.0=hd8ed1ab_0 103 | importlib_resources=5.12.0=pyhd8ed1ab_0 104 | iniconfig=2.0.0=pyhd8ed1ab_0 105 | ipykernel=6.23.1=pyh210e3f2_0 106 | ipython=8.14.0=pyh41d4057_0 107 | ipython_genutils=0.2.0=py_1 108 | jedi=0.18.2=pyhd8ed1ab_0 109 | jinja2=3.1.2=pyhd8ed1ab_1 110 | jsonpatch=1.32=pyhd3eb1b0_0 111 | jsonpointer=2.1=pyhd3eb1b0_0 112 | jsonschema=4.17.3=pyhd8ed1ab_0 113 | jupyter_client=8.2.0=pyhd8ed1ab_0 114 | jupyter_core=5.3.0=py310hff52083_0 115 | jupyter_events=0.6.3=pyhd8ed1ab_0 116 | jupyter_server=2.6.0=pyhd8ed1ab_0 117 | jupyter_server_terminals=0.4.4=pyhd8ed1ab_1 118 | jupyterlab_pygments=0.2.2=pyhd8ed1ab_0 119 | jxrlib=1.1=h7f98852_2 120 | keyutils=1.6.1=h166bdaf_0 121 | kiwisolver=1.4.4=py310hbf28c38_1 122 | krb5=1.20.1=h81ceb04_0 123 | lame=3.100=h166bdaf_1003 124 | lcms2=2.15=haa2dc70_1 125 | ld_impl_linux-64=2.38=h1181459_1 126 | lerc=4.0.0=h27087fc_0 127 | libaec=1.0.6=hcb278e6_1 128 | libarchive=3.6.2=h039dbb9_1 129 | libavif=0.11.1=h8182462_2 130 | libblas=3.9.0=17_linux64_openblas 131 | libbrotlicommon=1.0.9=h166bdaf_8 132 | libbrotlidec=1.0.9=h166bdaf_8 133 | libbrotlienc=1.0.9=h166bdaf_8 134 | libcap=2.67=he9d0100_0 135 | libcblas=3.9.0=17_linux64_openblas 136 | libclang=15.0.7=default_h7634d5b_2 137 | libclang13=15.0.7=default_h9986a30_2 138 | libcups=2.3.3=h36d4200_3 139 | libcurl=8.1.2=h409715c_0 140 | libdeflate=1.18=h0b41bf4_0 141 | libedit=3.1.20191231=he28a2e2_2 142 | libev=4.33=h516909a_1 143 | libevent=2.1.12=hf998b51_1 144 | libexpat=2.5.0=hcb278e6_1 145 | libffi=3.4.2=h6a678d5_6 146 
| libflac=1.4.2=h27087fc_0 147 | libgcc-ng=13.1.0=he5830b7_0 148 | libgcrypt=1.10.1=h166bdaf_0 149 | libgfortran-ng=13.1.0=h69a702a_0 150 | libgfortran5=13.1.0=h15d22d2_0 151 | libglib=2.76.3=hebfc3b9_0 152 | libgpg-error=1.46=h620e276_0 153 | libiconv=1.17=h166bdaf_0 154 | libjpeg-turbo=2.1.5.1=h0b41bf4_0 155 | liblapack=3.9.0=17_linux64_openblas 156 | libllvm15=15.0.7=h5cf9203_2 157 | libmamba=1.4.2=hcea66bb_0 158 | libmambapy=1.4.2=py310h1428755_0 159 | libnghttp2=1.52.0=h61bc06f_0 160 | libnsl=2.0.0=h7f98852_0 161 | libogg=1.3.4=h7f98852_1 162 | libopenblas=0.3.23=pthreads_h80387f5_0 163 | libopus=1.3.1=h7f98852_1 164 | libparallelproj=1.5.0=cuda112_hd8b62ff_200 165 | libpng=1.6.39=h753d276_0 166 | libpq=15.3=hbcd7760_1 167 | libsndfile=1.2.0=hb75c966_0 168 | libsodium=1.0.18=h36c2ea0_1 169 | libsolv=0.7.24=h3eb15da_0 170 | libsqlite=3.42.0=h2797004_0 171 | libssh2=1.11.0=h0841786_0 172 | libstdcxx-ng=13.1.0=hfd8a6a1_0 173 | libsystemd0=253=h8c4010b_1 174 | libtiff=4.5.0=ha587672_6 175 | libuuid=2.38.1=h0b41bf4_0 176 | libvorbis=1.3.7=h9c3ff4c_0 177 | libwebp-base=1.3.0=h0b41bf4_0 178 | libxcb=1.15=h0b41bf4_0 179 | libxkbcommon=1.5.0=h5d7e998_3 180 | libxml2=2.11.4=h0d562d8_0 181 | libzlib=1.2.13=h166bdaf_4 182 | libzopfli=1.0.3=h9c3ff4c_0 183 | lit=16.0.6=pypi_0 184 | llvm-openmp=16.0.5=h4dfa4b3_0 185 | locket=1.0.0=py310h06a4308_0 186 | lz4-c=1.9.4=hcb278e6_0 187 | lzo=2.10=h516909a_1000 188 | mamba=1.4.2=py310h51d5547_0 189 | markdown=3.4.4=pypi_0 190 | markupsafe=2.1.3=py310h2372a71_0 191 | matplotlib=3.7.1=py310hff52083_0 192 | matplotlib-base=3.7.1=py310he60537e_0 193 | matplotlib-inline=0.1.6=pyhd8ed1ab_0 194 | mistune=2.0.5=pyhd8ed1ab_0 195 | ml-collections=0.1.1=pypi_0 196 | mpg123=1.31.3=hcb278e6_0 197 | mpmath=1.3.0=pypi_0 198 | munkres=1.1.4=pyh9f0ad1d_0 199 | mysql-common=8.0.32=hf1915f5_2 200 | mysql-libs=8.0.32=hca2cd23_2 201 | nbclassic=1.0.0=pyhb4ecaf3_1 202 | nbclient=0.8.0=pyhd8ed1ab_0 203 | nbconvert=7.4.0=pyhd8ed1ab_0 204 | nbconvert-core=7.4.0=pyhd8ed1ab_0 205 | nbconvert-pandoc=7.4.0=pyhd8ed1ab_0 206 | nbformat=5.9.0=pyhd8ed1ab_0 207 | ncurses=6.4=h6a678d5_0 208 | nest-asyncio=1.5.6=pyhd8ed1ab_0 209 | networkx=2.8.4=py310h06a4308_0 210 | nibabel=5.1.0=py310hff52083_2 211 | nose=1.3.7=py_1006 212 | notebook=6.5.4=pyha770c72_0 213 | notebook-shim=0.2.3=pyhd8ed1ab_0 214 | nspr=4.35=h27087fc_0 215 | nss=3.89=he45b914_0 216 | numpy=1.24.3=py310ha4c1d20_0 217 | nvidia-cublas-cu11=11.10.3.66=pypi_0 218 | nvidia-cuda-cupti-cu11=11.7.101=pypi_0 219 | nvidia-cuda-nvrtc-cu11=11.7.99=pypi_0 220 | nvidia-cuda-runtime-cu11=11.7.99=pypi_0 221 | nvidia-cudnn-cu11=8.5.0.96=pypi_0 222 | nvidia-cufft-cu11=10.9.0.58=pypi_0 223 | nvidia-curand-cu11=10.2.10.91=pypi_0 224 | nvidia-cusolver-cu11=11.4.0.1=pypi_0 225 | nvidia-cusparse-cu11=11.7.4.91=pypi_0 226 | nvidia-nccl-cu11=2.14.3=pypi_0 227 | nvidia-nvtx-cu11=11.7.91=pypi_0 228 | oauthlib=3.2.2=pypi_0 229 | omegaconf=2.3.0=pypi_0 230 | openjpeg=2.5.0=hfec8fc6_2 231 | openssl=3.1.1=hd590300_1 232 | overrides=7.3.1=pyhd8ed1ab_0 233 | packaging=23.0=py310h06a4308_0 234 | pandas=2.0.2=py310h7cbd5c2_0 235 | pandoc=2.19.2=h32600fe_2 236 | pandocfilters=1.5.0=pyhd8ed1ab_0 237 | parallelproj=1.5.0=pyha770c72_200 238 | parso=0.8.3=pyhd8ed1ab_0 239 | partd=1.2.0=pyhd3eb1b0_1 240 | pcre2=10.40=hc3806b6_0 241 | pexpect=4.8.0=pyh1a96a4e_2 242 | pickleshare=0.7.5=py_1003 243 | pillow=9.5.0=py310h582fbeb_1 244 | pip=23.0.1=py310h06a4308_0 245 | pixman=0.40.0=h36c2ea0_0 246 | pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0 247 | 
platformdirs=3.5.3=pyhd8ed1ab_0 248 | pluggy=1.0.0=py310h06a4308_1 249 | ply=3.11=py_1 250 | pooch=1.4.0=pyhd3eb1b0_0 251 | prometheus_client=0.17.0=pyhd8ed1ab_0 252 | prompt-toolkit=3.0.38=pyha770c72_0 253 | prompt_toolkit=3.0.38=hd8ed1ab_0 254 | protobuf=4.24.0=pypi_0 255 | psutil=5.9.5=py310h1fa729e_0 256 | pthread-stubs=0.4=h36c2ea0_1001 257 | ptyprocess=0.7.0=pyhd3deb0d_0 258 | pulseaudio-client=16.1=hb77b528_4 259 | pure_eval=0.2.2=pyhd8ed1ab_0 260 | pyasn1=0.5.0=pypi_0 261 | pyasn1-modules=0.3.0=pypi_0 262 | pybind11-abi=4=hd8ed1ab_3 263 | pycosat=0.6.4=py310h5eee18b_0 264 | pycparser=2.21=pyhd3eb1b0_0 265 | pydantic=1.10.2=py310h5eee18b_0 266 | pygments=2.15.1=pyhd8ed1ab_0 267 | pyopenssl=23.0.0=py310h06a4308_0 268 | pyparsing=3.0.9=pyhd8ed1ab_0 269 | pyqt=5.15.7=py310hab646b1_3 270 | pyqt5-sip=12.11.0=py310heca2aa9_3 271 | pyrsistent=0.19.3=py310h1fa729e_0 272 | pysocks=1.7.1=py310h06a4308_0 273 | pytest=7.3.1=pyhd8ed1ab_0 274 | pytest-cov=4.1.0=pyhd8ed1ab_0 275 | python=3.10.11=he550d4f_0_cpython 276 | python-dateutil=2.8.2=pyhd8ed1ab_0 277 | python-fastjsonschema=2.17.1=pyhd8ed1ab_0 278 | python-json-logger=2.0.7=pyhd8ed1ab_0 279 | python-tzdata=2023.3=pyhd8ed1ab_0 280 | python_abi=3.10=3_cp310 281 | pytz=2023.3=pyhd8ed1ab_0 282 | pywavelets=1.4.1=py310h5eee18b_0 283 | pyyaml=6.0=py310h5764c6d_5 284 | pyzmq=25.1.0=py310h5bbb5d0_0 285 | qt-main=5.15.8=h01ceb2d_13 286 | readline=8.2=h5eee18b_0 287 | reproc=14.2.4=h0b41bf4_0 288 | reproc-cpp=14.2.4=hcb278e6_0 289 | requests=2.28.1=py310h06a4308_1 290 | requests-oauthlib=1.3.1=pypi_0 291 | rfc3339-validator=0.1.4=pyhd8ed1ab_0 292 | rfc3986-validator=0.1.1=pyh9f0ad1d_0 293 | rsa=4.9=pypi_0 294 | ruamel.yaml=0.17.21=py310h5eee18b_0 295 | ruamel.yaml.clib=0.2.6=py310h5eee18b_1 296 | scikit-image=0.19.3=py310h6a678d5_1 297 | scipy=1.10.0=py310heeff2f4_0 298 | seaborn=0.12.2=py310h06a4308_0 299 | send2trash=1.8.2=pyh41d4057_0 300 | setuptools=67.7.2=pyhd8ed1ab_0 301 | sip=6.7.9=py310hc6cd4ac_0 302 | six=1.16.0=pyhd3eb1b0_1 303 | snappy=1.1.10=h9fff704_0 304 | sniffio=1.3.0=pyhd8ed1ab_0 305 | soupsieve=2.3.2.post1=pyhd8ed1ab_0 306 | sqlite=3.41.1=h5eee18b_0 307 | stack_data=0.6.2=pyhd8ed1ab_0 308 | sympy=1.12=pypi_0 309 | tensorboard=2.14.0=pypi_0 310 | tensorboard-data-server=0.7.1=pypi_0 311 | tensorboardx=2.6.2=pypi_0 312 | terminado=0.17.1=pyh41d4057_0 313 | tifffile=2023.4.12=pyhd8ed1ab_0 314 | tinycss2=1.2.1=pyhd8ed1ab_0 315 | tk=8.6.12=h1ccaba5_0 316 | toml=0.10.2=pyhd8ed1ab_0 317 | tomli=2.0.1=pyhd8ed1ab_0 318 | toolz=0.12.0=py310h06a4308_0 319 | torch=2.0.1=pypi_0 320 | torchaudio=2.0.2=pypi_0 321 | torchvision=0.15.2=pypi_0 322 | tornado=6.3.2=py310h2372a71_0 323 | tqdm=4.65.0=pyhd8ed1ab_1 324 | traitlets=5.9.0=pyhd8ed1ab_0 325 | triton=2.0.0=pypi_0 326 | typing-extensions=4.6.3=hd8ed1ab_0 327 | typing_extensions=4.6.3=pyha770c72_0 328 | typing_utils=0.1.0=pyhd8ed1ab_0 329 | tzdata=2023c=h04d1e81_0 330 | unicodedata2=15.0.0=py310h5764c6d_0 331 | urllib3=1.26.15=py310h06a4308_0 332 | wcwidth=0.2.6=pyhd8ed1ab_0 333 | webencodings=0.5.1=py_1 334 | websocket-client=1.5.3=pyhd8ed1ab_0 335 | werkzeug=2.3.7=pypi_0 336 | wheel=0.40.0=pyhd8ed1ab_0 337 | xcb-util=0.4.0=hd590300_1 338 | xcb-util-image=0.4.0=h8ee46fc_1 339 | xcb-util-keysyms=0.4.0=h8ee46fc_1 340 | xcb-util-renderutil=0.3.9=hd590300_1 341 | xcb-util-wm=0.4.1=h8ee46fc_1 342 | xkeyboard-config=2.39=hd590300_0 343 | xorg-kbproto=1.0.7=h7f98852_1002 344 | xorg-libice=1.1.1=hd590300_0 345 | xorg-libsm=1.2.4=h7391055_0 346 | xorg-libx11=1.8.5=h8ee46fc_0 347 | 
xorg-libxau=1.0.11=hd590300_0 348 | xorg-libxdmcp=1.1.3=h7f98852_0 349 | xorg-libxext=1.3.4=h0b41bf4_2 350 | xorg-libxrender=0.9.10=h7f98852_1003 351 | xorg-renderproto=0.11.1=h7f98852_1002 352 | xorg-xextproto=7.3.0=h0b41bf4_1003 353 | xorg-xf86vidmodeproto=2.3.1=h7f98852_1002 354 | xorg-xproto=7.0.31=h7f98852_1007 355 | xz=5.2.10=h5eee18b_1 356 | yaml=0.2.5=h7f98852_2 357 | yaml-cpp=0.7.0=h27087fc_2 358 | zeromq=4.3.4=h9c3ff4c_1 359 | zfp=1.0.0=h27087fc_3 360 | zipp=3.15.0=pyhd8ed1ab_0 361 | zlib=1.2.13=h166bdaf_4 362 | zlib-ng=2.0.7=h0b41bf4_0 363 | zstandard=0.19.0=py310h5eee18b_0 364 | zstd=1.5.2=h3eb15da_6 365 | --------------------------------------------------------------------------------