├── 1.png
├── 2.png
├── CrackSegDiff
│   ├── guided_diffusion
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── custom_dataset_loader.cpython-310.pyc
│   │   │   ├── custom_dataset_loader.cpython-38.pyc
│   │   │   ├── dist_util.cpython-310.pyc
│   │   │   ├── dist_util.cpython-38.pyc
│   │   │   ├── dpm_solver.cpython-310.pyc
│   │   │   ├── dpm_solver.cpython-38.pyc
│   │   │   ├── fp16_util.cpython-310.pyc
│   │   │   ├── fp16_util.cpython-38.pyc
│   │   │   ├── gaussian_diffusion.cpython-310.pyc
│   │   │   ├── gaussian_diffusion.cpython-38.pyc
│   │   │   ├── logger.cpython-310.pyc
│   │   │   ├── logger.cpython-38.pyc
│   │   │   ├── losses.cpython-310.pyc
│   │   │   ├── losses.cpython-38.pyc
│   │   │   ├── nn.cpython-310.pyc
│   │   │   ├── nn.cpython-38.pyc
│   │   │   ├── resample.cpython-310.pyc
│   │   │   ├── resample.cpython-38.pyc
│   │   │   ├── respace.cpython-310.pyc
│   │   │   ├── respace.cpython-38.pyc
│   │   │   ├── script_util.cpython-310.pyc
│   │   │   ├── script_util.cpython-38.pyc
│   │   │   ├── train_util.cpython-310.pyc
│   │   │   ├── train_util.cpython-38.pyc
│   │   │   ├── unet.cpython-310.pyc
│   │   │   ├── unet.cpython-38.pyc
│   │   │   ├── utils.cpython-310.pyc
│   │   │   ├── utils.cpython-38.pyc
│   │   │   ├── vmamba.cpython-310.pyc
│   │   │   └── vmunet.cpython-310.pyc
│   │   ├── custom_dataset_loader.py
│   │   ├── dist_util.py
│   │   ├── dpm_solver.py
│   │   ├── fp16_util.py
│   │   ├── gaussian_diffusion.py
│   │   ├── logger.py
│   │   ├── losses.py
│   │   ├── nn.py
│   │   ├── resample.py
│   │   ├── respace.py
│   │   ├── script_util.py
│   │   ├── train_util.py
│   │   ├── unet.py
│   │   ├── utils.py
│   │   ├── vmamba.py
│   │   └── vmunet.py
│   ├── segmentation_env.py
│   ├── segmentation_sample.py
│   └── segmentation_train.py
├── README.md
└── requirement.txt

/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/1.png
--------------------------------------------------------------------------------
/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/2.png
--------------------------------------------------------------------------------
/CrackSegDiff/guided_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Codebase for "Diffusion Models for Implicit Image Segmentation Ensembles".
3 | """ 4 | -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/custom_dataset_loader.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/custom_dataset_loader.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/custom_dataset_loader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/custom_dataset_loader.cpython-38.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/dist_util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/dist_util.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/dist_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/dist_util.cpython-38.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/dpm_solver.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/dpm_solver.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/dpm_solver.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/dpm_solver.cpython-38.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/fp16_util.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/fp16_util.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/fp16_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/fp16_util.cpython-38.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/gaussian_diffusion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/gaussian_diffusion.cpython-38.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/logger.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/logger.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/losses.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/losses.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/losses.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/losses.cpython-38.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/nn.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/nn.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/nn.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/nn.cpython-38.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/resample.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/resample.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/resample.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/resample.cpython-38.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/respace.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/respace.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/respace.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/respace.cpython-38.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/script_util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/script_util.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/script_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/script_util.cpython-38.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/train_util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/train_util.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/train_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/train_util.cpython-38.pyc -------------------------------------------------------------------------------- 
/CrackSegDiff/guided_diffusion/__pycache__/unet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/unet.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/unet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/unet.cpython-38.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/vmamba.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/vmamba.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/__pycache__/vmunet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky-visionX/CrackSegDiff/ff54633e929d22d977d3e9dbbb36039a0f3aabad/CrackSegDiff/guided_diffusion/__pycache__/vmunet.cpython-310.pyc -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/custom_dataset_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.data import Dataset 3 | from PIL import Image 4 | from glob import glob 5 | import tifffile as tiff 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torchvision.transforms as transforms 11 | from torchvision.models import squeezenet1_1 12 | class CustomDataset(Dataset): 13 | def __init__(self, args, data_path, transform=None, mode='Training'): 14 | 15 | print("loading data from the directory :",data_path) 16 | path = data_path 17 | images = sorted(glob(os.path.join(path, "5d/*.png"))) 18 | masks = sorted(glob(os.path.join(path, "mask/*.bmp"))) 19 | 20 | self.name_list = images 21 | self.label_list = masks 22 | self.data_path = path 23 | self.mode = mode 24 | 25 | self.transform = transform 26 | 27 | def __len__(self): 28 | return len(self.name_list) 29 | 30 | def __getitem__(self, index): 31 | """Get the images""" 32 | name = self.name_list[index] 33 | img_path = os.path.join(name) 34 | 35 | mask_name = self.label_list[index] 36 | msk_path = os.path.join(mask_name) 37 | img 
= tiff.imread(img_path) 38 | # img = Image.open(img_path) 39 | # img = Image.open(img_path).convert('RGB') 40 | mask = Image.open(msk_path).convert('L') 41 | if self.transform: 42 | state = torch.get_rng_state() 43 | img = self.transform(img) 44 | torch.set_rng_state(state) 45 | mask = self.transform(mask) 46 | 47 | if self.mode == 'Training': 48 | return (img, mask, name) 49 | else: 50 | return (img, mask, name) -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | #from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(args): 22 | """ 23 | Setup a distributed process group. 24 | """ 25 | if dist.is_initialized(): 26 | return 27 | if not args.multi_gpu: 28 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_dev 29 | 30 | backend = "gloo" if not th.cuda.is_available() else "nccl" 31 | 32 | if backend == "gloo": 33 | hostname = "localhost" 34 | else: 35 | hostname = socket.gethostbyname(socket.getfqdn()) 36 | os.environ["MASTER_ADDR"] = '127.0.1.1'#comm.bcast(hostname, root=0) 37 | os.environ["RANK"] = '0'#str(comm.rank) 38 | os.environ["WORLD_SIZE"] = '1'#str(comm.size) 39 | 40 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 41 | s.bind(("", 0)) 42 | s.listen(1) 43 | port = s.getsockname()[1] 44 | s.close() 45 | os.environ["MASTER_PORT"] = str(port) 46 | dist.init_process_group(backend=backend, init_method="env://") 47 | 48 | 49 | def dev(): 50 | """ 51 | Get the device to use for torch.distributed. 52 | """ 53 | if th.cuda.is_available(): 54 | return th.device(f"cuda") 55 | return th.device("cpu") 56 | 57 | 58 | def load_state_dict(path, **kwargs): 59 | """ 60 | Load a PyTorch file without redundant fetches across MPI ranks. 61 | """ 62 | mpigetrank=0 63 | if mpigetrank==0: 64 | with bf.BlobFile(path, "rb") as f: 65 | data = f.read() 66 | else: 67 | data = None 68 | 69 | return th.load(io.BytesIO(data), **kwargs) 70 | 71 | 72 | def sync_params(params): 73 | """ 74 | Synchronize a sequence of Tensors across ranks from rank 0. 75 | """ 76 | for p in params: 77 | with th.no_grad(): 78 | dist.broadcast(p, 0) 79 | 80 | 81 | def _find_free_port(): 82 | try: 83 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 84 | s.bind(("", 0)) 85 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 86 | return s.getsockname()[1] 87 | finally: 88 | s.close() 89 | -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 
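    For example, a whole module tree can be cast in place with
    some_module.apply(convert_module_to_f16); only Conv1d/Conv2d/Conv3d weights
    and biases are touched by this helper.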
18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = 
get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 203 | opt.step() 204 | zero_master_grads(self.master_params) 205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 206 | self.lg_loss_scale += self.fp16_scale_growth 207 | return True 208 | 209 | def _optimize_normal(self, opt: th.optim.Optimizer): 210 | grad_norm, param_norm = self._compute_norms() 211 | logger.logkv_mean("grad_norm", grad_norm) 212 | logger.logkv_mean("param_norm", param_norm) 213 | opt.step() 214 | return True 215 | 216 | def _compute_norms(self, grad_scale=1.0): 217 | grad_norm = 0.0 218 | param_norm = 0.0 219 | for p in self.master_params: 220 | with th.no_grad(): 221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 222 | if p.grad is not None: 223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 225 | 226 | def 
master_params_to_state_dict(self, master_params): 227 | return master_params_to_state_dict( 228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 229 | ) 230 | 231 | def state_dict_to_master_params(self, state_dict): 232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 233 | 234 | 235 | def check_overflow(value): 236 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 237 | -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code started out as a PyTorch port of Ho et al's diffusion models: 3 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py 4 | Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. 5 | """ 6 | from torch.autograd import Variable 7 | import enum 8 | import torch.nn.functional as F 9 | from torchvision.utils import save_image 10 | import torch 11 | import math 12 | import os 13 | # from visdom import Visdom 14 | # viz = Visdom(port=8850) 15 | import numpy as np 16 | import torch as th 17 | import torch.nn as nn 18 | from .train_util import visualize 19 | from .nn import mean_flat 20 | from .losses import normal_kl, discretized_gaussian_log_likelihood 21 | from scipy import ndimage 22 | from torchvision import transforms 23 | from .utils import staple, dice_score, norm 24 | import torchvision.utils as vutils 25 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 26 | import string 27 | import random 28 | 29 | def standardize(img): 30 | mean = th.mean(img) 31 | std = th.std(img) 32 | img = (img - mean) / std 33 | return img 34 | 35 | 36 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 37 | """ 38 | Get a pre-defined beta schedule for the given name. 39 | The beta schedule library consists of beta schedules which remain similar 40 | in the limit of num_diffusion_timesteps. 41 | Beta schedules may be added, but should not be removed or changed once 42 | they are committed to maintain backwards compatibility. 43 | """ 44 | if schedule_name == "linear": 45 | # Linear schedule from Ho et al, extended to work for any number of 46 | # diffusion steps. 47 | scale = 1000 / num_diffusion_timesteps 48 | beta_start = scale * 0.0001 49 | beta_end = scale * 0.02 50 | return np.linspace( 51 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 52 | ) 53 | elif schedule_name == "cosine": 54 | return betas_for_alpha_bar( 55 | num_diffusion_timesteps, 56 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 57 | ) 58 | else: 59 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 60 | 61 | 62 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 63 | """ 64 | Create a beta schedule that discretizes the given alpha_t_bar function, 65 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 66 | :param num_diffusion_timesteps: the number of betas to produce. 67 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 68 | produces the cumulative product of (1-beta) up to that 69 | part of the diffusion process. 70 | :param max_beta: the maximum beta to use; use values lower than 1 to 71 | prevent singularities. 
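    For example, the "cosine" schedule above is obtained with
    betas_for_alpha_bar(T, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2).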
72 | """ 73 | betas = [] 74 | for i in range(num_diffusion_timesteps): 75 | t1 = i / num_diffusion_timesteps 76 | t2 = (i + 1) / num_diffusion_timesteps 77 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 78 | return np.array(betas) 79 | 80 | 81 | class ModelMeanType(enum.Enum): 82 | """ 83 | Which type of output the model predicts. 84 | """ 85 | 86 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 87 | START_X = enum.auto() # the model predicts x_0 88 | EPSILON = enum.auto() # the model predicts epsilon 89 | 90 | 91 | class ModelVarType(enum.Enum): 92 | """ 93 | What is used as the model's output variance. 94 | The LEARNED_RANGE option has been added to allow the model to predict 95 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 96 | """ 97 | 98 | LEARNED = enum.auto() 99 | FIXED_SMALL = enum.auto() 100 | FIXED_LARGE = enum.auto() 101 | LEARNED_RANGE = enum.auto() 102 | 103 | 104 | class LossType(enum.Enum): 105 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 106 | RESCALED_MSE = ( 107 | enum.auto() 108 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 109 | KL = enum.auto() # use the variational lower-bound 110 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 111 | BCE_DICE = enum.auto() 112 | 113 | def is_vb(self): 114 | return self == LossType.KL or self == LossType.RESCALED_KL 115 | 116 | 117 | class GaussianDiffusion: 118 | """ 119 | Utilities for training and sampling diffusion models. 120 | Ported directly from here, and then adapted over time to further experimentation. 121 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 122 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 123 | starting at T and going to 1. 124 | :param model_mean_type: a ModelMeanType determining what the model outputs. 125 | :param model_var_type: a ModelVarType determining how variance is output. 126 | :param loss_type: a LossType determining the loss function to use. 127 | :param rescale_timesteps: if True, pass floating point timesteps into the 128 | model so that they are always scaled like in the 129 | original paper (0 to 1000). 130 | """ 131 | 132 | def __init__( 133 | self, 134 | *, 135 | betas, 136 | model_mean_type, 137 | model_var_type, 138 | loss_type, 139 | dpm_solver, 140 | rescale_timesteps=False, 141 | ): 142 | self.BceDiceLoss_diff = BceDiceLoss() 143 | self.BceDiceLoss_unet = BceDiceLoss() 144 | self.model_mean_type = model_mean_type 145 | self.model_var_type = model_var_type 146 | self.loss_type = loss_type 147 | self.rescale_timesteps = rescale_timesteps 148 | self.dpm_solver = dpm_solver 149 | 150 | # Use float64 for accuracy. 
151 | betas = np.array(betas, dtype=np.float64) 152 | self.betas = betas 153 | assert len(betas.shape) == 1, "betas must be 1-D" 154 | assert (betas > 0).all() and (betas <= 1).all() 155 | 156 | self.num_timesteps = int(betas.shape[0]) 157 | 158 | alphas = 1.0 - betas 159 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 160 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 161 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 162 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 163 | 164 | # calculations for diffusion q(x_t | x_{t-1}) and others 165 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 166 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 167 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 168 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 169 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 170 | 171 | # calculations for posterior q(x_{t-1} | x_t, x_0) 172 | self.posterior_variance = ( 173 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 174 | ) 175 | # log calculation clipped because the posterior variance is 0 at the 176 | # beginning of the diffusion chain. 177 | self.posterior_log_variance_clipped = np.log( 178 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 179 | ) 180 | self.posterior_mean_coef1 = ( 181 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 182 | ) 183 | self.posterior_mean_coef2 = ( 184 | (1.0 - self.alphas_cumprod_prev) 185 | * np.sqrt(alphas) 186 | / (1.0 - self.alphas_cumprod) 187 | ) 188 | 189 | def q_mean_variance(self, x_start, t): 190 | """ 191 | Get the distribution q(x_t | x_0). 192 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 193 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 194 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 195 | """ 196 | mean = ( 197 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 198 | ) 199 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 200 | log_variance = _extract_into_tensor( 201 | self.log_one_minus_alphas_cumprod, t, x_start.shape 202 | ) 203 | return mean, variance, log_variance 204 | 205 | def q_sample(self, x_start, t, noise=None): 206 | """ 207 | Diffuse the data for a given number of diffusion steps. 208 | In other words, sample from q(x_t | x_0). 209 | :param x_start: the initial data batch. 210 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 211 | :param noise: if specified, the split-out normal noise. 212 | :return: A noisy version of x_start. 
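        In closed form (matching the return expression below):
            x_t = sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise,
        so small t leaves x_start nearly intact while large t is almost pure noise.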
213 | """ 214 | if noise is None: 215 | noise = th.randn_like(x_start) 216 | assert noise.shape == x_start.shape 217 | return ( 218 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 219 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 220 | * noise 221 | ) 222 | 223 | def q_posterior_mean_variance(self, x_start, x_t, t): 224 | """ 225 | Compute the mean and variance of the diffusion posterior: 226 | q(x_{t-1} | x_t, x_0) 227 | """ 228 | assert x_start.shape == x_t.shape 229 | posterior_mean = ( 230 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 231 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 232 | ) 233 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 234 | posterior_log_variance_clipped = _extract_into_tensor( 235 | self.posterior_log_variance_clipped, t, x_t.shape 236 | ) 237 | assert ( 238 | posterior_mean.shape[0] 239 | == posterior_variance.shape[0] 240 | == posterior_log_variance_clipped.shape[0] 241 | == x_start.shape[0] 242 | ) 243 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 244 | 245 | 246 | def p_mean_variance( 247 | self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None 248 | ): 249 | """ 250 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 251 | the initial x, x_0. 252 | :param model: the model, which takes a signal and a batch of timesteps 253 | as input. 254 | :param x: the [N x C x ...] tensor at time t. 255 | :param t: a 1-D Tensor of timesteps. 256 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 257 | :param denoised_fn: if not None, a function which applies to the 258 | x_start prediction before it is used to sample. Applies before 259 | clip_denoised. 260 | :param model_kwargs: if not None, a dict of extra keyword arguments to 261 | pass to the model. This can be used for conditioning. 262 | :return: a dict with the following keys: 263 | - 'mean': the model mean output. 264 | - 'variance': the model variance output. 265 | - 'log_variance': the log of 'variance'. 266 | - 'pred_xstart': the prediction for x_0. 267 | """ 268 | if model_kwargs is None: 269 | model_kwargs = {} 270 | B, C = x.shape[:2] 271 | C=1 272 | cal = 0 273 | assert t.shape == (B,) 274 | model_output = model(x, self._scale_timesteps(t), **model_kwargs) 275 | if isinstance(model_output, tuple): 276 | model_output, cal = model_output 277 | x = x[:, -1:, ...] #loss is only calculated on the last channel, not on the input brain MR image 278 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 279 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 280 | model_output, model_var_values = th.split(model_output, C, dim=1) 281 | if self.model_var_type == ModelVarType.LEARNED: 282 | model_log_variance = model_var_values 283 | model_variance = th.exp(model_log_variance) 284 | else: 285 | min_log = _extract_into_tensor( 286 | self.posterior_log_variance_clipped, t, x.shape 287 | ) 288 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 289 | # The model_var_values is [-1, 1] for [min_var, max_var]. 290 | frac = (model_var_values + 1) / 2 291 | model_log_variance = frac * max_log + (1 - frac) * min_log 292 | model_variance = th.exp(model_log_variance) 293 | else: 294 | model_variance, model_log_variance = { 295 | # for fixedlarge, we set the initial (log-)variance like so 296 | # to get a better decoder log likelihood. 
297 | ModelVarType.FIXED_LARGE: ( 298 | np.append(self.posterior_variance[1], self.betas[1:]), 299 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 300 | ), 301 | ModelVarType.FIXED_SMALL: ( 302 | self.posterior_variance, 303 | self.posterior_log_variance_clipped, 304 | ), 305 | }[self.model_var_type] 306 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 307 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 308 | 309 | def process_xstart(x): 310 | if denoised_fn is not None: 311 | x = denoised_fn(x) 312 | if clip_denoised: 313 | return x.clamp(-1, 1) 314 | return x 315 | 316 | if self.model_mean_type == ModelMeanType.PREVIOUS_X: 317 | pred_xstart = process_xstart( 318 | self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) 319 | ) 320 | model_mean = model_output 321 | elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: 322 | if self.model_mean_type == ModelMeanType.START_X: 323 | pred_xstart = process_xstart(model_output) 324 | else: 325 | pred_xstart = process_xstart( 326 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) 327 | ) 328 | model_mean, _, _ = self.q_posterior_mean_variance( 329 | x_start=pred_xstart, x_t=x, t=t 330 | ) 331 | else: 332 | raise NotImplementedError(self.model_mean_type) 333 | 334 | assert ( 335 | model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 336 | ) 337 | return { 338 | "mean": model_mean, 339 | "variance": model_variance, 340 | "log_variance": model_log_variance, 341 | "pred_xstart": pred_xstart, 342 | 'cal': cal, 343 | } 344 | 345 | 346 | 347 | def _predict_xstart_from_eps(self, x_t, t, eps): 348 | assert x_t.shape == eps.shape 349 | return ( 350 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 351 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 352 | ) 353 | 354 | def _predict_xstart_from_xprev(self, x_t, t, xprev): 355 | assert x_t.shape == xprev.shape 356 | return ( # (xprev - coef2*x_t) / coef1 357 | _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev 358 | - _extract_into_tensor( 359 | self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape 360 | ) 361 | * x_t 362 | ) 363 | 364 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 365 | return ( 366 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 367 | - pred_xstart 368 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 369 | 370 | def _scale_timesteps(self, t): 371 | if self.rescale_timesteps: 372 | 373 | return t.float() * (1000.0 / self.num_timesteps) 374 | return t 375 | 376 | def condition_mean(self, cond_fn, p_mean_var, x, t, org, model_kwargs=None): 377 | """ 378 | Compute the mean for the previous step, given a function cond_fn that 379 | computes the gradient of a conditional log probability with respect to 380 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to 381 | condition on y. 382 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 
383 | """ 384 | a, gradient = cond_fn(x, self._scale_timesteps(t),org, **model_kwargs) 385 | 386 | 387 | new_mean = ( 388 | p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() 389 | ) 390 | return a, new_mean 391 | 392 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 393 | """ 394 | Compute what the p_mean_variance output would have been, should the 395 | model's score function be conditioned by cond_fn. 396 | See condition_mean() for details on cond_fn. 397 | Unlike condition_mean(), this instead uses the conditioning strategy 398 | from Song et al (2020). 399 | """ 400 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 401 | 402 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) 403 | 404 | eps = eps.detach() - (1 - alpha_bar).sqrt() *p_mean_var["update"]*0 405 | 406 | out = p_mean_var.copy() 407 | out["pred_xstart"] = self._predict_xstart_from_eps(x.detach(), t.detach(), eps) 408 | out["mean"], _, _ = self.q_posterior_mean_variance( 409 | x_start=out["pred_xstart"], x_t=x, t=t 410 | ) 411 | return out, eps 412 | 413 | 414 | # def sample_known(self, img, batch_size = 1): 415 | # image_size = self.image_size 416 | # channels = self.channels 417 | # return self.p_sample_loop_known(model,(batch_size, channels, image_size, image_size), img) 418 | 419 | def p_sample( 420 | self, 421 | model, 422 | x, 423 | t, 424 | clip_denoised=True, 425 | denoised_fn=None, 426 | model_kwargs=None, 427 | ): 428 | """ 429 | Sample x_{t-1} from the model at the given timestep. 430 | :param model: the model to sample from. 431 | :param x: the current tensor at x_{t-1}. 432 | :param t: the value of t, starting at 0 for the first diffusion step. 433 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 434 | :param denoised_fn: if not None, a function which applies to the 435 | x_start prediction before it is used to sample. 436 | :param cond_fn: if not None, this is a gradient function that acts 437 | similarly to the model. 438 | :param model_kwargs: if not None, a dict of extra keyword arguments to 439 | pass to the model. This can be used for conditioning. 440 | :return: a dict containing the following keys: 441 | - 'sample': a random sample from the model. 442 | - 'pred_xstart': a prediction of x_0. 443 | """ 444 | out = self.p_mean_variance( 445 | model, 446 | x, 447 | t, 448 | clip_denoised=clip_denoised, 449 | denoised_fn=denoised_fn, 450 | model_kwargs=model_kwargs, 451 | ) 452 | noise = th.randn_like(x[:, -1:,...]) 453 | nonzero_mask = ( 454 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 455 | ) 456 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 457 | 458 | return {"sample": sample, "pred_xstart": out["pred_xstart"], "cal": out["cal"]} 459 | 460 | def p_sample_loop( 461 | self, 462 | model, 463 | shape, 464 | noise=None, 465 | clip_denoised=True, 466 | denoised_fn=None, 467 | cond_fn=None, 468 | model_kwargs=None, 469 | device=None, 470 | progress=False, 471 | 472 | ): 473 | """ 474 | Generate samples from the model. 475 | :param model: the model module. 476 | :param shape: the shape of the samples, (N, C, H, W). 477 | :param noise: if specified, the noise from the encoder to sample. 478 | Should be of the same shape as `shape`. 479 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 480 | :param denoised_fn: if not None, a function which applies to the 481 | x_start prediction before it is used to sample. 
482 | :param cond_fn: if not None, this is a gradient function that acts 483 | similarly to the model. 484 | :param model_kwargs: if not None, a dict of extra keyword arguments to 485 | pass to the model. This can be used for conditioning. 486 | :param device: if specified, the device to create the samples on. 487 | If not specified, use a model parameter's device. 488 | :param progress: if True, show a tqdm progress bar. 489 | :return: a non-differentiable batch of samples. 490 | """ 491 | final = None 492 | for sample in self.p_sample_loop_progressive( 493 | model, 494 | shape, 495 | noise=noise, 496 | clip_denoised=clip_denoised, 497 | denoised_fn=denoised_fn, 498 | cond_fn=cond_fn, 499 | model_kwargs=model_kwargs, 500 | device=device, 501 | progress=progress, 502 | ): 503 | final = sample 504 | return final["sample"] 505 | 506 | 507 | def p_sample_loop_known( 508 | self, 509 | model, 510 | shape, 511 | img, 512 | step = 1000, 513 | org=None, 514 | noise=None, 515 | clip_denoised=True, 516 | denoised_fn=None, 517 | cond_fn=None, 518 | model_kwargs=None, 519 | device=None, 520 | progress=False, 521 | conditioner = None, 522 | classifier=None 523 | ): 524 | if device is None: 525 | device = next(model.parameters()).device 526 | assert isinstance(shape, (tuple, list)) 527 | img = img.to(device) 528 | noise = th.randn_like(img[:, :1, ...]).to(device) 529 | x_noisy = torch.cat((img[:, :-1, ...], noise), dim=1) #add noise as the last channel 530 | img = img.to(device) 531 | 532 | if self.dpm_solver: 533 | final = {} 534 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas= th.from_numpy(self.betas)) 535 | 536 | model_fn = model_wrapper( 537 | model, 538 | noise_schedule, 539 | model_type="noise", # or "x_start" or "v" or "score" 540 | model_kwargs=model_kwargs, 541 | ) 542 | 543 | dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++", 544 | correcting_x0_fn="dynamic_thresholding", img = img[:, :-1, ...]) 545 | 546 | ## Steps in [20, 30] can generate quite good samples. 
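            # The call below returns the denoised sample (whose last channel is the
            # predicted mask) together with cal, the model's auxiliary calibration map;
            # the two are fused into cal_out a few lines further down.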
547 | sample, cal = dpm_solver.sample( 548 | noise.to(dtype=th.float), 549 | steps= step, 550 | order=2, 551 | skip_type="time_uniform", 552 | method="multistep", 553 | ) 554 | sample = sample.detach() ### MODIFIED: for DPM-Solver OOM issue 555 | sample[:,-1,:,:] = norm(sample[:,-1,:,:]) 556 | final["sample"] = sample 557 | final["cal"] = cal 558 | 559 | cal_out = torch.clamp(final["cal"] + 0.25 * final["sample"][:,-1,:,:].unsqueeze(1), 0, 1) 560 | else: 561 | # print('no dpm-solver') 562 | i = 0 563 | letters = string.ascii_lowercase 564 | name = ''.join(random.choice(letters) for i in range(10)) 565 | for sample in self.p_sample_loop_progressive( 566 | model, 567 | shape, 568 | time = step, 569 | noise=x_noisy, 570 | clip_denoised=clip_denoised, 571 | denoised_fn=denoised_fn, 572 | cond_fn=cond_fn, 573 | org=org, 574 | model_kwargs=model_kwargs, 575 | device=device, 576 | progress=progress, 577 | ): 578 | final = sample 579 | # i += 1 580 | # '''vis each step sample''' 581 | # if i % 5 == 0: 582 | 583 | # o1 = th.tensor(img)[:,0,:,:].unsqueeze(1) 584 | # o2 = th.tensor(img)[:,1,:,:].unsqueeze(1) 585 | # o3 = th.tensor(img)[:,2,:,:].unsqueeze(1) 586 | # o4 = th.tensor(img)[:,3,:,:].unsqueeze(1) 587 | # s = th.tensor(final["sample"])[:,-1,:,:].unsqueeze(1) 588 | # tup = (o1/o1.max(),o2/o2.max(),o3/o3.max(),o4/o4.max(),s) 589 | # compose = th.cat(tup,0) 590 | # vutils.save_image(s, fp = os.path.join('../res_temp_norm_6000_100', name+str(i)+".jpg"), nrow = 1, padding = 10) 591 | if dice_score(final["sample"][:, -1, :, :].unsqueeze(1), final["cal"]) < 0.65: 592 | cal_out = torch.clamp(final["cal"] + 0.25 * final["sample"][:,-1,:,:].unsqueeze(1), 0, 1) 593 | else: 594 | cal_out = torch.clamp(final["cal"] * 0.5 + 0.5 * final["sample"][:,-1,:,:].unsqueeze(1), 0, 1) 595 | # cal_out = torch.clamp(final["cal"] * 0.5 + 0.5 * final["sample"][:, -1, :, :].unsqueeze(1), 0, 1) 596 | 597 | return final["sample"], x_noisy, img, final["cal"], cal_out 598 | 599 | def p_sample_loop_progressive( 600 | self, 601 | model, 602 | shape, 603 | time=1000, 604 | noise=None, 605 | clip_denoised=True, 606 | denoised_fn=None, 607 | cond_fn=None, 608 | org=None, 609 | model_kwargs=None, 610 | device=None, 611 | progress=False, 612 | ): 613 | """ 614 | Generate samples from the model and yield intermediate samples from 615 | each timestep of diffusion. 616 | Arguments are the same as p_sample_loop(). 617 | Returns a generator over dicts, where each dict is the return value of 618 | p_sample(). 619 | """ 620 | 621 | if device is None: 622 | device = next(model.parameters()).device 623 | assert isinstance(shape, (tuple, list)) 624 | if noise is not None: 625 | img = noise 626 | else: 627 | img = th.randn(*shape, device=device) 628 | indices = list(range(time))[::-1] 629 | org_c = img.size(1) 630 | org_MRI = img[:, :-1, ...] #original brain MR image 631 | if progress: 632 | # Lazy import so that we don't depend on tqdm. 
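            # Note that the sampling loop below sits in the else branch, so intermediate
            # samples are only yielded when progress is False.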
633 | from tqdm.auto import tqdm 634 | 635 | indices = tqdm(indices) 636 | 637 | else: 638 | for i in indices: 639 | t = th.tensor([i] * shape[0], device=device) 640 | # if i%100==0: 641 | # print('sampling step', i) 642 | # viz.image(visualize(img.cpu()[0, -1,...]), opts=dict(caption="sample"+ str(i) )) 643 | 644 | with th.no_grad(): 645 | # print('img bef size',img.size()) 646 | if img.size(1) != org_c: 647 | img = torch.cat((org_MRI,img), dim=1) #in every step, make sure to concatenate the original image to the sampled segmentation mask 648 | 649 | out = self.p_sample( 650 | model, 651 | img.float(), 652 | t, 653 | clip_denoised=clip_denoised, 654 | denoised_fn=denoised_fn, 655 | model_kwargs=model_kwargs, 656 | ) 657 | yield out 658 | img = out["sample"] 659 | 660 | def ddim_sample( 661 | self, 662 | model, 663 | x, 664 | t, 665 | clip_denoised=True, 666 | denoised_fn=None, 667 | cond_fn=None, 668 | model_kwargs=None, 669 | eta=0.0, 670 | ): 671 | """ 672 | Sample x_{t-1} from the model using DDIM. 673 | Same usage as p_sample(). 674 | """ 675 | out = self.p_mean_variance( 676 | model, 677 | x, 678 | t, 679 | clip_denoised=clip_denoised, 680 | denoised_fn=denoised_fn, 681 | model_kwargs=model_kwargs, 682 | ) 683 | 684 | 685 | if cond_fn is not None: 686 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 687 | 688 | # Usually our model outputs epsilon, but we re-derive it 689 | # in case we used x_start or x_prev prediction. 690 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 691 | 692 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 693 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 694 | sigma = ( 695 | eta 696 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 697 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 698 | ) 699 | # Equation 12. 700 | noise = th.randn_like(x[:, -1:, ...]) 701 | 702 | mean_pred = ( 703 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 704 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 705 | ) 706 | nonzero_mask = ( 707 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 708 | ) # no noise when t == 0 709 | sample = mean_pred + nonzero_mask * sigma * noise 710 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 711 | 712 | 713 | def ddim_reverse_sample( 714 | self, 715 | model, 716 | x, 717 | t, 718 | clip_denoised=True, 719 | denoised_fn=None, 720 | model_kwargs=None, 721 | eta=0.0, 722 | ): 723 | """ 724 | Sample x_{t+1} from the model using DDIM reverse ODE. 725 | """ 726 | assert eta == 0.0, "Reverse ODE only for deterministic path" 727 | out = self.p_mean_variance( 728 | model, 729 | x, 730 | t, 731 | clip_denoised=clip_denoised, 732 | denoised_fn=denoised_fn, 733 | model_kwargs=model_kwargs, 734 | ) 735 | # Usually our model outputs epsilon, but we re-derive it 736 | # in case we used x_start or x_prev prediction. 737 | eps = ( 738 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 739 | - out["pred_xstart"] 740 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 741 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 742 | 743 | # Equation 12. 
reversed 744 | mean_pred = ( 745 | out["pred_xstart"] * th.sqrt(alpha_bar_next) 746 | + th.sqrt(1 - alpha_bar_next) * eps 747 | ) 748 | 749 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 750 | 751 | 752 | 753 | def ddim_sample_loop_interpolation( 754 | self, 755 | model, 756 | shape, 757 | img1, 758 | img2, 759 | lambdaint, 760 | noise=None, 761 | clip_denoised=True, 762 | denoised_fn=None, 763 | cond_fn=None, 764 | model_kwargs=None, 765 | device=None, 766 | progress=False, 767 | ): 768 | if device is None: 769 | device = next(model.parameters()).device 770 | assert isinstance(shape, (tuple, list)) 771 | b = shape[0] 772 | t = th.randint(499,500, (b,), device=device).long().to(device) 773 | 774 | img1=torch.tensor(img1).to(device) 775 | img2 = torch.tensor(img2).to(device) 776 | 777 | noise = th.randn_like(img1).to(device) 778 | x_noisy1 = self.q_sample(x_start=img1, t=t, noise=noise).to(device) 779 | x_noisy2 = self.q_sample(x_start=img2, t=t, noise=noise).to(device) 780 | interpol=lambdaint*x_noisy1+(1-lambdaint)*x_noisy2 781 | 782 | for sample in self.ddim_sample_loop_progressive( 783 | model, 784 | shape, 785 | time=t, 786 | noise=interpol, 787 | clip_denoised=clip_denoised, 788 | denoised_fn=denoised_fn, 789 | cond_fn=cond_fn, 790 | model_kwargs=model_kwargs, 791 | device=device, 792 | progress=progress, 793 | ): 794 | final = sample 795 | return final["sample"], interpol, img1, img2 796 | 797 | def ddim_sample_loop( 798 | self, 799 | model, 800 | shape, 801 | noise=None, 802 | clip_denoised=True, 803 | denoised_fn=None, 804 | cond_fn=None, 805 | model_kwargs=None, 806 | device=None, 807 | progress=False, 808 | eta=0.0, 809 | ): 810 | """ 811 | Generate samples from the model using DDIM. 812 | Same usage as p_sample_loop(). 
813 | """ 814 | final = None 815 | if device is None: 816 | device = next(model.parameters()).device 817 | assert isinstance(shape, (tuple, list)) 818 | b = shape[0] 819 | t = th.randint(99, 100, (b,), device=device).long().to(device) 820 | 821 | for sample in self.ddim_sample_loop_progressive( 822 | model, 823 | shape, 824 | time=t, 825 | noise=noise, 826 | clip_denoised=clip_denoised, 827 | denoised_fn=denoised_fn, 828 | cond_fn=cond_fn, 829 | model_kwargs=model_kwargs, 830 | device=device, 831 | progress=progress, 832 | eta=eta, 833 | ): 834 | 835 | final = sample 836 | # viz.image(visualize(final["sample"].cpu()[0, ...]), opts=dict(caption="sample"+ str(10) )) 837 | return final["sample"] 838 | 839 | 840 | 841 | def ddim_sample_loop_known( 842 | self, 843 | model, 844 | shape, 845 | img, 846 | clip_denoised=True, 847 | denoised_fn=None, 848 | cond_fn=None, 849 | model_kwargs=None, 850 | device=None, 851 | progress=False, 852 | eta = 0.0 853 | ): 854 | if device is None: 855 | device = next(model.parameters()).device 856 | assert isinstance(shape, (tuple, list)) 857 | b = shape[0] 858 | 859 | img = img.to(device) 860 | 861 | t = th.randint(499,500, (b,), device=device).long().to(device) 862 | noise = th.randn_like(img[:, :1, ...]).to(device) 863 | 864 | x_noisy = torch.cat((img[:, :-1, ...], noise), dim=1).float() 865 | img = img.to(device) 866 | 867 | final = None 868 | for sample in self.ddim_sample_loop_progressive( 869 | model, 870 | shape, 871 | time=t, 872 | noise=x_noisy, 873 | clip_denoised=clip_denoised, 874 | denoised_fn=denoised_fn, 875 | cond_fn=cond_fn, 876 | model_kwargs=model_kwargs, 877 | device=device, 878 | progress=progress, 879 | eta=eta, 880 | ): 881 | final = sample 882 | 883 | return final["sample"], x_noisy, img 884 | 885 | 886 | def ddim_sample_loop_progressive( 887 | self, 888 | model, 889 | shape, 890 | time=1000, 891 | noise=None, 892 | clip_denoised=True, 893 | denoised_fn=None, 894 | cond_fn=None, 895 | model_kwargs=None, 896 | device=None, 897 | progress=False, 898 | eta=0.0, 899 | ): 900 | """ 901 | Use DDIM to sample from the model and yield intermediate samples from 902 | each timestep of DDIM. 903 | Same usage as p_sample_loop_progressive(). 904 | """ 905 | if device is None: 906 | device = next(model.parameters()).device 907 | assert isinstance(shape, (tuple, list)) 908 | if noise is not None: 909 | img = noise 910 | else: 911 | img = th.randn(*shape, device=device) 912 | indices = list(range(time-1))[::-1] 913 | orghigh = img[:, :-1, ...] 914 | 915 | 916 | if progress: 917 | # Lazy import so that we don't depend on tqdm. 918 | from tqdm.auto import tqdm 919 | 920 | indices = tqdm(indices) 921 | 922 | for i in indices: 923 | t = th.tensor([i] * shape[0], device=device) 924 | with th.no_grad(): 925 | if img.shape != (1, 5, 224, 224): 926 | img = torch.cat((orghigh,img), dim=1).float() 927 | 928 | out = self.ddim_sample( 929 | model, 930 | img, 931 | t, 932 | clip_denoised=clip_denoised, 933 | denoised_fn=denoised_fn, 934 | cond_fn=cond_fn, 935 | model_kwargs=model_kwargs, 936 | eta=eta, 937 | ) 938 | yield out 939 | img = out["sample"] 940 | 941 | def _vb_terms_bpd( 942 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None 943 | ): 944 | """ 945 | Get a term for the variational lower-bound. 946 | The resulting units are bits (rather than nats, as one might expect). 947 | This allows for comparison to other papers. 948 | :return: a dict with the following keys: 949 | - 'output': a shape [N] tensor of NLLs or KLs. 
950 | - 'pred_xstart': the x_0 predictions. 951 | """ 952 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( 953 | x_start=x_start, x_t=x_t, t=t 954 | ) 955 | out = self.p_mean_variance( 956 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 957 | ) 958 | kl = normal_kl( 959 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] 960 | ) 961 | kl = mean_flat(kl) / np.log(2.0) 962 | 963 | decoder_nll = -discretized_gaussian_log_likelihood( 964 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 965 | ) 966 | assert decoder_nll.shape == x_start.shape 967 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0) 968 | 969 | # At the first timestep return the decoder NLL, 970 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 971 | output = th.where((t == 0), decoder_nll, kl) 972 | return {"output": output, "pred_xstart": out["pred_xstart"]} 973 | 974 | 975 | 976 | def training_losses_segmentation(self, model, classifier, x_start, t, model_kwargs=None, noise=None): 977 | """ 978 | Compute training losses for a single timestep. 979 | :param model: the model to evaluate loss on. 980 | :param x_start: the [N x C x ...] tensor of inputs. 981 | :param t: a batch of timestep indices. 982 | :param model_kwargs: if not None, a dict of extra keyword arguments to 983 | pass to the model. This can be used for conditioning. 984 | :param noise: if specified, the specific Gaussian noise to try to remove. 985 | :return: a dict with the key "loss" containing a tensor of shape [N]. 986 | Some mean or variance settings may also have other keys. 987 | """ 988 | if model_kwargs is None: 989 | model_kwargs = {} 990 | if noise is None: 991 | noise = th.randn_like(x_start[:, -1:, ...]) 992 | 993 | 994 | mask = x_start[:, -1:, ...] 995 | res = torch.where(mask > 0, 1, 0) #merge all tumor classes into one to get a binary segmentation mask 996 | 997 | res_t = self.q_sample(res, t, noise=noise) #add noise to the segmentation channel 998 | x_t=x_start.float() 999 | x_t[:, -1:, ...]=res_t.float() 1000 | terms = {} 1001 | 1002 | 1003 | if self.loss_type == LossType.MSE or self.loss_type == LossType.BCE_DICE or self.loss_type == LossType.RESCALED_MSE: 1004 | 1005 | model_output, cal = model(x_t, self._scale_timesteps(t), **model_kwargs) 1006 | if self.model_var_type in [ 1007 | ModelVarType.LEARNED, 1008 | ModelVarType.LEARNED_RANGE, 1009 | ]: 1010 | B, C= x_t.shape[:2] 1011 | C = 1 1012 | assert model_output.shape == (B, C * 2, *x_t.shape[2:]) 1013 | model_output, model_var_values = th.split(model_output, C, dim=1) 1014 | # Learn the variance using the variational bound, but don't let 1015 | # it affect our mean prediction. 1016 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) 1017 | terms["vb"] = self._vb_terms_bpd( 1018 | model=lambda *args, r=frozen_out: r, 1019 | x_start=res, 1020 | x_t=res_t, 1021 | t=t, 1022 | clip_denoised=False, 1023 | )["output"] 1024 | if self.loss_type == LossType.RESCALED_MSE: 1025 | # Divide by 1000 for equivalence with initial implementation. 1026 | # Without a factor of 1/1000, the VB term hurts the MSE term. 1027 | terms["vb"] *= self.num_timesteps / 1000.0 1028 | 1029 | target = { 1030 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( 1031 | x_start=res, x_t=res_t, t=t 1032 | )[0], 1033 | ModelMeanType.START_X: res, 1034 | ModelMeanType.EPSILON: noise, 1035 | }[self.model_mean_type] 1036 | 1037 | # model_output = (cal > 0.5) * (model_output >0.5) * model_output if 2. 
* (cal*model_output).sum() / (cal+model_output).sum() < 0.75 else model_output 1038 | # terms["loss_diff"] = self.bceloss(model_output, target) 1039 | terms["loss_diff"] = mean_flat((target - model_output) ** 2 ) 1040 | terms["loss_cal"] = self.BceDiceLoss_diff(cal.type(th.float), res.type(th.float)) 1041 | # terms["loss_unet"] = self.BceDiceLoss_unet(model_var_values.type(th.float), res.type(th.float)) 1042 | # terms["loss_cal"] = self.DiceBCELoss(cal.type(th.float), res.type(th.float)) 1043 | # terms["loss_cal"] = mean_flat((res - cal) ** 2) 1044 | # terms["loss_cal"] = self.bceloss(cal.type(th.float), res.type(th.float)) 1045 | # terms["mse"] = (terms["mse_diff"] + terms["mse_cal"]) / 2. 1046 | if "vb" in terms: 1047 | terms["loss"] = terms["loss_diff"] + terms["vb"] 1048 | 1049 | else: 1050 | terms["loss"] = terms["loss_diff"] 1051 | 1052 | else: 1053 | raise NotImplementedError(self.loss_type) 1054 | 1055 | return (terms, model_output) 1056 | 1057 | 1058 | def _prior_bpd(self, x_start): 1059 | """ 1060 | Get the prior KL term for the variational lower-bound, measured in 1061 | bits-per-dim. 1062 | This term can't be optimized, as it only depends on the encoder. 1063 | :param x_start: the [N x C x ...] tensor of inputs. 1064 | :return: a batch of [N] KL values (in bits), one per batch element. 1065 | """ 1066 | batch_size = x_start.shape[0] 1067 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 1068 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 1069 | kl_prior = normal_kl( 1070 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 1071 | ) 1072 | return mean_flat(kl_prior) / np.log(2.0) 1073 | 1074 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 1075 | """ 1076 | Compute the entire variational lower-bound, measured in bits-per-dim, 1077 | as well as other related quantities. 1078 | :param model: the model to evaluate loss on. 1079 | :param x_start: the [N x C x ...] tensor of inputs. 1080 | :param clip_denoised: if True, clip denoised samples. 1081 | :param model_kwargs: if not None, a dict of extra keyword arguments to 1082 | pass to the model. This can be used for conditioning. 1083 | :return: a dict containing the following keys: 1084 | - total_bpd: the total variational lower-bound, per batch element. 1085 | - prior_bpd: the prior term in the lower-bound. 1086 | - vb: an [N x T] tensor of terms in the lower-bound. 1087 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 1088 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
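        As the loop below shows, total_bpd is assembled as
        vb.sum(dim=1) + prior_bpd: the per-timestep variational terms (in bits)
        summed over all timesteps, plus the prior KL term from _prior_bpd.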
1089 | """ 1090 | device = x_start.device 1091 | batch_size = x_start.shape[0] 1092 | 1093 | vb = [] 1094 | xstart_mse = [] 1095 | mse = [] 1096 | for t in list(range(self.num_timesteps))[::-1]: 1097 | t_batch = th.tensor([t] * batch_size, device=device) 1098 | noise = th.randn_like(x_start) 1099 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 1100 | 1101 | # Calculate VLB term at the current timestep 1102 | with th.no_grad(): 1103 | out = self._vb_terms_bptimestepsd( 1104 | model, 1105 | x_start=x_start, 1106 | x_t=x_t, 1107 | t=t_batch, 1108 | clip_denoised=clip_denoised, 1109 | model_kwargs=model_kwargs, 1110 | ) 1111 | vb.append(out["output"]) 1112 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 1113 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 1114 | mse.append(mean_flat((eps - noise) ** 2)) 1115 | 1116 | vb = th.stack(vb, dim=1) 1117 | xstart_mse = th.stack(xstart_mse, dim=1) 1118 | mse = th.stack(mse, dim=1) 1119 | 1120 | prior_bpd = self._prior_bpd(x_start) 1121 | total_bpd = vb.sum(dim=1) + prior_bpd 1122 | return { 1123 | "total_bpd": total_bpd, 1124 | "prior_bpd": prior_bpd, 1125 | "vb": vb, 1126 | "xstart_mse": xstart_mse, 1127 | "mse": mse, 1128 | } 1129 | 1130 | 1131 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 1132 | """ 1133 | Extract values from a 1-D numpy array for a batch of indices. 1134 | :param arr: the 1-D numpy array. 1135 | :param timesteps: a tensor of indices into the array to extract. 1136 | :param broadcast_shape: a larger shape of K dimensions with the batch 1137 | dimension equal to the length of timesteps. 1138 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 1139 | """ 1140 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 1141 | while len(res.shape) < len(broadcast_shape): 1142 | res = res[..., None] 1143 | return res.expand(broadcast_shape) 1144 | 1145 | class DiceLoss(nn.Module): 1146 | def __init__(self): 1147 | super(DiceLoss, self).__init__() 1148 | 1149 | def forward(self, pred, target): 1150 | smooth = 1 1151 | size = pred.size(0) 1152 | 1153 | pred_ = pred.view(size, -1) 1154 | target_ = target.view(size, -1) 1155 | intersection = pred_ * target_ 1156 | dice_score = (2 * intersection.sum(1) + smooth)/(pred_.sum(1) + target_.sum(1) + smooth) 1157 | dice_loss = 1 - dice_score.sum()/size 1158 | 1159 | return dice_loss 1160 | 1161 | class BCELoss(nn.Module): 1162 | def __init__(self): 1163 | super(BCELoss, self).__init__() 1164 | self.bceloss = nn.BCELoss() 1165 | 1166 | def forward(self, pred, target): 1167 | size = pred.size(0) 1168 | pred_ = pred.view(size, -1) 1169 | target_ = target.view(size, -1) 1170 | 1171 | return self.bceloss(pred_, target_) 1172 | 1173 | class BceDiceLoss(nn.Module): 1174 | def __init__(self, wb=1, wd=1): 1175 | super(BceDiceLoss, self).__init__() 1176 | self.bce = BCELoss() 1177 | self.dice = DiceLoss() 1178 | self.wb = wb 1179 | self.wd = wd 1180 | 1181 | def forward(self, pred, target): 1182 | bceloss = self.bce(pred, target) 1183 | diceloss = self.dice(pred, target) 1184 | 1185 | loss = self.wd * diceloss + self.wb * bceloss 1186 | return loss -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | 
https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | # class TensorBoardOutputFormat(KVWriter): 151 | # """ 152 | # Dumps key/value pairs into TensorBoard's numeric format. 153 | # """ 154 | # 155 | # def __init__(self, dir): 156 | # os.makedirs(dir, exist_ok=True) 157 | # self.dir = dir 158 | # self.step = 1 159 | # prefix = "events" 160 | # path = osp.join(osp.abspath(dir), prefix) 161 | # import tensorflow as tf 162 | # from tensorflow.python import pywrap_tensorflow 163 | # from tensorflow.core.util import event_pb2 164 | # from tensorflow.python.util import compat 165 | # 166 | # self.tf = tf 167 | # self.event_pb2 = event_pb2 168 | # self.pywrap_tensorflow = pywrap_tensorflow 169 | # self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | # 171 | # def writekvs(self, kvs): 172 | # def summary_val(k, v): 173 | # kwargs = {"tag": k, "simple_value": float(v)} 174 | # return self.tf.Summary.Value(**kwargs) 175 | # 176 | # summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | # event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | # event.step = ( 179 | # self.step 180 | # ) # is there any reason why you'd want to specify the step? 
181 | # self.writer.WriteEvent(event) 182 | # self.writer.Flush() 183 | # self.step += 1 184 | # 185 | # def close(self): 186 | # if self.writer: 187 | # self.writer.Close() 188 | # self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | # elif format == "tensorboard": 202 | # return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 
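    # No MPI rank variable was found, so assume a single-process run and report rank 0.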
409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir='./results', format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | log("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger 495 | 496 | -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 
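    For illustration: normal_kl(th.tensor(0.0), 0.0, 0.0, 0.0) evaluates to 0,
    since 0.5 * (-1 + 0 - 0 + exp(0) + 0) = 0; note that at least one argument
    must already be a Tensor for the assertion below to pass.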
18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | def layer_norm(shape, *args, **kwargs): 35 | 36 | return nn.LayerNorm(shape, *args, **kwargs) 37 | 38 | def linear(*args, **kwargs): 39 | """ 40 | Create a linear module. 41 | """ 42 | return nn.Linear(*args, **kwargs) 43 | 44 | 45 | def avg_pool_nd(dims, *args, **kwargs): 46 | """ 47 | Create a 1D, 2D, or 3D average pooling module. 
48 | """ 49 | if dims == 1: 50 | return nn.AvgPool1d(*args, **kwargs) 51 | elif dims == 2: 52 | return nn.AvgPool2d(*args, **kwargs) 53 | elif dims == 3: 54 | return nn.AvgPool3d(*args, **kwargs) 55 | raise ValueError(f"unsupported dimensions: {dims}") 56 | 57 | 58 | def update_ema(target_params, source_params, rate=0.99): 59 | """ 60 | Update target parameters to be closer to those of source parameters using 61 | an exponential moving average. 62 | 63 | :param target_params: the target parameter sequence. 64 | :param source_params: the source parameter sequence. 65 | :param rate: the EMA rate (closer to 1 means slower). 66 | """ 67 | for targ, src in zip(target_params, source_params): 68 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 69 | 70 | 71 | def zero_module(module): 72 | """ 73 | Zero out the parameters of a module and return it. 74 | """ 75 | for p in module.parameters(): 76 | p.detach().zero_() 77 | return module 78 | 79 | 80 | def scale_module(module, scale): 81 | """ 82 | Scale the parameters of a module and return it. 83 | """ 84 | for p in module.parameters(): 85 | p.detach().mul_(scale) 86 | return module 87 | 88 | 89 | def mean_flat(tensor): 90 | """ 91 | Take the mean over all non-batch dimensions. 92 | """ 93 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 94 | 95 | 96 | def normalization(channels): 97 | """ 98 | Make a standard normalization layer. 99 | 100 | :param channels: number of input channels. 101 | :return: an nn.Module for normalization. 102 | """ 103 | return GroupNorm32(32, channels) 104 | 105 | 106 | def timestep_embedding(timesteps, dim, max_period=10000): 107 | """ 108 | Create sinusoidal timestep embeddings. 109 | 110 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 111 | These may be fractional. 112 | :param dim: the dimension of the output. 113 | :param max_period: controls the minimum frequency of the embeddings. 114 | :return: an [N x dim] Tensor of positional embeddings. 115 | """ 116 | half = dim // 2 117 | freqs = th.exp( 118 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 119 | ).to(device=timesteps.device) 120 | args = timesteps[:, None].float() * freqs[None] 121 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 122 | if dim % 2: 123 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 124 | return embedding 125 | 126 | 127 | def checkpoint(func, inputs, params, flag): 128 | """ 129 | Evaluate a function without caching intermediate activations, allowing for 130 | reduced memory at the expense of extra compute in the backward pass. 131 | 132 | :param func: the function to evaluate. 133 | :param inputs: the argument sequence to pass to `func`. 134 | :param params: a sequence of parameters `func` depends on but does not 135 | explicitly take as arguments. 136 | :param flag: if False, disable gradient checkpointing. 
137 | """ 138 | if flag: 139 | args = tuple(inputs) + tuple(params) 140 | return CheckpointFunction.apply(func, len(inputs), *args) 141 | else: 142 | return func(*inputs) 143 | 144 | 145 | class CheckpointFunction(th.autograd.Function): 146 | @staticmethod 147 | def forward(ctx, run_function, length, *args): 148 | ctx.run_function = run_function 149 | ctx.input_tensors = list(args[:length]) 150 | ctx.input_params = list(args[length:]) 151 | with th.no_grad(): 152 | output_tensors = ctx.run_function(*ctx.input_tensors) 153 | return output_tensors 154 | 155 | @staticmethod 156 | def backward(ctx, *output_grads): 157 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 158 | with th.enable_grad(): 159 | # Fixes a bug where the first op in run_function modifies the 160 | # Tensor storage in place, which is not allowed for detach()'d 161 | # Tensors. 162 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 163 | output_tensors = ctx.run_function(*shallow_copies) 164 | input_grads = th.autograd.grad( 165 | output_tensors, 166 | ctx.input_tensors + ctx.input_params, 167 | output_grads, 168 | allow_unused=True, 169 | ) 170 | del ctx.input_tensors 171 | del ctx.input_params 172 | del output_tensors 173 | return (None, None) + input_grads 174 | -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion, maxt): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion, maxt) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 
51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion, maxt): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([maxt]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 
121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 
28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 77 | last_alpha_cumprod = 1.0 78 | new_betas = [] 79 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 80 | if i in self.use_timesteps: 81 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 82 | last_alpha_cumprod = alpha_cumprod 83 | self.timestep_map.append(i) 84 | kwargs["betas"] = np.array(new_betas) 85 | super().__init__(**kwargs) 86 | 87 | def p_mean_variance( 88 | self, model, *args, **kwargs 89 | ): # pylint: disable=signature-differs 90 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 91 | 92 | def training_losses( 93 | self, model, *args, **kwargs 94 | ): # pylint: disable=signature-differs 95 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 96 | 97 | def condition_mean(self, cond_fn, *args, **kwargs): 98 | return super().condition_mean(self._wrap_model2(cond_fn), *args, **kwargs) 99 | 100 | def condition_score(self, cond_fn, *args, **kwargs): 101 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 102 | 103 | def _wrap_model(self, model): 104 | if isinstance(model, _WrappedModel): 105 | return model 106 | return _WrappedModel( 107 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 108 | ) 109 | def _wrap_model2(self, model): 110 | if isinstance(model, _WrappedModel2): 111 | return model 112 | return _WrappedModel2( 113 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 114 | ) 115 | 116 | def _scale_timesteps(self, t): 117 | # Scaling is done by the wrapped model. 
118 | return t 119 | 120 | 121 | class _WrappedModel: 122 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 123 | self.model = model 124 | self.timestep_map = timestep_map 125 | self.rescale_timesteps = rescale_timesteps 126 | self.original_num_steps = original_num_steps 127 | 128 | 129 | def __call__(self, x, ts, **kwargs): 130 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 131 | new_ts = map_tensor[ts] 132 | 133 | if self.rescale_timesteps: 134 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 135 | return self.model(x, new_ts, **kwargs) 136 | 137 | 138 | 139 | class _WrappedModel2: 140 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 141 | self.model = model 142 | self.timestep_map = timestep_map 143 | self.rescale_timesteps = rescale_timesteps 144 | self.original_num_steps = original_num_steps 145 | 146 | def __call__(self, x, ts, org, **kwargs): 147 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 148 | new_ts = map_tensor[ts] 149 | if self.rescale_timesteps: 150 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 151 | return self.model(x, new_ts,org, **kwargs) 152 | -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | from . import gaussian_diffusion as gd 5 | from .respace import SpacedDiffusion, space_timesteps 6 | from .unet import SuperResModel, UNetModel_newpreview, UNetModel_v1preview, EncoderUNetModel 7 | from . import dpm_solver 8 | NUM_CLASSES = 2 9 | 10 | 11 | def diffusion_defaults(): 12 | """ 13 | Defaults for image and classifier training. 14 | """ 15 | return dict( 16 | learn_sigma=False, 17 | diffusion_steps=1000, 18 | noise_schedule="linear", 19 | timestep_respacing="", 20 | use_kl=False, 21 | predict_xstart=False, 22 | rescale_timesteps=False, 23 | rescale_learned_sigmas=False, 24 | ) 25 | 26 | 27 | def classifier_defaults(): 28 | """ 29 | Defaults for classifier models. 30 | """ 31 | return dict( 32 | image_size=64, 33 | classifier_use_fp16=False, 34 | classifier_width=128, 35 | classifier_depth=2, 36 | classifier_attention_resolutions="32,16,8", # 16 37 | classifier_use_scale_shift_norm=True, # False 38 | classifier_resblock_updown=True, # False 39 | classifier_pool="spatial", 40 | ) 41 | 42 | 43 | def model_and_diffusion_defaults(): 44 | """ 45 | Defaults for image training. 
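    Illustrative usage (a minimal sketch, assuming the defaults below are used unchanged):
        model, diffusion = create_model_and_diffusion(**model_and_diffusion_defaults())
    Note that in_ch defaults to 5, which appears to correspond to the
    conditioning image channels plus the extra channel carrying the noisy
    segmentation mask.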
46 | """ 47 | res = dict( 48 | image_size=256, 49 | num_channels=128, 50 | num_res_blocks=2, 51 | num_heads=4, 52 | in_ch = 5, 53 | num_heads_upsample=-1, 54 | num_head_channels=-1, 55 | attention_resolutions="16,8", 56 | channel_mult="", 57 | dropout=0.0, 58 | class_cond=False, 59 | use_checkpoint=False, 60 | use_scale_shift_norm=True, 61 | resblock_updown=False, 62 | use_fp16=False, 63 | use_new_attention_order=False, 64 | dpm_solver = False, 65 | version = 'new', 66 | ) 67 | res.update(diffusion_defaults()) 68 | return res 69 | 70 | 71 | def classifier_and_diffusion_defaults(): 72 | res = classifier_defaults() 73 | res.update(diffusion_defaults()) 74 | return res 75 | 76 | 77 | def create_model_and_diffusion( 78 | image_size, 79 | class_cond, 80 | learn_sigma, 81 | num_channels, 82 | num_res_blocks, 83 | channel_mult, 84 | in_ch, 85 | num_heads, 86 | num_head_channels, 87 | num_heads_upsample, 88 | attention_resolutions, 89 | dropout, 90 | diffusion_steps, 91 | noise_schedule, 92 | timestep_respacing, 93 | use_kl, 94 | predict_xstart, 95 | rescale_timesteps, 96 | rescale_learned_sigmas, 97 | use_checkpoint, 98 | use_scale_shift_norm, 99 | resblock_updown, 100 | use_fp16, 101 | use_new_attention_order, 102 | dpm_solver, 103 | version, 104 | ): 105 | model = create_model( 106 | image_size, 107 | num_channels, 108 | num_res_blocks, 109 | channel_mult=channel_mult, 110 | learn_sigma=learn_sigma, 111 | class_cond=class_cond, 112 | use_checkpoint=use_checkpoint, 113 | attention_resolutions=attention_resolutions, 114 | in_ch = in_ch, 115 | num_heads=num_heads, 116 | num_head_channels=num_head_channels, 117 | num_heads_upsample=num_heads_upsample, 118 | use_scale_shift_norm=use_scale_shift_norm, 119 | dropout=dropout, 120 | resblock_updown=resblock_updown, 121 | use_fp16=use_fp16, 122 | use_new_attention_order=use_new_attention_order, 123 | version = version, 124 | ) 125 | diffusion = create_gaussian_diffusion( 126 | steps=diffusion_steps, 127 | learn_sigma=learn_sigma, 128 | noise_schedule=noise_schedule, 129 | use_kl=use_kl, 130 | predict_xstart=predict_xstart, 131 | rescale_timesteps=rescale_timesteps, 132 | rescale_learned_sigmas=rescale_learned_sigmas, 133 | dpm_solver=dpm_solver, 134 | timestep_respacing=timestep_respacing, 135 | ) 136 | return model, diffusion 137 | 138 | 139 | def create_model( 140 | image_size, 141 | num_channels, 142 | num_res_blocks, 143 | channel_mult="", 144 | learn_sigma=False, 145 | class_cond=False, 146 | use_checkpoint=False, 147 | attention_resolutions="16", 148 | in_ch=4, #4 149 | num_heads=1, 150 | num_head_channels=-1, 151 | num_heads_upsample=-1, 152 | use_scale_shift_norm=False, 153 | dropout=0, 154 | resblock_updown=False, 155 | use_fp16=False, 156 | use_new_attention_order=False, 157 | version = 'new', # 'new' 158 | ): 159 | if channel_mult == "": 160 | if image_size == 512: 161 | channel_mult = (1, 1, 2, 2, 4, 4) 162 | elif image_size == 256: 163 | channel_mult = (1, 1, 2, 2, 4, 4) 164 | elif image_size == 128: 165 | channel_mult = (1, 1, 2, 3, 4) 166 | elif image_size == 64: 167 | channel_mult = (1, 2, 3, 4) 168 | else: 169 | raise ValueError(f"unsupported image size: {image_size}") 170 | else: 171 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 172 | 173 | attention_ds = [] 174 | for res in attention_resolutions.split(","): 175 | attention_ds.append(image_size // int(res)) 176 | 177 | return UNetModel_newpreview( 178 | image_size=image_size, 179 | in_channels=in_ch, 180 | model_channels=num_channels, 181 | 
out_channels=2,#(3 if not learn_sigma else 6), 182 | num_res_blocks=num_res_blocks, 183 | attention_resolutions=tuple(attention_ds), 184 | dropout=dropout, 185 | channel_mult=channel_mult, 186 | num_classes=(NUM_CLASSES if class_cond else None), 187 | use_checkpoint=use_checkpoint, 188 | use_fp16=use_fp16, 189 | num_heads=num_heads, 190 | num_head_channels=num_head_channels, 191 | num_heads_upsample=num_heads_upsample, 192 | use_scale_shift_norm=use_scale_shift_norm, 193 | resblock_updown=resblock_updown, 194 | use_new_attention_order=use_new_attention_order, 195 | ) if version == 'new' else UNetModel_v1preview( 196 | image_size=image_size, 197 | in_channels=in_ch, 198 | model_channels=num_channels, 199 | out_channels=2,#(3 if not learn_sigma else 6), 200 | num_res_blocks=num_res_blocks, 201 | attention_resolutions=tuple(attention_ds), 202 | dropout=dropout, 203 | channel_mult=channel_mult, 204 | num_classes=(NUM_CLASSES if class_cond else None), 205 | use_checkpoint=use_checkpoint, 206 | use_fp16=use_fp16, 207 | num_heads=num_heads, 208 | num_head_channels=num_head_channels, 209 | num_heads_upsample=num_heads_upsample, 210 | use_scale_shift_norm=use_scale_shift_norm, 211 | resblock_updown=resblock_updown, 212 | use_new_attention_order=use_new_attention_order, 213 | ) 214 | 215 | 216 | def create_classifier_and_diffusion( 217 | image_size, 218 | classifier_use_fp16, 219 | classifier_width, 220 | classifier_depth, 221 | classifier_attention_resolutions, 222 | classifier_use_scale_shift_norm, 223 | classifier_resblock_updown, 224 | classifier_pool, 225 | learn_sigma, 226 | diffusion_steps, 227 | noise_schedule, 228 | timestep_respacing, 229 | use_kl, 230 | predict_xstart, 231 | rescale_timesteps, 232 | rescale_learned_sigmas, 233 | ): 234 | classifier = create_classifier( 235 | image_size, 236 | classifier_use_fp16, 237 | classifier_width, 238 | classifier_depth, 239 | classifier_attention_resolutions, 240 | classifier_use_scale_shift_norm, 241 | classifier_resblock_updown, 242 | classifier_pool, 243 | ) 244 | diffusion = create_gaussian_diffusion( 245 | steps=diffusion_steps, 246 | learn_sigma=learn_sigma, 247 | noise_schedule=noise_schedule, 248 | use_kl=use_kl, 249 | predict_xstart=predict_xstart, 250 | rescale_timesteps=rescale_timesteps, 251 | rescale_learned_sigmas=rescale_learned_sigmas, 252 | timestep_respacing=timestep_respacing, 253 | ) 254 | return classifier, diffusion 255 | 256 | 257 | def create_classifier( 258 | image_size, 259 | classifier_use_fp16, 260 | classifier_width, 261 | classifier_depth, 262 | classifier_attention_resolutions, 263 | classifier_use_scale_shift_norm, 264 | classifier_resblock_updown, 265 | classifier_pool, 266 | ): 267 | if image_size == 256: 268 | channel_mult = (1, 1, 2, 2, 4, 4) 269 | elif image_size == 128: 270 | channel_mult = (1, 1, 2, 3, 4) 271 | elif image_size == 64: 272 | channel_mult = (1, 2, 3, 4) 273 | else: 274 | raise ValueError(f"unsupported image size: {image_size}") 275 | 276 | attention_ds = [] 277 | for res in classifier_attention_resolutions.split(","): 278 | attention_ds.append(image_size // int(res)) 279 | 280 | return EncoderUNetModel( 281 | image_size=image_size, 282 | in_channels=3, 283 | model_channels=classifier_width, 284 | out_channels=2,#1000, 285 | num_res_blocks=classifier_depth, 286 | attention_resolutions=tuple(attention_ds), 287 | channel_mult=channel_mult, 288 | use_fp16=classifier_use_fp16, 289 | num_head_channels=64, 290 | use_scale_shift_norm=classifier_use_scale_shift_norm, 291 | 
resblock_updown=classifier_resblock_updown, 292 | pool=classifier_pool, 293 | ) 294 | 295 | 296 | def sr_model_and_diffusion_defaults(): 297 | res = model_and_diffusion_defaults() 298 | res["large_size"] = 256 299 | res["small_size"] = 64 300 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 301 | for k in res.copy().keys(): 302 | if k not in arg_names: 303 | del res[k] 304 | return res 305 | 306 | 307 | def sr_create_model_and_diffusion( 308 | large_size, 309 | small_size, 310 | class_cond, 311 | learn_sigma, 312 | num_channels, 313 | num_res_blocks, 314 | num_heads, 315 | num_head_channels, 316 | num_heads_upsample, 317 | attention_resolutions, 318 | dropout, 319 | diffusion_steps, 320 | noise_schedule, 321 | timestep_respacing, 322 | use_kl, 323 | predict_xstart, 324 | rescale_timesteps, 325 | rescale_learned_sigmas, 326 | use_checkpoint, 327 | use_scale_shift_norm, 328 | resblock_updown, 329 | use_fp16, 330 | ): 331 | model = sr_create_model( 332 | large_size, 333 | small_size, 334 | num_channels, 335 | num_res_blocks, 336 | learn_sigma=learn_sigma, 337 | class_cond=class_cond, 338 | use_checkpoint=use_checkpoint, 339 | attention_resolutions=attention_resolutions, 340 | num_heads=num_heads, 341 | num_head_channels=num_head_channels, 342 | num_heads_upsample=num_heads_upsample, 343 | use_scale_shift_norm=use_scale_shift_norm, 344 | dropout=dropout, 345 | resblock_updown=resblock_updown, 346 | use_fp16=use_fp16, 347 | ) 348 | diffusion = create_gaussian_diffusion( 349 | steps=diffusion_steps, 350 | learn_sigma=learn_sigma, 351 | noise_schedule=noise_schedule, 352 | use_kl=use_kl, 353 | predict_xstart=predict_xstart, 354 | dpm_solver = dpm_solver, 355 | rescale_timesteps=rescale_timesteps, 356 | rescale_learned_sigmas=rescale_learned_sigmas, 357 | timestep_respacing=timestep_respacing, 358 | ) 359 | return model, diffusion 360 | 361 | 362 | def sr_create_model( 363 | large_size, 364 | small_size, 365 | num_channels, 366 | num_res_blocks, 367 | learn_sigma, 368 | class_cond, 369 | use_checkpoint, 370 | attention_resolutions, 371 | num_heads, 372 | num_head_channels, 373 | num_heads_upsample, 374 | use_scale_shift_norm, 375 | dropout, 376 | resblock_updown, 377 | use_fp16, 378 | ): 379 | _ = small_size # hack to prevent unused variable 380 | 381 | if large_size == 512: 382 | channel_mult = (1, 1, 2, 2, 4, 4) 383 | elif large_size == 256: 384 | channel_mult = (1, 1, 2, 2, 4, 4) 385 | elif large_size == 64: 386 | channel_mult = (1, 2, 3, 4) 387 | else: 388 | raise ValueError(f"unsupported large size: {large_size}") 389 | 390 | attention_ds = [] 391 | for res in attention_resolutions.split(","): 392 | attention_ds.append(large_size // int(res)) 393 | 394 | return SuperResModel( 395 | image_size=large_size, 396 | in_channels=3, 397 | model_channels=num_channels, 398 | out_channels=(3 if not learn_sigma else 6), 399 | num_res_blocks=num_res_blocks, 400 | attention_resolutions=tuple(attention_ds), 401 | dropout=dropout, 402 | channel_mult=channel_mult, 403 | num_classes=(NUM_CLASSES if class_cond else None), 404 | use_checkpoint=use_checkpoint, 405 | num_heads=num_heads, 406 | num_head_channels=num_head_channels, 407 | num_heads_upsample=num_heads_upsample, 408 | use_scale_shift_norm=use_scale_shift_norm, 409 | resblock_updown=resblock_updown, 410 | use_fp16=use_fp16, 411 | ) 412 | 413 | 414 | def create_gaussian_diffusion( 415 | *, 416 | steps=1000, 417 | learn_sigma=False, 418 | sigma_small=False, 419 | noise_schedule="linear", 420 | use_kl=False, 421 | 
predict_xstart=False, 422 | dpm_solver = False, 423 | rescale_timesteps=False, 424 | rescale_learned_sigmas=False, 425 | timestep_respacing="", 426 | ): 427 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 428 | if use_kl: 429 | loss_type = gd.LossType.RESCALED_KL 430 | elif rescale_learned_sigmas: 431 | loss_type = gd.LossType.RESCALED_MSE 432 | else: 433 | loss_type = gd.LossType.MSE 434 | if not timestep_respacing: 435 | timestep_respacing = [steps] 436 | return SpacedDiffusion( 437 | use_timesteps=space_timesteps(steps, timestep_respacing), 438 | betas=betas, 439 | model_mean_type=( 440 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 441 | ), 442 | model_var_type=( 443 | ( 444 | gd.ModelVarType.FIXED_LARGE 445 | if not sigma_small 446 | else gd.ModelVarType.FIXED_SMALL 447 | ) 448 | if not learn_sigma 449 | else gd.ModelVarType.LEARNED_RANGE 450 | ), 451 | loss_type=loss_type, 452 | dpm_solver=dpm_solver, 453 | rescale_timesteps=rescale_timesteps, 454 | ) 455 | 456 | 457 | def add_dict_to_argparser(parser, default_dict): 458 | for k, v in default_dict.items(): 459 | v_type = type(v) 460 | if v is None: 461 | v_type = str 462 | elif isinstance(v, bool): 463 | v_type = str2bool 464 | parser.add_argument(f"--{k}", default=v, type=v_type) 465 | 466 | 467 | def args_to_dict(args, keys): 468 | return {k: getattr(args, k) for k in keys} 469 | 470 | 471 | def str2bool(v): 472 | """ 473 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 474 | """ 475 | if isinstance(v, bool): 476 | return v 477 | if v.lower() in ("yes", "true", "t", "y", "1"): 478 | return True 479 | elif v.lower() in ("no", "false", "f", "n", "0"): 480 | return False 481 | else: 482 | raise argparse.ArgumentTypeError("boolean value expected") 483 | -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import torch as th 7 | import torch.distributed as dist 8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 9 | from torch.optim import AdamW 10 | 11 | from . import dist_util, logger 12 | from .fp16_util import MixedPrecisionTrainer 13 | from .nn import update_ema 14 | from .resample import LossAwareSampler, UniformSampler 15 | # from visdom import Visdom 16 | # viz = Visdom(port=8850) 17 | # loss_window = viz.line( Y=th.zeros((1)).cpu(), X=th.zeros((1)).cpu(), opts=dict(xlabel='epoch', ylabel='Loss', title='loss')) 18 | # grad_window = viz.line(Y=th.zeros((1)).cpu(), X=th.zeros((1)).cpu(), 19 | # opts=dict(xlabel='step', ylabel='amplitude', title='gradient')) 20 | 21 | 22 | # For ImageNet experiments, this was a good default value. 23 | # We found that the lg_loss_scale quickly climbed to 24 | # 20-21 within the first ~1K steps of training. 
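# This value is the base-2 exponent of the fp16 loss scale, i.e. an initial
# scale of 2**20, presumably consumed as the starting lg_loss_scale by the
# mixed-precision training utilities in fp16_util.py.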
25 | INITIAL_LOG_LOSS_SCALE = 20.0 26 | 27 | def visualize(img): 28 | _min = img.min() 29 | _max = img.max() 30 | normalized_img = (img - _min)/ (_max - _min) 31 | return normalized_img 32 | 33 | class TrainLoop: 34 | def __init__( 35 | self, 36 | *, 37 | model, 38 | classifier, 39 | diffusion, 40 | data, 41 | dataloader, 42 | batch_size, 43 | microbatch, 44 | lr, 45 | ema_rate, 46 | log_interval, 47 | save_interval, 48 | resume_checkpoint, 49 | use_fp16=False, 50 | fp16_scale_growth=1e-3, 51 | schedule_sampler=None, 52 | weight_decay=0.0, 53 | lr_anneal_steps=0, 54 | ): 55 | self.model = model 56 | self.dataloader=dataloader 57 | self.classifier = classifier 58 | self.diffusion = diffusion 59 | self.data = data 60 | self.batch_size = batch_size 61 | self.microbatch = microbatch if microbatch > 0 else batch_size 62 | self.lr = lr 63 | self.ema_rate = ( 64 | [ema_rate] 65 | if isinstance(ema_rate, float) 66 | else [float(x) for x in ema_rate.split(",")] 67 | ) 68 | self.log_interval = log_interval 69 | self.save_interval = save_interval 70 | self.resume_checkpoint = resume_checkpoint 71 | self.use_fp16 = use_fp16 72 | self.fp16_scale_growth = fp16_scale_growth 73 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 74 | self.weight_decay = weight_decay 75 | self.lr_anneal_steps = lr_anneal_steps 76 | 77 | self.step = 0 78 | self.resume_step = 0 79 | self.global_batch = self.batch_size * dist.get_world_size() 80 | 81 | self.sync_cuda = th.cuda.is_available() 82 | 83 | self._load_and_sync_parameters() 84 | self.mp_trainer = MixedPrecisionTrainer( 85 | model=self.model, 86 | use_fp16=self.use_fp16, 87 | fp16_scale_growth=fp16_scale_growth, 88 | ) 89 | 90 | self.opt = AdamW( 91 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay 92 | ) 93 | if self.resume_step: 94 | self._load_optimizer_state() 95 | # Model was resumed, either due to a restart or a checkpoint 96 | # being specified at the command line. 97 | self.ema_params = [ 98 | self._load_ema_parameters(rate) for rate in self.ema_rate 99 | ] 100 | else: 101 | self.ema_params = [ 102 | copy.deepcopy(self.mp_trainer.master_params) 103 | for _ in range(len(self.ema_rate)) 104 | ] 105 | 106 | if th.cuda.is_available(): 107 | self.use_ddp = True 108 | self.ddp_model = DDP( 109 | self.model, 110 | device_ids=[dist_util.dev()], 111 | output_device=dist_util.dev(), 112 | broadcast_buffers=False, 113 | bucket_cap_mb=128, 114 | find_unused_parameters=False, 115 | ) 116 | else: 117 | if dist.get_world_size() > 1: 118 | logger.warn( 119 | "Distributed training requires CUDA. " 120 | "Gradients will not be synchronized properly!" 
121 | ) 122 | self.use_ddp = False 123 | self.ddp_model = self.model 124 | 125 | def _load_and_sync_parameters(self): 126 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 127 | 128 | if resume_checkpoint: 129 | print('resume model') 130 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 131 | if dist.get_rank() == 0: 132 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 133 | self.model.load_part_state_dict( 134 | dist_util.load_state_dict( 135 | resume_checkpoint, map_location=dist_util.dev() 136 | ) 137 | ) 138 | 139 | dist_util.sync_params(self.model.parameters()) 140 | 141 | def _load_ema_parameters(self, rate): 142 | ema_params = copy.deepcopy(self.mp_trainer.master_params) 143 | 144 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 145 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 146 | if ema_checkpoint: 147 | if dist.get_rank() == 0: 148 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 149 | state_dict = dist_util.load_state_dict( 150 | ema_checkpoint, map_location=dist_util.dev() 151 | ) 152 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict) 153 | 154 | dist_util.sync_params(ema_params) 155 | return ema_params 156 | 157 | def _load_optimizer_state(self): 158 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 159 | opt_checkpoint = bf.join( 160 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 161 | ) 162 | if bf.exists(opt_checkpoint): 163 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 164 | state_dict = dist_util.load_state_dict( 165 | opt_checkpoint, map_location=dist_util.dev() 166 | ) 167 | self.opt.load_state_dict(state_dict) 168 | 169 | def run_loop(self): 170 | i = 0 171 | data_iter = iter(self.dataloader) 172 | while ( 173 | not self.lr_anneal_steps 174 | or self.step + self.resume_step < self.lr_anneal_steps 175 | ): 176 | 177 | 178 | try: 179 | batch, cond, name = next(data_iter) 180 | except StopIteration: 181 | # StopIteration is thrown if dataset ends 182 | # reinitialize data loader 183 | data_iter = iter(self.dataloader) 184 | batch, cond, name = next(data_iter) 185 | 186 | self.run_step(batch, cond) 187 | 188 | 189 | i += 1 190 | 191 | if self.step % self.log_interval == 0: 192 | logger.dumpkvs() 193 | # if self.step % self.save_interval == 0: 194 | if self.step == 90000 or self.step == 95000 or self.step == 100000: 195 | self.save() 196 | # Run for a finite amount of time in integration tests. 197 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 198 | return 199 | self.step += 1 200 | # Save the last checkpoint if it wasn't already saved. 
201 | if (self.step - 1) % self.save_interval != 0: 202 | self.save() 203 | 204 | def run_step(self, batch, cond): 205 | batch=th.cat((batch, cond), dim=1) 206 | 207 | cond={} 208 | sample = self.forward_backward(batch, cond) 209 | took_step = self.mp_trainer.optimize(self.opt) 210 | if took_step: 211 | self._update_ema() 212 | self._anneal_lr() 213 | self.log_step() 214 | return sample 215 | 216 | def forward_backward(self, batch, cond): 217 | 218 | self.mp_trainer.zero_grad() 219 | for i in range(0, batch.shape[0], self.microbatch): 220 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 221 | micro_cond = { 222 | k: v[i : i + self.microbatch].to(dist_util.dev()) 223 | for k, v in cond.items() 224 | } 225 | 226 | last_batch = (i + self.microbatch) >= batch.shape[0] 227 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 228 | 229 | compute_losses = functools.partial( 230 | self.diffusion.training_losses_segmentation, 231 | self.ddp_model, 232 | self.classifier, 233 | micro, 234 | t, 235 | model_kwargs=micro_cond, 236 | ) 237 | 238 | if last_batch or not self.use_ddp: 239 | losses1 = compute_losses() 240 | 241 | else: 242 | with self.ddp_model.no_sync(): 243 | losses1 = compute_losses() 244 | 245 | if isinstance(self.schedule_sampler, LossAwareSampler): 246 | self.schedule_sampler.update_with_local_losses( 247 | t, losses1[0]["loss"].detach() 248 | ) 249 | losses = losses1[0] 250 | sample = losses1[1] 251 | 252 | loss = (losses["loss"] * weights + losses['loss_cal'] * 10).mean() 253 | 254 | log_loss_dict( 255 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 256 | ) 257 | self.mp_trainer.backward(loss) 258 | for name, param in self.ddp_model.named_parameters(): 259 | if param.grad is None: 260 | print(name) 261 | return sample 262 | 263 | def _update_ema(self): 264 | for rate, params in zip(self.ema_rate, self.ema_params): 265 | update_ema(params, self.mp_trainer.master_params, rate=rate) 266 | 267 | def _anneal_lr(self): 268 | if not self.lr_anneal_steps: 269 | return 270 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 271 | lr = self.lr * (1 - frac_done) 272 | for param_group in self.opt.param_groups: 273 | param_group["lr"] = lr 274 | 275 | def log_step(self): 276 | logger.logkv("step", self.step + self.resume_step) 277 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 278 | 279 | def save(self): 280 | def save_checkpoint(rate, params): 281 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 282 | if dist.get_rank() == 0: 283 | logger.log(f"saving model {rate}...") 284 | if not rate: 285 | filename = f"savedmodel{(self.step+self.resume_step):06d}.pt" 286 | else: 287 | filename = f"emasavedmodel_{rate}_{(self.step+self.resume_step):06d}.pt" 288 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 289 | th.save(state_dict, f) 290 | 291 | save_checkpoint(0, self.mp_trainer.master_params) 292 | for rate, params in zip(self.ema_rate, self.ema_params): 293 | save_checkpoint(rate, params) 294 | 295 | if dist.get_rank() == 0: 296 | with bf.BlobFile( 297 | bf.join(get_blob_logdir(), f"optsavedmodel{(self.step+self.resume_step):06d}.pt"), 298 | "wb", 299 | ) as f: 300 | th.save(self.opt.state_dict(), f) 301 | 302 | dist.barrier() 303 | 304 | 305 | def parse_resume_step_from_filename(filename): 306 | """ 307 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 308 | checkpoint's number of steps. 
309 | """ 310 | split = filename.split("model") 311 | if len(split) < 2: 312 | return 0 313 | split1 = split[-1].split(".")[0] 314 | try: 315 | return int(split1) 316 | except ValueError: 317 | return 0 318 | 319 | 320 | def get_blob_logdir(): 321 | # You can change this to be a separate path to save checkpoints to 322 | # a blobstore or some external drive. 323 | return logger.get_dir() 324 | 325 | 326 | def find_resume_checkpoint(): 327 | # On your infrastructure, you may want to override this to automatically 328 | # discover the latest checkpoint on your blob storage, etc. 329 | return None 330 | 331 | 332 | def find_ema_checkpoint(main_checkpoint, step, rate): 333 | if main_checkpoint is None: 334 | return None 335 | filename = f"ema_{rate}_{(step):06d}.pt" 336 | path = bf.join(bf.dirname(main_checkpoint), filename) 337 | if bf.exists(path): 338 | return path 339 | return None 340 | 341 | 342 | def log_loss_dict(diffusion, ts, losses): 343 | for key, values in losses.items(): 344 | logger.logkv_mean(key, values.mean().item()) 345 | # Log the quantiles (four quartiles, in particular). 346 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 347 | quartile = int(4 * sub_t / diffusion.num_timesteps) 348 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 349 | 350 | -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.utils as vutils 7 | from PIL import Image 8 | softmax_helper = lambda x: F.softmax(x, 1) 9 | sigmoid_helper = lambda x: F.sigmoid(x) 10 | 11 | 12 | class InitWeights_He(object): 13 | def __init__(self, neg_slope=1e-2): 14 | self.neg_slope = neg_slope 15 | 16 | def __call__(self, module): 17 | if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d): 18 | module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) 19 | if module.bias is not None: 20 | module.bias = nn.init.constant_(module.bias, 0) 21 | 22 | def maybe_to_torch(d): 23 | if isinstance(d, list): 24 | d = [maybe_to_torch(i) if not isinstance(i, torch.Tensor) else i for i in d] 25 | elif not isinstance(d, torch.Tensor): 26 | d = torch.from_numpy(d).float() 27 | return d 28 | 29 | 30 | def to_cuda(data, non_blocking=True, gpu_id=0): 31 | if isinstance(data, list): 32 | data = [i.cuda(gpu_id, non_blocking=non_blocking) for i in data] 33 | else: 34 | data = data.cuda(gpu_id, non_blocking=non_blocking) 35 | return data 36 | 37 | 38 | class no_op(object): 39 | def __enter__(self): 40 | pass 41 | 42 | def __exit__(self, *args): 43 | pass 44 | 45 | def staple(a): 46 | # a: n,c,h,w detach tensor 47 | mvres = mv(a) 48 | gap = 0.4 49 | if gap > 0.02: 50 | for i, s in enumerate(a): 51 | r = s * mvres 52 | res = r if i == 0 else torch.cat((res,r),0) 53 | nres = mv(res) 54 | gap = torch.mean(torch.abs(mvres - nres)) 55 | mvres = nres 56 | a = res 57 | return mvres 58 | 59 | def allone(disc,cup): 60 | disc = np.array(disc) / 255 61 | cup = np.array(cup) / 255 62 | res = np.clip(disc * 0.5 + cup,0,1) * 255 63 | res = 255 - res 64 | res = Image.fromarray(np.uint8(res)) 65 | return res 66 | 67 | def dice_score(pred, targs): 68 | pred = (pred>0).float() 69 | return 2. 
* (pred*targs).sum() / (pred+targs).sum() 70 | 71 | def mv(a): 72 | # res = Image.fromarray(np.uint8(img_list[0] / 2 + img_list[1] / 2 )) 73 | # res.show() 74 | b = a.size(0) 75 | return torch.sum(a, 0, keepdim=True) / b 76 | 77 | def tensor_to_img_array(tensor): 78 | image = tensor.cpu().detach().numpy() 79 | image = np.transpose(image, [0, 2, 3, 1]) 80 | return image 81 | 82 | def export(tar, img_path=None): 83 | # image_name = image_name or "image.jpg" 84 | c = tar.size(1) 85 | if c == 3: 86 | vutils.save_image(tar, fp = img_path) 87 | else: 88 | s = torch.tensor(tar)[:,-1,:,:].unsqueeze(1) 89 | s = torch.cat((s,s,s),1) 90 | vutils.save_image(s, fp = img_path) 91 | 92 | def norm(t): 93 | m, s, v = torch.mean(t), torch.std(t), torch.var(t) 94 | return (t - m) / s 95 | -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/vmamba.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | from functools import partial 4 | from typing import Optional, Callable 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.utils.checkpoint as checkpoint 9 | from einops import rearrange, repeat 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | try: 12 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref 13 | except: 14 | pass 15 | 16 | # an alternative for mamba_ssm (in which causal_conv1d is needed) 17 | try: 18 | from selective_scan import selective_scan_fn as selective_scan_fn_v1 19 | from selective_scan import selective_scan_ref as selective_scan_ref_v1 20 | except: 21 | pass 22 | 23 | DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})" 24 | 25 | 26 | def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): 27 | """ 28 | u: r(B D L) 29 | delta: r(B D L) 30 | A: r(D N) 31 | B: r(B N L) 32 | C: r(B N L) 33 | D: r(D) 34 | z: r(B D L) 35 | delta_bias: r(D), fp32 36 | 37 | ignores: 38 | [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] 39 | """ 40 | import numpy as np 41 | 42 | # fvcore.nn.jit_handles 43 | def get_flops_einsum(input_shapes, equation): 44 | np_arrs = [np.zeros(s) for s in input_shapes] 45 | optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] 46 | for line in optim.split("\n"): 47 | if "optimized flop" in line.lower(): 48 | # divided by 2 because we count MAC (multiply-add counted as one flop) 49 | flop = float(np.floor(float(line.split(":")[-1]) / 2)) 50 | return flop 51 | 52 | 53 | assert not with_complex 54 | 55 | flops = 0 # below code flops = 0 56 | if False: 57 | ... 58 | """ 59 | dtype_in = u.dtype 60 | u = u.float() 61 | delta = delta.float() 62 | if delta_bias is not None: 63 | delta = delta + delta_bias[..., None].float() 64 | if delta_softplus: 65 | delta = F.softplus(delta) 66 | batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] 67 | is_variable_B = B.dim() >= 3 68 | is_variable_C = C.dim() >= 3 69 | if A.is_complex(): 70 | if is_variable_B: 71 | B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) 72 | if is_variable_C: 73 | C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... 
L two", two=2)) 74 | else: 75 | B = B.float() 76 | C = C.float() 77 | x = A.new_zeros((batch, dim, dstate)) 78 | ys = [] 79 | """ 80 | 81 | flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln") 82 | if with_Group: 83 | flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln") 84 | else: 85 | flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln") 86 | if False: 87 | ... 88 | """ 89 | deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) 90 | if not is_variable_B: 91 | deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) 92 | else: 93 | if B.dim() == 3: 94 | deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) 95 | else: 96 | B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) 97 | deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) 98 | if is_variable_C and C.dim() == 4: 99 | C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) 100 | last_state = None 101 | """ 102 | 103 | in_for_flops = B * D * N 104 | if with_Group: 105 | in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd") 106 | else: 107 | in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd") 108 | flops += L * in_for_flops 109 | if False: 110 | ... 111 | """ 112 | for i in range(u.shape[2]): 113 | x = deltaA[:, :, i] * x + deltaB_u[:, :, i] 114 | if not is_variable_C: 115 | y = torch.einsum('bdn,dn->bd', x, C) 116 | else: 117 | if C.dim() == 3: 118 | y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) 119 | else: 120 | y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) 121 | if i == u.shape[2] - 1: 122 | last_state = x 123 | if y.is_complex(): 124 | y = y.real * 2 125 | ys.append(y) 126 | y = torch.stack(ys, dim=2) # (batch dim L) 127 | """ 128 | 129 | if with_D: 130 | flops += B * D * L 131 | if with_Z: 132 | flops += B * D * L 133 | if False: 134 | ... 135 | """ 136 | out = y if D is None else y + u * rearrange(D, "d -> d 1") 137 | if z is not None: 138 | out = out * F.silu(z) 139 | out = out.to(dtype=dtype_in) 140 | """ 141 | 142 | return flops 143 | 144 | 145 | class PatchEmbed2D(nn.Module): 146 | r""" Image to Patch Embedding 147 | Args: 148 | patch_size (int): Patch token size. Default: 4. 149 | in_chans (int): Number of input image channels. Default: 3. 150 | embed_dim (int): Number of linear projection output channels. Default: 96. 151 | norm_layer (nn.Module, optional): Normalization layer. Default: None 152 | """ 153 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs): 154 | super().__init__() 155 | if isinstance(patch_size, int): 156 | patch_size = (patch_size, patch_size) 157 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 158 | if norm_layer is not None: 159 | self.norm = norm_layer(embed_dim) 160 | else: 161 | self.norm = None 162 | 163 | def forward(self, x): 164 | x = self.proj(x).permute(0, 2, 3, 1) 165 | if self.norm is not None: 166 | x = self.norm(x) 167 | return x 168 | 169 | 170 | class PatchMerging2D(nn.Module): 171 | r""" Patch Merging Layer. 172 | Args: 173 | input_resolution (tuple[int]): Resolution of input feature. 174 | dim (int): Number of input channels. 175 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 176 | """ 177 | 178 | def __init__(self, dim, norm_layer=nn.LayerNorm): 179 | super().__init__() 180 | self.dim = dim 181 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 182 | self.norm = norm_layer(4 * dim) 183 | 184 | def forward(self, x): 185 | B, H, W, C = x.shape 186 | 187 | SHAPE_FIX = [-1, -1] 188 | if (W % 2 != 0) or (H % 2 != 0): 189 | print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True) 190 | SHAPE_FIX[0] = H // 2 191 | SHAPE_FIX[1] = W // 2 192 | 193 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 194 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 195 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 196 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 197 | 198 | if SHAPE_FIX[0] > 0: 199 | x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 200 | x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 201 | x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 202 | x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 203 | 204 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 205 | x = x.view(B, H//2, W//2, 4 * C) # B H/2*W/2 4*C 206 | 207 | x = self.norm(x) 208 | x = self.reduction(x) 209 | 210 | return x 211 | 212 | 213 | class PatchExpand2D(nn.Module): 214 | def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm): 215 | super().__init__() 216 | self.dim = dim*2 217 | self.dim_scale = dim_scale 218 | self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) 219 | self.norm = norm_layer(self.dim // dim_scale) 220 | 221 | def forward(self, x): 222 | B, H, W, C = x.shape 223 | x = self.expand(x) 224 | 225 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) 226 | x= self.norm(x) 227 | 228 | return x 229 | 230 | 231 | class Final_PatchExpand2D(nn.Module): 232 | def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm): 233 | super().__init__() 234 | self.dim = dim 235 | self.dim_scale = dim_scale 236 | self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) 237 | self.norm = norm_layer(self.dim // dim_scale) 238 | 239 | def forward(self, x): 240 | B, H, W, C = x.shape 241 | x = self.expand(x) 242 | 243 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) 244 | x= self.norm(x) 245 | 246 | return x 247 | 248 | 249 | class SS2D(nn.Module): 250 | def __init__( 251 | self, 252 | d_model, 253 | d_state=16, 254 | # d_state="auto", # 20240109 255 | d_conv=3, 256 | expand=2, 257 | dt_rank="auto", 258 | dt_min=0.001, 259 | dt_max=0.1, 260 | dt_init="random", 261 | dt_scale=1.0, 262 | dt_init_floor=1e-4, 263 | dropout=0., 264 | conv_bias=True, 265 | bias=False, 266 | device=None, 267 | dtype=None, 268 | **kwargs, 269 | ): 270 | factory_kwargs = {"device": device, "dtype": dtype} 271 | super().__init__() 272 | self.d_model = d_model 273 | self.d_state = d_state 274 | # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109 275 | self.d_conv = d_conv 276 | self.expand = expand 277 | self.d_inner = int(self.expand * self.d_model) 278 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 279 | 280 | self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) 281 | self.conv2d = nn.Conv2d( 282 | in_channels=self.d_inner, 283 | out_channels=self.d_inner, 284 | groups=self.d_inner, 285 | bias=conv_bias, 286 | kernel_size=d_conv, 287 | padding=(d_conv - 1) // 2, 288 | **factory_kwargs, 289 | ) 290 | self.act = nn.SiLU() 291 | 292 | self.x_proj = ( 293 | 
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 294 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 295 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 296 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 297 | ) 298 | self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner) 299 | del self.x_proj 300 | 301 | self.dt_projs = ( 302 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), 303 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), 304 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), 305 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), 306 | ) 307 | self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank) 308 | self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner) 309 | del self.dt_projs 310 | 311 | self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N) 312 | self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N) 313 | 314 | # self.selective_scan = selective_scan_fn 315 | self.forward_core = self.forward_corev0 316 | 317 | self.out_norm = nn.LayerNorm(self.d_inner) 318 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 319 | self.dropout = nn.Dropout(dropout) if dropout > 0. else None 320 | 321 | @staticmethod 322 | def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): 323 | dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) 324 | 325 | # Initialize special dt projection to preserve variance at initialization 326 | dt_init_std = dt_rank**-0.5 * dt_scale 327 | if dt_init == "constant": 328 | nn.init.constant_(dt_proj.weight, dt_init_std) 329 | elif dt_init == "random": 330 | nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) 331 | else: 332 | raise NotImplementedError 333 | 334 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 335 | dt = torch.exp( 336 | torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 337 | + math.log(dt_min) 338 | ).clamp(min=dt_init_floor) 339 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 340 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 341 | with torch.no_grad(): 342 | dt_proj.bias.copy_(inv_dt) 343 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 344 | dt_proj.bias._no_reinit = True 345 | 346 | return dt_proj 347 | 348 | @staticmethod 349 | def A_log_init(d_state, d_inner, copies=1, device=None, merge=True): 350 | # S4D real initialization 351 | A = repeat( 352 | torch.arange(1, d_state + 1, dtype=torch.float32, device=device), 353 | "n -> d n", 354 | d=d_inner, 355 | ).contiguous() 356 | A_log = torch.log(A) # Keep A_log in fp32 357 | if copies > 1: 358 | A_log = repeat(A_log, "d n -> r d n", r=copies) 359 | if merge: 360 | A_log = A_log.flatten(0, 1) 361 | A_log = nn.Parameter(A_log) 362 | A_log._no_weight_decay = True 363 | return A_log 364 | 365 | @staticmethod 366 | def D_init(d_inner, 
copies=1, device=None, merge=True): 367 | # D "skip" parameter 368 | D = torch.ones(d_inner, device=device) 369 | if copies > 1: 370 | D = repeat(D, "n1 -> r n1", r=copies) 371 | if merge: 372 | D = D.flatten(0, 1) 373 | D = nn.Parameter(D) # Keep in fp32 374 | D._no_weight_decay = True 375 | return D 376 | 377 | def forward_corev0(self, x: torch.Tensor): 378 | self.selective_scan = selective_scan_fn 379 | 380 | B, C, H, W = x.shape 381 | L = H * W 382 | K = 4 383 | 384 | x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) 385 | xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) 386 | 387 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) 388 | # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) 389 | dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) 390 | dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) 391 | # dts = dts + self.dt_projs_bias.view(1, K, -1, 1) 392 | 393 | xs = xs.float().view(B, -1, L) # (b, k * d, l) 394 | dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) 395 | Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l) 396 | Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) 397 | Ds = self.Ds.float().view(-1) # (k * d) 398 | As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state) 399 | dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) 400 | 401 | out_y = self.selective_scan( 402 | xs, dts, 403 | As, Bs, Cs, Ds, z=None, 404 | delta_bias=dt_projs_bias, 405 | delta_softplus=True, 406 | return_last_state=False, 407 | ).view(B, K, -1, L) 408 | assert out_y.dtype == torch.float 409 | 410 | inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) 411 | wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 412 | invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 413 | 414 | return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y 415 | 416 | # an alternative to forward_corev1 417 | def forward_corev1(self, x: torch.Tensor): 418 | self.selective_scan = selective_scan_fn_v1 419 | 420 | B, C, H, W = x.shape 421 | L = H * W 422 | K = 4 423 | 424 | x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) 425 | xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) 426 | 427 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) 428 | # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) 429 | dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) 430 | dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) 431 | # dts = dts + self.dt_projs_bias.view(1, K, -1, 1) 432 | 433 | xs = xs.float().view(B, -1, L) # (b, k * d, l) 434 | dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) 435 | Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l) 436 | Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) 437 | Ds = self.Ds.float().view(-1) # (k * d) 438 | As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state) 439 | dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) 440 | 441 | out_y = self.selective_scan( 442 | xs, dts, 443 | As, Bs, Cs, Ds, 444 | delta_bias=dt_projs_bias, 445 | 
delta_softplus=True, 446 | ).view(B, K, -1, L) 447 | assert out_y.dtype == torch.float 448 | 449 | inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) 450 | wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 451 | invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 452 | 453 | return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y 454 | 455 | def forward(self, x: torch.Tensor, **kwargs): 456 | B, H, W, C = x.shape 457 | 458 | xz = self.in_proj(x) 459 | x, z = xz.chunk(2, dim=-1) # (b, h, w, d) 460 | 461 | x = x.permute(0, 3, 1, 2).contiguous() 462 | x = self.act(self.conv2d(x)) # (b, d, h, w) 463 | y1, y2, y3, y4 = self.forward_core(x) 464 | assert y1.dtype == torch.float32 465 | y = y1 + y2 + y3 + y4 466 | y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) 467 | y = self.out_norm(y) 468 | y = y * F.silu(z) 469 | out = self.out_proj(y) 470 | if self.dropout is not None: 471 | out = self.dropout(out) 472 | return out 473 | 474 | 475 | class VSSBlock(nn.Module): 476 | def __init__( 477 | self, 478 | hidden_dim: int = 0, 479 | drop_path: float = 0, 480 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), 481 | attn_drop_rate: float = 0, 482 | d_state: int = 16, 483 | **kwargs, 484 | ): 485 | super().__init__() 486 | self.ln_1 = norm_layer(hidden_dim) 487 | self.self_attention = SS2D(d_model=hidden_dim, dropout=attn_drop_rate, d_state=d_state, **kwargs) 488 | self.drop_path = DropPath(drop_path) 489 | 490 | def forward(self, input: torch.Tensor): 491 | x = input + self.drop_path(self.self_attention(self.ln_1(input))) 492 | return x 493 | 494 | 495 | class VSSLayer(nn.Module): 496 | """ A basic Swin Transformer layer for one stage. 497 | Args: 498 | dim (int): Number of input channels. 499 | depth (int): Number of blocks. 500 | drop (float, optional): Dropout rate. Default: 0.0 501 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 502 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 503 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 504 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 505 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 506 | """ 507 | 508 | def __init__( 509 | self, 510 | dim, 511 | depth, 512 | attn_drop=0., 513 | drop_path=0., 514 | norm_layer=nn.LayerNorm, 515 | downsample=None, 516 | use_checkpoint=False, 517 | d_state=16, 518 | **kwargs, 519 | ): 520 | super().__init__() 521 | self.dim = dim 522 | self.use_checkpoint = use_checkpoint 523 | 524 | self.blocks = nn.ModuleList([ 525 | VSSBlock( 526 | hidden_dim=dim, 527 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 528 | norm_layer=norm_layer, 529 | attn_drop_rate=attn_drop, 530 | d_state=d_state, 531 | ) 532 | for i in range(depth)]) 533 | 534 | if True: # is this really applied? Yes, but been overriden later in VSSM! 535 | def _init_weights(module: nn.Module): 536 | for name, p in module.named_parameters(): 537 | if name in ["out_proj.weight"]: 538 | p = p.clone().detach_() # fake init, just to keep the seed .... 
539 | nn.init.kaiming_uniform_(p, a=math.sqrt(5)) 540 | self.apply(_init_weights) 541 | 542 | if downsample is not None: 543 | self.downsample = downsample(dim=dim, norm_layer=norm_layer) 544 | else: 545 | self.downsample = None 546 | 547 | 548 | def forward(self, x): 549 | for blk in self.blocks: 550 | if self.use_checkpoint: 551 | x = checkpoint.checkpoint(blk, x) 552 | else: 553 | x = blk(x) 554 | 555 | if self.downsample is not None: 556 | x = self.downsample(x) 557 | 558 | return x 559 | 560 | 561 | 562 | class VSSLayer_up(nn.Module): 563 | """ A basic Swin Transformer layer for one stage. 564 | Args: 565 | dim (int): Number of input channels. 566 | depth (int): Number of blocks. 567 | drop (float, optional): Dropout rate. Default: 0.0 568 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 569 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 570 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 571 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 572 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 573 | """ 574 | 575 | def __init__( 576 | self, 577 | dim, 578 | depth, 579 | attn_drop=0., 580 | drop_path=0., 581 | norm_layer=nn.LayerNorm, 582 | upsample=None, 583 | use_checkpoint=False, 584 | d_state=16, 585 | **kwargs, 586 | ): 587 | super().__init__() 588 | self.dim = dim 589 | self.use_checkpoint = use_checkpoint 590 | 591 | self.blocks = nn.ModuleList([ 592 | VSSBlock( 593 | hidden_dim=dim, 594 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 595 | norm_layer=norm_layer, 596 | attn_drop_rate=attn_drop, 597 | d_state=d_state, 598 | ) 599 | for i in range(depth)]) 600 | 601 | if True: # is this really applied? Yes, but been overriden later in VSSM! 602 | def _init_weights(module: nn.Module): 603 | for name, p in module.named_parameters(): 604 | if name in ["out_proj.weight"]: 605 | p = p.clone().detach_() # fake init, just to keep the seed .... 
606 | nn.init.kaiming_uniform_(p, a=math.sqrt(5)) 607 | self.apply(_init_weights) 608 | 609 | if upsample is not None: 610 | self.upsample = upsample(dim=dim, norm_layer=norm_layer) 611 | else: 612 | self.upsample = None 613 | 614 | 615 | def forward(self, x): 616 | if self.upsample is not None: 617 | x = self.upsample(x) 618 | for blk in self.blocks: 619 | if self.use_checkpoint: 620 | x = checkpoint.checkpoint(blk, x) 621 | else: 622 | x = blk(x) 623 | return x 624 | 625 | 626 | 627 | class VSSM(nn.Module): 628 | def __init__(self, patch_size=4, in_chans=3, num_classes=1000, depths=[2, 2, 9, 2], depths_decoder=[2, 9, 2, 2], 629 | dims=[64,128,256,512], dims_decoder=[512,256,128,64], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 630 | norm_layer=nn.LayerNorm, patch_norm=True, 631 | use_checkpoint=False, **kwargs): 632 | super().__init__() 633 | self.num_classes = num_classes 634 | self.num_layers = len(depths) 635 | if isinstance(dims, int): 636 | dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] 637 | self.embed_dim = dims[0] 638 | self.num_features = dims[-1] 639 | self.dims = dims 640 | 641 | self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim, 642 | norm_layer=norm_layer if patch_norm else None) 643 | # WASTED absolute position embedding ====================== 644 | self.ape = False 645 | # self.ape = False 646 | # drop_rate = 0.0 647 | if self.ape: 648 | self.patches_resolution = self.patch_embed.patches_resolution 649 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim)) 650 | trunc_normal_(self.absolute_pos_embed, std=.02) 651 | self.pos_drop = nn.Dropout(p=drop_rate) 652 | 653 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 654 | dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_decoder))][::-1] 655 | 656 | self.layers = nn.ModuleList() 657 | for i_layer in range(self.num_layers): 658 | # if i_layer != 3 and i_layer != 2: 659 | # layer = FFTBlock(dims[i_layer]) 660 | # self.layers.append(layer) 661 | layer = VSSLayer( 662 | dim=dims[i_layer], 663 | depth=depths[i_layer], 664 | d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109 665 | drop=drop_rate, 666 | attn_drop=attn_drop_rate, 667 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 668 | norm_layer=norm_layer, 669 | downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None, 670 | use_checkpoint=use_checkpoint, 671 | ) 672 | self.layers.append(layer) 673 | 674 | 675 | self.layers_up = nn.ModuleList() 676 | for i_layer in range(self.num_layers): 677 | layer = VSSLayer_up( 678 | dim=dims_decoder[i_layer], 679 | depth=depths_decoder[i_layer], 680 | d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109 681 | drop=drop_rate, 682 | attn_drop=attn_drop_rate, 683 | drop_path=dpr_decoder[sum(depths_decoder[:i_layer]):sum(depths_decoder[:i_layer + 1])], 684 | norm_layer=norm_layer, 685 | upsample=PatchExpand2D if (i_layer != 0) else None, 686 | use_checkpoint=use_checkpoint, 687 | ) 688 | self.layers_up.append(layer) 689 | self.final_up = Final_PatchExpand2D(dim=dims_decoder[-1], dim_scale=4, norm_layer=norm_layer) 690 | self.final_conv = nn.Conv2d(dims_decoder[-1]//4, num_classes, 1) 691 | # self.norm = norm_layer(self.num_features) 692 | # self.avgpool = nn.AdaptiveAvgPool1d(1) 693 | # self.head = nn.Linear(self.num_features, num_classes) if 
num_classes > 0 else nn.Identity() 694 | 695 | self.apply(self._init_weights) 696 | 697 | def _init_weights(self, m: nn.Module): 698 | """ 699 | out_proj.weight which is previously initilized in VSSBlock, would be cleared in nn.Linear 700 | no fc.weight found in the any of the model parameters 701 | no nn.Embedding found in the any of the model parameters 702 | so the thing is, VSSBlock initialization is useless 703 | 704 | Conv2D is not intialized !!! 705 | """ 706 | if isinstance(m, nn.Linear): 707 | trunc_normal_(m.weight, std=.02) 708 | if isinstance(m, nn.Linear) and m.bias is not None: 709 | nn.init.constant_(m.bias, 0) 710 | elif isinstance(m, nn.LayerNorm): 711 | nn.init.constant_(m.bias, 0) 712 | nn.init.constant_(m.weight, 1.0) 713 | 714 | @torch.jit.ignore 715 | def no_weight_decay(self): 716 | return {'absolute_pos_embed'} 717 | 718 | @torch.jit.ignore 719 | def no_weight_decay_keywords(self): 720 | return {'relative_position_bias_table'} 721 | 722 | def forward_features(self, x): 723 | skip_list = [] 724 | x = self.patch_embed(x) 725 | if self.ape: 726 | x = x + self.absolute_pos_embed 727 | x = self.pos_drop(x) 728 | # count = 0 729 | # skip_list.append(x) 730 | for layer in self.layers: 731 | skip_list.append(x) 732 | x = layer(x) 733 | # count += 1 734 | # if count == 1 or count == 3 or count == 4 or count == 6: 735 | # skip_list.append(x) 736 | return x, skip_list 737 | 738 | def forward_features_up(self, x, skip_list): 739 | for inx, layer_up in enumerate(self.layers_up): 740 | if inx == 0: 741 | x = layer_up(x) 742 | else: 743 | x = layer_up(x+skip_list[-inx]) 744 | 745 | return x 746 | 747 | def forward_final(self, x): 748 | x = self.final_up(x) 749 | x = x.permute(0, 3, 1, 2).contiguous() 750 | x = self.final_conv(x) 751 | return x 752 | 753 | def forward_backbone(self, x): 754 | x = self.patch_embed(x) 755 | if self.ape: 756 | x = x + self.absolute_pos_embed 757 | x = self.pos_drop(x) 758 | 759 | for layer in self.layers: 760 | x = layer(x) 761 | return x 762 | 763 | def forward(self, x): 764 | x, skip_list = self.forward_features(x) 765 | x = self.forward_features_up(x, skip_list) 766 | x = self.forward_final(x) 767 | 768 | return x, skip_list 769 | 770 | 771 | class SpectralGatingNetwork(nn.Module): 772 | def __init__(self, dim): 773 | super().__init__() 774 | # this weights are valid for h=14 and w=8 775 | if dim == 64: #96 for large model, 64 for small and base model 776 | self.h = 64 #H 777 | self.w = 33 #(W/2)+1 778 | self.complex_weight = nn.Parameter(torch.randn(self.h, self.w, dim, 2, dtype=torch.float32) * 0.02) 779 | if dim == 128: #96 for large model, 64 for small and base model 780 | self.h = 32 #H 781 | self.w = 17 #(W/2)+1 782 | self.complex_weight = nn.Parameter(torch.randn(self.h, self.w, dim, 2, dtype=torch.float32) * 0.02) 783 | 784 | def forward(self, x): 785 | B, H, W, C = x.shape 786 | x = x.to(torch.float32) 787 | y_fft = torch.fft.rfft2(x, dim=(1, 2), norm='ortho') 788 | weight = torch.view_as_complex(self.complex_weight) 789 | y_fft = y_fft * weight 790 | y_fft = torch.fft.irfft2(y_fft, s=(H, W), dim=(1, 2), norm='ortho') 791 | return y_fft 792 | 793 | -------------------------------------------------------------------------------- /CrackSegDiff/guided_diffusion/vmunet.py: -------------------------------------------------------------------------------- 1 | from .vmamba import VSSM 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class VMUNet(nn.Module): 7 | def __init__(self, 8 | input_channels=3, 9 | num_classes=2, 10 | depths=[2, 2, 9, 
2], 11 | depths_decoder=[2, 9, 2, 2], 12 | drop_path_rate=0.2, 13 | load_ckpt_path=None, 14 | ): 15 | super().__init__() 16 | 17 | self.load_ckpt_path = load_ckpt_path 18 | self.num_classes = num_classes 19 | 20 | self.vmunet = VSSM(in_chans=input_channels, 21 | num_classes=num_classes, 22 | depths=depths, 23 | depths_decoder=depths_decoder, 24 | drop_path_rate=drop_path_rate, 25 | ) 26 | 27 | def forward(self, x): 28 | if x.size()[1] == 1: 29 | x = x.repeat(1,3,1,1) 30 | logits, skip_list = self.vmunet(x) 31 | if self.num_classes == 1: return torch.sigmoid(logits), skip_list 32 | else: return logits 33 | 34 | def load_from(self): 35 | if self.load_ckpt_path is not None: 36 | model_dict = self.vmunet.state_dict() 37 | modelCheckpoint = torch.load(self.load_ckpt_path) 38 | pretrained_dict = modelCheckpoint['model'] 39 | # Filter: keep only the pretrained keys that exist in the current model 40 | new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} 41 | model_dict.update(new_dict) 42 | # Print how many parameters were updated 43 | print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict))) 44 | self.vmunet.load_state_dict(model_dict) 45 | 46 | not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()] 47 | print('Not loaded keys:', not_loaded_keys) 48 | print("encoder loaded finished!") 49 | 50 | model_dict = self.vmunet.state_dict() 51 | modelCheckpoint = torch.load(self.load_ckpt_path) 52 | pretrained_odict = modelCheckpoint['model'] 53 | pretrained_dict = {} 54 | for k, v in pretrained_odict.items(): 55 | if 'layers.0' in k: 56 | new_k = k.replace('layers.0', 'layers_up.3') 57 | pretrained_dict[new_k] = v 58 | elif 'layers.1' in k: 59 | new_k = k.replace('layers.1', 'layers_up.2') 60 | pretrained_dict[new_k] = v 61 | elif 'layers.2' in k: 62 | new_k = k.replace('layers.2', 'layers_up.1') 63 | pretrained_dict[new_k] = v 64 | elif 'layers.3' in k: 65 | new_k = k.replace('layers.3', 'layers_up.0') 66 | pretrained_dict[new_k] = v 67 | # Filter: keep only the remapped keys that exist in the current model 68 | new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} 69 | model_dict.update(new_dict) 70 | # Print how many parameters were updated 71 | print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict))) 72 | self.vmunet.load_state_dict(model_dict) 73 | 74 | # Find the keys that were not loaded 75 | not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()] 76 | print('Not loaded keys:', not_loaded_keys) 77 | print("decoder loaded finished!") -------------------------------------------------------------------------------- /CrackSegDiff/segmentation_env.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | import numpy as np 4 | import math 5 | import torch 6 | import torchvision 7 | from PIL import Image 8 | import argparse 9 | import os 10 | import cv2 11 | major = cv2.__version__.split('.')[0] # Get opencv version 12 | from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, jaccard_score 13 | def calculate_metrics(pred, gt, pred_, gt_): 14 | pred_ = np.array(pred_) 15 | gt_ = np.array(gt_) 16 | return (f1_score(gt.reshape(-1), pred.reshape(-1), zero_division=0), 17 | precision_score(gt.reshape(-1), pred.reshape(-1), zero_division=0), 18 | recall_score(gt.reshape(-1), pred.reshape(-1), zero_division=0), 19 | accuracy_score(gt.reshape(-1), pred.reshape(-1)), 20 | jaccard_score(gt.reshape(-1), pred.reshape(-1), zero_division=0), 21 | bfscore(pred_, gt_,
3)) 22 | def bfscore(prfile, gtfile, threshold): 23 | 24 | gt_ = gtfile # Convert color space 25 | pr_ = prfile # Convert color space 26 | classes_gt = np.unique(gt_) # Get GT classes 27 | classes_pr = np.unique(pr_) # Get predicted classes 28 | # Check classes from GT and prediction 29 | if not np.array_equiv(classes_gt, classes_pr): 30 | classes = np.concatenate((classes_gt, classes_pr)) 31 | classes = np.unique(classes) 32 | classes = np.sort(classes) 33 | else: 34 | classes = classes_gt # Get matched classes 35 | m = int(np.max(classes)) # Get max of classes (number of classes) 36 | # Define bfscore variable (initialized with zeros) 37 | bfscores = np.zeros((m+1), dtype=float) 38 | areas_gt = np.zeros((m+1), dtype=float) 39 | for i in range(m+1): 40 | bfscores[i] = np.nan 41 | areas_gt[i] = np.nan 42 | for target_class in classes: # Iterate over classes 43 | if target_class == 0: # Skip background 44 | continue 45 | gt = gt_.copy() 46 | gt[gt != target_class] = 0 47 | # contours is a list of points. 48 | if major == '3': # For opencv version 3.x 49 | _, contours, _ = cv2.findContours( 50 | gt, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) # Find contours of the shape 51 | else: # For other opencv versions 52 | contours, _ = cv2.findContours( 53 | gt, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) # Find contours of the shape 54 | 55 | # contours is a list of numpy arrays 56 | contours_gt = [] 57 | for i in range(len(contours)): 58 | for j in range(len(contours[i])): 59 | contours_gt.append(contours[i][j][0].tolist()) 60 | # Get contour area of GT 61 | if contours_gt: 62 | area = cv2.contourArea(np.array(contours_gt)) 63 | areas_gt[target_class] = area 64 | pr = pr_.copy() 65 | pr[pr != target_class] = 0 66 | # contours is a list of points. 67 | if major == '3': # For opencv version 3.x 68 | _, contours, _ = cv2.findContours( 69 | pr, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) 70 | else: # For other opencv versions 71 | contours, _ = cv2.findContours( 72 | pr, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) 73 | # contours is a list of numpy arrays 74 | contours_pr = [] 75 | for i in range(len(contours)): 76 | for j in range(len(contours[i])): 77 | contours_pr.append(contours[i][j][0].tolist()) 78 | # 3.
calculate 79 | if len(contours) == 0: 80 | return 0 81 | else: 82 | precision, numerator, denominator = calc_precision_recall( 83 | contours_gt, contours_pr, threshold) # Precision 84 | 85 | recall, numerator, denominator = calc_precision_recall( 86 | contours_pr, contours_gt, threshold) # Recall 87 | 88 | 89 | if (recall + precision) == 0: 90 | f1 = 0 91 | else: 92 | f1 = 2*recall*precision/(recall + precision) # F1 score 93 | bfscores[target_class] = f1 94 | 95 | 96 | # return bfscores[1:], np.sum(bfscores[1:])/len(classes[1:]) # Return bfscores, except for background, and per image score 97 | return bfscores[-1] # Return bfscores, except for background 98 | def calc_precision_recall(contours_a, contours_b, threshold): 99 | x = contours_a 100 | y = contours_b 101 | 102 | xx = np.array(x) 103 | hits = [] 104 | 105 | for yrec in y: 106 | d = np.square(xx[:,0] - yrec[0]) + np.square(xx[:,1] - yrec[1]) 107 | 108 | hits.append(np.any(d < threshold*threshold)) 109 | top_count = np.sum(hits) 110 | 111 | try: 112 | precision_recall = top_count / len(y) 113 | except ZeroDivisionError: 114 | precision_recall = 0 115 | 116 | return precision_recall, top_count, len(y) 117 | 118 | def main(): 119 | argParser = argparse.ArgumentParser() 120 | argParser.add_argument("--inp_pth") 121 | argParser.add_argument("--out_pth") 122 | args = argParser.parse_args() 123 | mix_res = (0., 0., 0., 0., 0., 0.) 124 | num = 0 125 | pred_path = args.inp_pth 126 | gt_path = args.out_pth 127 | for root, dirs, files in os.walk(pred_path, topdown=False): 128 | for name in files: 129 | if 'ens' in name: 130 | num += 1 131 | ind = name.split('_')[0] 132 | pred_ = Image.open(os.path.join(root, name)).convert('L') 133 | gt_name = ind + ".bmp" 134 | gt_ = Image.open(os.path.join(gt_path, gt_name)).convert('L') 135 | pred = torchvision.transforms.PILToTensor()(pred_) 136 | pred = torch.unsqueeze(pred,0).float() 137 | if pred.max() == 0: 138 | pred = pred 139 | else: 140 | pred = pred / pred.max() 141 | gt = torchvision.transforms.PILToTensor()(gt_) 142 | gt = torch.unsqueeze(gt, 0).float() / 255.0 143 | temp = calculate_metrics(pred, gt, pred_, gt_) 144 | printed = False 145 | for x in temp : 146 | if x < 0.7 and not printed: 147 | print(ind,temp[0],temp[4],temp[5]) 148 | printed = True 149 | break 150 | mix_res = tuple([sum(a) for a in zip(mix_res, temp)]) 151 | F1, Precision, Recall, Accuracy, IoU, BFScore= tuple([a / num for a in mix_res]) 152 | print('IoU is', IoU) 153 | print('F1 is', F1) 154 | print('Precision is', Precision) 155 | print('Recall is', Recall) 156 | print('Accuracy is', Accuracy) 157 | print('BFScore is', BFScore) 158 | if __name__ == "__main__": 159 | main() 160 | -------------------------------------------------------------------------------- /CrackSegDiff/segmentation_sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import random 5 | sys.path.append(".") 6 | import numpy as np 7 | from tqdm import tqdm 8 | import torch as th 9 | from PIL import Image 10 | from guided_diffusion import dist_util, logger 11 | from guided_diffusion.custom_dataset_loader import CustomDataset 12 | from guided_diffusion.utils import staple 13 | from guided_diffusion.script_util import ( 14 | NUM_CLASSES, 15 | model_and_diffusion_defaults, 16 | create_model_and_diffusion, 17 | add_dict_to_argparser, 18 | args_to_dict, 19 | ) 20 | import torchvision.transforms as transforms 21 | seed=10 22 | th.manual_seed(seed) 23 | 
th.cuda.manual_seed_all(seed) 24 | np.random.seed(seed) 25 | random.seed(seed) 26 | 27 | def visualize(img): 28 | _min = img.min() 29 | _max = img.max() 30 | normalized_img = (img - _min)/ (_max - _min) 31 | return normalized_img 32 | 33 | 34 | def main(): 35 | args = create_argparser().parse_args() 36 | dist_util.setup_dist(args) 37 | logger.configure(dir = args.out_dir) 38 | tran_list = [transforms.ToTensor()] 39 | transform_test = transforms.Compose(tran_list) 40 | print("Your current directory : ", args.data_dir) 41 | ds = CustomDataset(args, args.data_dir, transform_test, mode='Test') 42 | args.in_ch = 7 43 | datal = th.utils.data.DataLoader( 44 | ds, 45 | batch_size=args.batch_size, 46 | shuffle=True, 47 | num_workers=8, 48 | pin_memory=True 49 | ) 50 | data = iter(datal) 51 | 52 | logger.log("creating model and diffusion...") 53 | 54 | model, diffusion = create_model_and_diffusion( 55 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 56 | ) 57 | all_images = [] 58 | state_dict = dist_util.load_state_dict(args.model_path, map_location="cpu") 59 | from collections import OrderedDict 60 | new_state_dict = OrderedDict() 61 | for k, v in state_dict.items(): 62 | # name = k[7:] # remove `module.` 63 | if 'module.' in k: 64 | new_state_dict[k[7:]] = v 65 | # load params 66 | else: 67 | new_state_dict = state_dict 68 | 69 | model.load_state_dict(new_state_dict) 70 | 71 | model.to(dist_util.dev()) 72 | if args.use_fp16: 73 | model.convert_to_fp16() 74 | model.eval() 75 | num_tqdm = range(len(data)) 76 | # i_sample = 0 77 | for _ in tqdm(num_tqdm, desc='Processing'): 78 | b, m, path = next(data) # should return an image from the dataloader "data" 79 | c = th.randn_like(b[:, :1, ...]) 80 | # i_sample += 1 81 | # if i_sample < 400: 82 | # continue 83 | 84 | img = th.cat((b, c), dim=1) # add a noise channel$ 85 | slice_ID = path[0].split("/")[-1].split('.')[0] 86 | print(slice_ID) 87 | logger.log("sampling...") 88 | start = th.cuda.Event(enable_timing=True) 89 | end = th.cuda.Event(enable_timing=True) 90 | enslist = [] 91 | for i in range(args.num_ensemble): # this is for the generation of an ensemble of 5 masks. 
            model_kwargs = {}
            start.record()
            sample_fn = (
                diffusion.p_sample_loop_known if not args.use_ddim else diffusion.ddim_sample_loop_known
            )
            sample, x_noisy, org, cal, cal_out = sample_fn(
                model,
                (args.batch_size, 3, args.image_size, args.image_size), img,
                step=args.diffusion_steps,
                clip_denoised=args.clip_denoised,
                model_kwargs=model_kwargs,
            )
            end.record()
            th.cuda.synchronize()
            print('time for 1 sample', start.elapsed_time(end))  # time (ms) to generate one sample
            # co = th.tensor(cal_out)
            co = cal_out.clone().detach()
            if args.version == 'new':
                enslist.append(sample[:, -1, :, :])
                # enslist.append(co[:, -1, :, :])
                # enslist.append(cal[:, -1, :, :])
            else:
                # enslist.append(co)
                enslist.append(sample[:, -1, :, :])
        x = staple(th.stack(enslist, dim=0)).squeeze(0)  # fuse the ensemble predictions with STAPLE
        x = th.clamp(x, 0.0, 1.0)
        ensres = (x.mean(dim=0, keepdim=True).round()) * 255  # binarize and rescale to 8-bit
        out_img = Image.fromarray((ensres[0].detach().cpu().numpy()).astype(np.uint8))
        out_img.save(os.path.join(args.out_dir, str(slice_ID) + '_output_ens' + ".png"))


def create_argparser():
    defaults = dict(
        # data_name='BRATS',
        data_dir="/home/dell/jlc/data2500/Test",
        clip_denoised=True,
        num_samples=1,
        batch_size=1,
        use_ddim=False,
        model_path="/home/dell/jlc/segdiff/model-5df/savedmodel100000.pt",  # path to the pretrained model
        num_ensemble=1,  # number of samples in the ensemble
        gpu_dev="1",
        out_dir='/home/dell/jlc/segdiff/result/',
        multi_gpu=None,  # "0,1,2"
        debug=False
    )
    defaults.update(model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/CrackSegDiff/segmentation_train.py:
--------------------------------------------------------------------------------
import sys
import argparse
sys.path.append("../")
sys.path.append("./")
from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.custom_dataset_loader import CustomDataset
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    args_to_dict,
    add_dict_to_argparser,
)
import torch as th
from guided_diffusion.train_util import TrainLoop
import torchvision.transforms as transforms

find_unused_parameters = True


def main():
    args = create_argparser().parse_args()
    dist_util.setup_dist(args)
    logger.configure(dir=args.out_dir)
    logger.log("creating data loader...")
    # tran_list = [transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor(), ]
    tran_list = [transforms.ToTensor(), ]
    transform_train = transforms.Compose(tran_list)
    print("Your current directory : ", args.data_dir)
    ds = CustomDataset(args, args.data_dir, transform_train)
    args.in_ch = 7  # 6 fused image channels plus 1 noise channel
    datal = th.utils.data.DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=16,
        pin_memory=True
    )
    data = iter(datal)
    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    if args.multi_gpu:
        model = th.nn.DataParallel(model, device_ids=[int(id) for id in args.multi_gpu.split(',')])
        model.to(device=th.device('cuda', int(args.gpu_dev)))
    else:
        model.to(dist_util.dev())
    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion, maxt=args.diffusion_steps)
    logger.log("training...")
    TrainLoop(
        model=model,
        diffusion=diffusion,
        classifier=None,
        data=data,
        dataloader=datal,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
    ).run_loop()


def create_argparser():
    defaults = dict(
        # data_name = 'BRATS',
        data_dir="/home/dell/jlc/data2500/Train2000",
        schedule_sampler="uniform",
        lr=1e-4,
        weight_decay=0.0,
        lr_anneal_steps=0,
        batch_size=8,
        microbatch=-1,  # -1 disables microbatches
        ema_rate="0.9999",  # comma-separated list of EMA values
        log_interval=100,
        save_interval=100000,
        resume_checkpoint=None,  # "/results/pretrainedmodel.pt"
        use_fp16=False,
        fp16_scale_growth=1e-3,
        gpu_dev="1",
        multi_gpu=None,  # "0,1,2"
        out_dir='/home/dell/jlc/segdiff/model-3df1/'
    )
    defaults.update(model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CrackSegDiff: Diffusion Probability Model-based Multi-modal Crack Segmentation

## Abstract

Integrating grayscale and depth data in road inspection robots could enhance the accuracy, reliability, and comprehensiveness of road condition assessments, leading to improved maintenance strategies and safer infrastructure. However, these data sources are often compromised by significant background noise from the pavement. Recent advancements in Diffusion Probabilistic Models (DPM) have demonstrated remarkable success in image segmentation tasks, showcasing potent denoising capabilities, as evidenced in studies like SegDiff. Despite these advancements, current DPM-based segmentors do not fully capitalize on the potential of original image data. In this paper, we propose a novel DPM-based approach for crack segmentation, named CrackSegDiff, which uniquely fuses grayscale and range/depth images. This method enhances the reverse diffusion process by intensifying the interaction between local feature extraction via DPM and global feature extraction. Unlike traditional methods that utilize Transformers for global features, our approach employs Vm-unet to efficiently capture long-range information of the original data. The integration of features is further refined through two innovative modules: the Channel Fusion Module (CFM) and the Shallow Feature Compensation Module (SFCM).
Our experimental evaluation on the three-class crack image segmentation tasks within the FIND dataset demonstrates that CrackSegDiff outperforms state-of-the-art methods, particularly excelling in the detection of shallow cracks.

Paper: [arXiv](https://arxiv.org/abs/2410.08100)

## A Quick Overview

## Quantitative Results of CrackSegDiff

The overall architecture of CrackSegDiff.
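
As a rough sketch of the data flow (illustrative only, with placeholder shapes): both `segmentation_train.py` and `segmentation_sample.py` set `args.in_ch = 7`, and the sampling script builds the model input by concatenating the fused image channels with a single Gaussian noise channel. Assuming the dataloader stacks the 3-channel grayscale image and the 3-channel range/depth image into one 6-channel tensor, the construction looks like:

```python
import torch

# Hypothetical 6-channel input: 3 grayscale + 3 range/depth channels (assumed split),
# at a placeholder 256x256 resolution.
b = torch.rand(1, 6, 256, 256)

c = torch.randn_like(b[:, :1, ...])  # one Gaussian noise channel, as in segmentation_sample.py
img = torch.cat((b, c), dim=1)       # 7-channel model input, matching args.in_ch = 7
print(img.shape)                     # torch.Size([1, 7, 256, 256])
```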

| Model | Raw intensity: F1 score | Raw intensity: IoU | Raw intensity: BF score | Raw range: F1 score | Raw range: IoU | Raw range: BF score | Fused raw image: F1 score | Fused raw image: IoU | Fused raw image: BF score |
|---|---|---|---|---|---|---|---|---|---|
| DenseCrack | 68.2% | 56.5% | - | 78.4% | 65.3% | - | 81.5% | 69.7% | - |
| SegNet-FCN | 75.0% | 63.4% | - | 81.1% | 68.6% | - | 84.0% | 72.9% | - |
| CrackFusionNet | 77.8% | 66.5% | - | 82.6% | 71.3% | - | 86.8% | 77.3% | - |
| Unet-fcn | 80.57% | 71.25% | 84.44% | 84.86% | 74.69% | 87.44% | 89.84% | 82.53% | 91.56% |
| HRNet-OCR | 78.55% | 67.73% | 85.13% | 84.89% | 74.18% | 89.47% | 85.07% | 75.55% | 90.05% |
| Crackmer | 76.54% | 64.92% | 81.48% | 81.78% | 69.72% | 84.79% | 87.32% | 78.25% | 89.93% |
| CT-CrackSeg | 83.55% | 74.39% | 88.61% | 88.51% | 80.17% | 91.85% | 92.75% | 87.06% | 95.03% |
| MedSegDiff | 83.05% | 74.61% | 88.21% | 90.87% | 83.70% | 92.98% | 95.03% | 90.77% | 96.50% |
| CrackSegDiff (Ours) | 84.59% | 77.31% | 89.23% | 92.18% | 86.11% | 93.71% | 95.58% | 91.90% | 96.63% |

Comparison of CrackSegDiff with state-of-the-art grayscale and depth fused segmentors on the FIND Dataset.
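
For reference, the F1 score and IoU above follow the standard per-pixel definitions; the sketch below (plain NumPy, binary masks assumed, per-class averaging omitted) illustrates them on a toy example. The BF score is the boundary F1 score, which applies the same F1 computation only to pixels within a small tolerance of the predicted and ground-truth boundaries.

```python
import numpy as np

def f1_and_iou(pred, gt):
    """Per-pixel F1 score and IoU for binary masks with values in {0, 1}."""
    tp = np.logical_and(pred == 1, gt == 1).sum()
    fp = np.logical_and(pred == 1, gt == 0).sum()
    fn = np.logical_and(pred == 0, gt == 1).sum()
    f1 = 2 * tp / (2 * tp + fp + fn + 1e-8)   # harmonic mean of precision and recall
    iou = tp / (tp + fp + fn + 1e-8)          # intersection over union
    return f1, iou

# Toy example
pred = np.array([[1, 1, 0], [0, 1, 0]])
gt = np.array([[1, 0, 0], [0, 1, 1]])
print(f1_and_iou(pred, gt))  # approximately (0.667, 0.5)
```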

Qualitative comparison of CrackSegDiff with state-of-the-art segmentation methods. From left to right, the metrics used are F1-Score, IoU, and BF-Score.