├── environment.yml ├── cm ├── __init__.py ├── __pycache__ │ ├── nn.cpython-38.pyc │ ├── unet.cpython-38.pyc │ ├── logger.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── dist_util.cpython-38.pyc │ ├── fp16_util.cpython-38.pyc │ ├── resample.cpython-38.pyc │ ├── train_util.cpython-38.pyc │ ├── random_util.cpython-38.pyc │ ├── script_util.cpython-38.pyc │ ├── image_datasets.cpython-38.pyc │ └── karras_diffusion.cpython-38.pyc ├── dist_util.py ├── losses.py ├── nn.py ├── random_util.py ├── resample.py ├── script_util.py ├── fp16_util.py ├── image_datasets.py ├── logger.py ├── train_util.py └── unet.py ├── consistency_models ├── __init__.py ├── utils.py └── unet.py ├── __pycache__ ├── SwinUnetr.cpython-38.pyc └── Diffusion_model_transformer.cpython-38.pyc ├── Network ├── __pycache__ │ ├── util_nn.cpython-38.pyc │ ├── SwinUnetr.cpython-38.pyc │ ├── Diffusion_model_Unet_2d.cpython-38.pyc │ └── Diffusion_model_transformer.cpython-38.pyc ├── util_nn.py ├── Diffusion_model_Unet_2d.py └── Diffusion_model_transformer.py ├── data └── pet_38_aligned │ ├── imagesTr_full_2d │ ├── Patient_slice1.mat │ └── Patient_slice2.mat │ └── imagesTs_full_2d │ ├── Patient_slice1.mat │ └── Patient_slice2.mat ├── edm_train.py ├── cm_train.py └── README.md /environment.yml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cm/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 3 | """ 4 | -------------------------------------------------------------------------------- /consistency_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet import ConsistencyModel 2 | from .utils import kerras_boundaries 3 | -------------------------------------------------------------------------------- /cm/__pycache__/nn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/cm/__pycache__/nn.cpython-38.pyc -------------------------------------------------------------------------------- /cm/__pycache__/unet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/cm/__pycache__/unet.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/SwinUnetr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/__pycache__/SwinUnetr.cpython-38.pyc -------------------------------------------------------------------------------- /cm/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/cm/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /cm/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/cm/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /cm/__pycache__/dist_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/cm/__pycache__/dist_util.cpython-38.pyc -------------------------------------------------------------------------------- /cm/__pycache__/fp16_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/cm/__pycache__/fp16_util.cpython-38.pyc -------------------------------------------------------------------------------- /cm/__pycache__/resample.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/cm/__pycache__/resample.cpython-38.pyc -------------------------------------------------------------------------------- /cm/__pycache__/train_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/cm/__pycache__/train_util.cpython-38.pyc -------------------------------------------------------------------------------- /Network/__pycache__/util_nn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/Network/__pycache__/util_nn.cpython-38.pyc -------------------------------------------------------------------------------- /cm/__pycache__/random_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/cm/__pycache__/random_util.cpython-38.pyc -------------------------------------------------------------------------------- /cm/__pycache__/script_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/cm/__pycache__/script_util.cpython-38.pyc -------------------------------------------------------------------------------- /Network/__pycache__/SwinUnetr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/Network/__pycache__/SwinUnetr.cpython-38.pyc -------------------------------------------------------------------------------- /cm/__pycache__/image_datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/cm/__pycache__/image_datasets.cpython-38.pyc 
-------------------------------------------------------------------------------- /cm/__pycache__/karras_diffusion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/cm/__pycache__/karras_diffusion.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/Diffusion_model_transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/__pycache__/Diffusion_model_transformer.cpython-38.pyc -------------------------------------------------------------------------------- /data/pet_38_aligned/imagesTr_full_2d/Patient_slice1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/data/pet_38_aligned/imagesTr_full_2d/Patient_slice1.mat -------------------------------------------------------------------------------- /data/pet_38_aligned/imagesTr_full_2d/Patient_slice2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/data/pet_38_aligned/imagesTr_full_2d/Patient_slice2.mat -------------------------------------------------------------------------------- /data/pet_38_aligned/imagesTs_full_2d/Patient_slice1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/data/pet_38_aligned/imagesTs_full_2d/Patient_slice1.mat -------------------------------------------------------------------------------- /data/pet_38_aligned/imagesTs_full_2d/Patient_slice2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/data/pet_38_aligned/imagesTs_full_2d/Patient_slice2.mat -------------------------------------------------------------------------------- /Network/__pycache__/Diffusion_model_Unet_2d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/Network/__pycache__/Diffusion_model_Unet_2d.cpython-38.pyc -------------------------------------------------------------------------------- /Network/__pycache__/Diffusion_model_transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/HEAD/Network/__pycache__/Diffusion_model_transformer.cpython-38.pyc -------------------------------------------------------------------------------- /consistency_models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def kerras_boundaries(sigma, eps, N, T): 5 | # This will be used to generate the boundaries for the time 
discretization 6 | 7 | return torch.tensor( 8 | [ 9 | (eps ** (1 / sigma) + i / (N - 1) * (T ** (1 / sigma) - eps ** (1 / sigma))) 10 | ** sigma 11 | for i in range(N) 12 | ] 13 | ) 14 | -------------------------------------------------------------------------------- /cm/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | # from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 24 | """ 25 | if dist.is_initialized(): 26 | return 27 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 28 | 29 | comm = MPI.COMM_WORLD 30 | backend = "gloo" if not th.cuda.is_available() else "nccl" 31 | 32 | if backend == "gloo": 33 | hostname = "localhost" 34 | else: 35 | hostname = socket.gethostbyname(socket.getfqdn()) 36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 37 | os.environ["RANK"] = str(comm.rank) 38 | os.environ["WORLD_SIZE"] = str(comm.size) 39 | 40 | port = comm.bcast(_find_free_port(), root=0) 41 | os.environ["MASTER_PORT"] = str(port) 42 | dist.init_process_group(backend=backend, init_method="env://") 43 | 44 | 45 | def dev(): 46 | """ 47 | Get the device to use for torch.distributed. 48 | """ 49 | if th.cuda.is_available(): 50 | return th.device("cuda") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file without redundant fetches across MPI ranks. 57 | """ 58 | chunk_size = 2**30 # MPI has a relatively small size limit 59 | if MPI.COMM_WORLD.Get_rank() == 0: 60 | with bf.BlobFile(path, "rb") as f: 61 | data = f.read() 62 | num_chunks = len(data) // chunk_size 63 | if len(data) % chunk_size: 64 | num_chunks += 1 65 | MPI.COMM_WORLD.bcast(num_chunks) 66 | for i in range(0, len(data), chunk_size): 67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 68 | else: 69 | num_chunks = MPI.COMM_WORLD.bcast(None) 70 | data = bytes() 71 | for _ in range(num_chunks): 72 | data += MPI.COMM_WORLD.bcast(None) 73 | 74 | return th.load(io.BytesIO(data), **kwargs) 75 | 76 | 77 | def sync_params(params): 78 | """ 79 | Synchronize a sequence of Tensors across ranks from rank 0. 80 | """ 81 | for p in params: 82 | with th.no_grad(): 83 | dist.broadcast(p, 0) 84 | 85 | 86 | def _find_free_port(): 87 | try: 88 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 89 | s.bind(("", 0)) 90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 91 | return s.getsockname()[1] 92 | finally: 93 | s.close() 94 | -------------------------------------------------------------------------------- /cm/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 
15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /edm_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 
3 | """ 4 | 5 | import argparse 6 | 7 | from cm import dist_util, logger 8 | from cm.image_datasets import load_data 9 | from cm.resample import create_named_schedule_sampler 10 | from cm.script_util import ( 11 | model_and_diffusion_defaults, 12 | create_model_and_diffusion, 13 | args_to_dict, 14 | add_dict_to_argparser, 15 | ) 16 | from cm.train_util import TrainLoop 17 | import torch.distributed as dist 18 | 19 | 20 | def main(): 21 | args = create_argparser().parse_args() 22 | 23 | dist_util.setup_dist() 24 | logger.configure() 25 | 26 | logger.log("creating model and diffusion...") 27 | model, diffusion = create_model_and_diffusion( 28 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 29 | ) 30 | model.to(dist_util.dev()) 31 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 32 | 33 | logger.log("creating data loader...") 34 | if args.batch_size == -1: 35 | batch_size = args.global_batch_size // dist.get_world_size() 36 | if args.global_batch_size % dist.get_world_size() != 0: 37 | logger.log( 38 | f"warning, using smaller global_batch_size of {dist.get_world_size()*batch_size} instead of {args.global_batch_size}" 39 | ) 40 | else: 41 | batch_size = args.batch_size 42 | 43 | data = load_data( 44 | data_dir=args.data_dir, 45 | batch_size=batch_size, 46 | image_size=args.image_size, 47 | class_cond=args.class_cond, 48 | ) 49 | 50 | logger.log("creating data loader...") 51 | 52 | logger.log("training...") 53 | TrainLoop( 54 | model=model, 55 | diffusion=diffusion, 56 | data=data, 57 | batch_size=batch_size, 58 | microbatch=args.microbatch, 59 | lr=args.lr, 60 | ema_rate=args.ema_rate, 61 | log_interval=args.log_interval, 62 | save_interval=args.save_interval, 63 | resume_checkpoint=args.resume_checkpoint, 64 | use_fp16=args.use_fp16, 65 | fp16_scale_growth=args.fp16_scale_growth, 66 | schedule_sampler=schedule_sampler, 67 | weight_decay=args.weight_decay, 68 | lr_anneal_steps=args.lr_anneal_steps, 69 | ).run_loop() 70 | 71 | 72 | def create_argparser(): 73 | defaults = dict( 74 | data_dir="", 75 | schedule_sampler="uniform", 76 | lr=1e-4, 77 | weight_decay=0.0, 78 | lr_anneal_steps=0, 79 | global_batch_size=2048, 80 | batch_size=-1, 81 | microbatch=-1, # -1 disables microbatches 82 | ema_rate="0.9999", # comma-separated list of EMA values 83 | log_interval=10, 84 | save_interval=10000, 85 | resume_checkpoint="", 86 | use_fp16=False, 87 | fp16_scale_growth=1e-3, 88 | ) 89 | defaults.update(model_and_diffusion_defaults()) 90 | parser = argparse.ArgumentParser() 91 | add_dict_to_argparser(parser, defaults) 92 | return parser 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /consistency_models/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | blk = lambda ic, oc: nn.Sequential( 11 | nn.GroupNorm(32, num_channels=ic), 12 | nn.SiLU(), 13 | nn.Conv2d(ic, oc, 3, padding=1), 14 | nn.GroupNorm(32, num_channels=oc), 15 | nn.SiLU(), 16 | nn.Conv2d(oc, oc, 3, padding=1), 17 | ) 18 | 19 | 20 | class ConsistencyModel(nn.Module): 21 | """ 22 | This is ridiculous Unet structure, hey but it works! 
23 | """ 24 | 25 | def __init__(self, n_channel: int, eps: float = 0.002, D: int = 128) -> None: 26 | super(ConsistencyModel, self).__init__() 27 | 28 | self.eps = eps 29 | 30 | self.freqs = torch.exp( 31 | -math.log(10000) * torch.arange(start=0, end=D, dtype=torch.float32) / D 32 | ) 33 | 34 | self.down = nn.Sequential( 35 | *[ 36 | nn.Conv2d(n_channel, D, 3, padding=1), 37 | blk(D, D), 38 | blk(D, 2 * D), 39 | blk(2 * D, 2 * D), 40 | ] 41 | ) 42 | 43 | self.time_downs = nn.Sequential( 44 | nn.Linear(2 * D, D), 45 | nn.Linear(2 * D, D), 46 | nn.Linear(2 * D, 2 * D), 47 | nn.Linear(2 * D, 2 * D), 48 | ) 49 | 50 | self.mid = blk(2 * D, 2 * D) 51 | 52 | self.up = nn.Sequential( 53 | *[ 54 | blk(2 * D, 2 * D), 55 | blk(2 * 2 * D, D), 56 | blk(D, D), 57 | nn.Conv2d(2 * D, 2 * D, 3, padding=1), 58 | ] 59 | ) 60 | self.last = nn.Conv2d(2 * D + n_channel, n_channel, 3, padding=1) 61 | 62 | def forward(self, x, t) -> torch.Tensor: 63 | if isinstance(t, float): 64 | t = ( 65 | torch.tensor([t] * x.shape[0], dtype=torch.float32) 66 | .to(x.device) 67 | .unsqueeze(1) 68 | ) 69 | # time embedding 70 | args = t.float() * self.freqs[None].to(t.device) 71 | t_emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1).to(x.device) 72 | 73 | x_ori = x 74 | 75 | # perform F(x, t) 76 | hs = [] 77 | for idx, layer in enumerate(self.down): 78 | if idx % 2 == 1: 79 | x = layer(x) + x 80 | else: 81 | x = layer(x) 82 | x = F.interpolate(x, scale_factor=0.5) 83 | hs.append(x) 84 | 85 | x = x + self.time_downs[idx](t_emb)[:, :, None, None] 86 | 87 | x = self.mid(x) 88 | 89 | for idx, layer in enumerate(self.up): 90 | if idx % 2 == 0: 91 | x = layer(x) + x 92 | else: 93 | x = torch.cat([x, hs.pop()], dim=1) 94 | x = F.interpolate(x, scale_factor=2, mode="nearest") 95 | x = layer(x) 96 | 97 | x = self.last(torch.cat([x, x_ori], dim=1)) 98 | 99 | t = t - self.eps 100 | c_skip_t = 0.25 / (t.pow(2) + 0.25) 101 | c_out_t = 0.25 * t / ((t + self.eps).pow(2) + 0.25).pow(0.5) 102 | 103 | return c_skip_t[:, :, None, None] * x_ori + c_out_t[:, :, None, None] * x 104 | 105 | def loss(self, x, z, t1, t2, ema_model): 106 | x2 = x + z * t2[:, :, None, None] 107 | x2 = self(x2, t2) 108 | 109 | with torch.no_grad(): 110 | x1 = x + z * t1[:, :, None, None] 111 | x1 = ema_model(x1, t1) 112 | 113 | return F.mse_loss(x1, x2) 114 | 115 | @torch.no_grad() 116 | def sample(self, x, ts: List[float]): 117 | x = self(x, ts[0]) 118 | 119 | for t in ts[1:]: 120 | z = torch.randn_like(x) 121 | x = x + math.sqrt(t**2 - self.eps**2) * z 122 | x = self(x, t) 123 | 124 | return x 125 | -------------------------------------------------------------------------------- /Network/util_nn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Oct 14 17:44:01 2022 4 | 5 | @author: mhu58 6 | """ 7 | 8 | import math 9 | 10 | import torch as th 11 | import torch.nn as nn 12 | 13 | 14 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 15 | class SiLU(nn.Module): 16 | def forward(self, x): 17 | return x * th.sigmoid(x) 18 | 19 | 20 | class GroupNorm32(nn.GroupNorm): 21 | def forward(self, x): 22 | return super().forward(x.float()).type(x.dtype) 23 | 24 | 25 | def conv_nd(dims, *args, **kwargs): 26 | """ 27 | Create a 1D, 2D, or 3D convolution module. 
28 | """ 29 | if dims == 1: 30 | return nn.Conv1d(*args, **kwargs) 31 | elif dims == 2: 32 | return nn.Conv2d(*args, **kwargs) 33 | elif dims == 3: 34 | return nn.Conv3d(*args, **kwargs) 35 | raise ValueError(f"unsupported dimensions: {dims}") 36 | 37 | 38 | def linear(*args, **kwargs): 39 | """ 40 | Create a linear module. 41 | """ 42 | return nn.Linear(*args, **kwargs) 43 | 44 | 45 | def avg_pool_nd(dims, *args, **kwargs): 46 | """ 47 | Create a 1D, 2D, or 3D average pooling module. 48 | """ 49 | if dims == 1: 50 | return nn.AvgPool1d(*args, **kwargs) 51 | elif dims == 2: 52 | return nn.AvgPool2d(*args, **kwargs) 53 | elif dims == 3: 54 | return nn.AvgPool3d(*args, **kwargs) 55 | raise ValueError(f"unsupported dimensions: {dims}") 56 | 57 | 58 | def update_ema(target_params, source_params, rate=0.99): 59 | """ 60 | Update target parameters to be closer to those of source parameters using 61 | an exponential moving average. 62 | :param target_params: the target parameter sequence. 63 | :param source_params: the source parameter sequence. 64 | :param rate: the EMA rate (closer to 1 means slower). 65 | """ 66 | for targ, src in zip(target_params, source_params): 67 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 68 | 69 | 70 | def zero_module(module): 71 | """ 72 | Zero out the parameters of a module and return it. 73 | """ 74 | for p in module.parameters(): 75 | p.detach().zero_() 76 | return module 77 | 78 | 79 | def scale_module(module, scale): 80 | """ 81 | Scale the parameters of a module and return it. 82 | """ 83 | for p in module.parameters(): 84 | p.detach().mul_(scale) 85 | return module 86 | 87 | 88 | def mean_flat(tensor): 89 | """ 90 | Take the mean over all non-batch dimensions. 91 | """ 92 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 93 | 94 | 95 | def normalization(channels): 96 | """ 97 | Make a standard normalization layer. 98 | :param channels: number of input channels. 99 | :return: an nn.Module for normalization. 100 | """ 101 | return GroupNorm32(32,channels) 102 | 103 | 104 | def timestep_embedding(timesteps, dim, max_period=10000): 105 | """ 106 | Create sinusoidal timestep embeddings. 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | :param func: the function to evaluate. 129 | :param inputs: the argument sequence to pass to `func`. 130 | :param params: a sequence of parameters `func` depends on but does not 131 | explicitly take as arguments. 132 | :param flag: if False, disable gradient checkpointing. 
133 | """ 134 | if flag: 135 | args = tuple(inputs) + tuple(params) 136 | return CheckpointFunction.apply(func, len(inputs), *args) 137 | else: 138 | return func(*inputs) 139 | 140 | 141 | class CheckpointFunction(th.autograd.Function): 142 | @staticmethod 143 | def forward(ctx, run_function, length, *args): 144 | ctx.run_function = run_function 145 | ctx.input_tensors = list(args[:length]) 146 | ctx.input_params = list(args[length:]) 147 | with th.no_grad(): 148 | output_tensors = ctx.run_function(*ctx.input_tensors) 149 | return output_tensors 150 | 151 | @staticmethod 152 | def backward(ctx, *output_grads): 153 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 154 | with th.enable_grad(): 155 | # Fixes a bug where the first op in run_function modifies the 156 | # Tensor storage in place, which is not allowed for detach()'d 157 | # Tensors. 158 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 159 | output_tensors = ctx.run_function(*shallow_copies) 160 | input_grads = th.autograd.grad( 161 | output_tensors, 162 | ctx.input_tensors + ctx.input_params, 163 | output_grads, 164 | allow_unused=True, 165 | ) 166 | del ctx.input_tensors 167 | del ctx.input_params 168 | del output_tensors 169 | return (None, None) + input_grads -------------------------------------------------------------------------------- /cm_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 3 | """ 4 | 5 | import argparse 6 | 7 | # from cm import dist_util, logger 8 | from cm.image_datasets import load_data 9 | from cm.resample import create_named_schedule_sampler 10 | from cm.script_util import ( 11 | model_and_diffusion_defaults, 12 | create_model_and_diffusion, 13 | cm_train_defaults, 14 | args_to_dict, 15 | add_dict_to_argparser, 16 | create_ema_and_scales_fn, 17 | ) 18 | from cm.train_util import CMTrainLoop 19 | import torch.distributed as dist 20 | import copy 21 | 22 | 23 | def main(): 24 | args = create_argparser().parse_args() 25 | 26 | # dist_util.setup_dist() 27 | # logger.configure() 28 | 29 | # logger.log("creating model and diffusion...") 30 | ema_scale_fn = create_ema_and_scales_fn( 31 | target_ema_mode=args.target_ema_mode, 32 | start_ema=args.start_ema, 33 | scale_mode=args.scale_mode, 34 | start_scales=args.start_scales, 35 | end_scales=args.end_scales, 36 | total_steps=args.total_training_steps, 37 | distill_steps_per_iter=args.distill_steps_per_iter, 38 | ) 39 | if args.training_mode == "progdist": 40 | distillation = False 41 | elif "consistency" in args.training_mode: 42 | distillation = True 43 | else: 44 | raise ValueError(f"unknown training mode {args.training_mode}") 45 | 46 | model_and_diffusion_kwargs = args_to_dict( 47 | args, model_and_diffusion_defaults().keys() 48 | ) 49 | model_and_diffusion_kwargs["distillation"] = distillation 50 | model, diffusion = create_model_and_diffusion(**model_and_diffusion_kwargs) 51 | # model.to(dist_util.dev()) 52 | model.train() 53 | if args.use_fp16: 54 | model.convert_to_fp16() 55 | 56 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 57 | 58 | # logger.log("creating data loader...") 59 | # if args.batch_size == -1: 60 | # batch_size = args.global_batch_size // dist.get_world_size() 61 | # if args.global_batch_size % dist.get_world_size() != 0: 62 | # logger.log( 63 | # f"warning, using smaller global_batch_size of {dist.get_world_size()*batch_size} instead of {args.global_batch_size}" 64 | # ) 
65 | # else: 66 | # batch_size = args.batch_size 67 | 68 | # data = load_data( 69 | # data_dir=args.data_dir, 70 | # batch_size=batch_size, 71 | # image_size=args.image_size, 72 | # class_cond=args.class_cond, 73 | # ) 74 | 75 | if len(args.teacher_model_path) > 0: # path to the teacher score model. 76 | # logger.log(f"loading the teacher model from {args.teacher_model_path}") 77 | teacher_model_and_diffusion_kwargs = copy.deepcopy(model_and_diffusion_kwargs) 78 | teacher_model_and_diffusion_kwargs["dropout"] = args.teacher_dropout 79 | teacher_model_and_diffusion_kwargs["distillation"] = False 80 | teacher_model, teacher_diffusion = create_model_and_diffusion( 81 | **teacher_model_and_diffusion_kwargs, 82 | ) 83 | 84 | # teacher_model.load_state_dict( 85 | # dist_util.load_state_dict(args.teacher_model_path, map_location="cpu"), 86 | # ) 87 | 88 | # teacher_model.to(dist_util.dev()) 89 | teacher_model.eval() 90 | 91 | for dst, src in zip(model.parameters(), teacher_model.parameters()): 92 | dst.data.copy_(src.data) 93 | 94 | if args.use_fp16: 95 | teacher_model.convert_to_fp16() 96 | 97 | else: 98 | teacher_model = None 99 | teacher_diffusion = None 100 | 101 | # load the target model for distillation, if path specified. 102 | 103 | # logger.log("creating the target model") 104 | target_model, _ = create_model_and_diffusion( 105 | **model_and_diffusion_kwargs, 106 | ) 107 | 108 | # target_model.to(dist_util.dev()) 109 | target_model.train() 110 | 111 | # dist_util.sync_params(target_model.parameters()) 112 | # dist_util.sync_params(target_model.buffers()) 113 | 114 | for dst, src in zip(target_model.parameters(), model.parameters()): 115 | dst.data.copy_(src.data) 116 | 117 | if args.use_fp16: 118 | target_model.convert_to_fp16() 119 | 120 | # logger.log("training...") 121 | CMTrainLoop( 122 | model=model, 123 | target_model=target_model, 124 | teacher_model=teacher_model, 125 | teacher_diffusion=teacher_diffusion, 126 | training_mode=args.training_mode, 127 | ema_scale_fn=ema_scale_fn, 128 | total_training_steps=args.total_training_steps, 129 | diffusion=diffusion, 130 | data=None, 131 | batch_size=None, 132 | microbatch=args.microbatch, 133 | lr=args.lr, 134 | ema_rate=args.ema_rate, 135 | log_interval=args.log_interval, 136 | save_interval=args.save_interval, 137 | resume_checkpoint=args.resume_checkpoint, 138 | use_fp16=args.use_fp16, 139 | fp16_scale_growth=args.fp16_scale_growth, 140 | schedule_sampler=schedule_sampler, 141 | weight_decay=args.weight_decay, 142 | lr_anneal_steps=args.lr_anneal_steps, 143 | ).run_loop() 144 | 145 | 146 | def create_argparser(): 147 | defaults = dict( 148 | data_dir="", 149 | schedule_sampler="uniform", 150 | lr=1e-4, 151 | weight_decay=0.0, 152 | lr_anneal_steps=0, 153 | global_batch_size=2048, 154 | batch_size=-1, 155 | microbatch=-1, # -1 disables microbatches 156 | ema_rate="0.9999", # comma-separated list of EMA values 157 | log_interval=10, 158 | save_interval=10000, 159 | resume_checkpoint="", 160 | use_fp16=False, 161 | fp16_scale_growth=1e-3, 162 | ) 163 | defaults.update(model_and_diffusion_defaults()) 164 | defaults.update(cm_train_defaults()) 165 | parser = argparse.ArgumentParser() 166 | add_dict_to_argparser(parser, defaults) 167 | return parser 168 | 169 | 170 | if __name__ == "__main__": 171 | main() 172 | -------------------------------------------------------------------------------- /cm/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 
3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | import numpy as np 10 | import torch.nn.functional as F 11 | 12 | 13 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 14 | class SiLU(nn.Module): 15 | def forward(self, x): 16 | return x * th.sigmoid(x) 17 | 18 | 19 | class GroupNorm32(nn.GroupNorm): 20 | def forward(self, x): 21 | return super().forward(x.float()).type(x.dtype) 22 | 23 | 24 | def conv_nd(dims, *args, **kwargs): 25 | """ 26 | Create a 1D, 2D, or 3D convolution module. 27 | """ 28 | if dims == 1: 29 | return nn.Conv1d(*args, **kwargs) 30 | elif dims == 2: 31 | return nn.Conv2d(*args, **kwargs) 32 | elif dims == 3: 33 | return nn.Conv3d(*args, **kwargs) 34 | raise ValueError(f"unsupported dimensions: {dims}") 35 | 36 | 37 | def linear(*args, **kwargs): 38 | """ 39 | Create a linear module. 40 | """ 41 | return nn.Linear(*args, **kwargs) 42 | 43 | 44 | def avg_pool_nd(dims, *args, **kwargs): 45 | """ 46 | Create a 1D, 2D, or 3D average pooling module. 47 | """ 48 | if dims == 1: 49 | return nn.AvgPool1d(*args, **kwargs) 50 | elif dims == 2: 51 | return nn.AvgPool2d(*args, **kwargs) 52 | elif dims == 3: 53 | return nn.AvgPool3d(*args, **kwargs) 54 | raise ValueError(f"unsupported dimensions: {dims}") 55 | 56 | 57 | def update_ema(target_params, source_params, rate=0.99): 58 | """ 59 | Update target parameters to be closer to those of source parameters using 60 | an exponential moving average. 61 | 62 | :param target_params: the target parameter sequence. 63 | :param source_params: the source parameter sequence. 64 | :param rate: the EMA rate (closer to 1 means slower). 65 | """ 66 | for targ, src in zip(target_params, source_params): 67 | targ[1].detach().mul_(rate).add_(src[1], alpha=1 - rate) 68 | 69 | 70 | def zero_module(module): 71 | """ 72 | Zero out the parameters of a module and return it. 73 | """ 74 | for p in module.parameters(): 75 | p.detach().zero_() 76 | return module 77 | 78 | 79 | def scale_module(module, scale): 80 | """ 81 | Scale the parameters of a module and return it. 82 | """ 83 | for p in module.parameters(): 84 | p.detach().mul_(scale) 85 | return module 86 | 87 | 88 | def mean_flat(tensor): 89 | """ 90 | Take the mean over all non-batch dimensions. 91 | """ 92 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 93 | # return th.nn.L1Loss()(tensor,tensor2) 94 | 95 | 96 | def append_dims(x, target_dims): 97 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 98 | dims_to_append = target_dims - x.ndim 99 | if dims_to_append < 0: 100 | raise ValueError( 101 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" 102 | ) 103 | return x[(...,) + (None,) * dims_to_append] 104 | 105 | 106 | def append_zero(x): 107 | return th.cat([x, x.new_zeros([1])]) 108 | 109 | 110 | def normalization(channels): 111 | """ 112 | Make a standard normalization layer. 113 | 114 | :param channels: number of input channels. 115 | :return: an nn.Module for normalization. 116 | """ 117 | return GroupNorm32(32, channels) 118 | 119 | 120 | def timestep_embedding(timesteps, dim, max_period=10000): 121 | """ 122 | Create sinusoidal timestep embeddings. 123 | 124 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 125 | These may be fractional. 126 | :param dim: the dimension of the output. 127 | :param max_period: controls the minimum frequency of the embeddings. 128 | :return: an [N x dim] Tensor of positional embeddings. 
129 | """ 130 | half = dim // 2 131 | freqs = th.exp( 132 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 133 | ).to(device=timesteps.device) 134 | args = timesteps[:, None].float() * freqs[None] 135 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 136 | if dim % 2: 137 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 138 | return embedding 139 | 140 | 141 | def checkpoint(func, inputs, params, flag): 142 | """ 143 | Evaluate a function without caching intermediate activations, allowing for 144 | reduced memory at the expense of extra compute in the backward pass. 145 | 146 | :param func: the function to evaluate. 147 | :param inputs: the argument sequence to pass to `func`. 148 | :param params: a sequence of parameters `func` depends on but does not 149 | explicitly take as arguments. 150 | :param flag: if False, disable gradient checkpointing. 151 | """ 152 | if flag: 153 | args = tuple(inputs) + tuple(params) 154 | return CheckpointFunction.apply(func, len(inputs), *args) 155 | else: 156 | return func(*inputs) 157 | 158 | 159 | class CheckpointFunction(th.autograd.Function): 160 | @staticmethod 161 | def forward(ctx, run_function, length, *args): 162 | ctx.run_function = run_function 163 | ctx.input_tensors = list(args[:length]) 164 | ctx.input_params = list(args[length:]) 165 | with th.no_grad(): 166 | output_tensors = ctx.run_function(*ctx.input_tensors) 167 | return output_tensors 168 | 169 | @staticmethod 170 | def backward(ctx, *output_grads): 171 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 172 | with th.enable_grad(): 173 | # Fixes a bug where the first op in run_function modifies the 174 | # Tensor storage in place, which is not allowed for detach()'d 175 | # Tensors. 176 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 177 | output_tensors = ctx.run_function(*shallow_copies) 178 | input_grads = th.autograd.grad( 179 | output_tensors, 180 | ctx.input_tensors + ctx.input_params, 181 | output_grads, 182 | allow_unused=True, 183 | ) 184 | del ctx.input_tensors 185 | del ctx.input_params 186 | del output_tensors 187 | return (None, None) + input_grads 188 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # 2D-Medical-Consistency-Model 3 | **This is the repository for the paper published in Medical Physics: "[Full-dose Whole-body PET Synthesis from Low-dose PET Using High-efficiency Denoising Diffusion Probabilistic Model: PET Consistency Model](https://aapm.onlinelibrary.wiley.com/doi/10.1002/mp.17068)".** 4 | 5 | Consistency Model is one of the super fast Denoising Diffusion Probability Models (DDPMs), which only use 2-timestep to generate the target image, while the DDPMs usually require 50- to 1000-timesteps. This is particular useful for: 1) Three-dimensional Medical image synthesis, 2) Image translation instead image creation like traditional DDPMs do. 6 | 7 | The codes were created based on [image-guided diffusion](https://github.com/openai/guided-diffusion), [SwinUnet](https://github.com/HuCaoFighting/Swin-Unet), and [Monai](https://monai.io/) 8 | 9 | Notice: Due to the data restriction, we can only provide MATLAB file (so no patient information) with over-smoothed PET images. The data we show just to demonstrate how the user should organize their data. The dicom or nii file processing are also included in the Jupyter notebook. 
10 | 11 | # Required packages 12 | 13 | The required packages are in test_env.yaml. 14 | 15 | Create an environment using Anaconda: 16 | ``` 17 | conda env create -f \your directory\test_env.yaml 18 | ``` 19 | 20 | # How to organize your data 21 | The data organization example is shown in the folder "data/pet_38_aligned". Or you can see the screenshots below: 22 | ![image](https://github.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/assets/89927506/a2fdf7af-25be-47d7-8b49-7bc7c2c2468f) 23 | MATLAB files: every MATLAB file can contain a dict that has the image and label together. So you only need two folders: imagesTr_full_2d for training and imagesTs_full_2d for testing. You can change the names, but please make sure to also change the reading directory in the Jupyter notebook. 24 | 25 | ![image](https://github.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/assets/89927506/a7bf529f-3e4e-4e58-b0fe-3a87fb5ecbe9) 26 | Nii files: one nii file can only contain either the image or the label. So in this case, you need imagesTr and labelsTr for training, imagesTs and labelsTs for testing, and imagesVal and labelsVal for validation. 27 | 28 | # Usage 29 | 30 | The usage is shown in the Jupyter notebook Consistency_Low_Dose_Denoising_main.ipynb, including how to build the consistency-diffusion forward process, how to build a network, and how to call the whole consistency process to train and sample new synthetic images. However, we give a simple example below: 31 | 32 | **Create Consistency-diffusion** 33 | ``` 34 | from cm.resample import UniformSampler 35 | from cm.karras_diffusion import KarrasDenoiser,karras_sample 36 | consistency = KarrasDenoiser( 37 | sigma_data=0.5, 38 | sigma_max=80.0, 39 | sigma_min=0.002, 40 | rho=7.0, 41 | weight_schedule="karras", 42 | distillation=False, 43 | loss_norm="l1") 44 | 45 | schedule_sampler = UniformSampler(consistency) 46 | ``` 47 | 48 | **Create a network for an input image of size 64x64 (notice this is because we apply the 64x64 patch-based training and inference for our 96x196 low-dose PET images)** 49 | ``` 50 | from Diffusion_model_transformer import * 51 | 52 | num_channels=128 53 | attention_resolutions="16,8" 54 | channel_mult = (1, 2, 3, 4) 55 | num_heads=[4,4,8,16] 56 | window_size = [[4,4],[4,4],[4,4],[4,4]] 57 | num_res_blocks = [2,2,2,2] 58 | sample_kernel=([2,2],[2,2],[2,2]), 59 | 60 | attention_ds = [] 61 | for res in attention_resolutions.split(","): 62 | # Careful with the image_size//int(res); only use for CNN 63 | attention_ds.append(image_size//int(res)) 64 | class_cond = False 65 | use_scale_shift_norm = True 66 | 67 | Consistency_network = SwinVITModel( 68 | image_size=img_size, 69 | in_channels=2, 70 | model_channels=num_channels, 71 | out_channels=1, 72 | dims=2, 73 | sample_kernel = sample_kernel, 74 | num_res_blocks=num_res_blocks, 75 | attention_resolutions=tuple(attention_ds), 76 | dropout=0, 77 | channel_mult=channel_mult, 78 | num_classes=None, 79 | use_checkpoint=False, 80 | use_fp16=False, 81 | num_heads=num_heads, 82 | window_size = window_size, 83 | num_head_channels=64, 84 | num_heads_upsample=-1, 85 | use_scale_shift_norm=use_scale_shift_norm, 86 | resblock_updown=False, 87 | use_new_attention_order=False, 88 | ).to(device) 89 | 90 | # Don't forget the EMA model. You must have this to run the code whether you use EMA or not.
91 | Consistency_network_ema = copy.deepcopy(Consistency_network) 92 | ``` 93 | 94 | **Train the consistency model (you don't have to use the EMA as in our .ipynb)** 95 | ``` 96 | # Create fake examples, just for you to run the code 97 | img_size = (96,192) # Adjust this for the size of your image input 98 | condition = torch.randn([1,1,96,192]) #batch, channel, height, width 99 | target = torch.randn([1,1,96,192]) #batch, channel, height, width 100 | 101 | all_loss = consistency.consistency_losses(Consistency_network, 102 | target, 103 | condition, 104 | num_scales, 105 | target_model=Consistency_network_ema) 106 | loss = (all_loss["loss"] * weights).mean() 107 | ``` 108 | 109 | **Generate new synthetic images** 110 | ``` 111 | # Create fake examples 112 | Low_dose = torch.randn([1,1,96,192]) #batch, channel, height, width 113 | img_size = (96,192) # Adjust this for the size of your image input 114 | 115 | # Set up the number of steps for your inference 116 | consistency_num = 3 117 | steps = np.round(np.linspace(1.0, 150.0, num=consistency_num)) 118 | def diffusion_sampling(Low_dose,A_to_B_model): 119 | sampled_images = karras_sample( 120 | consistency, 121 | A_to_B_model, 122 | shape=Low_dose.shape, 123 | condition=Low_dose, 124 | sampler="multistep", 125 | steps = 151, 126 | ts = steps, 127 | device = device) 128 | return sampled_images 129 | 130 | # Patch-based inference parameters 131 | overlap = 0.75 132 | mode ='constant' 133 | back_ground_intensity = -1 134 | Inference_patch_number_each_time = 40 135 | from monai.inferers import SlidingWindowInferer 136 | inferer = SlidingWindowInferer(img_size, Inference_patch_number_each_time, overlap=overlap, 137 | mode =mode ,cval = back_ground_intensity, sw_device=device,device = device) 138 | 139 | # Run the patch-based sliding-window inference to synthesize the full-dose image 140 | High_dose_samples = inferer(Low_dose,diffusion_sampling,Consistency_network) 141 | ``` 142 | 143 | 144 | # Visual examples 145 | ![Picture1](https://github.com/shaoyanpan/Full-dose-Whole-body-PET-Synthesis-from-Low-dose-PET-Using-Consistency-Model/assets/89927506/15e56941-d7c6-4eab-994a-04e2d1d4d1df) 146 | 147 | -------------------------------------------------------------------------------- /cm/random_util.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.distributed as dist 3 | from .
import dist_util 4 | 5 | 6 | def get_generator(generator, num_samples=0, seed=0): 7 | if generator == "dummy": 8 | return DummyGenerator() 9 | elif generator == "determ": 10 | return DeterministicGenerator(num_samples, seed) 11 | elif generator == "determ-indiv": 12 | return DeterministicIndividualGenerator(num_samples, seed) 13 | else: 14 | raise NotImplementedError 15 | 16 | 17 | class DummyGenerator: 18 | def randn(self, *args, **kwargs): 19 | return th.randn(*args, **kwargs) 20 | 21 | def randint(self, *args, **kwargs): 22 | return th.randint(*args, **kwargs) 23 | 24 | def randn_like(self, *args, **kwargs): 25 | return th.randn_like(*args, **kwargs) 26 | 27 | 28 | class DeterministicGenerator: 29 | """ 30 | RNG to deterministically sample num_samples samples that does not depend on batch_size or mpi_machines 31 | Uses a single rng and samples num_samples sized randomness and subsamples the current indices 32 | """ 33 | 34 | def __init__(self, num_samples, seed=0): 35 | if dist.is_initialized(): 36 | self.rank = dist.get_rank() 37 | self.world_size = dist.get_world_size() 38 | else: 39 | print("Warning: Distributed not initialised, using single rank") 40 | self.rank = 0 41 | self.world_size = 1 42 | self.num_samples = num_samples 43 | self.done_samples = 0 44 | self.seed = seed 45 | self.rng_cpu = th.Generator() 46 | if th.cuda.is_available(): 47 | self.rng_cuda = th.Generator(dist_util.dev()) 48 | self.set_seed(seed) 49 | 50 | def get_global_size_and_indices(self, size): 51 | global_size = (self.num_samples, *size[1:]) 52 | indices = th.arange( 53 | self.done_samples + self.rank, 54 | self.done_samples + self.world_size * int(size[0]), 55 | self.world_size, 56 | ) 57 | indices = th.clamp(indices, 0, self.num_samples - 1) 58 | assert ( 59 | len(indices) == size[0] 60 | ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}" 61 | return global_size, indices 62 | 63 | def get_generator(self, device): 64 | return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda 65 | 66 | def randn(self, *size, dtype=th.float, device="cpu"): 67 | global_size, indices = self.get_global_size_and_indices(size) 68 | generator = self.get_generator(device) 69 | return th.randn(*global_size, generator=generator, dtype=dtype, device=device)[ 70 | indices 71 | ] 72 | 73 | def randint(self, low, high, size, dtype=th.long, device="cpu"): 74 | global_size, indices = self.get_global_size_and_indices(size) 75 | generator = self.get_generator(device) 76 | return th.randint( 77 | low, high, generator=generator, size=global_size, dtype=dtype, device=device 78 | )[indices] 79 | 80 | def randn_like(self, tensor): 81 | size, dtype, device = tensor.size(), tensor.dtype, tensor.device 82 | return self.randn(*size, dtype=dtype, device=device) 83 | 84 | def set_done_samples(self, done_samples): 85 | self.done_samples = done_samples 86 | self.set_seed(self.seed) 87 | 88 | def get_seed(self): 89 | return self.seed 90 | 91 | def set_seed(self, seed): 92 | self.rng_cpu.manual_seed(seed) 93 | if th.cuda.is_available(): 94 | self.rng_cuda.manual_seed(seed) 95 | 96 | 97 | class DeterministicIndividualGenerator: 98 | """ 99 | RNG to deterministically sample num_samples samples that does not depend on batch_size or mpi_machines 100 | Uses a separate rng for each sample to reduce memoery usage 101 | """ 102 | 103 | def __init__(self, num_samples, seed=0): 104 | if dist.is_initialized(): 105 | self.rank = dist.get_rank() 106 | self.world_size = dist.get_world_size() 107 | else: 108 | print("Warning: 
Distributed not initialised, using single rank") 109 | self.rank = 0 110 | self.world_size = 1 111 | self.num_samples = num_samples 112 | self.done_samples = 0 113 | self.seed = seed 114 | self.rng_cpu = [th.Generator() for _ in range(num_samples)] 115 | if th.cuda.is_available(): 116 | self.rng_cuda = [th.Generator(dist_util.dev()) for _ in range(num_samples)] 117 | self.set_seed(seed) 118 | 119 | def get_size_and_indices(self, size): 120 | indices = th.arange( 121 | self.done_samples + self.rank, 122 | self.done_samples + self.world_size * int(size[0]), 123 | self.world_size, 124 | ) 125 | indices = th.clamp(indices, 0, self.num_samples - 1) 126 | assert ( 127 | len(indices) == size[0] 128 | ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}" 129 | return (1, *size[1:]), indices 130 | 131 | def get_generator(self, device): 132 | return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda 133 | 134 | def randn(self, *size, dtype=th.float, device="cpu"): 135 | size, indices = self.get_size_and_indices(size) 136 | generator = self.get_generator(device) 137 | return th.cat( 138 | [ 139 | th.randn(*size, generator=generator[i], dtype=dtype, device=device) 140 | for i in indices 141 | ], 142 | dim=0, 143 | ) 144 | 145 | def randint(self, low, high, size, dtype=th.long, device="cpu"): 146 | size, indices = self.get_size_and_indices(size) 147 | generator = self.get_generator(device) 148 | return th.cat( 149 | [ 150 | th.randint( 151 | low, 152 | high, 153 | generator=generator[i], 154 | size=size, 155 | dtype=dtype, 156 | device=device, 157 | ) 158 | for i in indices 159 | ], 160 | dim=0, 161 | ) 162 | 163 | def randn_like(self, tensor): 164 | size, dtype, device = tensor.size(), tensor.dtype, tensor.device 165 | return self.randn(*size, dtype=dtype, device=device) 166 | 167 | def set_done_samples(self, done_samples): 168 | self.done_samples = done_samples 169 | 170 | def get_seed(self): 171 | return self.seed 172 | 173 | def set_seed(self, seed): 174 | [ 175 | rng_cpu.manual_seed(i + self.num_samples * seed) 176 | for i, rng_cpu in enumerate(self.rng_cpu) 177 | ] 178 | if th.cuda.is_available(): 179 | [ 180 | rng_cuda.manual_seed(i + self.num_samples * seed) 181 | for i, rng_cuda in enumerate(self.rng_cuda) 182 | ] 183 | -------------------------------------------------------------------------------- /cm/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | from scipy.stats import norm 6 | import torch.distributed as dist 7 | 8 | 9 | def create_named_schedule_sampler(name, diffusion): 10 | """ 11 | Create a ScheduleSampler from a library of pre-defined samplers. 12 | 13 | :param name: the name of the sampler. 14 | :param diffusion: the diffusion object to sample for. 15 | """ 16 | if name == "uniform": 17 | return UniformSampler(diffusion) 18 | elif name == "loss-second-moment": 19 | return LossSecondMomentResampler(diffusion) 20 | elif name == "lognormal": 21 | return LogNormalSampler() 22 | else: 23 | raise NotImplementedError(f"unknown schedule sampler: {name}") 24 | 25 | 26 | class ScheduleSampler(ABC): 27 | """ 28 | A distribution over timesteps in the diffusion process, intended to reduce 29 | variance of the objective. 30 | 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 
33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | 42 | The weights needn't be normalized, but must be positive. 43 | """ 44 | 45 | def sample(self, batch_size, device): 46 | """ 47 | Importance-sample timesteps for a batch. 48 | 49 | :param batch_size: the number of timesteps. 50 | :param device: the torch device to save to. 51 | :return: a tuple (timesteps, weights): 52 | - timesteps: a tensor of timestep indices. 53 | - weights: a tensor of weights to scale the resulting losses. 54 | """ 55 | w = self.weights() 56 | p = w / np.sum(w) 57 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 58 | indices = th.from_numpy(indices_np).long().to(device) 59 | weights_np = 1 / (len(p) * p[indices_np]) 60 | weights = th.from_numpy(weights_np).float().to(device) 61 | return indices, weights 62 | 63 | 64 | class UniformSampler(ScheduleSampler): 65 | def __init__(self, diffusion): 66 | self.diffusion = diffusion 67 | self._weights = np.ones([diffusion.num_timesteps]) 68 | 69 | def weights(self): 70 | return self._weights 71 | 72 | 73 | class LossAwareSampler(ScheduleSampler): 74 | def update_with_local_losses(self, local_ts, local_losses): 75 | """ 76 | Update the reweighting using losses from a model. 77 | 78 | Call this method from each rank with a batch of timesteps and the 79 | corresponding losses for each of those timesteps. 80 | This method will perform synchronization to make sure all of the ranks 81 | maintain the exact same reweighting. 82 | 83 | :param local_ts: an integer Tensor of timesteps. 84 | :param local_losses: a 1D Tensor of losses. 85 | """ 86 | batch_sizes = [ 87 | th.tensor([0], dtype=th.int32, device=local_ts.device) 88 | for _ in range(dist.get_world_size()) 89 | ] 90 | dist.all_gather( 91 | batch_sizes, 92 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 93 | ) 94 | 95 | # Pad all_gather batches to be the maximum batch size. 96 | batch_sizes = [x.item() for x in batch_sizes] 97 | max_bs = max(batch_sizes) 98 | 99 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 100 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 101 | dist.all_gather(timestep_batches, local_ts) 102 | dist.all_gather(loss_batches, local_losses) 103 | timesteps = [ 104 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 105 | ] 106 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 107 | self.update_with_all_losses(timesteps, losses) 108 | 109 | @abstractmethod 110 | def update_with_all_losses(self, ts, losses): 111 | """ 112 | Update the reweighting using losses from a model. 113 | 114 | Sub-classes should override this method to update the reweighting 115 | using losses from the model. 116 | 117 | This method directly updates the reweighting without synchronizing 118 | between workers. It is called by update_with_local_losses from all 119 | ranks with identical arguments. Thus, it should have deterministic 120 | behavior to maintain state across workers. 121 | 122 | :param ts: a list of int timesteps. 123 | :param losses: a list of float losses, one per timestep. 
124 | """ 125 | 126 | 127 | class LossSecondMomentResampler(LossAwareSampler): 128 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 129 | self.diffusion = diffusion 130 | self.history_per_term = history_per_term 131 | self.uniform_prob = uniform_prob 132 | self._loss_history = np.zeros( 133 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 134 | ) 135 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 136 | 137 | def weights(self): 138 | if not self._warmed_up(): 139 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 140 | weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) 141 | weights /= np.sum(weights) 142 | weights *= 1 - self.uniform_prob 143 | weights += self.uniform_prob / len(weights) 144 | return weights 145 | 146 | def update_with_all_losses(self, ts, losses): 147 | for t, loss in zip(ts, losses): 148 | if self._loss_counts[t] == self.history_per_term: 149 | # Shift out the oldest loss term. 150 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 151 | self._loss_history[t, -1] = loss 152 | else: 153 | self._loss_history[t, self._loss_counts[t]] = loss 154 | self._loss_counts[t] += 1 155 | 156 | def _warmed_up(self): 157 | return (self._loss_counts == self.history_per_term).all() 158 | 159 | 160 | class LogNormalSampler: 161 | def __init__(self, p_mean=-1.2, p_std=1.2, even=False): 162 | self.p_mean = p_mean 163 | self.p_std = p_std 164 | self.even = even 165 | if self.even: 166 | self.inv_cdf = lambda x: norm.ppf(x, loc=p_mean, scale=p_std) 167 | self.rank, self.size = dist.get_rank(), dist.get_world_size() 168 | 169 | def sample(self, bs, device): 170 | if self.even: 171 | # buckets = [1/G] 172 | start_i, end_i = self.rank * bs, (self.rank + 1) * bs 173 | global_batch_size = self.size * bs 174 | locs = (th.arange(start_i, end_i) + th.rand(bs)) / global_batch_size 175 | log_sigmas = th.tensor(self.inv_cdf(locs), dtype=th.float32, device=device) 176 | else: 177 | log_sigmas = self.p_mean + self.p_std * th.randn(bs, device=device) 178 | sigmas = th.exp(log_sigmas) 179 | weights = th.ones_like(sigmas) 180 | return sigmas, weights 181 | -------------------------------------------------------------------------------- /cm/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from .karras_diffusion import KarrasDenoiser 4 | from .unet import UNetModel 5 | import numpy as np 6 | 7 | NUM_CLASSES = 1000 8 | 9 | 10 | def cm_train_defaults(): 11 | return dict( 12 | teacher_model_path="", 13 | teacher_dropout=0.1, 14 | training_mode="consistency_distillation", 15 | target_ema_mode="fixed", 16 | scale_mode="fixed", 17 | total_training_steps=600000, 18 | start_ema=0.0, 19 | start_scales=40, 20 | end_scales=40, 21 | distill_steps_per_iter=50000, 22 | loss_norm="lpips", 23 | ) 24 | 25 | 26 | def model_and_diffusion_defaults(): 27 | """ 28 | Defaults for image training. 
29 | """ 30 | res = dict( 31 | sigma_min=0.002, 32 | sigma_max=80.0, 33 | image_size=64, 34 | num_channels=128, 35 | num_res_blocks=2, 36 | num_heads=4, 37 | num_heads_upsample=-1, 38 | num_head_channels=-1, 39 | attention_resolutions="32,16,8", 40 | channel_mult="", 41 | dropout=0.0, 42 | class_cond=False, 43 | use_checkpoint=False, 44 | use_scale_shift_norm=True, 45 | resblock_updown=False, 46 | use_fp16=False, 47 | use_new_attention_order=False, 48 | learn_sigma=False, 49 | weight_schedule="karras", 50 | ) 51 | return res 52 | 53 | 54 | def create_model_and_diffusion( 55 | image_size, 56 | class_cond, 57 | learn_sigma, 58 | num_channels, 59 | num_res_blocks, 60 | channel_mult, 61 | num_heads, 62 | num_head_channels, 63 | num_heads_upsample, 64 | attention_resolutions, 65 | dropout, 66 | use_checkpoint, 67 | use_scale_shift_norm, 68 | resblock_updown, 69 | use_fp16, 70 | use_new_attention_order, 71 | weight_schedule, 72 | sigma_min=0.002, 73 | sigma_max=80.0, 74 | distillation=False, 75 | ): 76 | model = create_model( 77 | image_size, 78 | num_channels, 79 | num_res_blocks, 80 | channel_mult=channel_mult, 81 | learn_sigma=learn_sigma, 82 | class_cond=class_cond, 83 | use_checkpoint=use_checkpoint, 84 | attention_resolutions=attention_resolutions, 85 | num_heads=num_heads, 86 | num_head_channels=num_head_channels, 87 | num_heads_upsample=num_heads_upsample, 88 | use_scale_shift_norm=use_scale_shift_norm, 89 | dropout=dropout, 90 | resblock_updown=resblock_updown, 91 | use_fp16=use_fp16, 92 | use_new_attention_order=use_new_attention_order, 93 | ) 94 | diffusion = KarrasDenoiser( 95 | sigma_data=0.5, 96 | sigma_max=sigma_max, 97 | sigma_min=sigma_min, 98 | distillation=distillation, 99 | weight_schedule=weight_schedule, 100 | ) 101 | return model, diffusion 102 | 103 | 104 | def create_model( 105 | image_size, 106 | num_channels, 107 | num_res_blocks, 108 | channel_mult="", 109 | learn_sigma=False, 110 | class_cond=False, 111 | use_checkpoint=False, 112 | attention_resolutions="16", 113 | num_heads=1, 114 | num_head_channels=-1, 115 | num_heads_upsample=-1, 116 | use_scale_shift_norm=False, 117 | dropout=0, 118 | resblock_updown=False, 119 | use_fp16=False, 120 | use_new_attention_order=False, 121 | ): 122 | if channel_mult == "": 123 | if image_size == 512: 124 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 125 | elif image_size == 256: 126 | channel_mult = (1, 1, 2, 2, 4, 4) 127 | elif image_size == 128: 128 | channel_mult = (1, 1, 2, 3, 4) 129 | elif image_size == 64: 130 | channel_mult = (1, 2, 3, 4) 131 | else: 132 | raise ValueError(f"unsupported image size: {image_size}") 133 | else: 134 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 135 | 136 | attention_ds = [] 137 | for res in attention_resolutions.split(","): 138 | attention_ds.append(image_size // int(res)) 139 | 140 | return UNetModel( 141 | image_size=image_size, 142 | in_channels=3, 143 | model_channels=num_channels, 144 | out_channels=(3 if not learn_sigma else 6), 145 | num_res_blocks=num_res_blocks, 146 | attention_resolutions=tuple(attention_ds), 147 | dropout=dropout, 148 | channel_mult=channel_mult, 149 | num_classes=(NUM_CLASSES if class_cond else None), 150 | use_checkpoint=use_checkpoint, 151 | use_fp16=use_fp16, 152 | num_heads=num_heads, 153 | num_head_channels=num_head_channels, 154 | num_heads_upsample=num_heads_upsample, 155 | use_scale_shift_norm=use_scale_shift_norm, 156 | resblock_updown=resblock_updown, 157 | use_new_attention_order=use_new_attention_order, 158 | ) 159 | 160 | 161 | 
def create_ema_and_scales_fn( 162 | target_ema_mode, 163 | start_ema, 164 | scale_mode, 165 | start_scales, 166 | end_scales, 167 | total_steps, 168 | distill_steps_per_iter, 169 | ): 170 | def ema_and_scales_fn(step): 171 | if target_ema_mode == "fixed" and scale_mode == "fixed": 172 | target_ema = start_ema 173 | scales = start_scales 174 | elif target_ema_mode == "fixed" and scale_mode == "progressive": 175 | target_ema = start_ema 176 | scales = np.ceil( 177 | np.sqrt( 178 | (step / total_steps) * ((end_scales + 1) ** 2 - start_scales**2) 179 | + start_scales**2 180 | ) 181 | - 1 182 | ).astype(np.int32) 183 | scales = np.maximum(scales, 1) 184 | scales = scales + 1 185 | 186 | elif target_ema_mode == "adaptive" and scale_mode == "progressive": 187 | scales = np.ceil( 188 | np.sqrt( 189 | (step / total_steps) * ((end_scales + 1) ** 2 - start_scales**2) 190 | + start_scales**2 191 | ) 192 | - 1 193 | ).astype(np.int32) 194 | scales = np.maximum(scales, 1) 195 | c = -np.log(start_ema) * start_scales 196 | target_ema = np.exp(-c / scales) 197 | scales = scales + 1 198 | elif target_ema_mode == "fixed" and scale_mode == "progdist": 199 | distill_stage = step // distill_steps_per_iter 200 | scales = start_scales // (2**distill_stage) 201 | scales = np.maximum(scales, 2) 202 | 203 | sub_stage = np.maximum( 204 | step - distill_steps_per_iter * (np.log2(start_scales) - 1), 205 | 0, 206 | ) 207 | sub_stage = sub_stage // (distill_steps_per_iter * 2) 208 | sub_scales = 2 // (2**sub_stage) 209 | sub_scales = np.maximum(sub_scales, 1) 210 | 211 | scales = np.where(scales == 2, sub_scales, scales) 212 | 213 | target_ema = 1.0 214 | else: 215 | raise NotImplementedError 216 | 217 | return float(target_ema), int(scales) 218 | 219 | return ema_and_scales_fn 220 | 221 | 222 | def add_dict_to_argparser(parser, default_dict): 223 | for k, v in default_dict.items(): 224 | v_type = type(v) 225 | if v is None: 226 | v_type = str 227 | elif isinstance(v, bool): 228 | v_type = str2bool 229 | parser.add_argument(f"--{k}", default=v, type=v_type) 230 | 231 | 232 | def args_to_dict(args, keys): 233 | return {k: getattr(args, k) for k in keys} 234 | 235 | 236 | def str2bool(v): 237 | """ 238 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 239 | """ 240 | if isinstance(v, bool): 241 | return v 242 | if v.lower() in ("yes", "true", "t", "y", "1"): 243 | return True 244 | elif v.lower() in ("no", "false", "f", "n", "0"): 245 | return False 246 | else: 247 | raise argparse.ArgumentTypeError("boolean value expected") 248 | -------------------------------------------------------------------------------- /cm/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 
28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = 
None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | # self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2**self.lg_loss_scale 179 | (loss * loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2**self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | for p in self.master_params: 203 | p.grad.mul_(1.0 / (2**self.lg_loss_scale)) 204 | opt.step() 205 | zero_master_grads(self.master_params) 206 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 207 | self.lg_loss_scale += self.fp16_scale_growth 208 | return True 209 | 210 | def _optimize_normal(self, opt: th.optim.Optimizer): 211 | grad_norm, param_norm = self._compute_norms() 212 | logger.logkv_mean("grad_norm", grad_norm) 213 | logger.logkv_mean("param_norm", param_norm) 214 | opt.step() 215 | return True 216 | 217 | def _compute_norms(self, grad_scale=1.0): 218 | grad_norm = 0.0 219 | param_norm = 0.0 220 | for p in self.master_params: 221 | with th.no_grad(): 222 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 223 | if p.grad is not None: 224 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 225 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 226 | 227 | def master_params_to_state_dict(self, master_params): 228 | return master_params_to_state_dict( 229 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 230 | ) 231 | 232 | def state_dict_to_master_params(self, state_dict): 233 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 234 | 235 | 236 | def 
check_overflow(value): 237 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 238 | -------------------------------------------------------------------------------- /cm/image_datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | from PIL import Image 5 | import blobfile as bf 6 | # from mpi4py import MPI 7 | import numpy as np 8 | from torch.utils.data import DataLoader, Dataset 9 | import scipy.io 10 | 11 | from monai.transforms import ( 12 | AsDiscrete, 13 | AddChanneld, 14 | Compose, 15 | CropForegroundd, 16 | LoadImaged, 17 | Orientationd, 18 | RandFlipd, 19 | RandCropByPosNegLabeld, 20 | RandShiftIntensityd, 21 | ScaleIntensityRanged, 22 | Spacingd, 23 | RandRotate90d, 24 | ToTensord, 25 | RandAffined, 26 | RandCropByLabelClassesd, 27 | SpatialPadd, 28 | RandAdjustContrastd, 29 | RandShiftIntensityd, 30 | ScaleIntensityd, 31 | NormalizeIntensityd, 32 | RandScaleIntensityd, 33 | RandGaussianNoised, 34 | RandGaussianSmoothd, 35 | ScaleIntensityRangePercentilesd, 36 | Resized, 37 | Transposed, 38 | RandSpatialCropd, 39 | RandSpatialCropSamplesd 40 | ) 41 | from monai.transforms import (CastToTyped, 42 | Compose, CropForegroundd, EnsureChannelFirstd, LoadImaged, 43 | NormalizeIntensity, RandCropByPosNegLabeld, 44 | RandFlipd, RandGaussianNoised, 45 | RandGaussianSmoothd, RandScaleIntensityd, 46 | RandZoomd, SpatialCrop, SpatialPadd, EnsureTyped) 47 | from natsort import natsorted 48 | import glob 49 | import torch 50 | 51 | def load_data( 52 | *, 53 | data_dir, 54 | batch_size, 55 | image_size, 56 | class_cond=False, 57 | deterministic=False, 58 | random_crop=False, 59 | random_flip=True, 60 | ): 61 | """ 62 | For a dataset, create a generator over (images, kwargs) pairs. 63 | 64 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 65 | more keys, each of which map to a batched Tensor of their own. 66 | The kwargs dict can be used for class labels, in which case the key is "y" 67 | and the values are integer tensors of class labels. 68 | 69 | :param data_dir: a dataset directory. 70 | :param batch_size: the batch size of each returned pair. 71 | :param image_size: the size to which images are resized. 72 | :param class_cond: if True, include a "y" key in returned dicts for class 73 | label. If classes are not available and this is true, an 74 | exception will be raised. 75 | :param deterministic: if True, yield results in a deterministic order. 76 | :param random_crop: if True, randomly crop the images for augmentation. 77 | :param random_flip: if True, randomly flip the images for augmentation. 78 | """ 79 | if not data_dir: 80 | raise ValueError("unspecified data directory") 81 | all_files = _list_image_files_recursively(data_dir) 82 | classes = None 83 | if class_cond: 84 | # Assume classes are the first part of the filename, 85 | # before an underscore. 
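        # e.g. a (hypothetical) file named "class3_slice001.png" yields class name "class3".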
86 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 87 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 88 | classes = [sorted_classes[x] for x in class_names] 89 | dataset = ImageDataset( 90 | image_size, 91 | all_files, 92 | classes=classes, 93 | shard=MPI.COMM_WORLD.Get_rank(), 94 | num_shards=MPI.COMM_WORLD.Get_size(), 95 | random_crop=random_crop, 96 | random_flip=random_flip, 97 | ) 98 | if deterministic: 99 | loader = DataLoader( 100 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True 101 | ) 102 | else: 103 | loader = DataLoader( 104 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True 105 | ) 106 | while True: 107 | yield from loader 108 | 109 | 110 | def _list_image_files_recursively(data_dir): 111 | results = [] 112 | for entry in sorted(bf.listdir(data_dir)): 113 | full_path = bf.join(data_dir, entry) 114 | ext = entry.split(".")[-1] 115 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 116 | results.append(full_path) 117 | elif bf.isdir(full_path): 118 | results.extend(_list_image_files_recursively(full_path)) 119 | return results 120 | 121 | 122 | # class ImageDataset(Dataset): 123 | # def __init__( 124 | # self, 125 | # resolution, 126 | # image_paths, 127 | # classes=None, 128 | # shard=0, 129 | # num_shards=1, 130 | # random_crop=False, 131 | # random_flip=True, 132 | # ): 133 | # super().__init__() 134 | # self.resolution = resolution 135 | # self.local_images = image_paths[shard:][::num_shards] 136 | # self.local_classes = None if classes is None else classes[shard:][::num_shards] 137 | # self.random_crop = random_crop 138 | # self.random_flip = random_flip 139 | 140 | # def __len__(self): 141 | # return len(self.local_images) 142 | 143 | # def __getitem__(self, idx): 144 | # path = self.local_images[idx] 145 | # with bf.BlobFile(path, "rb") as f: 146 | # pil_image = Image.open(f) 147 | # pil_image.load() 148 | # pil_image = pil_image.convert("RGB") 149 | 150 | # if self.random_crop: 151 | # arr = random_crop_arr(pil_image, self.resolution) 152 | # else: 153 | # arr = center_crop_arr(pil_image, self.resolution) 154 | 155 | # if self.random_flip and random.random() < 0.5: 156 | # arr = arr[:, ::-1] 157 | 158 | # arr = arr.astype(np.float32) / 127.5 - 1 159 | 160 | # out_dict = {} 161 | # if self.local_classes is not None: 162 | # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 163 | # return np.transpose(arr, [2, 0, 1]), out_dict 164 | 165 | class CustomDataset(Dataset): 166 | def __init__(self,imgs_path,labels_path, train_flag = True): 167 | self.imgs_path = imgs_path 168 | self.labels_path = labels_path 169 | self.train_flag = train_flag 170 | file_list = natsorted(glob.glob(self.imgs_path + "*"), key=lambda y: y.lower()) 171 | label_list = natsorted(glob.glob(self.labels_path + "*"), key=lambda y: y.lower()) 172 | # print(file_list) 173 | self.data = [] 174 | self.label = [] 175 | # self.loader = LoadImaged(keys= ['image','label'],reader='nibabelreader') 176 | # self.loader = LoadImaged(keys= ['image','label'],reader='PILReader') 177 | for img_path in file_list: 178 | class_name = img_path.split("/")[-1] 179 | self.data.append([img_path, class_name]) 180 | for label_path in label_list: 181 | class_name = label_path.split("/")[-1] 182 | self.label.append([label_path, class_name]) 183 | self.train_transforms = Compose( 184 | [ 185 | # LoadImaged(keys=["image","label"],reader='nibabelreader'), 186 | AddChanneld(keys=["image","label"]), 
187 | ToTensord(keys=["image","label"]), 188 | ] 189 | ) 190 | self.test_transforms = Compose( 191 | [ 192 | AddChanneld(keys=["image","label"]), 193 | ToTensord(keys=["image","label"]), 194 | ] 195 | ) 196 | def __len__(self): 197 | return len(self.data) 198 | 199 | def __getitem__(self, idx): 200 | 201 | img_path, class_name = self.data[idx] 202 | # label_path, class_name = self.label[idx] 203 | # image = scipy.io.loadmat(img_path) 204 | cao = scipy.io.loadmat(img_path) 205 | # cao = {"image":img_path,'label':label_path} 206 | # cao = {"image":image['image'],'label':image['label']} 207 | 208 | 209 | if not self.train_flag: 210 | affined_data_dict = self.test_transforms(cao) 211 | img_tensor = affined_data_dict['image'].to(torch.float) 212 | label_tensor = affined_data_dict['label'].to(torch.float) 213 | img_tensor = torch.unsqueeze(img_tensor, 1) 214 | label_tensor = torch.unsqueeze(label_tensor, 1) 215 | 216 | else: 217 | affined_data_dict = self.train_transforms(cao) 218 | img = affined_data_dict['image'] 219 | label = affined_data_dict['label'] 220 | img_tensor = torch.unsqueeze(img, 1).to(torch.float) 221 | label_tensor = torch.unsqueeze(label, 1).to(torch.float) 222 | 223 | return img_tensor,label_tensor 224 | 225 | def center_crop_arr(pil_image, image_size): 226 | # We are not on a new enough PIL to support the `reducing_gap` 227 | # argument, which uses BOX downsampling at powers of two first. 228 | # Thus, we do it by hand to improve downsample quality. 229 | while min(*pil_image.size) >= 2 * image_size: 230 | pil_image = pil_image.resize( 231 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 232 | ) 233 | 234 | scale = image_size / min(*pil_image.size) 235 | pil_image = pil_image.resize( 236 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 237 | ) 238 | 239 | arr = np.array(pil_image) 240 | crop_y = (arr.shape[0] - image_size) // 2 241 | crop_x = (arr.shape[1] - image_size) // 2 242 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 243 | 244 | 245 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): 246 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 247 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 248 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 249 | 250 | # We are not on a new enough PIL to support the `reducing_gap` 251 | # argument, which uses BOX downsampling at powers of two first. 252 | # Thus, we do it by hand to improve downsample quality. 
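    # (Halve with BOX filtering while the short side is still at least twice the
    # sampled target size, then finish with one BICUBIC resize and a random crop below.)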
253 | while min(*pil_image.size) >= 2 * smaller_dim_size: 254 | pil_image = pil_image.resize( 255 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 256 | ) 257 | 258 | scale = smaller_dim_size / min(*pil_image.size) 259 | pil_image = pil_image.resize( 260 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 261 | ) 262 | 263 | arr = np.array(pil_image) 264 | crop_y = random.randrange(arr.shape[0] - image_size + 1) 265 | crop_x = random.randrange(arr.shape[1] - image_size + 1) 266 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 267 | -------------------------------------------------------------------------------- /cm/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 
409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | log("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger 495 | 496 | -------------------------------------------------------------------------------- /cm/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import torch as th 7 | import torch.distributed as dist 8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 9 | from torch.optim import RAdam 10 | 11 | from . 
import dist_util, logger 12 | from .fp16_util import MixedPrecisionTrainer 13 | from .nn import update_ema 14 | from .resample import LossAwareSampler, UniformSampler 15 | 16 | from .fp16_util import ( 17 | get_param_groups_and_shapes, 18 | make_master_params, 19 | master_params_to_model_params, 20 | ) 21 | import numpy as np 22 | 23 | # For ImageNet experiments, this was a good default value. 24 | # We found that the lg_loss_scale quickly climbed to 25 | # 20-21 within the first ~1K steps of training. 26 | INITIAL_LOG_LOSS_SCALE = 20.0 27 | 28 | 29 | class TrainLoop: 30 | def __init__( 31 | self, 32 | *, 33 | model, 34 | diffusion, 35 | data, 36 | batch_size, 37 | microbatch, 38 | lr, 39 | ema_rate, 40 | log_interval, 41 | save_interval, 42 | resume_checkpoint, 43 | use_fp16=False, 44 | fp16_scale_growth=1e-3, 45 | schedule_sampler=None, 46 | weight_decay=0.0, 47 | lr_anneal_steps=0, 48 | ): 49 | self.model = model 50 | self.diffusion = diffusion 51 | self.data = data 52 | self.batch_size = batch_size 53 | self.microbatch = microbatch if microbatch > 0 else batch_size 54 | self.lr = lr 55 | self.ema_rate = ( 56 | [ema_rate] 57 | if isinstance(ema_rate, float) 58 | else [float(x) for x in ema_rate.split(",")] 59 | ) 60 | self.log_interval = log_interval 61 | self.save_interval = save_interval 62 | self.resume_checkpoint = resume_checkpoint 63 | self.use_fp16 = use_fp16 64 | self.fp16_scale_growth = fp16_scale_growth 65 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 66 | self.weight_decay = weight_decay 67 | self.lr_anneal_steps = lr_anneal_steps 68 | 69 | self.step = 0 70 | self.resume_step = 0 71 | # self.global_batch = self.batch_size * dist.get_world_size() 72 | 73 | self.sync_cuda = th.cuda.is_available() 74 | 75 | # self._load_and_sync_parameters() 76 | self.mp_trainer = MixedPrecisionTrainer( 77 | model=self.model, 78 | use_fp16=self.use_fp16, 79 | fp16_scale_growth=fp16_scale_growth, 80 | ) 81 | 82 | self.opt = RAdam( 83 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay 84 | ) 85 | if self.resume_step: 86 | self._load_optimizer_state() 87 | # Model was resumed, either due to a restart or a checkpoint 88 | # being specified at the command line. 89 | self.ema_params = [ 90 | self._load_ema_parameters(rate) for rate in self.ema_rate 91 | ] 92 | else: 93 | self.ema_params = [ 94 | copy.deepcopy(self.mp_trainer.master_params) 95 | for _ in range(len(self.ema_rate)) 96 | ] 97 | 98 | # if th.cuda.is_available(): 99 | # self.use_ddp = True 100 | # self.ddp_model = DDP( 101 | # self.model, 102 | # device_ids=[dist_util.dev()], 103 | # output_device=dist_util.dev(), 104 | # broadcast_buffers=False, 105 | # bucket_cap_mb=128, 106 | # find_unused_parameters=False, 107 | # ) 108 | # else: 109 | # if dist.get_world_size() > 1: 110 | # logger.warn( 111 | # "Distributed training requires CUDA. " 112 | # "Gradients will not be synchronized properly!" 
113 | # ) 114 | # self.use_ddp = False 115 | # self.ddp_model = self.model 116 | 117 | self.step = self.resume_step 118 | 119 | def _load_and_sync_parameters(self): 120 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 121 | 122 | if resume_checkpoint: 123 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 124 | if dist.get_rank() == 0: 125 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 126 | self.model.load_state_dict( 127 | dist_util.load_state_dict( 128 | resume_checkpoint, map_location=dist_util.dev() 129 | ), 130 | ) 131 | 132 | dist_util.sync_params(self.model.parameters()) 133 | dist_util.sync_params(self.model.buffers()) 134 | 135 | def _load_ema_parameters(self, rate): 136 | ema_params = copy.deepcopy(self.mp_trainer.master_params) 137 | 138 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 139 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 140 | if ema_checkpoint: 141 | if dist.get_rank() == 0: 142 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 143 | state_dict = dist_util.load_state_dict( 144 | ema_checkpoint, map_location=dist_util.dev() 145 | ) 146 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict) 147 | 148 | dist_util.sync_params(ema_params) 149 | return ema_params 150 | 151 | def _load_optimizer_state(self): 152 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 153 | opt_checkpoint = bf.join( 154 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 155 | ) 156 | if bf.exists(opt_checkpoint): 157 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 158 | state_dict = dist_util.load_state_dict( 159 | opt_checkpoint, map_location=dist_util.dev() 160 | ) 161 | self.opt.load_state_dict(state_dict) 162 | 163 | def run_loop(self): 164 | while not self.lr_anneal_steps or self.step < self.lr_anneal_steps: 165 | batch, cond = next(self.data) 166 | self.run_step(batch, cond) 167 | if self.step % self.log_interval == 0: 168 | logger.dumpkvs() 169 | if self.step % self.save_interval == 0: 170 | self.save() 171 | # Run for a finite amount of time in integration tests. 172 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 173 | return 174 | # Save the last checkpoint if it wasn't already saved. 
175 | if (self.step - 1) % self.save_interval != 0: 176 | self.save() 177 | 178 | def run_step(self, batch, cond): 179 | self.forward_backward(batch, cond) 180 | took_step = self.mp_trainer.optimize(self.opt) 181 | if took_step: 182 | self.step += 1 183 | self._update_ema() 184 | self._anneal_lr() 185 | self.log_step() 186 | 187 | def forward_backward(self, batch, cond): 188 | self.mp_trainer.zero_grad() 189 | for i in range(0, batch.shape[0], self.microbatch): 190 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 191 | micro_cond = { 192 | k: v[i : i + self.microbatch].to(dist_util.dev()) 193 | for k, v in cond.items() 194 | } 195 | last_batch = (i + self.microbatch) >= batch.shape[0] 196 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 197 | 198 | compute_losses = functools.partial( 199 | self.diffusion.training_losses, 200 | self.ddp_model, 201 | micro, 202 | t, 203 | model_kwargs=micro_cond, 204 | ) 205 | 206 | if last_batch or not self.use_ddp: 207 | losses = compute_losses() 208 | else: 209 | with self.ddp_model.no_sync(): 210 | losses = compute_losses() 211 | 212 | if isinstance(self.schedule_sampler, LossAwareSampler): 213 | self.schedule_sampler.update_with_local_losses( 214 | t, losses["loss"].detach() 215 | ) 216 | 217 | loss = (losses["loss"] * weights).mean() 218 | log_loss_dict( 219 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 220 | ) 221 | self.mp_trainer.backward(loss) 222 | 223 | def _update_ema(self): 224 | for rate, params in zip(self.ema_rate, self.ema_params): 225 | update_ema(params, self.mp_trainer.master_params, rate=rate) 226 | 227 | def _anneal_lr(self): 228 | if not self.lr_anneal_steps: 229 | return 230 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 231 | lr = self.lr * (1 - frac_done) 232 | for param_group in self.opt.param_groups: 233 | param_group["lr"] = lr 234 | 235 | def log_step(self): 236 | logger.logkv("step", self.step + self.resume_step) 237 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 238 | 239 | def save(self): 240 | def save_checkpoint(rate, params): 241 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 242 | if dist.get_rank() == 0: 243 | logger.log(f"saving model {rate}...") 244 | if not rate: 245 | filename = f"model{(self.step+self.resume_step):06d}.pt" 246 | else: 247 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 248 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 249 | th.save(state_dict, f) 250 | 251 | for rate, params in zip(self.ema_rate, self.ema_params): 252 | save_checkpoint(rate, params) 253 | 254 | if dist.get_rank() == 0: 255 | with bf.BlobFile( 256 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 257 | "wb", 258 | ) as f: 259 | th.save(self.opt.state_dict(), f) 260 | 261 | # Save model parameters last to prevent race conditions where a restart 262 | # loads model at step N, but opt/ema state isn't saved for step N. 
263 | save_checkpoint(0, self.mp_trainer.master_params) 264 | dist.barrier() 265 | 266 | 267 | class CMTrainLoop(TrainLoop): 268 | def __init__( 269 | self, 270 | *, 271 | target_model, 272 | teacher_model, 273 | teacher_diffusion, 274 | training_mode, 275 | ema_scale_fn, 276 | total_training_steps, 277 | **kwargs, 278 | ): 279 | super().__init__(**kwargs) 280 | self.training_mode = training_mode 281 | self.ema_scale_fn = ema_scale_fn 282 | self.target_model = target_model 283 | self.teacher_model = teacher_model 284 | self.teacher_diffusion = teacher_diffusion 285 | self.total_training_steps = total_training_steps 286 | 287 | if target_model: 288 | # self._load_and_sync_target_parameters() 289 | self.target_model.requires_grad_(False) 290 | self.target_model.train() 291 | 292 | self.target_model_param_groups_and_shapes = get_param_groups_and_shapes( 293 | self.target_model.named_parameters() 294 | ) 295 | self.target_model_master_params = make_master_params( 296 | self.target_model_param_groups_and_shapes 297 | ) 298 | 299 | if teacher_model: 300 | self._load_and_sync_teacher_parameters() 301 | self.teacher_model.requires_grad_(False) 302 | self.teacher_model.eval() 303 | 304 | self.global_step = self.step 305 | if training_mode == "progdist": 306 | self.target_model.eval() 307 | _, scale = ema_scale_fn(self.global_step) 308 | if scale == 1 or scale == 2: 309 | _, start_scale = ema_scale_fn(0) 310 | n_normal_steps = int(np.log2(start_scale // 2)) * self.lr_anneal_steps 311 | step = self.global_step - n_normal_steps 312 | if step != 0: 313 | self.lr_anneal_steps *= 2 314 | self.step = step % self.lr_anneal_steps 315 | else: 316 | self.step = 0 317 | else: 318 | self.step = self.global_step % self.lr_anneal_steps 319 | 320 | def _load_and_sync_target_parameters(self): 321 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 322 | if resume_checkpoint: 323 | path, name = os.path.split(resume_checkpoint) 324 | target_name = name.replace("model", "target_model") 325 | resume_target_checkpoint = os.path.join(path, target_name) 326 | if bf.exists(resume_target_checkpoint) and dist.get_rank() == 0: 327 | logger.log( 328 | "loading model from checkpoint: {resume_target_checkpoint}..." 329 | ) 330 | self.target_model.load_state_dict( 331 | dist_util.load_state_dict( 332 | resume_target_checkpoint, map_location=dist_util.dev() 333 | ), 334 | ) 335 | 336 | dist_util.sync_params(self.target_model.parameters()) 337 | dist_util.sync_params(self.target_model.buffers()) 338 | 339 | def _load_and_sync_teacher_parameters(self): 340 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 341 | if resume_checkpoint: 342 | path, name = os.path.split(resume_checkpoint) 343 | teacher_name = name.replace("model", "teacher_model") 344 | resume_teacher_checkpoint = os.path.join(path, teacher_name) 345 | 346 | if bf.exists(resume_teacher_checkpoint) and dist.get_rank() == 0: 347 | logger.log( 348 | "loading model from checkpoint: {resume_teacher_checkpoint}..." 
349 | ) 350 | self.teacher_model.load_state_dict( 351 | dist_util.load_state_dict( 352 | resume_teacher_checkpoint, map_location=dist_util.dev() 353 | ), 354 | ) 355 | 356 | dist_util.sync_params(self.teacher_model.parameters()) 357 | dist_util.sync_params(self.teacher_model.buffers()) 358 | 359 | def run_loop(self): 360 | saved = False 361 | while ( 362 | not self.lr_anneal_steps 363 | or self.step < self.lr_anneal_steps 364 | or self.global_step < self.total_training_steps 365 | ): 366 | batch, cond = next(self.data) 367 | self.run_step(batch, cond) 368 | saved = False 369 | if ( 370 | self.global_step 371 | and self.save_interval != -1 372 | and self.global_step % self.save_interval == 0 373 | ): 374 | self.save() 375 | saved = True 376 | th.cuda.empty_cache() 377 | # Run for a finite amount of time in integration tests. 378 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 379 | return 380 | 381 | if self.global_step % self.log_interval == 0: 382 | logger.dumpkvs() 383 | 384 | # Save the last checkpoint if it wasn't already saved. 385 | if not saved: 386 | self.save() 387 | 388 | def run_step(self, batch, cond): 389 | self.forward_backward(batch, cond) 390 | took_step = self.mp_trainer.optimize(self.opt) 391 | if took_step: 392 | self._update_ema() 393 | if self.target_model: 394 | self._update_target_ema() 395 | if self.training_mode == "progdist": 396 | self.reset_training_for_progdist() 397 | self.step += 1 398 | self.global_step += 1 399 | 400 | self._anneal_lr() 401 | self.log_step() 402 | 403 | def _update_target_ema(self): 404 | target_ema, scales = self.ema_scale_fn(self.global_step) 405 | with th.no_grad(): 406 | update_ema( 407 | self.target_model_master_params, 408 | self.mp_trainer.master_params, 409 | rate=target_ema, 410 | ) 411 | master_params_to_model_params( 412 | self.target_model_param_groups_and_shapes, 413 | self.target_model_master_params, 414 | ) 415 | 416 | def reset_training_for_progdist(self): 417 | assert self.training_mode == "progdist", "Training mode must be progdist" 418 | if self.global_step > 0: 419 | scales = self.ema_scale_fn(self.global_step)[1] 420 | scales2 = self.ema_scale_fn(self.global_step - 1)[1] 421 | if scales != scales2: 422 | with th.no_grad(): 423 | update_ema( 424 | self.teacher_model.parameters(), 425 | self.model.parameters(), 426 | 0.0, 427 | ) 428 | # reset optimizer 429 | self.opt = RAdam( 430 | self.mp_trainer.master_params, 431 | lr=self.lr, 432 | weight_decay=self.weight_decay, 433 | ) 434 | 435 | self.ema_params = [ 436 | copy.deepcopy(self.mp_trainer.master_params) 437 | for _ in range(len(self.ema_rate)) 438 | ] 439 | if scales == 2: 440 | self.lr_anneal_steps *= 2 441 | self.teacher_model.eval() 442 | self.step = 0 443 | 444 | def forward_backward(self, batch, cond): 445 | self.mp_trainer.zero_grad() 446 | for i in range(0, batch.shape[0], self.microbatch): 447 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 448 | micro_cond = { 449 | k: v[i : i + self.microbatch].to(dist_util.dev()) 450 | for k, v in cond.items() 451 | } 452 | last_batch = (i + self.microbatch) >= batch.shape[0] 453 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 454 | 455 | ema, num_scales = self.ema_scale_fn(self.global_step) 456 | if self.training_mode == "progdist": 457 | if num_scales == self.ema_scale_fn(0)[1]: 458 | compute_losses = functools.partial( 459 | self.diffusion.progdist_losses, 460 | self.ddp_model, 461 | micro, 462 | num_scales, 463 | target_model=self.teacher_model, 464 | 
target_diffusion=self.teacher_diffusion, 465 | model_kwargs=micro_cond, 466 | ) 467 | else: 468 | compute_losses = functools.partial( 469 | self.diffusion.progdist_losses, 470 | self.ddp_model, 471 | micro, 472 | num_scales, 473 | target_model=self.target_model, 474 | target_diffusion=self.diffusion, 475 | model_kwargs=micro_cond, 476 | ) 477 | elif self.training_mode == "consistency_distillation": 478 | compute_losses = functools.partial( 479 | self.diffusion.consistency_losses, 480 | self.ddp_model, 481 | micro, 482 | num_scales, 483 | target_model=self.target_model, 484 | teacher_model=self.teacher_model, 485 | teacher_diffusion=self.teacher_diffusion, 486 | model_kwargs=micro_cond, 487 | ) 488 | elif self.training_mode == "consistency_training": 489 | compute_losses = functools.partial( 490 | self.diffusion.consistency_losses, 491 | self.ddp_model, 492 | micro, 493 | num_scales, 494 | target_model=self.target_model, 495 | model_kwargs=micro_cond, 496 | ) 497 | else: 498 | raise ValueError(f"Unknown training mode {self.training_mode}") 499 | 500 | if last_batch or not self.use_ddp: 501 | losses = compute_losses() 502 | else: 503 | with self.ddp_model.no_sync(): 504 | losses = compute_losses() 505 | 506 | if isinstance(self.schedule_sampler, LossAwareSampler): 507 | self.schedule_sampler.update_with_local_losses( 508 | t, losses["loss"].detach() 509 | ) 510 | 511 | loss = (losses["loss"] * weights).mean() 512 | 513 | log_loss_dict( 514 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 515 | ) 516 | self.mp_trainer.backward(loss) 517 | 518 | def save(self): 519 | import blobfile as bf 520 | 521 | step = self.global_step 522 | 523 | def save_checkpoint(rate, params): 524 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 525 | if dist.get_rank() == 0: 526 | logger.log(f"saving model {rate}...") 527 | if not rate: 528 | filename = f"model{step:06d}.pt" 529 | else: 530 | filename = f"ema_{rate}_{step:06d}.pt" 531 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 532 | th.save(state_dict, f) 533 | 534 | for rate, params in zip(self.ema_rate, self.ema_params): 535 | save_checkpoint(rate, params) 536 | 537 | logger.log("saving optimizer state...") 538 | if dist.get_rank() == 0: 539 | with bf.BlobFile( 540 | bf.join(get_blob_logdir(), f"opt{step:06d}.pt"), 541 | "wb", 542 | ) as f: 543 | th.save(self.opt.state_dict(), f) 544 | 545 | if dist.get_rank() == 0: 546 | if self.target_model: 547 | logger.log("saving target model state") 548 | filename = f"target_model{step:06d}.pt" 549 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 550 | th.save(self.target_model.state_dict(), f) 551 | if self.teacher_model and self.training_mode == "progdist": 552 | logger.log("saving teacher model state") 553 | filename = f"teacher_model{step:06d}.pt" 554 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 555 | th.save(self.teacher_model.state_dict(), f) 556 | 557 | # Save model parameters last to prevent race conditions where a restart 558 | # loads model at step N, but opt/ema state isn't saved for step N. 559 | save_checkpoint(0, self.mp_trainer.master_params) 560 | dist.barrier() 561 | 562 | def log_step(self): 563 | step = self.global_step 564 | logger.logkv("step", step) 565 | logger.logkv("samples", (step + 1) * self.global_batch) 566 | 567 | 568 | def parse_resume_step_from_filename(filename): 569 | """ 570 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 571 | checkpoint's number of steps. 
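    For example, "path/to/model012345.pt" yields 12345; a name that does not
    contain "model" or has a non-numeric suffix returns 0.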
572 | """ 573 | split = filename.split("model") 574 | if len(split) < 2: 575 | return 0 576 | split1 = split[-1].split(".")[0] 577 | try: 578 | return int(split1) 579 | except ValueError: 580 | return 0 581 | 582 | 583 | def get_blob_logdir(): 584 | # You can change this to be a separate path to save checkpoints to 585 | # a blobstore or some external drive. 586 | return logger.get_dir() 587 | 588 | 589 | def find_resume_checkpoint(): 590 | # On your infrastructure, you may want to override this to automatically 591 | # discover the latest checkpoint on your blob storage, etc. 592 | return None 593 | 594 | 595 | def find_ema_checkpoint(main_checkpoint, step, rate): 596 | if main_checkpoint is None: 597 | return None 598 | filename = f"ema_{rate}_{(step):06d}.pt" 599 | path = bf.join(bf.dirname(main_checkpoint), filename) 600 | if bf.exists(path): 601 | return path 602 | return None 603 | 604 | 605 | def log_loss_dict(diffusion, ts, losses): 606 | for key, values in losses.items(): 607 | logger.logkv_mean(key, values.mean().item()) 608 | # Log the quantiles (four quartiles, in particular). 609 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 610 | quartile = int(4 * sub_t / diffusion.num_timesteps) 611 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 612 | -------------------------------------------------------------------------------- /Network/Diffusion_model_Unet_2d.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | # from .fp16_util import convert_module_to_f16, convert_module_to_f32 11 | from Network.util_nn import ( 12 | checkpoint, 13 | conv_nd, 14 | linear, 15 | avg_pool_nd, 16 | zero_module, 17 | normalization, 18 | timestep_embedding, 19 | ) 20 | 21 | class AttentionPool2d(nn.Module): 22 | """ 23 | Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py 24 | """ 25 | 26 | def __init__( 27 | self, 28 | spacial_dim: int, 29 | embed_dim: int, 30 | num_heads_channels: int, 31 | output_dim: int = None, 32 | ): 33 | super().__init__() 34 | self.positional_embedding = nn.Parameter( 35 | th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5 36 | ) 37 | self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) 38 | self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) 39 | self.num_heads = embed_dim // num_heads_channels 40 | self.attention = QKVAttention(self.num_heads) 41 | 42 | def forward(self, x): 43 | b, c, *_spatial = x.shape 44 | x = x.reshape(b, c, -1) # NC(HW) 45 | x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) 46 | x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) 47 | x = self.qkv_proj(x) 48 | x = self.attention(x) 49 | x = self.c_proj(x) 50 | return x[:, :, 0] 51 | 52 | 53 | class TimestepBlock(nn.Module): 54 | """ 55 | Any module where forward() takes timestep embeddings as a second argument. 56 | """ 57 | 58 | @abstractmethod 59 | def forward(self, x, emb): 60 | """ 61 | Apply the module to `x` given `emb` timestep embeddings. 62 | """ 63 | 64 | 65 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 66 | """ 67 | A sequential module that passes timestep embeddings to the children that 68 | support it as an extra input. 
69 | """ 70 | 71 | def forward(self, x, emb): 72 | for layer in self: 73 | if isinstance(layer, TimestepBlock): 74 | x = layer(x, emb) 75 | else: 76 | x = layer(x) 77 | return x 78 | 79 | 80 | class Upsample(nn.Module): 81 | """ 82 | An upsampling layer with an optional convolution. 83 | :param channels: channels in the inputs and outputs. 84 | :param use_conv: a bool determining if a convolution is applied. 85 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 86 | upsampling occurs in the inner-two dimensions. 87 | """ 88 | 89 | def __init__(self, channels, use_conv, sample_kernel, dims=2, out_channels=None): 90 | super().__init__() 91 | self.channels = channels 92 | if dims == 3: 93 | self.sample_kernel=(sample_kernel[0],sample_kernel[1],sample_kernel[2]) 94 | else: 95 | self.sample_kernel=(sample_kernel[0],sample_kernel[1]) 96 | self.dims = dims 97 | if use_conv: 98 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) 99 | else: 100 | self.conv = th.nn.Upsample(scale_factor=self.sample_kernel,mode='nearest') 101 | 102 | def forward(self, x): 103 | assert x.shape[1] == self.channels 104 | x = self.conv(x) 105 | 106 | # if self.dims == 3: 107 | # x = F.interpolate( 108 | # x, scale_factor=self.sample_kernel, mode="nearest" 109 | # ) 110 | # else: 111 | # x = F.interpolate(x, scale_factor=self.sample_kernel, mode="nearest") 112 | # if self.use_conv: 113 | # x = self.conv(x) 114 | return x 115 | 116 | 117 | class Downsample(nn.Module): 118 | """ 119 | A downsampling layer with an optional convolution. 120 | :param channels: channels in the inputs and outputs. 121 | :param use_conv: a bool determining if a convolution is applied. 122 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 123 | downsampling occurs in the inner-two dimensions. 124 | """ 125 | 126 | def __init__(self, channels, use_conv,sample_kernel, dims=2, out_channels=None): 127 | super().__init__() 128 | self.channels = channels 129 | self.out_channels = out_channels or channels 130 | self.use_conv = use_conv 131 | self.dims = dims 132 | if self.dims == 3: 133 | self.sample_kernel = (1/sample_kernel[0],1/sample_kernel[1],1/sample_kernel[2]) 134 | else: 135 | self.sample_kernel = (1/sample_kernel[0],1/sample_kernel[1]) 136 | # stride = 2 if dims != 3 else (2, 2, 2) 137 | # stride = 2 138 | if use_conv: 139 | self.op = th.nn.Upsample(scale_factor=self.sample_kernel,mode='nearest') 140 | else: 141 | assert self.channels == self.out_channels 142 | self.op = th.nn.Upsample(scale_factor=self.sample_kernel,mode='nearest') 143 | 144 | def forward(self, x): 145 | assert x.shape[1] == self.channels 146 | # x = F.interpolate( 147 | # x, scale_factor=self.sample_kernel, mode="nearest" 148 | # ) 149 | return self.op(x) 150 | 151 | 152 | class ResBlock(TimestepBlock): 153 | """ 154 | A residual block that can optionally change the number of channels. 155 | :param channels: the number of input channels. 156 | :param emb_channels: the number of timestep embedding channels. 157 | :param dropout: the rate of dropout. 158 | :param out_channels: if specified, the number of out channels. 159 | :param use_conv: if True and out_channels is specified, use a spatial 160 | convolution instead of a smaller 1x1 convolution to change the 161 | channels in the skip connection. 162 | :param dims: determines if the signal is 1D, 2D, or 3D. 163 | :param use_checkpoint: if True, use gradient checkpointing on this module. 164 | :param up: if True, use this block for upsampling. 
165 | :param down: if True, use this block for downsampling. 166 | """ 167 | 168 | def __init__( 169 | self, 170 | channels, 171 | emb_channels, 172 | dropout, 173 | out_channels=None, 174 | use_conv=False, 175 | use_scale_shift_norm=False, 176 | dims=2, 177 | sample_kernel = None, 178 | use_checkpoint=False, 179 | up=False, 180 | down=False, 181 | ): 182 | super().__init__() 183 | self.channels = channels 184 | self.emb_channels = emb_channels 185 | self.dropout = dropout 186 | self.out_channels = out_channels or channels 187 | self.use_conv = use_conv 188 | self.use_checkpoint = use_checkpoint 189 | self.use_scale_shift_norm = use_scale_shift_norm 190 | 191 | self.in_layers = nn.Sequential( 192 | normalization(channels), 193 | nn.SiLU(), 194 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 195 | ) 196 | 197 | self.updown = up or down 198 | 199 | if up: 200 | self.h_upd = Upsample(channels, False,sample_kernel, dims) 201 | self.x_upd = Upsample(channels, False,sample_kernel, dims) 202 | elif down: 203 | self.h_upd = Downsample(channels, False,sample_kernel, dims) 204 | self.x_upd = Downsample(channels, False,sample_kernel, dims) 205 | else: 206 | self.h_upd = self.x_upd = nn.Identity() 207 | 208 | self.emb_layers = nn.Sequential( 209 | nn.SiLU(), 210 | linear( 211 | emb_channels, 212 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 213 | ), 214 | ) 215 | self.out_layers = nn.Sequential( 216 | normalization(self.out_channels), 217 | nn.SiLU(), 218 | nn.Dropout(p=dropout), 219 | zero_module( 220 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 221 | ), 222 | ) 223 | 224 | if self.out_channels == channels: 225 | self.skip_connection = nn.Identity() 226 | elif use_conv: 227 | self.skip_connection = conv_nd( 228 | dims, channels, self.out_channels, 3, padding=1 229 | ) 230 | else: 231 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 232 | 233 | def forward(self, x, emb): 234 | """ 235 | Apply the block to a Tensor, conditioned on a timestep embedding. 236 | :param x: an [N x C x ...] Tensor of features. 237 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 238 | :return: an [N x C x ...] Tensor of outputs. 239 | """ 240 | return checkpoint( 241 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 242 | ) 243 | 244 | def _forward(self, x, emb): 245 | if self.updown: 246 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 247 | h = in_rest(x) 248 | h = self.h_upd(h) 249 | x = self.x_upd(x) 250 | h = in_conv(h) 251 | else: 252 | h = self.in_layers(x) 253 | emb_out = self.emb_layers(emb).type(h.dtype) 254 | while len(emb_out.shape) < len(h.shape): 255 | emb_out = emb_out[..., None] 256 | if self.use_scale_shift_norm: 257 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 258 | scale, shift = th.chunk(emb_out, 2, dim=1) 259 | h = out_norm(h) * (1 + scale) + shift 260 | h = out_rest(h) 261 | else: 262 | h = h + emb_out 263 | h = self.out_layers(h) 264 | return self.skip_connection(x) + h 265 | 266 | 267 | class AttentionBlock(nn.Module): 268 | """ 269 | An attention block that allows spatial positions to attend to each other. 270 | Originally ported from here, but adapted to the N-d case. 271 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
272 | """ 273 | 274 | def __init__( 275 | self, 276 | channels, 277 | num_heads=1, 278 | num_head_channels=-1, 279 | use_checkpoint=False, 280 | use_new_attention_order=False, 281 | ): 282 | super().__init__() 283 | self.channels = channels 284 | if num_head_channels == -1: 285 | self.num_heads = num_heads 286 | else: 287 | assert ( 288 | channels % num_head_channels == 0 289 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 290 | self.num_heads = channels // num_head_channels 291 | self.use_checkpoint = use_checkpoint 292 | self.norm = normalization(channels) 293 | self.qkv = conv_nd(1, channels, channels * 3, 1) 294 | if use_new_attention_order: 295 | # split qkv before split heads 296 | self.attention = QKVAttention(self.num_heads) 297 | else: 298 | # split heads before split qkv 299 | self.attention = QKVAttentionLegacy(self.num_heads) 300 | 301 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 302 | 303 | def forward(self, x): 304 | return checkpoint(self._forward, (x,), self.parameters(), True) 305 | 306 | def _forward(self, x): 307 | b, c, *spatial = x.shape 308 | x = x.reshape(b, c, -1) 309 | qkv = self.qkv(self.norm(x).float()).type(x.dtype) 310 | h = self.attention(qkv) 311 | h = self.proj_out(h.float()).type(x.dtype) 312 | return (x + h).reshape(b, c, *spatial) 313 | 314 | 315 | def count_flops_attn(model, _x, y): 316 | """ 317 | A counter for the `thop` package to count the operations in an 318 | attention operation. 319 | Meant to be used like: 320 | macs, params = thop.profile( 321 | model, 322 | inputs=(inputs, timestamps), 323 | custom_ops={QKVAttention: QKVAttention.count_flops}, 324 | ) 325 | """ 326 | b, c, *spatial = y[0].shape 327 | num_spatial = int(np.prod(spatial)) 328 | # We perform two matmuls with the same number of ops. 329 | # The first computes the weight matrix, the second computes 330 | # the combination of the value vectors. 331 | matmul_ops = 2 * b * (num_spatial ** 2) * c 332 | model.total_ops += th.DoubleTensor([matmul_ops]) 333 | 334 | 335 | class QKVAttentionLegacy(nn.Module): 336 | """ 337 | A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping 338 | """ 339 | 340 | def __init__(self, n_heads): 341 | super().__init__() 342 | self.n_heads = n_heads 343 | 344 | def forward(self, qkv): 345 | """ 346 | Apply QKV attention. 347 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 348 | :return: an [N x (H * C) x T] tensor after attention. 349 | """ 350 | bs, width, length = qkv.shape 351 | assert width % (3 * self.n_heads) == 0 352 | ch = width // (3 * self.n_heads) 353 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) 354 | scale = 1 / math.sqrt(math.sqrt(ch)) 355 | weight = th.einsum( 356 | "bct,bcs->bts", q * scale, k * scale 357 | ) # More stable with f16 than dividing afterwards 358 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 359 | a = th.einsum("bts,bcs->bct", weight, v) 360 | return a.reshape(bs, -1, length) 361 | 362 | @staticmethod 363 | def count_flops(model, _x, y): 364 | return count_flops_attn(model, _x, y) 365 | 366 | 367 | class QKVAttention(nn.Module): 368 | """ 369 | A module which performs QKV attention and splits in a different order. 370 | """ 371 | 372 | def __init__(self, n_heads): 373 | super().__init__() 374 | self.n_heads = n_heads 375 | 376 | def forward(self, qkv): 377 | """ 378 | Apply QKV attention. 379 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 
380 | :return: an [N x (H * C) x T] tensor after attention. 381 | """ 382 | bs, width, length = qkv.shape 383 | assert width % (3 * self.n_heads) == 0 384 | ch = width // (3 * self.n_heads) 385 | q, k, v = qkv.chunk(3, dim=1) 386 | scale = 1 / math.sqrt(math.sqrt(ch)) 387 | weight = th.einsum( 388 | "bct,bcs->bts", 389 | (q * scale).view(bs * self.n_heads, ch, length), 390 | (k * scale).view(bs * self.n_heads, ch, length), 391 | ) # More stable with f16 than dividing afterwards 392 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 393 | a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) 394 | return a.reshape(bs, -1, length) 395 | 396 | @staticmethod 397 | def count_flops(model, _x, y): 398 | return count_flops_attn(model, _x, y) 399 | 400 | 401 | class UNetModel(nn.Module): 402 | """ 403 | The full UNet model with attention and timestep embedding. 404 | :param in_channels: channels in the input Tensor. 405 | :param model_channels: base channel count for the model. 406 | :param out_channels: channels in the output Tensor. 407 | :param num_res_blocks: number of residual blocks per downsample. 408 | :param attention_resolutions: a collection of downsample rates at which 409 | attention will take place. May be a set, list, or tuple. 410 | For example, if this contains 4, then at 4x downsampling, attention 411 | will be used. 412 | :param dropout: the dropout probability. 413 | :param channel_mult: channel multiplier for each level of the UNet. 414 | :param conv_resample: if True, use learned convolutions for upsampling and 415 | downsampling. 416 | :param dims: determines if the signal is 1D, 2D, or 3D. 417 | :param num_classes: if specified (as an int), then this model will be 418 | class-conditional with `num_classes` classes. 419 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 420 | :param num_heads: the number of attention heads in each attention layer. 421 | :param num_heads_channels: if specified, ignore num_heads and instead use 422 | a fixed channel width per attention head. 423 | :param num_heads_upsample: works with num_heads to set a different number 424 | of heads for upsampling. Deprecated. 425 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 426 | :param resblock_updown: use residual blocks for up/downsampling. 427 | :param use_new_attention_order: use a different attention pattern for potentially 428 | increased efficiency. 
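    :param img_size: input size kept on the module for reference; not used elsewhere
        in this file.
    :param image_size: spatial size of the input. Note that in this variant the running
        resolution starts at image_size and is halved per level, so attention_resolutions
        are absolute feature-map sizes (e.g. 32, 16) rather than downsample rates.
    :param sample_kernel: nested per-level scale factors; the model indexes
        sample_kernel[0][level], and each entry is a per-axis tuple such as (2, 2) in 2D.
    :param use_fp16: if True, run the torso of the model in float16.

    Illustrative construction with placeholder hyper-parameters (not the configuration
    used in this repository)::

        model = UNetModel(
            img_size=(256, 256), image_size=256,
            in_channels=2, model_channels=64, out_channels=1,
            num_res_blocks=2, attention_resolutions=[64, 32],
            channel_mult=(1, 2, 4, 8), dims=2,
            sample_kernel=([(2, 2), (2, 2), (2, 2)],),
        )
        x = th.randn(1, 2, 256, 256)  # placeholder conditioning/noise channels
        t = th.rand(1)                # placeholder timestep / noise-level values
        out = model(x, t)             # -> [1, 1, 256, 256]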
429 | """ 430 | 431 | def __init__( 432 | self, 433 | img_size, 434 | image_size, 435 | in_channels, 436 | model_channels, 437 | out_channels, 438 | num_res_blocks, 439 | attention_resolutions, 440 | dropout=0, 441 | channel_mult=(1, 2, 4, 8), 442 | conv_resample=False, 443 | dims=2, 444 | sample_kernel = None, 445 | num_classes=None, 446 | use_checkpoint=False, 447 | use_fp16=False, 448 | num_heads=1, 449 | num_head_channels=-1, 450 | num_heads_upsample=-1, 451 | use_scale_shift_norm=False, 452 | resblock_updown=False, 453 | use_new_attention_order=False, 454 | ): 455 | super().__init__() 456 | 457 | if num_heads_upsample == -1: 458 | num_heads_upsample = num_heads 459 | self.img_size=img_size 460 | self.image_size = image_size 461 | self.in_channels = in_channels 462 | self.model_channels = model_channels 463 | self.out_channels = out_channels 464 | self.num_res_blocks = num_res_blocks 465 | self.attention_resolutions = attention_resolutions 466 | self.dropout = dropout 467 | self.channel_mult = channel_mult 468 | self.conv_resample = conv_resample 469 | self.num_classes = num_classes 470 | self.use_checkpoint = use_checkpoint 471 | self.dtype = th.float16 if use_fp16 else th.float32 472 | self.num_heads = num_heads 473 | self.num_head_channels = num_head_channels 474 | self.num_heads_upsample = num_heads_upsample 475 | self.sample_kernel =sample_kernel[0] 476 | 477 | time_embed_dim = model_channels * 4 478 | self.time_embed = nn.Sequential( 479 | linear(model_channels, time_embed_dim), 480 | nn.SiLU(), 481 | linear(time_embed_dim, time_embed_dim), 482 | ) 483 | 484 | if self.num_classes is not None: 485 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 486 | 487 | ch = input_ch = int(channel_mult[0] * model_channels) 488 | self.input_blocks = nn.ModuleList( 489 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] 490 | ) 491 | self._feature_size = ch 492 | input_block_chans = [ch] 493 | ds = image_size 494 | for level, mult in enumerate(channel_mult): 495 | for _ in range(num_res_blocks): 496 | layers = [ 497 | ResBlock( 498 | ch, 499 | time_embed_dim, 500 | dropout, 501 | out_channels=int(mult * model_channels), 502 | dims=dims, 503 | use_checkpoint=use_checkpoint, 504 | use_scale_shift_norm=use_scale_shift_norm, 505 | ) 506 | ] 507 | ch = int(mult * model_channels) 508 | if ds in attention_resolutions: 509 | layers.append( 510 | AttentionBlock( 511 | ch, 512 | use_checkpoint=use_checkpoint, 513 | num_heads=num_heads, 514 | num_head_channels=num_head_channels, 515 | use_new_attention_order=use_new_attention_order, 516 | ) 517 | ) 518 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 519 | self._feature_size += ch 520 | input_block_chans.append(ch) 521 | if level != len(channel_mult) - 1: 522 | out_ch = ch 523 | self.input_blocks.append( 524 | TimestepEmbedSequential( 525 | ResBlock( 526 | ch, 527 | time_embed_dim, 528 | dropout, 529 | out_channels=out_ch, 530 | dims=dims, 531 | sample_kernel = self.sample_kernel[level], 532 | use_checkpoint=use_checkpoint, 533 | use_scale_shift_norm=use_scale_shift_norm, 534 | down=True, 535 | ) 536 | if resblock_updown 537 | else Downsample( 538 | ch, conv_resample,self.sample_kernel[level], dims=dims, out_channels=out_ch 539 | ) 540 | ) 541 | ) 542 | ch = out_ch 543 | input_block_chans.append(ch) 544 | ds //= 2 545 | self._feature_size += ch 546 | 547 | self.middle_block = TimestepEmbedSequential( 548 | ResBlock( 549 | ch, 550 | time_embed_dim, 551 | dropout, 552 | dims=dims, 553 | 
use_checkpoint=use_checkpoint, 554 | use_scale_shift_norm=use_scale_shift_norm, 555 | ), 556 | AttentionBlock( 557 | ch, 558 | use_checkpoint=use_checkpoint, 559 | num_heads=num_heads, 560 | num_head_channels=num_head_channels, 561 | use_new_attention_order=use_new_attention_order, 562 | ), 563 | ResBlock( 564 | ch, 565 | time_embed_dim, 566 | dropout, 567 | dims=dims, 568 | use_checkpoint=use_checkpoint, 569 | use_scale_shift_norm=use_scale_shift_norm, 570 | ), 571 | ) 572 | self._feature_size += ch 573 | 574 | self.output_blocks = nn.ModuleList([]) 575 | for level, mult in list(enumerate(channel_mult))[::-1]: 576 | for i in range(num_res_blocks+1): 577 | ich = input_block_chans.pop() 578 | layers = [ 579 | ResBlock( 580 | ch + ich, 581 | time_embed_dim, 582 | dropout, 583 | out_channels=int(model_channels * mult), 584 | dims=dims, 585 | use_checkpoint=use_checkpoint, 586 | use_scale_shift_norm=use_scale_shift_norm, 587 | ) 588 | ] 589 | ch = int(model_channels * mult) 590 | if ds in attention_resolutions: 591 | layers.append( 592 | AttentionBlock( 593 | ch, 594 | use_checkpoint=use_checkpoint, 595 | num_heads=num_heads_upsample, 596 | num_head_channels=num_head_channels, 597 | use_new_attention_order=use_new_attention_order, 598 | ) 599 | ) 600 | if level and i == num_res_blocks: 601 | out_ch = ch 602 | layers.append( 603 | ResBlock( 604 | ch, 605 | time_embed_dim, 606 | dropout, 607 | out_channels=out_ch, 608 | dims=dims, 609 | sample_kernel = self.sample_kernel[level-1], 610 | use_checkpoint=use_checkpoint, 611 | use_scale_shift_norm=use_scale_shift_norm, 612 | up=True, 613 | ) 614 | if resblock_updown 615 | else Upsample(ch, conv_resample,self.sample_kernel[level-1], dims=dims, out_channels=out_ch) 616 | ) 617 | ds *= 2 618 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 619 | self._feature_size += ch 620 | 621 | self.out = nn.Sequential( 622 | normalization(ch), 623 | nn.SiLU(), 624 | zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), 625 | ) 626 | 627 | def forward(self, x, timesteps, y=None): 628 | """ 629 | Apply the model to an input batch. 630 | :param x: an [N x C x ...] Tensor of inputs. 631 | :param timesteps: a 1-D batch of timesteps. 632 | :param y: an [N] Tensor of labels, if class-conditional. 633 | :return: an [N x C x ...] Tensor of outputs. 
634 | """ 635 | assert (y is not None) == ( 636 | self.num_classes is not None 637 | ), "must specify y if and only if the model is class-conditional" 638 | 639 | hs = [] 640 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 641 | 642 | if self.num_classes is not None: 643 | assert y.shape == (x.shape[0],) 644 | emb = emb + self.label_emb(y) 645 | 646 | h = x.type(self.dtype) 647 | for module in self.input_blocks: 648 | h = module(h, emb) 649 | hs.append(h) 650 | h = self.middle_block(h, emb) 651 | for module in self.output_blocks: 652 | h = th.cat([h, hs.pop()], dim=1) 653 | h = module(h, emb) 654 | h = h.type(x.dtype) 655 | return self.out(h) -------------------------------------------------------------------------------- /cm/unet.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .fp16_util import convert_module_to_f16, convert_module_to_f32 11 | from .nn import ( 12 | checkpoint, 13 | conv_nd, 14 | linear, 15 | avg_pool_nd, 16 | zero_module, 17 | normalization, 18 | timestep_embedding, 19 | ) 20 | 21 | 22 | class AttentionPool2d(nn.Module): 23 | """ 24 | Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py 25 | """ 26 | 27 | def __init__( 28 | self, 29 | spacial_dim: int, 30 | embed_dim: int, 31 | num_heads_channels: int, 32 | output_dim: int = None, 33 | ): 34 | super().__init__() 35 | self.positional_embedding = nn.Parameter( 36 | th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 37 | ) 38 | self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) 39 | self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) 40 | self.num_heads = embed_dim // num_heads_channels 41 | self.attention = QKVAttention(self.num_heads) 42 | 43 | def forward(self, x): 44 | b, c, *_spatial = x.shape 45 | x = x.reshape(b, c, -1) # NC(HW) 46 | x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) 47 | x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) 48 | x = self.qkv_proj(x) 49 | x = self.attention(x) 50 | x = self.c_proj(x) 51 | return x[:, :, 0] 52 | 53 | 54 | class TimestepBlock(nn.Module): 55 | """ 56 | Any module where forward() takes timestep embeddings as a second argument. 57 | """ 58 | 59 | @abstractmethod 60 | def forward(self, x, emb): 61 | """ 62 | Apply the module to `x` given `emb` timestep embeddings. 63 | """ 64 | 65 | 66 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 67 | """ 68 | A sequential module that passes timestep embeddings to the children that 69 | support it as an extra input. 70 | """ 71 | 72 | def forward(self, x, emb): 73 | for layer in self: 74 | if isinstance(layer, TimestepBlock): 75 | x = layer(x, emb) 76 | else: 77 | x = layer(x) 78 | return x 79 | 80 | 81 | class Upsample(nn.Module): 82 | """ 83 | An upsampling layer with an optional convolution. 84 | 85 | :param channels: channels in the inputs and outputs. 86 | :param use_conv: a bool determining if a convolution is applied. 87 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 88 | upsampling occurs in the inner-two dimensions. 
89 | """ 90 | 91 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 92 | super().__init__() 93 | self.channels = channels 94 | self.out_channels = out_channels or channels 95 | self.use_conv = use_conv 96 | self.dims = dims 97 | if use_conv: 98 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) 99 | 100 | def forward(self, x): 101 | assert x.shape[1] == self.channels 102 | if self.dims == 3: 103 | x = F.interpolate( 104 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 105 | ) 106 | else: 107 | x = F.interpolate(x, scale_factor=2, mode="nearest") 108 | if self.use_conv: 109 | x = self.conv(x) 110 | return x 111 | 112 | 113 | class Downsample(nn.Module): 114 | """ 115 | A downsampling layer with an optional convolution. 116 | 117 | :param channels: channels in the inputs and outputs. 118 | :param use_conv: a bool determining if a convolution is applied. 119 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 120 | downsampling occurs in the inner-two dimensions. 121 | """ 122 | 123 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 124 | super().__init__() 125 | self.channels = channels 126 | self.out_channels = out_channels or channels 127 | self.use_conv = use_conv 128 | self.dims = dims 129 | stride = 2 if dims != 3 else (1, 2, 2) 130 | if use_conv: 131 | self.op = conv_nd( 132 | dims, self.channels, self.out_channels, 3, stride=stride, padding=1 133 | ) 134 | else: 135 | assert self.channels == self.out_channels 136 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 137 | 138 | def forward(self, x): 139 | assert x.shape[1] == self.channels 140 | return self.op(x) 141 | 142 | 143 | class ResBlock(TimestepBlock): 144 | """ 145 | A residual block that can optionally change the number of channels. 146 | 147 | :param channels: the number of input channels. 148 | :param emb_channels: the number of timestep embedding channels. 149 | :param dropout: the rate of dropout. 150 | :param out_channels: if specified, the number of out channels. 151 | :param use_conv: if True and out_channels is specified, use a spatial 152 | convolution instead of a smaller 1x1 convolution to change the 153 | channels in the skip connection. 154 | :param dims: determines if the signal is 1D, 2D, or 3D. 155 | :param use_checkpoint: if True, use gradient checkpointing on this module. 156 | :param up: if True, use this block for upsampling. 157 | :param down: if True, use this block for downsampling. 
158 | """ 159 | 160 | def __init__( 161 | self, 162 | channels, 163 | emb_channels, 164 | dropout, 165 | out_channels=None, 166 | use_conv=False, 167 | use_scale_shift_norm=False, 168 | dims=2, 169 | use_checkpoint=False, 170 | up=False, 171 | down=False, 172 | ): 173 | super().__init__() 174 | self.channels = channels 175 | self.emb_channels = emb_channels 176 | self.dropout = dropout 177 | self.out_channels = out_channels or channels 178 | self.use_conv = use_conv 179 | self.use_checkpoint = use_checkpoint 180 | self.use_scale_shift_norm = use_scale_shift_norm 181 | 182 | self.in_layers = nn.Sequential( 183 | normalization(channels), 184 | nn.SiLU(), 185 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 186 | ) 187 | 188 | self.updown = up or down 189 | 190 | if up: 191 | self.h_upd = Upsample(channels, False, dims) 192 | self.x_upd = Upsample(channels, False, dims) 193 | elif down: 194 | self.h_upd = Downsample(channels, False, dims) 195 | self.x_upd = Downsample(channels, False, dims) 196 | else: 197 | self.h_upd = self.x_upd = nn.Identity() 198 | 199 | self.emb_layers = nn.Sequential( 200 | nn.SiLU(), 201 | linear( 202 | emb_channels, 203 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 204 | ), 205 | ) 206 | self.out_layers = nn.Sequential( 207 | normalization(self.out_channels), 208 | nn.SiLU(), 209 | nn.Dropout(p=dropout), 210 | zero_module( 211 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 212 | ), 213 | ) 214 | 215 | if self.out_channels == channels: 216 | self.skip_connection = nn.Identity() 217 | elif use_conv: 218 | self.skip_connection = conv_nd( 219 | dims, channels, self.out_channels, 3, padding=1 220 | ) 221 | else: 222 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 223 | 224 | def forward(self, x, emb): 225 | """ 226 | Apply the block to a Tensor, conditioned on a timestep embedding. 227 | 228 | :param x: an [N x C x ...] Tensor of features. 229 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 230 | :return: an [N x C x ...] Tensor of outputs. 231 | """ 232 | return checkpoint( 233 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 234 | ) 235 | 236 | def _forward(self, x, emb): 237 | if self.updown: 238 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 239 | h = in_rest(x) 240 | h = self.h_upd(h) 241 | x = self.x_upd(x) 242 | h = in_conv(h) 243 | else: 244 | h = self.in_layers(x) 245 | emb_out = self.emb_layers(emb).type(h.dtype) 246 | while len(emb_out.shape) < len(h.shape): 247 | emb_out = emb_out[..., None] 248 | if self.use_scale_shift_norm: 249 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 250 | scale, shift = th.chunk(emb_out, 2, dim=1) 251 | h = out_norm(h) * (1 + scale) + shift 252 | h = out_rest(h) 253 | else: 254 | h = h + emb_out 255 | h = self.out_layers(h) 256 | return self.skip_connection(x) + h 257 | 258 | 259 | class AttentionBlock(nn.Module): 260 | """ 261 | An attention block that allows spatial positions to attend to each other. 262 | 263 | Originally ported from here, but adapted to the N-d case. 264 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
265 | """ 266 | 267 | def __init__( 268 | self, 269 | channels, 270 | num_heads=1, 271 | num_head_channels=-1, 272 | use_checkpoint=False, 273 | attention_type="flash", 274 | encoder_channels=None, 275 | dims=2, 276 | channels_last=False, 277 | use_new_attention_order=False, 278 | ): 279 | super().__init__() 280 | self.channels = channels 281 | if num_head_channels == -1: 282 | self.num_heads = num_heads 283 | else: 284 | assert ( 285 | channels % num_head_channels == 0 286 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 287 | self.num_heads = channels // num_head_channels 288 | self.use_checkpoint = use_checkpoint 289 | self.norm = normalization(channels) 290 | self.qkv = conv_nd(dims, channels, channels * 3, 1) 291 | self.attention_type = attention_type 292 | if attention_type == "flash": 293 | self.attention = QKVAttentionLegacy(self.num_heads) 294 | else: 295 | # split heads before split qkv 296 | self.attention = QKVAttentionLegacy(self.num_heads) 297 | 298 | self.use_attention_checkpoint = not ( 299 | self.use_checkpoint or self.attention_type == "flash" 300 | ) 301 | if encoder_channels is not None: 302 | assert attention_type != "flash" 303 | self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1) 304 | self.proj_out = zero_module(conv_nd(dims, channels, channels, 1)) 305 | 306 | def forward(self, x, encoder_out=None): 307 | if encoder_out is None: 308 | return checkpoint( 309 | self._forward, (x,), self.parameters(), self.use_checkpoint 310 | ) 311 | else: 312 | return checkpoint( 313 | self._forward, (x, encoder_out), self.parameters(), self.use_checkpoint 314 | ) 315 | 316 | def _forward(self, x, encoder_out=None): 317 | b, _, *spatial = x.shape 318 | qkv = self.qkv(self.norm(x)).view(b, -1, np.prod(spatial)) 319 | if encoder_out is not None: 320 | encoder_out = self.encoder_kv(encoder_out) 321 | h = checkpoint( 322 | self.attention, (qkv, encoder_out), (), self.use_attention_checkpoint 323 | ) 324 | else: 325 | h = checkpoint(self.attention, (qkv,), (), self.use_attention_checkpoint) 326 | h = h.view(b, -1, *spatial) 327 | h = self.proj_out(h) 328 | return x + h 329 | 330 | 331 | # class QKVFlashAttention(nn.Module): 332 | # def __init__( 333 | # self, 334 | # embed_dim, 335 | # num_heads, 336 | # batch_first=True, 337 | # attention_dropout=0.0, 338 | # causal=False, 339 | # device=None, 340 | # dtype=None, 341 | # **kwargs, 342 | # ) -> None: 343 | # from einops import rearrange 344 | # from flash_attn.flash_attention import FlashAttention 345 | 346 | # assert batch_first 347 | # factory_kwargs = {"device": device, "dtype": dtype} 348 | # super().__init__() 349 | # self.embed_dim = embed_dim 350 | # self.num_heads = num_heads 351 | # self.causal = causal 352 | 353 | # assert ( 354 | # self.embed_dim % num_heads == 0 355 | # ), "self.kdim must be divisible by num_heads" 356 | # self.head_dim = self.embed_dim // num_heads 357 | # assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64" 358 | 359 | # self.inner_attn = FlashAttention( 360 | # attention_dropout=attention_dropout, **factory_kwargs 361 | # ) 362 | # self.rearrange = rearrange 363 | 364 | # def forward(self, qkv, attn_mask=None, key_padding_mask=None, need_weights=False): 365 | # qkv = self.rearrange( 366 | # qkv, "b (three h d) s -> b s three h d", three=3, h=self.num_heads 367 | # ) 368 | # qkv, _ = self.inner_attn( 369 | # qkv, 370 | # key_padding_mask=key_padding_mask, 371 | # need_weights=need_weights, 372 | # causal=self.causal, 373 | # ) 
374 | # return self.rearrange(qkv, "b s h d -> b (h d) s") 375 | 376 | 377 | def count_flops_attn(model, _x, y): 378 | """ 379 | A counter for the `thop` package to count the operations in an 380 | attention operation. 381 | Meant to be used like: 382 | macs, params = thop.profile( 383 | model, 384 | inputs=(inputs, timestamps), 385 | custom_ops={QKVAttention: QKVAttention.count_flops}, 386 | ) 387 | """ 388 | b, c, *spatial = y[0].shape 389 | num_spatial = int(np.prod(spatial)) 390 | # We perform two matmuls with the same number of ops. 391 | # The first computes the weight matrix, the second computes 392 | # the combination of the value vectors. 393 | matmul_ops = 2 * b * (num_spatial**2) * c 394 | model.total_ops += th.DoubleTensor([matmul_ops]) 395 | 396 | 397 | class QKVAttentionLegacy(nn.Module): 398 | """ 399 | A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping 400 | """ 401 | 402 | def __init__(self, n_heads): 403 | super().__init__() 404 | self.n_heads = n_heads 405 | 406 | def forward(self, qkv): 407 | """ 408 | Apply QKV attention. 409 | 410 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 411 | :return: an [N x (H * C) x T] tensor after attention. 412 | """ 413 | bs, width, length = qkv.shape 414 | assert width % (3 * self.n_heads) == 0 415 | ch = width // (3 * self.n_heads) 416 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) 417 | scale = 1 / math.sqrt(math.sqrt(ch)) 418 | weight = th.einsum( 419 | "bct,bcs->bts", q * scale, k * scale 420 | ) # More stable with f16 than dividing afterwards 421 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 422 | a = th.einsum("bts,bcs->bct", weight, v) 423 | return a.reshape(bs, -1, length) 424 | 425 | @staticmethod 426 | def count_flops(model, _x, y): 427 | return count_flops_attn(model, _x, y) 428 | 429 | 430 | # class QKVAttention(nn.Module): 431 | # """ 432 | # A module which performs QKV attention and splits in a different order. 433 | # """ 434 | 435 | # def __init__(self, n_heads): 436 | # super().__init__() 437 | # self.n_heads = n_heads 438 | 439 | # def forward(self, qkv): 440 | # """ 441 | # Apply QKV attention. 442 | 443 | # :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 444 | # :return: an [N x (H * C) x T] tensor after attention. 445 | # """ 446 | # bs, width, length = qkv.shape 447 | # assert width % (3 * self.n_heads) == 0 448 | # ch = width // (3 * self.n_heads) 449 | # q, k, v = qkv.chunk(3, dim=1) 450 | # scale = 1 / math.sqrt(math.sqrt(ch)) 451 | # weight = th.einsum( 452 | # "bct,bcs->bts", 453 | # (q * scale).view(bs * self.n_heads, ch, length), 454 | # (k * scale).view(bs * self.n_heads, ch, length), 455 | # ) # More stable with f16 than dividing afterwards 456 | # weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 457 | # a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) 458 | # return a.reshape(bs, -1, length) 459 | 460 | # @staticmethod 461 | # def count_flops(model, _x, y): 462 | # return count_flops_attn(model, _x, y) 463 | 464 | 465 | class QKVAttention(nn.Module): 466 | """ 467 | A module which performs QKV attention. Fallback from Blocksparse if use_fp16=False 468 | """ 469 | 470 | def __init__(self, n_heads): 471 | super().__init__() 472 | self.n_heads = n_heads 473 | 474 | def forward(self, qkv, encoder_kv=None): 475 | """ 476 | Apply QKV attention. 477 | 478 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 
479 | :return: an [N x (H * C) x T] tensor after attention. 480 | """ 481 | bs, width, length = qkv.shape 482 | assert width % (3 * self.n_heads) == 0 483 | ch = width // (3 * self.n_heads) 484 | q, k, v = qkv.chunk(3, dim=1) 485 | if encoder_kv is not None: 486 | assert encoder_kv.shape[1] == 2 * ch * self.n_heads 487 | ek, ev = encoder_kv.chunk(2, dim=1) 488 | k = th.cat([ek, k], dim=-1) 489 | v = th.cat([ev, v], dim=-1) 490 | scale = 1 / math.sqrt(math.sqrt(ch)) 491 | weight = th.einsum( 492 | "bct,bcs->bts", 493 | (q * scale).view(bs * self.n_heads, ch, length), 494 | (k * scale).view(bs * self.n_heads, ch, -1), 495 | ) # More stable with f16 than dividing afterwards 496 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 497 | a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, -1)) 498 | return a.reshape(bs, -1, length) 499 | 500 | @staticmethod 501 | def count_flops(model, _x, y): 502 | return count_flops_attn(model, _x, y) 503 | 504 | 505 | class UNetModel(nn.Module): 506 | """ 507 | The full UNet model with attention and timestep embedding. 508 | 509 | :param in_channels: channels in the input Tensor. 510 | :param model_channels: base channel count for the model. 511 | :param out_channels: channels in the output Tensor. 512 | :param num_res_blocks: number of residual blocks per downsample. 513 | :param attention_resolutions: a collection of downsample rates at which 514 | attention will take place. May be a set, list, or tuple. 515 | For example, if this contains 4, then at 4x downsampling, attention 516 | will be used. 517 | :param dropout: the dropout probability. 518 | :param channel_mult: channel multiplier for each level of the UNet. 519 | :param conv_resample: if True, use learned convolutions for upsampling and 520 | downsampling. 521 | :param dims: determines if the signal is 1D, 2D, or 3D. 522 | :param num_classes: if specified (as an int), then this model will be 523 | class-conditional with `num_classes` classes. 524 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 525 | :param num_heads: the number of attention heads in each attention layer. 526 | :param num_heads_channels: if specified, ignore num_heads and instead use 527 | a fixed channel width per attention head. 528 | :param num_heads_upsample: works with num_heads to set a different number 529 | of heads for upsampling. Deprecated. 530 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 531 | :param resblock_updown: use residual blocks for up/downsampling. 532 | :param use_new_attention_order: use a different attention pattern for potentially 533 | increased efficiency. 
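    Minimal illustrative construction (hypothetical values, not this repository's
    training settings)::

        model = UNetModel(
            image_size=64, in_channels=3, model_channels=128, out_channels=3,
            num_res_blocks=2, attention_resolutions=(8, 4),
            channel_mult=(1, 2, 2, 2), num_head_channels=64,
        )
        x = th.randn(2, 3, 64, 64)
        t = th.randint(0, 1000, (2,))
        out = model(x, t)   # -> [2, 3, 64, 64]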
534 | """ 535 | 536 | def __init__( 537 | self, 538 | image_size, 539 | in_channels, 540 | model_channels, 541 | out_channels, 542 | num_res_blocks, 543 | attention_resolutions, 544 | dropout=0, 545 | channel_mult=(1, 2, 4, 8), 546 | conv_resample=True, 547 | dims=2, 548 | num_classes=None, 549 | use_checkpoint=False, 550 | use_fp16=False, 551 | num_heads=1, 552 | num_head_channels=-1, 553 | num_heads_upsample=-1, 554 | use_scale_shift_norm=False, 555 | resblock_updown=False, 556 | use_new_attention_order=False, 557 | ): 558 | super().__init__() 559 | 560 | if num_heads_upsample == -1: 561 | num_heads_upsample = num_heads 562 | 563 | self.image_size = image_size 564 | self.in_channels = in_channels 565 | self.model_channels = model_channels 566 | self.out_channels = out_channels 567 | self.num_res_blocks = num_res_blocks 568 | self.attention_resolutions = attention_resolutions 569 | self.dropout = dropout 570 | self.channel_mult = channel_mult 571 | self.conv_resample = conv_resample 572 | self.num_classes = num_classes 573 | self.use_checkpoint = use_checkpoint 574 | self.dtype = th.float16 if use_fp16 else th.float32 575 | self.num_heads = num_heads 576 | self.num_head_channels = num_head_channels 577 | self.num_heads_upsample = num_heads_upsample 578 | 579 | time_embed_dim = model_channels * 4 580 | self.time_embed = nn.Sequential( 581 | linear(model_channels, time_embed_dim), 582 | nn.SiLU(), 583 | linear(time_embed_dim, time_embed_dim), 584 | ) 585 | 586 | if self.num_classes is not None: 587 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 588 | 589 | ch = input_ch = int(channel_mult[0] * model_channels) 590 | self.input_blocks = nn.ModuleList( 591 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] 592 | ) 593 | self._feature_size = ch 594 | input_block_chans = [ch] 595 | ds = 1 596 | for level, mult in enumerate(channel_mult): 597 | for _ in range(num_res_blocks): 598 | layers = [ 599 | ResBlock( 600 | ch, 601 | time_embed_dim, 602 | dropout, 603 | out_channels=int(mult * model_channels), 604 | dims=dims, 605 | use_checkpoint=use_checkpoint, 606 | use_scale_shift_norm=use_scale_shift_norm, 607 | ) 608 | ] 609 | ch = int(mult * model_channels) 610 | if ds in attention_resolutions: 611 | layers.append( 612 | AttentionBlock( 613 | ch, 614 | use_checkpoint=use_checkpoint, 615 | num_heads=num_heads, 616 | num_head_channels=num_head_channels, 617 | use_new_attention_order=use_new_attention_order, 618 | ) 619 | ) 620 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 621 | self._feature_size += ch 622 | input_block_chans.append(ch) 623 | if level != len(channel_mult) - 1: 624 | out_ch = ch 625 | self.input_blocks.append( 626 | TimestepEmbedSequential( 627 | ResBlock( 628 | ch, 629 | time_embed_dim, 630 | dropout, 631 | out_channels=out_ch, 632 | dims=dims, 633 | use_checkpoint=use_checkpoint, 634 | use_scale_shift_norm=use_scale_shift_norm, 635 | down=True, 636 | ) 637 | if resblock_updown 638 | else Downsample( 639 | ch, conv_resample, dims=dims, out_channels=out_ch 640 | ) 641 | ) 642 | ) 643 | ch = out_ch 644 | input_block_chans.append(ch) 645 | ds *= 2 646 | self._feature_size += ch 647 | 648 | self.middle_block = TimestepEmbedSequential( 649 | ResBlock( 650 | ch, 651 | time_embed_dim, 652 | dropout, 653 | dims=dims, 654 | use_checkpoint=use_checkpoint, 655 | use_scale_shift_norm=use_scale_shift_norm, 656 | ), 657 | AttentionBlock( 658 | ch, 659 | use_checkpoint=use_checkpoint, 660 | num_heads=num_heads, 661 | 
num_head_channels=num_head_channels, 662 | use_new_attention_order=use_new_attention_order, 663 | ), 664 | ResBlock( 665 | ch, 666 | time_embed_dim, 667 | dropout, 668 | dims=dims, 669 | use_checkpoint=use_checkpoint, 670 | use_scale_shift_norm=use_scale_shift_norm, 671 | ), 672 | ) 673 | self._feature_size += ch 674 | 675 | self.output_blocks = nn.ModuleList([]) 676 | for level, mult in list(enumerate(channel_mult))[::-1]: 677 | for i in range(num_res_blocks + 1): 678 | ich = input_block_chans.pop() 679 | layers = [ 680 | ResBlock( 681 | ch + ich, 682 | time_embed_dim, 683 | dropout, 684 | out_channels=int(model_channels * mult), 685 | dims=dims, 686 | use_checkpoint=use_checkpoint, 687 | use_scale_shift_norm=use_scale_shift_norm, 688 | ) 689 | ] 690 | ch = int(model_channels * mult) 691 | if ds in attention_resolutions: 692 | layers.append( 693 | AttentionBlock( 694 | ch, 695 | use_checkpoint=use_checkpoint, 696 | num_heads=num_heads_upsample, 697 | num_head_channels=num_head_channels, 698 | use_new_attention_order=use_new_attention_order, 699 | ) 700 | ) 701 | if level and i == num_res_blocks: 702 | out_ch = ch 703 | layers.append( 704 | ResBlock( 705 | ch, 706 | time_embed_dim, 707 | dropout, 708 | out_channels=out_ch, 709 | dims=dims, 710 | use_checkpoint=use_checkpoint, 711 | use_scale_shift_norm=use_scale_shift_norm, 712 | up=True, 713 | ) 714 | if resblock_updown 715 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 716 | ) 717 | ds //= 2 718 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 719 | self._feature_size += ch 720 | 721 | self.out = nn.Sequential( 722 | normalization(ch), 723 | nn.SiLU(), 724 | zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), 725 | ) 726 | 727 | def convert_to_fp16(self): 728 | """ 729 | Convert the torso of the model to float16. 730 | """ 731 | self.input_blocks.apply(convert_module_to_f16) 732 | self.middle_block.apply(convert_module_to_f16) 733 | self.output_blocks.apply(convert_module_to_f16) 734 | 735 | def convert_to_fp32(self): 736 | """ 737 | Convert the torso of the model to float32. 738 | """ 739 | self.input_blocks.apply(convert_module_to_f32) 740 | self.middle_block.apply(convert_module_to_f32) 741 | self.output_blocks.apply(convert_module_to_f32) 742 | 743 | def forward(self, x, timesteps, y=None): 744 | """ 745 | Apply the model to an input batch. 746 | 747 | :param x: an [N x C x ...] Tensor of inputs. 748 | :param timesteps: a 1-D batch of timesteps. 749 | :param y: an [N] Tensor of labels, if class-conditional. 750 | :return: an [N x C x ...] Tensor of outputs. 
751 | """ 752 | assert (y is not None) == ( 753 | self.num_classes is not None 754 | ), "must specify y if and only if the model is class-conditional" 755 | 756 | hs = [] 757 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 758 | 759 | if self.num_classes is not None: 760 | assert y.shape == (x.shape[0],) 761 | emb = emb + self.label_emb(y) 762 | 763 | h = x.type(self.dtype) 764 | for module in self.input_blocks: 765 | h = module(h, emb) 766 | hs.append(h) 767 | h = self.middle_block(h, emb) 768 | for module in self.output_blocks: 769 | h = th.cat([h, hs.pop()], dim=1) 770 | h = module(h, emb) 771 | h = h.type(x.dtype) 772 | return self.out(h) 773 | -------------------------------------------------------------------------------- /Network/Diffusion_model_transformer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | # from nnFormer import * 10 | from Network.SwinUnetr import * 11 | # from .fp16_util import convert_module_to_f16, convert_module_to_f32 12 | from Network.util_nn import ( 13 | checkpoint, 14 | conv_nd, 15 | linear, 16 | avg_pool_nd, 17 | zero_module, 18 | normalization, 19 | timestep_embedding, 20 | ) 21 | from monai.utils import ensure_tuple_rep 22 | 23 | class TimestepBlock(nn.Module): 24 | """ 25 | Any module where forward() takes timestep embeddings as a second argument. 26 | """ 27 | 28 | @abstractmethod 29 | def forward(self, x, emb): 30 | """ 31 | Apply the module to `x` given `emb` timestep embeddings. 32 | """ 33 | 34 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 35 | """ 36 | A sequential module that passes timestep embeddings to the children that 37 | support it as an extra input. 38 | """ 39 | 40 | def forward(self, x, emb): 41 | for layer in self: 42 | if isinstance(layer, TimestepBlock): 43 | x = layer(x, emb) 44 | else: 45 | x = layer(x) 46 | return x 47 | 48 | class Upsample(nn.Module): 49 | """ 50 | An upsampling layer with an optional convolution. 51 | :param channels: channels in the inputs and outputs. 52 | :param use_conv: a bool determining if a convolution is applied. 53 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 54 | upsampling occurs in the inner-two dimensions. 
55 | """ 56 | 57 | def __init__(self, channels, use_conv, sample_kernel, dims=2, out_channels=None): 58 | super().__init__() 59 | self.channels = channels 60 | self.out_channels = out_channels or channels 61 | self.use_conv = use_conv 62 | if dims == 3: 63 | self.sample_kernel=(sample_kernel[0],sample_kernel[1],sample_kernel[2]) 64 | else: 65 | self.sample_kernel=(sample_kernel[0],sample_kernel[1]) 66 | self.dims = dims 67 | if use_conv: 68 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) 69 | else: 70 | self.up = th.nn.Upsample(scale_factor=self.sample_kernel,mode='nearest') 71 | self.conv = conv_nd(dims, self.channels, self.channels, 3, padding=1) 72 | 73 | def forward(self, x): 74 | assert x.shape[1] == self.channels 75 | x = self.up(x) 76 | x = self.conv(x) 77 | 78 | # if self.dims == 3: 79 | # x = F.interpolate( 80 | # x, scale_factor=self.sample_kernel, mode="nearest" 81 | # ) 82 | # else: 83 | # x = F.interpolate(x, scale_factor=self.sample_kernel, mode="nearest") 84 | # if self.use_conv: 85 | # x = self.conv(x) 86 | return x 87 | 88 | class Downsample(nn.Module): 89 | """ 90 | A downsampling layer with an optional convolution. 91 | :param channels: channels in the inputs and outputs. 92 | :param use_conv: a bool determining if a convolution is applied. 93 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 94 | downsampling occurs in the inner-two dimensions. 95 | """ 96 | 97 | def __init__(self, channels, use_conv,sample_kernel, dims=2, out_channels=None): 98 | super().__init__() 99 | self.channels = channels 100 | self.out_channels = out_channels or channels 101 | self.use_conv = use_conv 102 | self.dims = dims 103 | if self.dims == 3: 104 | self.sample_kernel = (1/sample_kernel[0],1/sample_kernel[1],1/sample_kernel[2]) 105 | else: 106 | self.sample_kernel = (1/sample_kernel[0],1/sample_kernel[1]) 107 | # stride = 2 if dims != 3 else (2, 2, 2) 108 | # stride = 2 109 | if use_conv: 110 | self.op = th.nn.Upsample(scale_factor=self.sample_kernel,mode='nearest') 111 | else: 112 | assert self.channels == self.out_channels 113 | self.op = th.nn.Upsample(scale_factor=self.sample_kernel,mode='nearest') 114 | self.conv = conv_nd(dims, self.channels, self.channels, 3, padding=1) 115 | 116 | def forward(self, x): 117 | assert x.shape[1] == self.channels 118 | # x = F.interpolate( 119 | # x, scale_factor=self.sample_kernel, mode="nearest" 120 | # ) 121 | return self.conv(self.op(x)) 122 | 123 | 124 | class ResBlock(TimestepBlock): 125 | """ 126 | A residual block that can optionally change the number of channels. 127 | :param channels: the number of input channels. 128 | :param emb_channels: the number of timestep embedding channels. 129 | :param dropout: the rate of dropout. 130 | :param out_channels: if specified, the number of out channels. 131 | :param use_conv: if True and out_channels is specified, use a spatial 132 | convolution instead of a smaller 1x1 convolution to change the 133 | channels in the skip connection. 134 | :param dims: determines if the signal is 1D, 2D, or 3D. 135 | :param use_checkpoint: if True, use gradient checkpointing on this module. 136 | :param up: if True, use this block for upsampling. 137 | :param down: if True, use this block for downsampling. 
138 | """ 139 | 140 | def __init__( 141 | self, 142 | channels, 143 | emb_channels, 144 | dropout, 145 | out_channels=None, 146 | use_conv=False, 147 | use_scale_shift_norm=False, 148 | dims=2, 149 | use_checkpoint=False, 150 | up=False, 151 | down=False, 152 | sample_kernel = None, 153 | use_swin = False, 154 | num_heads = 4, 155 | window_size = [4,4,4], 156 | input_resolution = [1,1,1], 157 | drop_path = 0.1 158 | ): 159 | super().__init__() 160 | self.channels = channels 161 | self.emb_channels = emb_channels 162 | self.dropout = dropout 163 | self.out_channels = out_channels or channels 164 | self.use_conv = use_conv 165 | self.use_checkpoint = use_checkpoint 166 | self.use_scale_shift_norm = use_scale_shift_norm 167 | self.input_resolution=input_resolution 168 | self.use_swin=use_swin 169 | self.dims = dims 170 | self.window_size=window_size 171 | self.use_swin=use_swin 172 | self.updown = up or down 173 | 174 | if up: 175 | self.h_upd = Upsample(channels, False,sample_kernel, dims) 176 | self.x_upd = Upsample(channels, False,sample_kernel, dims) 177 | elif down: 178 | self.h_upd = Downsample(channels, False,sample_kernel, dims) 179 | self.x_upd = Downsample(channels, False,sample_kernel, dims) 180 | else: 181 | self.h_upd = self.x_upd = nn.Identity() 182 | 183 | if use_swin: 184 | self.in_layers = nn.Sequential( 185 | normalization(channels), 186 | nn.SiLU(), 187 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 188 | ) 189 | 190 | self.shift_size = tuple(i // 2 for i in window_size) 191 | self.no_shift = tuple(0 for i in window_size) 192 | # self.swin_layer = nn.ModuleList([ 193 | # SwinTransformerBlock( 194 | # dim=self.out_channels, 195 | # input_resolution=self.input_resolution, 196 | # num_heads=num_heads, 197 | # window_size=window_size, 198 | # shift_size=self.no_shift if (i % 2 == 0) else self.shift_size, 199 | # mlp_ratio=4, 200 | # qkv_bias=True, 201 | # qk_scale=None, 202 | # drop=0, 203 | # attn_drop=0, 204 | # drop_path=drop_path, 205 | # norm_layer = nn.LayerNorm) 206 | # for i in range(2)]) 207 | self.swin_layer = nn.ModuleList([SwinTransformerBlock( 208 | dim=self.out_channels, 209 | num_heads=num_heads, 210 | window_size=window_size, 211 | shift_size=self.no_shift if (i % 2 == 0) else self.shift_size, 212 | mlp_ratio=4, 213 | qkv_bias=True, 214 | drop=0, 215 | attn_drop=0, 216 | drop_path=drop_path, 217 | norm_layer=nn.LayerNorm, 218 | use_checkpoint=None) 219 | for i in range(2)]) 220 | self.out_layers = nn.Sequential( 221 | normalization(self.out_channels), 222 | nn.Identity()) 223 | else: 224 | self.in_layers = nn.Sequential( 225 | normalization(channels), 226 | nn.SiLU(), 227 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 228 | ) 229 | self.swin_layer = nn.ModuleList([nn.Identity()]) 230 | self.out_layers = nn.Sequential( 231 | normalization(self.out_channels), 232 | nn.SiLU(), 233 | nn.Dropout(p=0), 234 | zero_module( 235 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 236 | ), 237 | ) 238 | 239 | self.emb_layers = nn.Sequential( 240 | nn.SiLU(), 241 | linear( 242 | emb_channels, 243 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 244 | ), 245 | ) 246 | 247 | 248 | if self.out_channels == channels: 249 | self.skip_connection = nn.Identity() 250 | elif use_conv: 251 | self.skip_connection = conv_nd( 252 | dims, channels, self.out_channels, 3, padding=1 253 | ) 254 | else: 255 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 256 | 257 | def forward(self, x, emb): 258 | """ 259 
| Apply the block to a Tensor, conditioned on a timestep embedding. 260 | :param x: an [N x C x ...] Tensor of features. 261 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 262 | :return: an [N x C x ...] Tensor of outputs. 263 | """ 264 | return checkpoint( 265 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 266 | ) 267 | 268 | def _forward(self, x_in, emb): 269 | if self.updown: 270 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 271 | h = in_rest(x_in) 272 | h = self.h_upd(h) 273 | x_in = self.x_upd(x_in) 274 | h = in_conv(h) 275 | else: 276 | h_in = self.in_layers(x_in) 277 | 278 | emb_out = self.emb_layers(emb).type(h_in.dtype) 279 | while len(emb_out.shape) < len(h_in.shape): 280 | emb_out = emb_out[..., None] 281 | if self.use_scale_shift_norm: 282 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 283 | scale, shift = th.chunk(emb_out, 2, dim=1) 284 | h_in = out_norm(h_in) * (1 + scale) + shift 285 | 286 | if self.use_swin: 287 | if self.dims == 3: 288 | b, c, d, h, w = x_in.shape 289 | window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) 290 | h_in = rearrange(h_in, "b c d h w -> b d h w c") 291 | dp = int(np.ceil(d / window_size[0])) * window_size[0] 292 | hp = int(np.ceil(h / window_size[1])) * window_size[1] 293 | wp = int(np.ceil(w / window_size[2])) * window_size[2] 294 | attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device) 295 | for blk in self.swin_layer: 296 | h_in = blk(h_in, attn_mask) 297 | h_in = h_in.view(b, d, h, w, -1) 298 | h_in = rearrange(h_in, "b d h w c -> b c d h w") 299 | 300 | elif self.dims == 2: 301 | b, c, h, w = h_in.shape 302 | window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size) 303 | h_in = rearrange(h_in, "b c h w -> b h w c") 304 | hp = int(np.ceil(h / window_size[0])) * window_size[0] 305 | wp = int(np.ceil(w / window_size[1])) * window_size[1] 306 | attn_mask = compute_mask([hp, wp], window_size, shift_size, h_in.device) 307 | for blk in self.swin_layer: 308 | h_in = blk(h_in, attn_mask) 309 | h_in = h_in.view(b, h, w, -1) 310 | h_in = rearrange(h_in, "b h w c -> b c h w") 311 | else: 312 | for blk in self.swin_layer: 313 | h_in = blk(h_in) 314 | 315 | 316 | 317 | # S, H, W = h.size(2), h.size(3), h.size(4) 318 | # h = h.flatten(2).transpose(1, 2).contiguous() 319 | # for blk in self.swin_layer: 320 | # h = blk(h) 321 | # h = h.transpose(1, 2).contiguous().view(-1, self.out_channels, S, H, W) 322 | 323 | h_in = out_rest(h_in) 324 | else: 325 | h_in = h_in + emb_out 326 | 327 | if self.use_swin: 328 | if self.dims == 3: 329 | b, c, d, h, w = x_in.shape 330 | window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) 331 | h_in = rearrange(h_in, "b c d h w -> b d h w c") 332 | dp = int(np.ceil(d / window_size[0])) * window_size[0] 333 | hp = int(np.ceil(h / window_size[1])) * window_size[1] 334 | wp = int(np.ceil(w / window_size[2])) * window_size[2] 335 | attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device) 336 | for blk in self.swin_layer: 337 | h_in = blk(h_in, attn_mask) 338 | h_in = h_in.view(b, d, h, w, -1) 339 | h_in = rearrange(h_in, "b d h w c -> b c d h w") 340 | 341 | elif self.dims == 2: 342 | b, c, h, w = h_in.shape 343 | window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size) 344 | h_in = rearrange(h_in, "b c h w -> b h w c") 345 | hp = int(np.ceil(h / window_size[0])) * window_size[0] 346 | wp = int(np.ceil(w 
347 | attn_mask = compute_mask([hp, wp], window_size, shift_size, h_in.device) 348 | for blk in self.swin_layer: 349 | h_in = blk(h_in, attn_mask) 350 | h_in = h_in.view(b, h, w, -1) 351 | h_in = rearrange(h_in, "b h w c -> b c h w") 352 | else: 353 | for blk in self.swin_layer: 354 | h_in = blk(h_in) 355 | 356 | h_in = self.out_layers(h_in) 357 | return self.skip_connection(x_in) + h_in 358 | 359 | class SwinVITModel(nn.Module): 360 | """ 361 | The full UNet model with attention and timestep embedding. 362 | :param in_channels: channels in the input Tensor. 363 | :param model_channels: base channel count for the model. 364 | :param out_channels: channels in the output Tensor. 365 | :param num_res_blocks: number of residual blocks at each level; expects a 366 | sequence with one entry per entry of channel_mult. 367 | :param attention_resolutions: a collection of feature-map resolutions (the first 368 | spatial dimension of the current feature map) at which Swin attention is used. 369 | May be a set, list, or tuple. For example, if this contains 32, Swin blocks are applied whenever the feature map is 32 pixels along its first axis. 370 | :param dropout: the dropout probability. 371 | :param channel_mult: channel multiplier for each level of the UNet. 372 | :param conv_resample: if True, use learned convolutions for upsampling and 373 | downsampling. 374 | :param dims: determines if the signal is 1D, 2D, or 3D. 375 | :param num_classes: if specified (as an int), then this model will be 376 | class-conditional with `num_classes` classes. 377 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 378 | :param num_heads: the number of attention heads at each level; expects a sequence with one entry per level. 379 | :param num_head_channels: if specified, ignore num_heads and instead use 380 | a fixed channel width per attention head. 381 | :param num_heads_upsample: works with num_heads to set a different number 382 | of heads for upsampling. Deprecated. 383 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 384 | :param resblock_updown: use residual blocks for up/downsampling. 385 | :param use_new_attention_order: use a different attention pattern for potentially 386 | increased efficiency.
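:param sample_kernel: per-level up/downsampling factors; the constructor reads sample_kernel[0], so the per-level kernels are expected to be wrapped in an outer tuple or list.
:param window_size: the Swin attention window size used at each level; expects a sequence with one entry per level.
:param use_fp16: if True, cast activations to float16 in the forward pass.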
387 | """ 388 | 389 | def __init__( 390 | self, 391 | image_size, 392 | in_channels, 393 | model_channels, 394 | out_channels, 395 | num_res_blocks, 396 | attention_resolutions, 397 | dropout=0, 398 | channel_mult=(1, 2, 4, 8), 399 | conv_resample=False, 400 | dims=2, 401 | sample_kernel=None, 402 | num_classes=None, 403 | use_checkpoint=False, 404 | use_fp16=False, 405 | num_heads=1, 406 | window_size = 4, 407 | num_head_channels=-1, 408 | num_heads_upsample=-1, 409 | use_scale_shift_norm=False, 410 | resblock_updown=False, 411 | use_new_attention_order=False, 412 | ): 413 | super().__init__() 414 | 415 | if num_heads_upsample == -1: 416 | num_heads_upsample = num_heads 417 | 418 | self.image_size = image_size 419 | self.in_channels = in_channels 420 | self.model_channels = model_channels 421 | self.out_channels = out_channels 422 | self.num_res_blocks = num_res_blocks 423 | self.attention_resolutions = attention_resolutions 424 | self.dropout = dropout 425 | self.channel_mult = channel_mult 426 | self.conv_resample = conv_resample 427 | self.num_classes = num_classes 428 | self.use_checkpoint = use_checkpoint 429 | self.dtype = th.float16 if use_fp16 else th.float32 430 | self.num_heads = num_heads 431 | self.num_head_channels = num_head_channels 432 | self.num_heads_upsample = num_heads_upsample 433 | self.sample_kernel = sample_kernel[0] 434 | spatial_dims = dims 435 | drop_path = [x.item() for x in th.linspace(0, dropout, len(channel_mult))] 436 | 437 | 438 | time_embed_dim = model_channels * 4 439 | self.time_embed = nn.Sequential( 440 | linear(model_channels, time_embed_dim), 441 | nn.SiLU(), 442 | linear(time_embed_dim, time_embed_dim), 443 | ) 444 | 445 | if self.num_classes is not None: 446 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 447 | 448 | ch = input_ch = int(channel_mult[0] * model_channels) 449 | self.input_blocks = nn.ModuleList( 450 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] 451 | ) 452 | self._feature_size = ch 453 | input_block_chans = [ch] 454 | ds = image_size 455 | 456 | for level, mult in enumerate(channel_mult): 457 | for _ in range(num_res_blocks[level]): 458 | if ds[0] in attention_resolutions: 459 | use_swin = True 460 | else: 461 | use_swin = False 462 | layers = [ 463 | ResBlock( 464 | ch, 465 | time_embed_dim, 466 | dropout, 467 | out_channels=int(mult * model_channels), 468 | dims=dims, 469 | use_checkpoint=use_checkpoint, 470 | use_scale_shift_norm=use_scale_shift_norm, 471 | use_swin = use_swin, 472 | num_heads = num_heads[level], 473 | window_size = window_size[level], 474 | input_resolution = ds, 475 | drop_path = drop_path[level] 476 | ) 477 | ] 478 | ch = int(mult * model_channels) 479 | # if ds in attention_resolutions: 480 | # layers.append( 481 | # AttentionBlock( 482 | # ch, 483 | # use_checkpoint=use_checkpoint, 484 | # num_heads=num_heads, 485 | # num_head_channels=num_head_channels, 486 | # use_new_attention_order=use_new_attention_order, 487 | # ) 488 | # ) 489 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 490 | self._feature_size += ch 491 | input_block_chans.append(ch) 492 | if level != len(channel_mult) -1: 493 | out_ch = ch 494 | self.input_blocks.append( 495 | TimestepEmbedSequential( 496 | ResBlock( 497 | ch, 498 | time_embed_dim, 499 | dropout, 500 | out_channels=int(mult * model_channels), 501 | dims=dims, 502 | use_checkpoint=use_checkpoint, 503 | use_scale_shift_norm=use_scale_shift_norm, 504 | use_swin = use_swin, 505 | num_heads = num_heads[level], 506 | window_size = 
window_size[level], 507 | input_resolution = ds, 508 | drop_path = drop_path[level], 509 | down=True, 510 | sample_kernel=self.sample_kernel[level], 511 | ) 512 | if resblock_updown 513 | else Downsample( 514 | ch, conv_resample,self.sample_kernel[level], dims=dims, out_channels=out_ch 515 | ) 516 | ) 517 | ) 518 | ch = out_ch 519 | input_block_chans.append(ch) 520 | if dims == 3: 521 | ds = [ds[0]//self.sample_kernel[level][0],ds[1]//self.sample_kernel[level][1],ds[2]//self.sample_kernel[level][2]] 522 | else: 523 | ds = [ds[0]//self.sample_kernel[level][0],ds[1]//self.sample_kernel[level][1]] 524 | self._feature_size += ch 525 | 526 | self.middle_block = TimestepEmbedSequential( 527 | ResBlock( 528 | ch, 529 | time_embed_dim, 530 | dropout, 531 | out_channels=int(mult * model_channels), 532 | dims=dims, 533 | use_checkpoint=use_checkpoint, 534 | use_scale_shift_norm=use_scale_shift_norm, 535 | use_swin = use_swin, 536 | num_heads = num_heads[level], 537 | window_size = window_size[level], 538 | input_resolution = ds, 539 | drop_path = drop_path[level] 540 | ), 541 | # AttentionBlock( 542 | # ch, 543 | # use_checkpoint=use_checkpoint, 544 | # num_heads=num_heads, 545 | # num_head_channels=num_head_channels, 546 | # use_new_attention_order=use_new_attention_order, 547 | # ), 548 | ResBlock( 549 | ch, 550 | time_embed_dim, 551 | dropout, 552 | out_channels=int(mult * model_channels), 553 | dims=dims, 554 | use_checkpoint=use_checkpoint, 555 | use_scale_shift_norm=use_scale_shift_norm, 556 | use_swin = use_swin, 557 | num_heads = num_heads[level], 558 | window_size = window_size[level], 559 | input_resolution = ds, 560 | drop_path = drop_path[level] 561 | ), 562 | ) 563 | self._feature_size += ch 564 | 565 | self.output_blocks = nn.ModuleList([]) 566 | for level, mult in list(enumerate(channel_mult))[::-1]: 567 | for i in range(num_res_blocks[level] + 1): 568 | ich = input_block_chans.pop() 569 | if ds[0] in attention_resolutions: 570 | use_swin = True 571 | else: 572 | use_swin = False 573 | layers = [ 574 | ResBlock( 575 | ch + ich, 576 | time_embed_dim, 577 | dropout, 578 | out_channels=int(model_channels * mult), 579 | dims=dims, 580 | use_checkpoint=use_checkpoint, 581 | use_scale_shift_norm=use_scale_shift_norm, 582 | use_swin = use_swin, 583 | num_heads = num_heads[level], 584 | window_size = window_size[level], 585 | input_resolution = ds, 586 | drop_path = drop_path[level] 587 | ) 588 | ] 589 | ch = int(model_channels * mult) 590 | # if ds in attention_resolutions: 591 | # layers.append( 592 | # AttentionBlock( 593 | # ch, 594 | # use_checkpoint=use_checkpoint, 595 | # num_heads=num_heads_upsample, 596 | # num_head_channels=num_head_channels, 597 | # use_new_attention_order=use_new_attention_order, 598 | # ) 599 | # ) 600 | if level and i == num_res_blocks[level]: 601 | out_ch = ch 602 | layers.append( 603 | ResBlock( 604 | ch, 605 | time_embed_dim, 606 | dropout, 607 | out_channels=int(model_channels * mult), 608 | dims=dims, 609 | use_checkpoint=use_checkpoint, 610 | use_scale_shift_norm=use_scale_shift_norm, 611 | use_swin = use_swin, 612 | num_heads = num_heads[level], 613 | window_size = window_size[level], 614 | input_resolution = ds, 615 | drop_path = drop_path[level], 616 | up=True, 617 | sample_kernel=self.sample_kernel[level-1], 618 | ) 619 | if resblock_updown 620 | else Upsample(ch, conv_resample,self.sample_kernel[level-1], dims=dims, out_channels=out_ch) 621 | ) 622 | if dims == 3: 623 | ds = [ds[0]*self.sample_kernel[level-1][0], 624 | 
ds[1]*self.sample_kernel[level-1][1], 625 | ds[2]*self.sample_kernel[level-1][2]] 626 | else: 627 | ds = [ds[0]*self.sample_kernel[level-1][0], 628 | ds[1]*self.sample_kernel[level-1][1]] 629 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 630 | self._feature_size += ch 631 | 632 | self.out = nn.Sequential( 633 | normalization(ch), 634 | nn.SiLU(), 635 | zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), 636 | ) 637 | 638 | def forward_with_cond_scale( 639 | self, 640 | *args, 641 | cond_scale = 2., 642 | **kwargs 643 | ): 644 | logits = self.forward(*args, null_cond_prob = 0., **kwargs) 645 | if cond_scale == 1 or not getattr(self, "has_cond", False):  # has_cond is never set in __init__, so default to returning the conditional output 646 | return logits 647 | 648 | null_logits = self.forward(*args, null_cond_prob = 1., **kwargs) 649 | return null_logits + (logits - null_logits) * cond_scale 650 | 651 | def forward(self, x, timesteps, cond=None, null_cond_prob=0., y=None):  # cond and null_cond_prob are accepted for interface compatibility but are not used below 652 | """ 653 | Apply the model to an input batch. 654 | :param x: an [N x C x ...] Tensor of inputs. 655 | :param timesteps: a 1-D batch of timesteps. 656 | :param y: an [N] Tensor of labels, if class-conditional. 657 | :return: an [N x C x ...] Tensor of outputs. 658 | """ 659 | assert (y is not None) == ( 660 | self.num_classes is not None 661 | ), "must specify y if and only if the model is class-conditional" 662 | 663 | hs = [] 664 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 665 | 666 | if self.num_classes is not None: 667 | assert y.shape == (x.shape[0],) 668 | emb = emb + self.label_emb(y) 669 | 670 | h = x.type(self.dtype) 671 | for module in self.input_blocks: 672 | h = module(h, emb) 673 | hs.append(h) 674 | h = self.middle_block(h, emb) 675 | for module in self.output_blocks: 676 | h = th.cat([h, hs.pop()], dim=1) 677 | h = module(h, emb) 678 | h = h.type(x.dtype) 679 | return self.out(h) --------------------------------------------------------------------------------
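The following is not part of the repository: a minimal usage sketch for SwinVITModel, assuming this file is Network/Diffusion_model_transformer.py and that the helper layers it relies on (ResBlock, Downsample, Upsample and the Swin blocks) accept the arguments exactly as the constructor above passes them. All hyperparameters are illustrative; note that num_res_blocks, num_heads and window_size are per-level sequences, image_size is a list of spatial sizes, and sample_kernel is wrapped in an outer tuple because the constructor reads sample_kernel[0].

import torch as th
from Network.Diffusion_model_transformer import SwinVITModel  # hypothetical import path

model = SwinVITModel(
    image_size=[64, 64],                        # starting 2D feature-map size
    in_channels=1,
    model_channels=32,
    out_channels=1,
    num_res_blocks=[1, 1, 1],                   # one entry per level
    attention_resolutions=[32, 16],             # use Swin blocks at 32x32 and 16x16
    dropout=0.0,
    channel_mult=(1, 2, 4),
    dims=2,
    sample_kernel=([[2, 2], [2, 2], [2, 2]],),  # constructor keeps sample_kernel[0]
    num_heads=[4, 4, 8],                        # one entry per level
    window_size=[4, 4, 4],                      # one entry per level
    use_scale_shift_norm=True,
)

x = th.randn(2, 1, 64, 64)      # e.g. a batch of noisy PET slices
t = th.randint(0, 1000, (2,))   # one diffusion timestep per sample
out = model(x, t)
print(out.shape)                # expected: torch.Size([2, 1, 64, 64])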
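The guidance arithmetic in forward_with_cond_scale is the usual classifier-free-guidance interpolation between a conditional and an unconditional prediction. A self-contained numeric sketch of just that formula (values are illustrative and independent of the model above):

import torch as th

cond_scale = 2.0
logits = th.tensor([1.0, 3.0])       # prediction with conditioning kept
null_logits = th.tensor([0.0, 1.0])  # prediction with conditioning dropped

# null + (cond - null) * s pushes the output away from the unconditional
# prediction; s = 1 recovers the conditional prediction exactly
guided = null_logits + (logits - null_logits) * cond_scale
print(guided)  # tensor([2., 5.])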