├── .gitignore ├── LICENSE ├── README.md ├── figures └── visual-results.gif ├── main_train.py ├── models ├── aps.py ├── helpers.py ├── losses.py └── misc_models.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Paul Ang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # APS: A deep learning framework for MR-to-CT image synthesis
2 | 
3 | This repository is the official implementation of _An Improved Deep Learning
4 | Framework for MR-to-CT Image Synthesis with a New Hybrid Objective Function_.
5 | 
6 | ## Getting started
7 | 
8 | - Download or clone this repo to your computer.
9 | - Run `pip install -r requirements.txt` to install the required Python packages.
10 | - The code was developed and tested on _Python 3.6.13_ and _Ubuntu 16.04_.
11 | - Note that this codebase is built on the [PyTorch Lightning](https://www.pytorchlightning.ai) library.
12 | 
13 | ## Training and testing the APS framework on your own dataset
14 | 
15 | ### Steps
16 | 1. Create a custom Dataset class for your own data. Follow the official
17 | PyTorch [tutorial](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files)
18 | if you are unfamiliar with this step.
19 | 2. Integrate your custom Dataset class into `main_train.get_training_dataloaders()`.
20 | 3. Run `python main_train.py --gpus 1` to start training and testing the APS model on 1 GPU. Execute `python main_train.py --help` to see the available input arguments.
21 | 
22 | ### Notes
23 | 
24 | - Ensure that the data are normalized using the min-max formula and resized to the
25 | expected input image size (default is 288x288). The expected input image size
26 | is set via the `--image_size` argument of `main_train.py`.
27 | - The `__getitem__()` of your custom Dataset class must return a _dict_ with the following keys (an illustrative sketch is given at the end of this README):
28 |   - _ct_: the CT image tensor.
29 |   - _in_phase_: the MR in-phase image tensor.
30 |   - _ct_min_: the smallest CT HU value before the min-max normalization. This is needed for
31 | metric computations.
32 |   - _ct_max_: the largest CT HU value before the min-max normalization. This is needed for
33 | metric computations.
34 | 
35 | 
36 | ## Visual results
37 | 
38 | ![visual-results](figures/visual-results.gif)
39 | 
40 | ## Citation
41 | 
42 | If you use this code for your research, please cite our paper.
43 | 
44 | ```
45 | @inproceedings{ang2022,
46 | title={An improved deep learning framework for MR-to-CT image synthesis with a new hybrid objective function},
47 | author={Ang, Sui Paul and Phung, Son Lam and Field, Matthew and Schira, Mark Matthias},
48 | booktitle={Proceedings of the IEEE International Symposium on Biomedical Imaging},
49 | year={2022}
50 | }
51 | ```
52 | 
53 | ## Acknowledgement
54 | 
55 | Some parts of our code are inspired by [F-LSeSim](https://github.com/lyndonzheng/F-LSeSim)
56 | and [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
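## Example custom Dataset (illustrative sketch)

The following is a minimal sketch of a Dataset that satisfies the contract described in the Notes above. It is not part of the framework: the class name `PairedMrCtDataset` and the in-memory slice lists are illustrative assumptions, and the loading logic should be replaced with whatever fits your storage format. Single-channel `[1, H, W]` tensors are assumed, matching the generator's `input_nc=1`. Return your train/val/test instances from the `your_custom_dataset(**kwargs)` placeholder inside `main_train.get_training_dataloaders()`.

```python
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset


class PairedMrCtDataset(Dataset):
    """Returns the dict expected by APS: 'ct', 'in_phase', 'ct_min', 'ct_max'."""

    def __init__(self, mri_slices, ct_slices, image_size=(288, 288)):
        # mri_slices / ct_slices: lists of co-registered 2-D numpy arrays,
        # where index i holds a paired in-phase MR / CT slice.
        assert len(mri_slices) == len(ct_slices)
        self.mri_slices = mri_slices
        self.ct_slices = ct_slices
        self.image_size = list(image_size)

    def __len__(self):
        return len(self.ct_slices)

    def _to_tensor(self, image):
        # [H, W] numpy array -> [1, H, W] float tensor resized to image_size.
        t = torch.from_numpy(image.astype(np.float32))[None, None]
        t = F.interpolate(t, size=self.image_size, mode='bilinear',
                          align_corners=False)
        return t[0]

    def __getitem__(self, idx):
        mri = self.mri_slices[idx]
        ct = self.ct_slices[idx]

        # Keep the original HU range; models/helpers.py uses it to undo the
        # min-max scaling before MAE and PSNR are computed.
        ct_min, ct_max = float(ct.min()), float(ct.max())

        # Min-max normalize both modalities to [0, 1].
        mri = (mri - mri.min()) / (mri.max() - mri.min() + 1e-8)
        ct = (ct - ct_min) / (ct_max - ct_min + 1e-8)

        return {
            'ct': self._to_tensor(ct),
            'in_phase': self._to_tensor(mri),
            'ct_min': ct_min,
            'ct_max': ct_max,
        }
```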
57 | -------------------------------------------------------------------------------- /figures/visual-results.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paul-ang/aps-deep-learning-framework/a8b7e9c82c188d404f0ed6be8454fe89d90ff36a/figures/visual-results.gif -------------------------------------------------------------------------------- /main_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from argparse import ArgumentParser 4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | from pytorch_lightning import seed_everything 8 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 9 | 10 | from models.aps import APS 11 | 12 | 13 | def main(): 14 | # Args 15 | parser = ArgumentParser() 16 | parser = pl.Trainer.add_argparse_args(parser) 17 | # Training 18 | parser.add_argument('--seed', default=555, type=int, 19 | help='Set the random seed.') 20 | parser.add_argument('--lr', default=1e-4, type=float, 21 | help='Learning rate for the optimizers.') 22 | parser.add_argument('--debug', default=0, type=int, 23 | help='Activate debug mode.') 24 | parser.add_argument('--name', default='Default-name', type=str, 25 | help="Name of the experiment folder.") 26 | parser.add_argument('--monitor_loss', default='val_mae', type=str, 27 | help="For early stopping and save best model.") 28 | parser.add_argument('--image_size', default=[288, 288], type=int, 29 | help="image size. (Default = 288x288). ", nargs="+") 30 | parser.add_argument('--batch_size', default=2, type=int, 31 | help='batch size for the train and val dataloaders. ' 32 | 'Test dataloader always use batch size of 1 for ' 33 | 'visualization compatibility.') 34 | parser.add_argument('--num_workers', default=6, type=int, 35 | help='Num workers for the dataloaders.') 36 | parser.add_argument('--saved_weight', default='', type=str, 37 | help="path to the saved weight.") 38 | # APS model specific 39 | parser.add_argument('--lambda_adv', default=1.0, type=float, 40 | help="The hyperparameter L_adv for APS.") 41 | parser.add_argument('--lambda_pix', default=100.0, type=float, 42 | help="The hyperparameter L_pix for APS.") 43 | parser.add_argument('--lambda_str', default=10.0, type=float, 44 | help="The hyperparameter L_str for APS.") 45 | args = parser.parse_args() 46 | 47 | # Setup experiment dir 48 | if args.debug == 1: 49 | # A debug save_dir 50 | save_dir = 'experiments/test-debug' 51 | create_exp_dir(save_dir, visual_folder=True) 52 | else: 53 | # Create an experiment folder for logging and saving model's weights 54 | save_dir = 'experiments/{}-{}'.format( 55 | args.name, 56 | time.strftime("%Y%m%d-%H%M%S")) 57 | create_exp_dir(save_dir, visual_folder=True) 58 | 59 | # Some args logic 60 | if args.debug == 1: 61 | print("Debug mode on.") 62 | args.fast_dev_run = True 63 | args.num_workers = 0 64 | 65 | seed_everything(args.seed) 66 | 67 | # Setup dataloaders 68 | train_loader, val_loader, test_loader = get_training_dataloaders( 69 | batch_size=args.batch_size, num_workers=args.num_workers, 70 | image_size=args.image_size) 71 | 72 | # Setup model 73 | assert args.batch_size % 2 == 0, "APS model only works with even number batch size" 74 | print("APS model") 75 | model = APS(lr=args.lr, input_size=args.image_size, 76 | lambda_adv=args.lambda_adv, lambda_pix=args.lambda_pix, 77 | lambda_str=args.lambda_str) 78 | 79 | # Setup checkpoint callbacks 80 | save_best_model = 
ModelCheckpoint(monitor=args.monitor_loss, dirpath=save_dir, 81 | filename='best_model', save_top_k=1, 82 | mode='min', save_last=True, verbose=True) 83 | 84 | early_stop = EarlyStopping( 85 | monitor='val_mae', 86 | patience=50, 87 | mode='min', 88 | verbose=True 89 | ) 90 | 91 | # Trainer 92 | trainer = pl.Trainer.from_argparse_args(args, default_root_dir=save_dir, 93 | callbacks=[save_best_model, early_stop]) 94 | 95 | if len(args.saved_weight) == 0: # train 96 | print("Train and test the model.") 97 | trainer.fit(model, train_loader, val_loader) 98 | 99 | trainer.test(test_dataloaders=test_loader, ckpt_path='best') 100 | else: # test 101 | print("Test the model using the saved weight.") 102 | print(f"Using {args.saved_weight}.") 103 | model = model.load_from_checkpoint(args.saved_weight) 104 | 105 | trainer.test(model=model, test_dataloaders=test_loader) 106 | 107 | 108 | def create_exp_dir(path, visual_folder=False): 109 | if not os.path.exists(path): 110 | os.makedirs(path, exist_ok=True) 111 | if visual_folder is True: 112 | os.mkdir(path + '/visual') # for visual results 113 | else: 114 | print("DIR already exists.") 115 | print('Experiment dir : {}'.format(path)) 116 | 117 | 118 | def get_training_dataloaders(batch_size=32, num_workers:int =6, **kwargs): 119 | # Integrate your custom dataset class here. 120 | train_dataset, val_dataset, test_dataset = your_custom_dataset(**kwargs) 121 | 122 | train_loader = torch.utils.data.DataLoader( 123 | train_dataset, 124 | batch_size=batch_size, 125 | num_workers=num_workers, 126 | pin_memory=True, 127 | shuffle=True 128 | ) 129 | 130 | val_loader = torch.utils.data.DataLoader( 131 | val_dataset, 132 | batch_size=batch_size, 133 | num_workers=num_workers, 134 | pin_memory=True, 135 | shuffle=False 136 | ) 137 | 138 | test_loader = torch.utils.data.DataLoader( 139 | test_dataset, 140 | batch_size=1, 141 | shuffle=False, 142 | num_workers=num_workers, 143 | pin_memory=True 144 | ) 145 | 146 | return train_loader, val_loader, test_loader 147 | 148 | 149 | if __name__ == '__main__': 150 | main() -------------------------------------------------------------------------------- /models/aps.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | from models import losses 7 | from models.misc_models import define_G, define_D, GANLoss 8 | from models.helpers import compute_metrics 9 | 10 | 11 | class APS(pl.LightningModule): 12 | def __init__(self, lr=0.0002, input_size:list =[288, 288], lambda_adv=1.0, 13 | lambda_str=10.0, lambda_pix=100.0, num_patches=256, layers=[4, 7, 9]): 14 | super().__init__() 15 | self.save_hyperparameters() 16 | 17 | # Define generator (ResNet-based) 18 | self.net_G = define_G(input_nc=1, output_nc=1, ngf=64, 19 | netG="resnet_9blocks", 20 | norm="instance", use_dropout=False, 21 | init_type="xavier", init_gain=0.02, 22 | no_antialias=False, 23 | no_antialias_up=False) 24 | 25 | # Define discriminator (PatchGAN) 26 | self.net_D = define_D(input_nc=1, ndf=64, netD='basic', 27 | n_layers_D=3, norm="instance", 28 | init_type="xavier", init_gain=0.02, 29 | no_antialias=False) 30 | 31 | # Define l_str func 32 | self.l_str_fn = StructuralConsistencyLoss(num_patches=num_patches, 33 | patch_size=64, 34 | input_size=input_size, 35 | layers=layers) 36 | 37 | # Define l_adv func 38 | self.l_adv_fn = GANLoss(gan_mode="lsgan") 39 | 40 | print(self.hparams) 41 | 42 | def forward(self, 
real_mri): 43 | return self.net_G(real_mri) 44 | 45 | def l_pix_fn(self, fake_ct, real_ct, p_patch, ori_hw): 46 | with torch.no_grad(): 47 | p_pix = torch.nn.functional.interpolate(torch.relu(p_patch), 48 | ori_hw, align_corners=False, 49 | mode='bilinear') 50 | p_pix_min = p_pix.flatten(1).min(1)[0].view(-1, 1, 1, 1) # min values for each item in the batch 51 | p_pix_max = p_pix.flatten(1).max(1)[0].view(-1, 1, 1, 1) # max values for each item in the batch 52 | p_pix = (p_pix - p_pix_min) / (p_pix_max - p_pix_min) # minmax normalize 53 | 54 | loss = nn.functional.l1_loss(fake_ct, real_ct, reduction='none') * p_pix 55 | return loss.mean() 56 | 57 | def _se_step(self, real_ct, real_mri): 58 | fake_ct = self(real_mri) 59 | real_ct_reversed = torch.flip(real_ct, [0]) 60 | 61 | x = torch.cat([real_mri, real_mri], dim= 0) 62 | y_hat = torch.cat([fake_ct, real_ct], dim=0) 63 | y_tilda = torch.cat([real_ct_reversed, real_ct_reversed], dim=0) 64 | 65 | l_c = self.l_str_fn(x, y_hat, y_tilda) # this func also computes l_c 66 | 67 | return l_c 68 | 69 | def _gen_step(self, real_ct, real_mri): 70 | fake_ct = self(real_mri) 71 | 72 | # L_adv loss 73 | p_patch = self.net_D(fake_ct) 74 | l_adv = self.l_adv_fn(p_patch, True).mean() * self.hparams.lambda_adv 75 | 76 | # L_pix loss 77 | l_pix = self.l_pix_fn(fake_ct, real_ct, p_patch, real_mri.shape[2:]) * self.hparams.lambda_pix 78 | 79 | # L_str loss 80 | l_str = self.l_str_fn(real_mri, fake_ct, None) * self.hparams.lambda_str 81 | 82 | return l_adv + l_pix + l_str 83 | 84 | def _disc_step(self, real_ct, real_mri): 85 | # Fake 86 | fake_ct = self(real_mri).detach() 87 | fake_logits = self.net_D(fake_ct) 88 | fake_loss = self.l_adv_fn(fake_logits, False).mean() 89 | 90 | # Real 91 | real_logits = self.net_D(real_ct) 92 | real_loss = self.l_adv_fn(real_logits, True).mean() 93 | 94 | l_disc = (fake_loss + real_loss) * 0.5 95 | 96 | return l_disc 97 | 98 | def training_step(self, batch, batch_idx, optimizer_idx): 99 | real_ct, real_mri = batch['ct'], batch['in_phase'] 100 | if optimizer_idx == 0: # train the structure encoder 101 | loss = self._se_step(real_ct, real_mri) 102 | self.log('train_SE_loss', loss, on_epoch=True) 103 | elif optimizer_idx == 1: # train the discriminator 104 | loss = self._disc_step(real_ct, real_mri) 105 | self.log('train_D_loss', loss, on_epoch=True) 106 | elif optimizer_idx == 2: # train the generator 107 | loss = self._gen_step(real_ct, real_mri) 108 | self.log('train_G_loss', loss, on_epoch=True) 109 | 110 | return loss 111 | 112 | def validation_step(self, batch, batch_idx): 113 | ct, mri = batch['ct'], batch['in_phase'] 114 | pred = self(mri) 115 | 116 | with torch.no_grad(): 117 | # Add dim to make it same as the ct shape so it can be broadcasted 118 | ct_min, ct_max = batch['ct_min'], batch['ct_max'] 119 | ct_min = ct_min.view(-1, 1, 1) 120 | ct_max = ct_max.view(-1, 1, 1) 121 | 122 | mets = compute_metrics(pred.squeeze(1).cpu().numpy(), 123 | ct_min.cpu().numpy(), 124 | ct_max.cpu().numpy(), 125 | ct.squeeze(1).cpu().numpy()) 126 | self.log('val_mae', mets['mae'], on_epoch=True) 127 | self.log('val_psnr', mets['psnr'], on_epoch=True) 128 | 129 | return {'mae': mets['mae']} 130 | 131 | def test_step(self, batch, batch_idx): 132 | ct, mri = batch['ct'], batch['in_phase'] 133 | pred = self(mri) 134 | 135 | with torch.no_grad(): 136 | # Add dim to make it same as the ct shape so it can be broadcasted 137 | ct_min, ct_max = batch['ct_min'], batch['ct_max'] 138 | ct_min = ct_min.view(-1, 1, 1) 139 | ct_max = ct_max.view(-1, 1, 1) 
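            # compute_metrics() maps pred and ct back to HU using ct_min/ct_max
            # before MAE and PSNR are computed (see models/helpers.py).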
140 | 141 | # Compute MAE, PSNR, and save visual results 142 | save_figurename = f"{self.trainer.default_root_dir}/visual/Result {batch_idx}.png" 143 | mets = compute_metrics(pred.squeeze(1).cpu().numpy(), 144 | ct_min.cpu().numpy(), 145 | ct_max.cpu().numpy(), 146 | ct.squeeze(1).cpu().numpy(), 147 | mri.squeeze(1).cpu().numpy(), 148 | create_figure=True, 149 | save_figurename=save_figurename) 150 | self.log('test_mae', mets['mae'], on_epoch=True) 151 | self.log('test_psnr', mets['psnr'], on_epoch=True) 152 | 153 | return {'mae': mets['mae']} 154 | 155 | def configure_optimizers(self): 156 | assert self.l_str_fn.LSeSim.conv_init 157 | lr = self.hparams.lr 158 | 159 | net_D_opt = torch.optim.Adam(self.net_D.parameters(), lr=lr) 160 | net_G_opt = torch.optim.Adam(self.net_G.parameters(), lr=lr) 161 | net_SE_opt = torch.optim.Adam(self.l_str_fn.parameters(), lr=lr) 162 | 163 | return net_SE_opt, net_D_opt, net_G_opt 164 | 165 | 166 | class StructuralConsistencyLoss(nn.Module): 167 | def __init__(self, num_patches=256, patch_size=64, input_size=[288, 288], 168 | layers=[4, 7, 9]): 169 | super().__init__() 170 | 171 | self.structure_encoder = losses.VGG16() 172 | 173 | # Re-use the code from LSeSim 174 | self.LSeSim = losses.SpatialCorrelativeLoss('cos', num_patches, 175 | patch_size, True, True) 176 | 177 | # Run a dummy data to initialize the 1x1 convolution operations 178 | self.layers = layers # layer id to extract features from 179 | dummy_fea = torch.randn([1, 1] + input_size) 180 | self(dummy_fea, dummy_fea, None) 181 | 182 | def forward(self, src, tgt, other=None): 183 | n_layers = len(self.layers) 184 | feats_src = self.structure_encoder(src, self.layers, encode_only=True) 185 | feats_tgt = self.structure_encoder(tgt.float(), self.layers, encode_only=True) 186 | if other is not None: 187 | feats_oth = self.structure_encoder( 188 | torch.flip(other.float(), [2, 3]), self.layers, 189 | encode_only=True) 190 | else: 191 | feats_oth = [None for _ in range(n_layers)] 192 | 193 | total_loss = 0.0 194 | for i, (feat_src, feat_tgt, feat_oth) in enumerate(zip(feats_src, feats_tgt, feats_oth)): 195 | loss = self.LSeSim.loss(feat_src, feat_tgt, feat_oth, i) 196 | total_loss += loss.mean() 197 | 198 | if not self.LSeSim.conv_init: 199 | self.LSeSim.update_init_() 200 | 201 | return total_loss / n_layers -------------------------------------------------------------------------------- /models/helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from skimage.metrics import structural_similarity as ssim 4 | from sklearn.metrics import jaccard_score 5 | 6 | 7 | def compute_mae(fake_ct, real_ct): 8 | mae = np.sum(np.abs(fake_ct - real_ct)) / real_ct.size 9 | 10 | return mae 11 | 12 | 13 | def unscale_image(scaled_image: np.array, original_range: list, scaled_range:list =[0, 1]): 14 | minmax_form = (scaled_image - scaled_range[0]) / (scaled_range[1] - scaled_range[0]) 15 | original_image = minmax_form * (original_range[1] - original_range[0]) + original_range[0] 16 | 17 | return original_image 18 | 19 | 20 | def compute_metrics(fake_ct, ori_ct_min, ori_ct_max, real_ct, real_mri=None, 21 | create_figure=False, save_figurename: str = None, 22 | scale_data=[0, 1]): 23 | ''' 24 | :param fake_ct: numpy array of fake ct with shape [batch, H, W] 25 | :param ori_ct_min: original max ct HU value with shape [batch, 1, 1] 26 | :param ori_ct_max: orignal min ct HU value with shape [batch, 1, 1] 27 | :param 
real_ct: numpy array of real ct with shape [batch, H, W] 28 | :param real_mri: numpy array of real input image with shape [batch, H, W] (for visualization). 29 | If there is a channel dim and more than one channel, then this func will split into t2_F and t2_W 30 | :param create_figure: create a figure that visualizes the data 31 | :param save_figurename: file path to save the figure in 32 | :param scale_data: the range where the data has been normalized to. Default minmax [0, 1] 33 | :return: ssim, mae, fig 34 | ''' 35 | 36 | # Calculate SSIM 37 | # The data is normalized in [1, 0] range, so data_range = max - min = 1 38 | ssim_score = ssim(real_ct.transpose([1, 2, 0]), 39 | fake_ct.transpose([1, 2, 0]), multichannel=True, 40 | data_range=1) 41 | 42 | # Reverse minmax scaling to compute the evaluation metric 43 | # Unscale 44 | real_ct = unscale_image(real_ct, [ori_ct_min, ori_ct_max], scale_data) 45 | fake_ct = unscale_image(fake_ct, [ori_ct_min, ori_ct_max], scale_data) 46 | 47 | # Calculate MAE 48 | mae = compute_mae(fake_ct, real_ct) 49 | 50 | # Calculate PSNR 51 | mse = np.sum((fake_ct - real_ct) ** 2) / real_ct.size 52 | data_range = real_ct.max() - real_ct.min() 53 | psnr = 10 * np.log10((data_range ** 2) / mse) 54 | 55 | # Tissue specific metric 56 | # Air 57 | air_mask = (real_ct < -400) 58 | air_mae = compute_mae(fake_ct[air_mask], real_ct[air_mask]) 59 | 60 | # Soft tissue 61 | tissue_mask = np.logical_and(real_ct >= -400, real_ct <= 160) 62 | tissue_mae = compute_mae(fake_ct[tissue_mask], real_ct[tissue_mask]) 63 | 64 | # Bone 65 | bone_mask = (real_ct > 160) 66 | bone_mae = compute_mae(fake_ct[bone_mask], real_ct[bone_mask]) 67 | 68 | # Calculate IOC 69 | pred_seg_map = get_seg_map(fake_ct) 70 | gt_seg_map = get_seg_map(real_ct) 71 | mean_iou = jaccard_score(gt_seg_map.flatten(), pred_seg_map.flatten(), 72 | average="macro") 73 | air_iou = jaccard_score(gt_seg_map.flatten(), pred_seg_map.flatten(), 74 | labels=[0], average="macro") 75 | tissue_iou = jaccard_score(gt_seg_map.flatten(), pred_seg_map.flatten(), 76 | labels=[1], average="macro") 77 | bone_iou = jaccard_score(gt_seg_map.flatten(), pred_seg_map.flatten(), 78 | labels=[2], average="macro") 79 | 80 | # Log figure 81 | fig = None 82 | if create_figure: # only compatible when batch=1 83 | assert real_ct.shape[0] == 1, 'batch size is not 1 for visualizing figure.' 84 | assert real_mri is not None, "Input MRI is needed for visualization." 
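        # The 3x2 grid built below shows the input MR (one or two channels), the
        # synthetic CT (with its SSIM/MAE in the title), the real CT, and a
        # difference map; the figure is saved when save_figurename is provided.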
85 | fig = plt.figure(figsize=(10, 10)) 86 | 87 | # MRI 88 | if len(real_mri.shape) == 4 and real_mri.shape[1] == 2: 89 | real_t2_F = real_mri[:, 0, :, :] 90 | real_t2_W = real_mri[:, 1, :, :] 91 | 92 | plt.subplot(3, 2, 1) 93 | plt.imshow(real_t2_F.squeeze(), cmap='gray') 94 | plt.axis('off') 95 | plt.title('Input - T2_F') 96 | 97 | plt.subplot(3, 2, 2) 98 | plt.imshow(real_t2_W.squeeze(), cmap='gray') 99 | plt.axis('off') 100 | plt.title('Input - T2_W') 101 | else: 102 | plt.subplot(3, 2, 1) 103 | plt.imshow(real_mri.squeeze(), cmap='gray') 104 | plt.axis('off') 105 | plt.title('Input - in-phase MRI') 106 | 107 | plt.subplot(3, 2, 3) 108 | plt.imshow(fake_ct.squeeze(), cmap='gray') 109 | plt.axis('off') 110 | plt.title(f'sCT (SSIM: {ssim_score:.4f}, MAE: {mae:.4f})') 111 | 112 | plt.subplot(3, 2, 4) 113 | plt.imshow(real_ct.squeeze(), cmap='gray') 114 | plt.axis('off') 115 | plt.title(f'Real CT') 116 | 117 | plt.subplot(3, 2, 5) 118 | im = plt.imshow((real_ct.squeeze() - fake_ct.squeeze())) 119 | plt.axis('off') 120 | plt.title('Difference map') 121 | fig.colorbar(im) 122 | 123 | if save_figurename is not None: 124 | plt.savefig(save_figurename, dpi=100) 125 | 126 | plt.close(fig) 127 | 128 | # Return metrics 129 | to_return = { 130 | 'ssim': ssim_score, 131 | 'mae': mae, 132 | 'fig': fig, 133 | 'psnr': psnr, 134 | 'air_mae': air_mae, 135 | 'tissue_mae': tissue_mae, 136 | 'bone_mae': bone_mae, 137 | 'iou': mean_iou, 138 | 'bone_iou': bone_iou, 139 | 'air_iou': air_iou, 140 | 'tissue_iou': tissue_iou 141 | } 142 | return to_return 143 | 144 | 145 | def get_seg_map(ct): 146 | seg_map = np.zeros_like(ct) 147 | seg_map[(ct < -400)] = 0 # air 148 | seg_map[(np.logical_and(ct >= -400, ct <= 160))] = 1 # soft tissue 149 | seg_map[(ct > 160)] = 2 # bone 150 | 151 | return seg_map 152 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | import numpy as np 6 | 7 | from models.misc_models import init_weights 8 | 9 | 10 | class GANLoss(nn.Module): 11 | """Define different GAN objectives. 12 | 13 | The GANLoss class abstracts away the need to create the target label tensor 14 | that has the same size as the input. 15 | """ 16 | 17 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): 18 | """ Initialize the GANLoss class. 19 | 20 | Parameters: 21 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. 22 | target_real_label (bool) - - label for a real image 23 | target_fake_label (bool) - - label of a fake image 24 | 25 | Note: Do not use sigmoid as the last layer of Discriminator. 26 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. 
27 | """ 28 | super(GANLoss, self).__init__() 29 | self.register_buffer('real_label', torch.tensor(target_real_label)) 30 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 31 | self.gan_mode = gan_mode 32 | if gan_mode == 'lsgan': 33 | self.loss = nn.MSELoss() 34 | elif gan_mode == 'vanilla': 35 | self.loss = nn.BCEWithLogitsLoss() 36 | elif gan_mode == 'hinge': 37 | self.loss = nn.ReLU() 38 | elif gan_mode in ['wgangp', 'nonsaturating']: 39 | self.loss = None 40 | else: 41 | raise NotImplementedError('gan mode %s not implemented' % gan_mode) 42 | 43 | def get_target_tensor(self, prediction, target_is_real): 44 | """Create label tensors with the same size as the input. 45 | 46 | Parameters: 47 | prediction (tensor) - - tpyically the prediction from a discriminator 48 | target_is_real (bool) - - if the ground truth label is for real images or fake images 49 | 50 | Returns: 51 | A label tensor filled with ground truth label, and with the size of the input 52 | """ 53 | 54 | if target_is_real: 55 | target_tensor = self.real_label 56 | else: 57 | target_tensor = self.fake_label 58 | return target_tensor.expand_as(prediction) 59 | 60 | def calculate_loss(self, prediction, target_is_real, is_dis=False): 61 | """Calculate loss given Discriminator's output and grount truth labels. 62 | 63 | Parameters: 64 | prediction (tensor) - - tpyically the prediction output from a discriminator 65 | target_is_real (bool) - - if the ground truth label is for real images or fake images 66 | 67 | Returns: 68 | the calculated loss. 69 | """ 70 | if self.gan_mode in ['lsgan', 'vanilla']: 71 | target_tensor = self.get_target_tensor(prediction, target_is_real) 72 | loss = self.loss(prediction, target_tensor) 73 | else: 74 | if is_dis: 75 | if target_is_real: 76 | prediction = -prediction 77 | if self.gan_mode == 'wgangp': 78 | loss = prediction.mean() 79 | elif self.gan_mode == 'nonsaturating': 80 | loss = F.softplus(prediction).mean() 81 | elif self.gan_mode == 'hinge': 82 | loss = self.loss(1+prediction).mean() 83 | else: 84 | if self.gan_mode == 'nonsaturating': 85 | loss = F.softplus(-prediction).mean() 86 | else: 87 | loss = -prediction.mean() 88 | return loss 89 | 90 | def __call__(self, predictions, target_is_real, is_dis=False): 91 | """Calculate loss for multi-scales gan""" 92 | if isinstance(predictions, list): 93 | losses = [] 94 | for prediction in predictions: 95 | losses.append(self.calculate_loss(prediction, target_is_real, is_dis)) 96 | loss = sum(losses) 97 | else: 98 | loss = self.calculate_loss(predictions, target_is_real, is_dis) 99 | 100 | return loss 101 | 102 | 103 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0): 104 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 105 | 106 | Arguments: 107 | netD (network) -- discriminator network 108 | real_data (tensor array) -- real images 109 | fake_data (tensor array) -- generated images from the generator 110 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 111 | type (str) -- if we mix real and fake data or not [real | fake | mixed]. 112 | constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2 113 | lambda_gp (float) -- weight for this loss 114 | 115 | Returns the gradient penalty loss 116 | """ 117 | if lambda_gp > 0.0: 118 | if type == 'real': # either use real images, fake images, or a linear interpolation of two. 
119 | interpolatesv = real_data 120 | elif type == 'fake': 121 | interpolatesv = fake_data 122 | elif type == 'mixed': 123 | alpha = torch.rand(real_data.shape[0], 1, device=device) 124 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape) 125 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) 126 | else: 127 | raise NotImplementedError('{} not implemented'.format(type)) 128 | interpolatesv.requires_grad_(True) 129 | disc_interpolates = netD(interpolatesv) 130 | if isinstance(disc_interpolates, list): 131 | gradients = 0 132 | for disc_interpolate in disc_interpolates: 133 | gradients += torch.autograd.grad(outputs=disc_interpolate, inputs=interpolatesv, 134 | grad_outputs=torch.ones(disc_interpolate.size()).to(device), 135 | create_graph=True, retain_graph=True, only_inputs=True)[0] 136 | else: 137 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, 138 | grad_outputs=torch.ones(disc_interpolates.size()).to(device), 139 | create_graph=True, retain_graph=True, only_inputs=True)[0] 140 | gradients = gradients.view(real_data.size(0), -1) # flat the data 141 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps 142 | return gradient_penalty, gradients 143 | else: 144 | return 0.0, None 145 | 146 | 147 | class StyleLoss(nn.Module): 148 | r""" 149 | Perceptual loss, VGG-based 150 | https://arxiv.org/abs/1603.08155 151 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 152 | """ 153 | 154 | def __init__(self): 155 | super(StyleLoss, self).__init__() 156 | self.add_module('vgg', VGG16()) 157 | self.criterion = nn.L1Loss() 158 | 159 | def compute_gram(self, x): 160 | b, ch, h, w = x.size() 161 | f = x.view(b, ch, w * h) 162 | f_T = f.transpose(1, 2) 163 | G = f.bmm(f_T) / (b * h * w * ch) 164 | 165 | return G 166 | 167 | def __call__(self, x, y): 168 | # Compute features 169 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 170 | 171 | # Compute loss 172 | style_loss = 0.0 173 | style_loss += self.criterion(self.compute_gram(x_vgg['relu1_2']), self.compute_gram(y_vgg['relu1_2'])) 174 | style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2'])) 175 | style_loss += self.criterion(self.compute_gram(x_vgg['relu3_3']), self.compute_gram(y_vgg['relu3_3'])) 176 | style_loss += self.criterion(self.compute_gram(x_vgg['relu4_3']), self.compute_gram(y_vgg['relu4_3'])) 177 | 178 | return style_loss 179 | 180 | 181 | class PerceptualLoss(nn.Module): 182 | r""" 183 | Perceptual loss, VGG-based 184 | https://arxiv.org/abs/1603.08155 185 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 186 | """ 187 | 188 | def __init__(self, weights=[0.0, 0.0, 1.0, 0.0, 0.0]): 189 | super(PerceptualLoss, self).__init__() 190 | self.add_module('vgg', VGG16()) 191 | self.criterion = nn.L1Loss() 192 | self.weights = weights 193 | 194 | def __call__(self, x, y): 195 | # Compute features 196 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 197 | 198 | content_loss = 0.0 199 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_2'], y_vgg['relu1_2']) if self.weights[0] > 0 else 0 200 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_2'], y_vgg['relu2_2']) if self.weights[1] > 0 else 0 201 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_3'], y_vgg['relu3_3']) if self.weights[2] > 0 else 0 202 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_3'], y_vgg['relu4_3']) 
if self.weights[3] > 0 else 0 203 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_3'], y_vgg['relu5_3']) if self.weights[4] > 0 else 0 204 | 205 | return content_loss 206 | 207 | 208 | class PatchSim(nn.Module): 209 | """Calculate the similarity in selected patches""" 210 | def __init__(self, patch_nums=256, patch_size=None, norm=True): 211 | super(PatchSim, self).__init__() 212 | self.patch_nums = patch_nums 213 | self.patch_size = patch_size 214 | self.use_norm = norm 215 | 216 | def forward(self, feat, patch_ids=None): 217 | """ 218 | Calculate the similarity for selected patches 219 | """ 220 | B, C, W, H = feat.size() 221 | feat = feat - feat.mean(dim=[-2, -1], keepdim=True) 222 | feat = F.normalize(feat, dim=1) if self.use_norm else feat / np.sqrt(C) 223 | query, key, patch_ids = self.select_patch(feat, patch_ids=patch_ids) 224 | patch_sim = query.bmm(key) if self.use_norm else torch.tanh(query.bmm(key)/10) 225 | if patch_ids is not None: 226 | patch_sim = patch_sim.view(B, len(patch_ids), -1) 227 | 228 | return patch_sim, patch_ids 229 | 230 | def select_patch(self, feat, patch_ids=None): 231 | """ 232 | Select the patches 233 | """ 234 | B, C, W, H = feat.size() 235 | pw, ph = self.patch_size, self.patch_size 236 | feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2) # B*N*C 237 | if self.patch_nums > 0: 238 | if patch_ids is None: 239 | patch_ids = torch.randperm(feat_reshape.size(1), device=feat.device) 240 | patch_ids = patch_ids[:int(min(self.patch_nums, patch_ids.size(0)))] 241 | feat_query = feat_reshape[:, patch_ids, :] # B*Num*C 242 | feat_key = [] 243 | Num = feat_query.size(1) 244 | if pw < W and ph < H: 245 | pos_x, pos_y = patch_ids // W, patch_ids % W 246 | # patch should in the feature 247 | left, top = pos_x - int(pw / 2), pos_y - int(ph / 2) 248 | left, top = torch.where(left > 0, left, torch.zeros_like(left)), torch.where(top > 0, top, torch.zeros_like(top)) 249 | start_x = torch.where(left > (W - pw), (W - pw) * torch.ones_like(left), left) 250 | start_y = torch.where(top > (H - ph), (H - ph) * torch.ones_like(top), top) 251 | for i in range(Num): 252 | feat_key.append(feat[:, :, start_x[i]:start_x[i]+pw, start_y[i]:start_y[i]+ph]) # B*C*patch_w*patch_h 253 | feat_key = torch.stack(feat_key, dim=0).permute(1, 0, 2, 3, 4) # B*Num*C*patch_w*patch_h 254 | feat_key = feat_key.reshape(B * Num, C, pw * ph) # Num * C * N 255 | feat_query = feat_query.reshape(B * Num, 1, C) # Num * 1 * C 256 | else: # if patch larger than features size, use B * C * N (H * W) 257 | feat_key = feat.reshape(B, C, W*H) 258 | else: 259 | feat_query = feat.reshape(B, C, H*W).permute(0, 2, 1) # B * N (H * W) * C 260 | feat_key = feat.reshape(B, C, H*W) # B * C * N (H * W) 261 | 262 | return feat_query, feat_key, patch_ids 263 | 264 | 265 | class SpatialCorrelativeLoss(nn.Module): 266 | """ 267 | learnable patch-based spatially-correlative loss with contrastive learning 268 | """ 269 | def __init__(self, loss_mode='cos', patch_nums=256, patch_size=32, norm=True, use_conv=True, 270 | init_type='normal', init_gain=0.02, gpu_ids=[], T=0.1): 271 | super(SpatialCorrelativeLoss, self).__init__() 272 | self.patch_sim = PatchSim(patch_nums=patch_nums, patch_size=patch_size, norm=norm) 273 | self.patch_size = patch_size 274 | self.patch_nums = patch_nums 275 | self.norm = norm 276 | self.use_conv = use_conv 277 | self.conv_init = False 278 | self.init_type = init_type 279 | self.init_gain = init_gain 280 | self.gpu_ids = gpu_ids 281 | self.loss_mode = loss_mode 282 | self.T = T 283 | 
self.criterion = nn.L1Loss() if norm else nn.SmoothL1Loss() 284 | self.cross_entropy_loss = nn.CrossEntropyLoss() 285 | 286 | def update_init_(self): 287 | self.conv_init = True 288 | 289 | def create_conv(self, feat, layer): 290 | """ 291 | create the 1*1 conv filter to select the features for a specific task 292 | :param feat: extracted features from a pretrained VGG or encoder for the similarity and dissimilarity map 293 | :param layer: different layers use different filter 294 | :return: 295 | """ 296 | input_nc = feat.size(1) 297 | output_nc = max(32, input_nc // 4) 298 | conv = nn.Sequential(*[nn.Conv2d(input_nc, output_nc, kernel_size=1), 299 | nn.ReLU(), 300 | nn.Conv2d(output_nc, output_nc, kernel_size=1)]) 301 | conv.to(feat.device) 302 | setattr(self, 'conv_%d' % layer, conv) 303 | self.init_net(conv, self.init_type, self.init_gain, self.gpu_ids) 304 | 305 | def init_net(self, net, init_type='normal', init_gain=0.02, gpu_ids=[], 306 | debug=False, initialize_weights=True): 307 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 308 | Parameters: 309 | net (network) -- the network to be initialized 310 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 311 | gain (float) -- scaling factor for normal, xavier and orthogonal. 312 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 313 | 314 | Return an initialized network. 315 | """ 316 | # if len(gpu_ids) > 0: 317 | # assert (torch.cuda.is_available()) 318 | # net.to(gpu_ids[0]) 319 | if initialize_weights: 320 | init_weights(net, init_type, init_gain=init_gain, debug=debug) 321 | return net 322 | 323 | def cal_sim(self, f_src, f_tgt, f_other=None, layer=0, patch_ids=None): 324 | """ 325 | calculate the similarity map using the fixed/learned query and key 326 | :param f_src: feature map from source domain 327 | :param f_tgt: feature map from target domain 328 | :param f_other: feature map from other image (only used for contrastive learning for spatial network) 329 | :return: 330 | """ 331 | if self.use_conv: 332 | if not self.conv_init: 333 | self.create_conv(f_src, layer) 334 | conv = getattr(self, 'conv_%d' % layer) 335 | f_src, f_tgt = conv(f_src), conv(f_tgt) 336 | f_other = conv(f_other) if f_other is not None else None 337 | sim_src, patch_ids = self.patch_sim(f_src, patch_ids) 338 | sim_tgt, patch_ids = self.patch_sim(f_tgt, patch_ids) 339 | if f_other is not None: 340 | sim_other, _ = self.patch_sim(f_other, patch_ids) 341 | else: 342 | sim_other = None 343 | 344 | return sim_src, sim_tgt, sim_other 345 | 346 | def compare_sim(self, sim_src, sim_tgt, sim_other): 347 | """ 348 | measure the shape distance between the same shape and different inputs 349 | :param sim_src: the shape similarity map from source input image 350 | :param sim_tgt: the shape similarity map from target output image 351 | :param sim_other: the shape similarity map from other input image 352 | :return: 353 | """ 354 | B, Num, N = sim_src.size() 355 | if self.loss_mode == 'info' or sim_other is not None: 356 | sim_src = F.normalize(sim_src, dim=-1) 357 | sim_tgt = F.normalize(sim_tgt, dim=-1) 358 | sim_other = F.normalize(sim_other, dim=-1) 359 | sam_neg1 = (sim_src.bmm(sim_other.permute(0, 2, 1))).view(-1, Num) / self.T 360 | sam_neg2 = (sim_tgt.bmm(sim_other.permute(0, 2, 1))).view(-1, Num) / self.T 361 | sam_self = (sim_src.bmm(sim_tgt.permute(0, 2, 1))).view(-1, Num) / self.T 362 | sam_self = torch.cat([sam_self, sam_neg1, 
sam_neg2], dim=-1) 363 | loss = self.cross_entropy_loss(sam_self, torch.arange(0, sam_self.size(0), dtype=torch.long, device=sim_src.device) % (Num)) 364 | else: 365 | tgt_sorted, _ = sim_tgt.sort(dim=-1, descending=True) 366 | num = int(N / 4) 367 | src = torch.where(sim_tgt < tgt_sorted[:, :, num:num + 1], 0 * sim_src, sim_src) 368 | tgt = torch.where(sim_tgt < tgt_sorted[:, :, num:num + 1], 0 * sim_tgt, sim_tgt) 369 | if self.loss_mode == 'l1': 370 | loss = self.criterion((N / num) * src, (N / num) * tgt) 371 | elif self.loss_mode == 'cos': 372 | sim_pos = F.cosine_similarity(src, tgt, dim=-1) 373 | loss = self.criterion(torch.ones_like(sim_pos), sim_pos) 374 | else: 375 | raise NotImplementedError('padding [%s] is not implemented' % self.loss_mode) 376 | 377 | return loss 378 | 379 | def loss(self, f_src, f_tgt, f_other=None, layer=0): 380 | """ 381 | calculate the spatial similarity and dissimilarity loss for given features from source and target domain 382 | :param f_src: source domain features 383 | :param f_tgt: target domain features 384 | :param f_other: other random sampled features 385 | :param layer: 386 | :return: 387 | """ 388 | sim_src, sim_tgt, sim_other = self.cal_sim(f_src, f_tgt, f_other, layer) 389 | # calculate the spatial similarity for source and target domain 390 | loss = self.compare_sim(sim_src, sim_tgt, sim_other) 391 | return loss 392 | 393 | 394 | class Normalization(nn.Module): 395 | def __init__(self, device): 396 | super(Normalization, self).__init__() 397 | # .view the mean and std to make them [C x 1 x 1] so that they can 398 | # directly work with image Tensor of shape [B x C x H x W]. 399 | # B is batch size. C is number of channels. H is height and W is width. 400 | mean = torch.tensor([0.485, 0.456, 0.406]).to(device) 401 | std = torch.tensor([0.229, 0.224, 0.225]).to(device) 402 | self.mean = mean.view(-1, 1, 1) 403 | self.std = std.view(-1, 1, 1) 404 | 405 | def forward(self, img): 406 | # normalize img 407 | return (img - self.mean) / self.std 408 | 409 | 410 | class VGG16(nn.Module): 411 | def __init__(self): 412 | super(VGG16, self).__init__() 413 | features = models.vgg16(pretrained=True).features 414 | 415 | self.relu1_1 = torch.nn.Sequential() 416 | self.relu1_2 = torch.nn.Sequential() 417 | 418 | self.relu2_1 = torch.nn.Sequential() 419 | self.relu2_2 = torch.nn.Sequential() 420 | 421 | self.relu3_1 = torch.nn.Sequential() 422 | self.relu3_2 = torch.nn.Sequential() 423 | self.relu3_3 = torch.nn.Sequential() 424 | 425 | self.relu4_1 = torch.nn.Sequential() 426 | self.relu4_2 = torch.nn.Sequential() 427 | self.relu4_3 = torch.nn.Sequential() 428 | 429 | self.relu5_1 = torch.nn.Sequential() 430 | self.relu5_2 = torch.nn.Sequential() 431 | self.relu5_3 = torch.nn.Sequential() 432 | 433 | for x in range(2): 434 | if x == 0: 435 | tmp = nn.Conv2d(1, 64, 3, 1, 1) 436 | self.relu1_1.add_module(str(x), tmp) 437 | else: 438 | self.relu1_1.add_module(str(x), features[x]) 439 | 440 | for x in range(2, 4): 441 | self.relu1_2.add_module(str(x), features[x]) 442 | 443 | for x in range(4, 7): 444 | self.relu2_1.add_module(str(x), features[x]) 445 | 446 | for x in range(7, 9): 447 | self.relu2_2.add_module(str(x), features[x]) 448 | 449 | for x in range(9, 12): 450 | self.relu3_1.add_module(str(x), features[x]) 451 | 452 | for x in range(12, 14): 453 | self.relu3_2.add_module(str(x), features[x]) 454 | 455 | for x in range(14, 16): 456 | self.relu3_3.add_module(str(x), features[x]) 457 | 458 | for x in range(16, 18): 459 | self.relu4_1.add_module(str(x), 
features[x]) 460 | 461 | for x in range(18, 21): 462 | self.relu4_2.add_module(str(x), features[x]) 463 | 464 | for x in range(21, 23): 465 | self.relu4_3.add_module(str(x), features[x]) 466 | 467 | for x in range(23, 26): 468 | self.relu5_1.add_module(str(x), features[x]) 469 | 470 | for x in range(26, 28): 471 | self.relu5_2.add_module(str(x), features[x]) 472 | 473 | for x in range(28, 30): 474 | self.relu5_3.add_module(str(x), features[x]) 475 | 476 | # don't need the gradients, just want the features 477 | #for param in self.parameters(): 478 | # param.requires_grad = False 479 | 480 | def forward(self, x, layers=None, encode_only=False, resize=False): 481 | relu1_1 = self.relu1_1(x) 482 | relu1_2 = self.relu1_2(relu1_1) 483 | 484 | relu2_1 = self.relu2_1(relu1_2) 485 | relu2_2 = self.relu2_2(relu2_1) 486 | 487 | relu3_1 = self.relu3_1(relu2_2) 488 | relu3_2 = self.relu3_2(relu3_1) 489 | relu3_3 = self.relu3_3(relu3_2) 490 | 491 | relu4_1 = self.relu4_1(relu3_3) 492 | relu4_2 = self.relu4_2(relu4_1) 493 | relu4_3 = self.relu4_3(relu4_2) 494 | 495 | relu5_1 = self.relu5_1(relu4_3) 496 | relu5_2 = self.relu5_2(relu5_1) 497 | relu5_3 = self.relu5_3(relu5_2) 498 | 499 | out = { 500 | 'relu1_1': relu1_1, 501 | 'relu1_2': relu1_2, 502 | 503 | 'relu2_1': relu2_1, 504 | 'relu2_2': relu2_2, 505 | 506 | 'relu3_1': relu3_1, 507 | 'relu3_2': relu3_2, 508 | 'relu3_3': relu3_3, 509 | 510 | 'relu4_1': relu4_1, 511 | 'relu4_2': relu4_2, 512 | 'relu4_3': relu4_3, 513 | 514 | 'relu5_1': relu5_1, 515 | 'relu5_2': relu5_2, 516 | 'relu5_3': relu5_3, 517 | } 518 | if encode_only: 519 | if len(layers) > 0: 520 | feats = [] 521 | for layer, key in enumerate(out): 522 | if layer in layers: 523 | feats.append(out[key]) 524 | return feats 525 | else: 526 | return out['relu3_1'] 527 | return out -------------------------------------------------------------------------------- /models/misc_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | import functools 6 | from torch.optim import lr_scheduler 7 | import numpy as np 8 | 9 | 10 | def get_filter(filt_size=3): 11 | if(filt_size == 1): 12 | a = np.array([1., ]) 13 | elif(filt_size == 2): 14 | a = np.array([1., 1.]) 15 | elif(filt_size == 3): 16 | a = np.array([1., 2., 1.]) 17 | elif(filt_size == 4): 18 | a = np.array([1., 3., 3., 1.]) 19 | elif(filt_size == 5): 20 | a = np.array([1., 4., 6., 4., 1.]) 21 | elif(filt_size == 6): 22 | a = np.array([1., 5., 10., 10., 5., 1.]) 23 | elif(filt_size == 7): 24 | a = np.array([1., 6., 15., 20., 15., 6., 1.]) 25 | 26 | filt = torch.Tensor(a[:, None] * a[None, :]) 27 | filt = filt / torch.sum(filt) 28 | 29 | return filt 30 | 31 | 32 | class Downsample(nn.Module): 33 | def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0): 34 | super(Downsample, self).__init__() 35 | self.filt_size = filt_size 36 | self.pad_off = pad_off 37 | self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2)), int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))] 38 | self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes] 39 | self.stride = stride 40 | self.off = int((self.stride - 1) / 2.) 
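        # The 'filt' buffer registered below is a normalized 2-D binomial kernel
        # from get_filter() (e.g. outer([1, 2, 1], [1, 2, 1]) / 16 for filt_size=3);
        # blurring with it before the strided sampling gives anti-aliased
        # (blur-pool style) downsampling.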
41 | self.channels = channels 42 | 43 | filt = get_filter(filt_size=self.filt_size) 44 | self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) 45 | 46 | self.pad = get_pad_layer(pad_type)(self.pad_sizes) 47 | 48 | def forward(self, inp): 49 | if(self.filt_size == 1): 50 | if(self.pad_off == 0): 51 | return inp[:, :, ::self.stride, ::self.stride] 52 | else: 53 | return self.pad(inp)[:, :, ::self.stride, ::self.stride] 54 | else: 55 | return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) 56 | 57 | 58 | class Upsample2(nn.Module): 59 | def __init__(self, scale_factor, mode='nearest'): 60 | super().__init__() 61 | self.factor = scale_factor 62 | self.mode = mode 63 | 64 | def forward(self, x): 65 | return torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode) 66 | 67 | 68 | class Upsample(nn.Module): 69 | def __init__(self, channels, pad_type='repl', filt_size=4, stride=2): 70 | super(Upsample, self).__init__() 71 | self.filt_size = filt_size 72 | self.filt_odd = np.mod(filt_size, 2) == 1 73 | self.pad_size = int((filt_size - 1) / 2) 74 | self.stride = stride 75 | self.off = int((self.stride - 1) / 2.) 76 | self.channels = channels 77 | 78 | filt = get_filter(filt_size=self.filt_size) * (stride**2) 79 | self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) 80 | 81 | self.pad = get_pad_layer(pad_type)([1, 1, 1, 1]) 82 | 83 | def forward(self, inp): 84 | ret_val = F.conv_transpose2d(self.pad(inp), self.filt, stride=self.stride, padding=1 + self.pad_size, groups=inp.shape[1])[:, :, 1:, 1:] 85 | if(self.filt_odd): 86 | return ret_val 87 | else: 88 | return ret_val[:, :, :-1, :-1] 89 | 90 | 91 | def get_pad_layer(pad_type): 92 | if(pad_type in ['refl', 'reflect']): 93 | PadLayer = nn.ReflectionPad2d 94 | elif(pad_type in ['repl', 'replicate']): 95 | PadLayer = nn.ReplicationPad2d 96 | elif(pad_type == 'zero'): 97 | PadLayer = nn.ZeroPad2d 98 | else: 99 | print('Pad type [%s] not recognized' % pad_type) 100 | return PadLayer 101 | 102 | 103 | class Identity(nn.Module): 104 | def forward(self, x): 105 | return x 106 | 107 | 108 | def get_norm_layer(norm_type='instance'): 109 | """Return a normalization layer 110 | 111 | Parameters: 112 | norm_type (str) -- the name of the normalization layer: batch | instance | none 113 | 114 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 115 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 116 | """ 117 | if norm_type == 'batch': 118 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 119 | elif norm_type == 'instance': 120 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 121 | elif norm_type == 'none': 122 | def norm_layer(x): 123 | return Identity() 124 | else: 125 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 126 | return norm_layer 127 | 128 | 129 | def get_scheduler(optimizer, opt): 130 | """Return a learning rate scheduler 131 | 132 | Parameters: 133 | optimizer -- the optimizer of the network 134 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  135 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 136 | 137 | For 'linear', we keep the same learning rate for the first epochs 138 | and linearly decay the rate to zero over the next epochs. 
139 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 140 | See https://pytorch.org/docs/stable/optim.html for more details. 141 | """ 142 | if opt.lr_policy == 'linear': 143 | def lambda_rule(epoch): 144 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) 145 | return lr_l 146 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 147 | elif opt.lr_policy == 'step': 148 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 149 | elif opt.lr_policy == 'plateau': 150 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 151 | elif opt.lr_policy == 'cosine': 152 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) 153 | else: 154 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 155 | return scheduler 156 | 157 | 158 | def init_weights(net, init_type='normal', init_gain=0.02, debug=False): 159 | """Initialize network weights. 160 | 161 | Parameters: 162 | net (network) -- network to be initialized 163 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 164 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 165 | 166 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 167 | work better for some applications. Feel free to try yourself. 168 | """ 169 | def init_func(m): # define the initialization function 170 | classname = m.__class__.__name__ 171 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 172 | if debug: 173 | print(classname) 174 | if init_type == 'normal': 175 | init.normal_(m.weight.data, 0.0, init_gain) 176 | elif init_type == 'xavier': 177 | init.xavier_normal_(m.weight.data, gain=init_gain) 178 | elif init_type == 'kaiming': 179 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 180 | elif init_type == 'orthogonal': 181 | init.orthogonal_(m.weight.data, gain=init_gain) 182 | else: 183 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 184 | if hasattr(m, 'bias') and m.bias is not None: 185 | init.constant_(m.bias.data, 0.0) 186 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 187 | init.normal_(m.weight.data, 1.0, init_gain) 188 | init.constant_(m.bias.data, 0.0) 189 | 190 | net.apply(init_func) # apply the initialization function 191 | 192 | 193 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True): 194 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 195 | Parameters: 196 | net (network) -- the network to be initialized 197 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 198 | gain (float) -- scaling factor for normal, xavier and orthogonal. 199 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 200 | 201 | Return an initialized network. 
202 | """ 203 | # if len(gpu_ids) > 0: 204 | # assert(torch.cuda.is_available()) 205 | # net.to(gpu_ids[0]) 206 | # if not amp: 207 | # net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs for non-AMP training 208 | if initialize_weights: 209 | init_weights(net, init_type, init_gain=init_gain, debug=debug) 210 | return net 211 | 212 | 213 | def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', 214 | init_gain=0.02, no_antialias=False, no_antialias_up=False, gpu_ids=[], opt=None): 215 | """Create a generator 216 | 217 | Parameters: 218 | input_nc (int) -- the number of channels in input images 219 | output_nc (int) -- the number of channels in output images 220 | ngf (int) -- the number of filters in the last conv layer 221 | netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128 222 | norm (str) -- the name of normalization layers used in the network: batch | instance | none 223 | use_dropout (bool) -- if use dropout layers. 224 | init_type (str) -- the name of our initialization method. 225 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 226 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 227 | 228 | Returns a generator 229 | 230 | Our current implementation provides two types of generators: 231 | U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images) 232 | The original U-Net paper: https://arxiv.org/abs/1505.04597 233 | 234 | Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks) 235 | Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations. 236 | We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style). 237 | 238 | 239 | The generator has been initialized by . It uses RELU for non-linearity. 
240 | """ 241 | net = None 242 | norm_layer = get_norm_layer(norm_type=norm) 243 | 244 | if netG == 'resnet_9blocks': 245 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=9, opt=opt) 246 | elif netG == 'resnet_6blocks': 247 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=6, opt=opt) 248 | elif netG == 'resnet_4blocks': 249 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=4, opt=opt) 250 | elif netG == 'unet_128': 251 | net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 252 | elif netG == 'unet_256': 253 | net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 254 | elif netG == 'resnet_cat': 255 | n_blocks = 8 256 | net = G_Resnet(input_nc, output_nc, opt.nz, num_downs=2, n_res=n_blocks - 4, ngf=ngf, norm='inst', nl_layer='relu') 257 | else: 258 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG) 259 | return init_net(net, init_type, init_gain, gpu_ids, initialize_weights=('stylegan2' not in netG)) 260 | 261 | 262 | def define_F(netF, init_type='normal', init_gain=0.02, gpu_ids=[], netF_nc=256): 263 | if netF == 'global_pool': 264 | net = PoolingF() 265 | elif netF == 'reshape': 266 | net = ReshapeF() 267 | elif netF == 'sample': 268 | net = PatchSampleF(use_mlp=False, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids, nc=netF_nc) 269 | elif netF == 'mlp_sample': 270 | net = PatchSampleF(use_mlp=True, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids, nc=netF_nc) 271 | elif netF == 'strided_conv': 272 | net = StridedConvF(init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids) 273 | else: 274 | raise NotImplementedError('projection model name [%s] is not recognized' % netF) 275 | return init_net(net, init_type, init_gain, gpu_ids) 276 | 277 | 278 | def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], opt=None): 279 | """Create a discriminator 280 | 281 | Parameters: 282 | input_nc (int) -- the number of channels in input images 283 | ndf (int) -- the number of filters in the first conv layer 284 | netD (str) -- the architecture's name: basic | n_layers | pixel 285 | n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' 286 | norm (str) -- the type of normalization layers used in the network. 287 | init_type (str) -- the name of the initialization method. 288 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 289 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 290 | 291 | Returns a discriminator 292 | 293 | Our current implementation provides three types of discriminators: 294 | [basic]: 'PatchGAN' classifier described in the original pix2pix paper. 295 | It can classify whether 70×70 overlapping patches are real or fake. 296 | Such a patch-level discriminator architecture has fewer parameters 297 | than a full-image discriminator and can work on arbitrarily-sized images 298 | in a fully convolutional fashion. 
299 | 300 | [n_layers]: With this mode, you can specify the number of conv layers in the discriminator 301 | with the parameter (default=3, as used in [basic] (PatchGAN)). 302 | 303 | [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not. 304 | It encourages greater color diversity but has no effect on spatial statistics. 305 | 306 | The discriminator has been initialized by init_net(). It uses Leaky ReLU for non-linearity. 307 | """ 308 | net = None 309 | norm_layer = get_norm_layer(norm_type=norm) 310 | 311 | if netD == 'basic': # default PatchGAN classifier 312 | net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, no_antialias=no_antialias,) 313 | elif netD == 'n_layers': # more options 314 | net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias,) 315 | elif netD == 'pixel': # classify if each pixel is real or fake 316 | net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) 317 | else: 318 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) 319 | return init_net(net, init_type, init_gain, gpu_ids, 320 | initialize_weights=('stylegan2' not in netD)) 321 | 322 | 323 | ############################################################################## 324 | # Classes 325 | ############################################################################## 326 | class GANLoss(nn.Module): 327 | """Define different GAN objectives. 328 | 329 | The GANLoss class abstracts away the need to create the target label tensor 330 | that has the same size as the input. 331 | """ 332 | 333 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): 334 | """ Initialize the GANLoss class. 335 | 336 | Parameters: 337 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. 338 | target_real_label (bool) - - label for a real image 339 | target_fake_label (bool) - - label of a fake image 340 | 341 | Note: Do not use sigmoid as the last layer of Discriminator. 342 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. 343 | """ 344 | super(GANLoss, self).__init__() 345 | self.register_buffer('real_label', torch.tensor(target_real_label)) 346 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 347 | self.gan_mode = gan_mode 348 | if gan_mode == 'lsgan': 349 | self.loss = nn.MSELoss() 350 | elif gan_mode == 'vanilla': 351 | self.loss = nn.BCEWithLogitsLoss() 352 | elif gan_mode in ['wgangp', 'nonsaturating']: 353 | self.loss = None 354 | else: 355 | raise NotImplementedError('gan mode %s not implemented' % gan_mode) 356 | 357 | def get_target_tensor(self, prediction, target_is_real): 358 | """Create label tensors with the same size as the input. 359 | 360 | Parameters: 361 | prediction (tensor) - - typically the prediction from a discriminator 362 | target_is_real (bool) - - if the ground truth label is for real images or fake images 363 | 364 | Returns: 365 | A label tensor filled with ground truth label, and with the size of the input 366 | """ 367 | 368 | if target_is_real: 369 | target_tensor = self.real_label 370 | else: 371 | target_tensor = self.fake_label 372 | return target_tensor.expand_as(prediction) 373 | 374 | def __call__(self, prediction, target_is_real): 375 | """Calculate loss given Discriminator's output and ground truth labels.
376 | 377 | Parameters: 378 | prediction (tensor) - - typically the prediction output from a discriminator 379 | target_is_real (bool) - - if the ground truth label is for real images or fake images 380 | 381 | Returns: 382 | the calculated loss. 383 | """ 384 | if self.gan_mode in ['lsgan', 'vanilla']: 385 | if isinstance(prediction, list): # if a list 386 | loss = 0 387 | for prediction_i in prediction: 388 | target_tensor = self.get_target_tensor(prediction_i, target_is_real) 389 | loss += self.loss(prediction_i, target_tensor) 390 | else: 391 | target_tensor = self.get_target_tensor(prediction, target_is_real) 392 | loss = self.loss(prediction, target_tensor) 393 | elif self.gan_mode == 'wgangp': 394 | if target_is_real: 395 | loss = -prediction.mean() 396 | else: 397 | loss = prediction.mean() 398 | elif self.gan_mode == 'nonsaturating': 399 | bs = prediction.size(0) 400 | if target_is_real: 401 | loss = F.softplus(-prediction).view(bs, -1).mean(dim=1) 402 | else: 403 | loss = F.softplus(prediction).view(bs, -1).mean(dim=1) 404 | return loss 405 | 406 | 407 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0): 408 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 409 | 410 | Arguments: 411 | netD (network) -- discriminator network 412 | real_data (tensor array) -- real images 413 | fake_data (tensor array) -- generated images from the generator 414 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 415 | type (str) -- if we mix real and fake data or not [real | fake | mixed]. 416 | constant (float) -- the constant used in formula (||gradient||_2 - constant)^2 417 | lambda_gp (float) -- weight for this loss 418 | 419 | Returns the gradient penalty loss 420 | """ 421 | if lambda_gp > 0.0: 422 | if type == 'real': # either use real images, fake images, or a linear interpolation of two. 423 | interpolatesv = real_data 424 | elif type == 'fake': 425 | interpolatesv = fake_data 426 | elif type == 'mixed': 427 | alpha = torch.rand(real_data.shape[0], 1, device=device) 428 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape) 429 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) 430 | else: 431 | raise NotImplementedError('{} not implemented'.format(type)) 432 | interpolatesv.requires_grad_(True) 433 | disc_interpolates = netD(interpolatesv) 434 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, 435 | grad_outputs=torch.ones(disc_interpolates.size()).to(device), 436 | create_graph=True, retain_graph=True, only_inputs=True) 437 | gradients = gradients[0].view(real_data.size(0), -1) # flatten the data 438 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps 439 | return gradient_penalty, gradients 440 | else: 441 | return 0.0, None 442 | 443 | 444 | class Normalize(nn.Module): 445 | 446 | def __init__(self, power=2): 447 | super(Normalize, self).__init__() 448 | self.power = power 449 | 450 | def forward(self, x): 451 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 452 | out = x.div(norm + 1e-7) 453 | return out 454 | 455 | 456 | class PoolingF(nn.Module): 457 | def __init__(self): 458 | super(PoolingF, self).__init__() 459 | model = [nn.AdaptiveMaxPool2d(1)] 460 | self.model = nn.Sequential(*model) 461 | self.l2norm = Normalize(2) 462 | 463 | def forward(self, x): 464 | return self.l2norm(self.model(x)) 465 | 466 | 467 | class ReshapeF(nn.Module): 468 | def __init__(self): 469 | super(ReshapeF, self).__init__() 470 | model = [nn.AdaptiveAvgPool2d(4)] 471 | self.model = nn.Sequential(*model) 472 | self.l2norm = Normalize(2) 473 | 474 | def forward(self, x): 475 | x = self.model(x) 476 | x_reshape = x.permute(0, 2, 3, 1).flatten(0, 2) 477 | return self.l2norm(x_reshape) 478 | 479 | 480 | class StridedConvF(nn.Module): 481 | def __init__(self, init_type='normal', init_gain=0.02, gpu_ids=[]): 482 | super().__init__() 483 | # self.conv1 = nn.Conv2d(256, 128, 3, stride=2) 484 | # self.conv2 = nn.Conv2d(128, 64, 3, stride=1) 485 | self.l2_norm = Normalize(2) 486 | self.mlps = {} 487 | self.moving_averages = {} 488 | self.init_type = init_type 489 | self.init_gain = init_gain 490 | self.gpu_ids = gpu_ids 491 | 492 | def create_mlp(self, x): 493 | C, H = x.shape[1], x.shape[2] 494 | n_down = int(np.rint(np.log2(H / 32))) 495 | mlp = [] 496 | for i in range(n_down): 497 | mlp.append(nn.Conv2d(C, max(C // 2, 64), 3, stride=2)) 498 | mlp.append(nn.ReLU()) 499 | C = max(C // 2, 64) 500 | mlp.append(nn.Conv2d(C, 64, 3)) 501 | mlp = nn.Sequential(*mlp) 502 | init_net(mlp, self.init_type, self.init_gain, self.gpu_ids) 503 | return mlp 504 | 505 | def update_moving_average(self, key, x): 506 | if key not in self.moving_averages: 507 | self.moving_averages[key] = x.detach() 508 | 509 | self.moving_averages[key] = self.moving_averages[key] * 0.999 + x.detach() * 0.001 510 | 511 | def forward(self, x, use_instance_norm=False): 512 | C, H = x.shape[1], x.shape[2] 513 | key = '%d_%d' % (C, H) 514 | if key not in self.mlps: 515 | self.mlps[key] = self.create_mlp(x) 516 | self.add_module("child_%s" % key, self.mlps[key]) 517 | mlp = self.mlps[key] 518 | x = mlp(x) 519 | self.update_moving_average(key, x) 520 | x = x - self.moving_averages[key] 521 | if use_instance_norm: 522 | x = F.instance_norm(x) 523 | return self.l2_norm(x) 524 | 525 | 526 | class PatchSampleF(nn.Module): 527 | def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[]): 528 | # potential issues: currently, we use the same patch_ids for multiple images in the batch 529 | super(PatchSampleF, self).__init__() 530 | self.l2norm = Normalize(2) 531 | self.use_mlp = use_mlp 532 | self.nc = nc # hard-coded 533 | self.mlp_init = False 534 | self.init_type = init_type 535 | self.init_gain = init_gain 536 | self.gpu_ids = gpu_ids 537 | 538 | def create_mlp(self, feats): 539 | for mlp_id, feat in enumerate(feats): 540 | if type(feat) is int: 541 | input_nc = feat 542 | else: 543 | input_nc = feat.shape[1] 544 | mlp = nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]) 545 | # if len(self.gpu_ids) > 0: 546 | # mlp.cuda() 547 | setattr(self, 'mlp_%d' % mlp_id, mlp) 548 | init_net(self, self.init_type, self.init_gain, self.gpu_ids) 549 | self.mlp_init = True 550 | 551 | def forward(self, feats, num_patches=64, patch_ids=None): 552 | return_ids = [] 553 | return_feats = [] 554 | if self.use_mlp and not self.mlp_init: 555 | self.create_mlp(feats) 556 | for feat_id, feat in enumerate(feats): 557 | B, H, W = feat.shape[0], feat.shape[2], 
feat.shape[3] 558 | feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2) 559 | if num_patches > 0: 560 | if patch_ids is not None: 561 | patch_id = patch_ids[feat_id] 562 | else: 563 | patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device) 564 | patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # .to(patch_ids.device) 565 | x_sample = feat_reshape[:, patch_id, :].flatten(0, 1) # reshape(-1, x.shape[1]) 566 | else: 567 | x_sample = feat_reshape 568 | patch_id = [] 569 | if self.use_mlp: 570 | mlp = getattr(self, 'mlp_%d' % feat_id) 571 | x_sample = mlp(x_sample) 572 | return_ids.append(patch_id) 573 | x_sample = self.l2norm(x_sample) 574 | 575 | if num_patches == 0: 576 | x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W]) 577 | return_feats.append(x_sample) 578 | return return_feats, return_ids 579 | 580 | 581 | class G_Resnet(nn.Module): 582 | def __init__(self, input_nc, output_nc, nz, num_downs, n_res, ngf=64, 583 | norm=None, nl_layer=None): 584 | super(G_Resnet, self).__init__() 585 | n_downsample = num_downs 586 | pad_type = 'reflect' 587 | self.enc_content = ContentEncoder(n_downsample, n_res, input_nc, ngf, norm, nl_layer, pad_type=pad_type) 588 | if nz == 0: 589 | self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, output_nc, norm=norm, activ=nl_layer, pad_type=pad_type, nz=nz) 590 | else: 591 | self.dec = Decoder_all(n_downsample, n_res, self.enc_content.output_dim, output_nc, norm=norm, activ=nl_layer, pad_type=pad_type, nz=nz) 592 | 593 | def decode(self, content, style=None): 594 | return self.dec(content, style) 595 | 596 | def forward(self, image, style=None, nce_layers=[], encode_only=False): 597 | content, feats = self.enc_content(image, nce_layers=nce_layers, encode_only=encode_only) 598 | if encode_only: 599 | return feats 600 | else: 601 | images_recon = self.decode(content, style) 602 | if len(nce_layers) > 0: 603 | return images_recon, feats 604 | else: 605 | return images_recon 606 | 607 | ################################################################################## 608 | # Encoder and Decoders 609 | ################################################################################## 610 | 611 | 612 | class E_adaIN(nn.Module): 613 | def __init__(self, input_nc, output_nc=1, nef=64, n_layers=4, 614 | norm=None, nl_layer=None, vae=False): 615 | # style encoder 616 | super(E_adaIN, self).__init__() 617 | self.enc_style = StyleEncoder(n_layers, input_nc, nef, output_nc, norm='none', activ='relu', vae=vae) 618 | 619 | def forward(self, image): 620 | style = self.enc_style(image) 621 | return style 622 | 623 | 624 | class StyleEncoder(nn.Module): 625 | def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, vae=False): 626 | super(StyleEncoder, self).__init__() 627 | self.vae = vae 628 | self.model = [] 629 | self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type='reflect')] 630 | for i in range(2): 631 | self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')] 632 | dim *= 2 633 | for i in range(n_downsample - 2): 634 | self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')] 635 | self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling 636 | if self.vae: 637 | self.fc_mean = nn.Linear(dim, style_dim) # , 1, 1, 0) 638 | self.fc_var = nn.Linear(dim, style_dim) # , 1, 1, 0) 639 | else: 640 | self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)] 641 | 642 | 
self.model = nn.Sequential(*self.model) 643 | self.output_dim = dim 644 | 645 | def forward(self, x): 646 | if self.vae: 647 | output = self.model(x) 648 | output = output.view(x.size(0), -1) 649 | output_mean = self.fc_mean(output) 650 | output_var = self.fc_var(output) 651 | return output_mean, output_var 652 | else: 653 | return self.model(x).view(x.size(0), -1) 654 | 655 | 656 | class ContentEncoder(nn.Module): 657 | def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type='zero'): 658 | super(ContentEncoder, self).__init__() 659 | self.model = [] 660 | self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type='reflect')] 661 | # downsampling blocks 662 | for i in range(n_downsample): 663 | self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')] 664 | dim *= 2 665 | # residual blocks 666 | self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)] 667 | self.model = nn.Sequential(*self.model) 668 | self.output_dim = dim 669 | 670 | def forward(self, x, nce_layers=[], encode_only=False): 671 | if len(nce_layers) > 0: 672 | feat = x 673 | feats = [] 674 | for layer_id, layer in enumerate(self.model): 675 | feat = layer(feat) 676 | if layer_id in nce_layers: 677 | feats.append(feat) 678 | if layer_id == nce_layers[-1] and encode_only: 679 | return None, feats 680 | return feat, feats 681 | else: 682 | return self.model(x), None 683 | 684 | for layer_id, layer in enumerate(self.model): 685 | print(layer_id, layer) 686 | 687 | 688 | class Decoder_all(nn.Module): 689 | def __init__(self, n_upsample, n_res, dim, output_dim, norm='batch', activ='relu', pad_type='zero', nz=0): 690 | super(Decoder_all, self).__init__() 691 | # AdaIN residual blocks 692 | self.resnet_block = ResBlocks(n_res, dim, norm, activ, pad_type=pad_type, nz=nz) 693 | self.n_blocks = 0 694 | # upsampling blocks 695 | for i in range(n_upsample): 696 | block = [Upsample2(scale_factor=2), Conv2dBlock(dim + nz, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type='reflect')] 697 | setattr(self, 'block_{:d}'.format(self.n_blocks), nn.Sequential(*block)) 698 | self.n_blocks += 1 699 | dim //= 2 700 | # use reflection padding in the last conv layer 701 | setattr(self, 'block_{:d}'.format(self.n_blocks), Conv2dBlock(dim + nz, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type='reflect')) 702 | self.n_blocks += 1 703 | 704 | def forward(self, x, y=None): 705 | if y is not None: 706 | output = self.resnet_block(cat_feature(x, y)) 707 | for n in range(self.n_blocks): 708 | block = getattr(self, 'block_{:d}'.format(n)) 709 | if n > 0: 710 | output = block(cat_feature(output, y)) 711 | else: 712 | output = block(output) 713 | return output 714 | 715 | 716 | class Decoder(nn.Module): 717 | def __init__(self, n_upsample, n_res, dim, output_dim, norm='batch', activ='relu', pad_type='zero', nz=0): 718 | super(Decoder, self).__init__() 719 | 720 | self.model = [] 721 | # AdaIN residual blocks 722 | self.model += [ResBlocks(n_res, dim, norm, activ, pad_type=pad_type, nz=nz)] 723 | # upsampling blocks 724 | for i in range(n_upsample): 725 | if i == 0: 726 | input_dim = dim + nz 727 | else: 728 | input_dim = dim 729 | self.model += [Upsample2(scale_factor=2), Conv2dBlock(input_dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type='reflect')] 730 | dim //= 2 731 | # use reflection padding in the last conv layer 732 | self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', 
pad_type='reflect')] 733 | self.model = nn.Sequential(*self.model) 734 | 735 | def forward(self, x, y=None): 736 | if y is not None: 737 | return self.model(cat_feature(x, y)) 738 | else: 739 | return self.model(x) 740 | 741 | ################################################################################## 742 | # Sequential Models 743 | ################################################################################## 744 | 745 | 746 | class ResBlocks(nn.Module): 747 | def __init__(self, num_blocks, dim, norm='inst', activation='relu', pad_type='zero', nz=0): 748 | super(ResBlocks, self).__init__() 749 | self.model = [] 750 | for i in range(num_blocks): 751 | self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type, nz=nz)] 752 | self.model = nn.Sequential(*self.model) 753 | 754 | def forward(self, x): 755 | return self.model(x) 756 | 757 | 758 | ################################################################################## 759 | # Basic Blocks 760 | ################################################################################## 761 | def cat_feature(x, y): 762 | y_expand = y.view(y.size(0), y.size(1), 1, 1).expand( 763 | y.size(0), y.size(1), x.size(2), x.size(3)) 764 | x_cat = torch.cat([x, y_expand], 1) 765 | return x_cat 766 | 767 | 768 | class ResBlock(nn.Module): 769 | def __init__(self, dim, norm='inst', activation='relu', pad_type='zero', nz=0): 770 | super(ResBlock, self).__init__() 771 | 772 | model = [] 773 | model += [Conv2dBlock(dim + nz, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] 774 | model += [Conv2dBlock(dim, dim + nz, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] 775 | self.model = nn.Sequential(*model) 776 | 777 | def forward(self, x): 778 | residual = x 779 | out = self.model(x) 780 | out += residual 781 | return out 782 | 783 | 784 | class Conv2dBlock(nn.Module): 785 | def __init__(self, input_dim, output_dim, kernel_size, stride, 786 | padding=0, norm='none', activation='relu', pad_type='zero'): 787 | super(Conv2dBlock, self).__init__() 788 | self.use_bias = True 789 | # initialize padding 790 | if pad_type == 'reflect': 791 | self.pad = nn.ReflectionPad2d(padding) 792 | elif pad_type == 'zero': 793 | self.pad = nn.ZeroPad2d(padding) 794 | else: 795 | assert 0, "Unsupported padding type: {}".format(pad_type) 796 | 797 | # initialize normalization 798 | norm_dim = output_dim 799 | if norm == 'batch': 800 | self.norm = nn.BatchNorm2d(norm_dim) 801 | elif norm == 'inst': 802 | self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=False) 803 | elif norm == 'ln': 804 | self.norm = LayerNorm(norm_dim) 805 | elif norm == 'none': 806 | self.norm = None 807 | else: 808 | assert 0, "Unsupported normalization: {}".format(norm) 809 | 810 | # initialize activation 811 | if activation == 'relu': 812 | self.activation = nn.ReLU(inplace=True) 813 | elif activation == 'lrelu': 814 | self.activation = nn.LeakyReLU(0.2, inplace=True) 815 | elif activation == 'prelu': 816 | self.activation = nn.PReLU() 817 | elif activation == 'selu': 818 | self.activation = nn.SELU(inplace=True) 819 | elif activation == 'tanh': 820 | self.activation = nn.Tanh() 821 | elif activation == 'none': 822 | self.activation = None 823 | else: 824 | assert 0, "Unsupported activation: {}".format(activation) 825 | 826 | # initialize convolution 827 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) 828 | 829 | def forward(self, x): 830 | x = self.conv(self.pad(x)) 831 | if self.norm: 832 | x = 
self.norm(x) 833 | if self.activation: 834 | x = self.activation(x) 835 | return x 836 | 837 | 838 | class LinearBlock(nn.Module): 839 | def __init__(self, input_dim, output_dim, norm='none', activation='relu'): 840 | super(LinearBlock, self).__init__() 841 | use_bias = True 842 | # initialize fully connected layer 843 | self.fc = nn.Linear(input_dim, output_dim, bias=use_bias) 844 | 845 | # initialize normalization 846 | norm_dim = output_dim 847 | if norm == 'batch': 848 | self.norm = nn.BatchNorm1d(norm_dim) 849 | elif norm == 'inst': 850 | self.norm = nn.InstanceNorm1d(norm_dim) 851 | elif norm == 'ln': 852 | self.norm = LayerNorm(norm_dim) 853 | elif norm == 'none': 854 | self.norm = None 855 | else: 856 | assert 0, "Unsupported normalization: {}".format(norm) 857 | 858 | # initialize activation 859 | if activation == 'relu': 860 | self.activation = nn.ReLU(inplace=True) 861 | elif activation == 'lrelu': 862 | self.activation = nn.LeakyReLU(0.2, inplace=True) 863 | elif activation == 'prelu': 864 | self.activation = nn.PReLU() 865 | elif activation == 'selu': 866 | self.activation = nn.SELU(inplace=True) 867 | elif activation == 'tanh': 868 | self.activation = nn.Tanh() 869 | elif activation == 'none': 870 | self.activation = None 871 | else: 872 | assert 0, "Unsupported activation: {}".format(activation) 873 | 874 | def forward(self, x): 875 | out = self.fc(x) 876 | if self.norm: 877 | out = self.norm(out) 878 | if self.activation: 879 | out = self.activation(out) 880 | return out 881 | 882 | ################################################################################## 883 | # Normalization layers 884 | ################################################################################## 885 | 886 | 887 | class LayerNorm(nn.Module): 888 | def __init__(self, num_features, eps=1e-5, affine=True): 889 | super(LayerNorm, self).__init__() 890 | self.num_features = num_features 891 | self.affine = affine 892 | self.eps = eps 893 | 894 | if self.affine: 895 | self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_()) 896 | self.beta = nn.Parameter(torch.zeros(num_features)) 897 | 898 | def forward(self, x): 899 | shape = [-1] + [1] * (x.dim() - 1) 900 | mean = x.view(x.size(0), -1).mean(1).view(*shape) 901 | std = x.view(x.size(0), -1).std(1).view(*shape) 902 | x = (x - mean) / (std + self.eps) 903 | 904 | if self.affine: 905 | shape = [1, -1] + [1] * (x.dim() - 2) 906 | x = x * self.gamma.view(*shape) + self.beta.view(*shape) 907 | return x 908 | 909 | 910 | class ResnetGenerator(nn.Module): 911 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. 
912 | 913 | We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) 914 | """ 915 | 916 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None): 917 | """Construct a Resnet-based generator 918 | 919 | Parameters: 920 | input_nc (int) -- the number of channels in input images 921 | output_nc (int) -- the number of channels in output images 922 | ngf (int) -- the number of filters in the last conv layer 923 | norm_layer -- normalization layer 924 | use_dropout (bool) -- if use dropout layers 925 | n_blocks (int) -- the number of ResNet blocks 926 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero 927 | """ 928 | assert(n_blocks >= 0) 929 | super(ResnetGenerator, self).__init__() 930 | self.opt = opt 931 | if type(norm_layer) == functools.partial: 932 | use_bias = norm_layer.func == nn.InstanceNorm2d 933 | else: 934 | use_bias = norm_layer == nn.InstanceNorm2d 935 | 936 | model = [nn.ReflectionPad2d(3), 937 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 938 | norm_layer(ngf), 939 | nn.ReLU(True)] 940 | 941 | n_downsampling = 2 942 | for i in range(n_downsampling): # add downsampling layers 943 | mult = 2 ** i 944 | if(no_antialias): 945 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), 946 | norm_layer(ngf * mult * 2), 947 | nn.ReLU(True)] 948 | else: 949 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias), 950 | norm_layer(ngf * mult * 2), 951 | nn.ReLU(True), 952 | Downsample(ngf * mult * 2)] 953 | 954 | mult = 2 ** n_downsampling 955 | for i in range(n_blocks): # add ResNet blocks 956 | 957 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 958 | 959 | for i in range(n_downsampling): # add upsampling layers 960 | mult = 2 ** (n_downsampling - i) 961 | if no_antialias_up: 962 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 963 | kernel_size=3, stride=2, 964 | padding=1, output_padding=1, 965 | bias=use_bias), 966 | norm_layer(int(ngf * mult / 2)), 967 | nn.ReLU(True)] 968 | else: 969 | model += [Upsample(ngf * mult), 970 | nn.Conv2d(ngf * mult, int(ngf * mult / 2), 971 | kernel_size=3, stride=1, 972 | padding=1, # output_padding=1, 973 | bias=use_bias), 974 | norm_layer(int(ngf * mult / 2)), 975 | nn.ReLU(True)] 976 | model += [nn.ReflectionPad2d(3)] 977 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 978 | model += [nn.Sigmoid()] 979 | 980 | self.model = nn.Sequential(*model) 981 | 982 | def forward(self, input, layers=[], encode_only=False): 983 | if -1 in layers: 984 | layers.append(len(self.model)) 985 | if len(layers) > 0: 986 | feat = input 987 | feats = [] 988 | for layer_id, layer in enumerate(self.model): 989 | # print(layer_id, layer) 990 | feat = layer(feat) 991 | if layer_id in layers: 992 | # print("%d: adding the output of %s %d" % (layer_id, layer.__class__.__name__, feat.size(1))) 993 | feats.append(feat) 994 | else: 995 | # print("%d: skipping %s %d" % (layer_id, layer.__class__.__name__, feat.size(1))) 996 | pass 997 | if layer_id == layers[-1] and encode_only: 998 | # print('encoder only return features') 999 | return feats # return intermediate features alone; stop in the last layers 1000 | 1001 | 
return feat, feats # return both output and intermediate features 1002 | else: 1003 | """Standard forward""" 1004 | fake = self.model(input) 1005 | return fake 1006 | 1007 | 1008 | class ResnetDecoder(nn.Module): 1009 | """Resnet-based decoder that consists of a few Resnet blocks + a few upsampling operations. 1010 | """ 1011 | 1012 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False): 1013 | """Construct a Resnet-based decoder 1014 | 1015 | Parameters: 1016 | input_nc (int) -- the number of channels in input images 1017 | output_nc (int) -- the number of channels in output images 1018 | ngf (int) -- the number of filters in the last conv layer 1019 | norm_layer -- normalization layer 1020 | use_dropout (bool) -- if use dropout layers 1021 | n_blocks (int) -- the number of ResNet blocks 1022 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero 1023 | """ 1024 | assert(n_blocks >= 0) 1025 | super(ResnetDecoder, self).__init__() 1026 | if type(norm_layer) == functools.partial: 1027 | use_bias = norm_layer.func == nn.InstanceNorm2d 1028 | else: 1029 | use_bias = norm_layer == nn.InstanceNorm2d 1030 | model = [] 1031 | n_downsampling = 2 1032 | mult = 2 ** n_downsampling 1033 | for i in range(n_blocks): # add ResNet blocks 1034 | 1035 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 1036 | 1037 | for i in range(n_downsampling): # add upsampling layers 1038 | mult = 2 ** (n_downsampling - i) 1039 | if(no_antialias): 1040 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 1041 | kernel_size=3, stride=2, 1042 | padding=1, output_padding=1, 1043 | bias=use_bias), 1044 | norm_layer(int(ngf * mult / 2)), 1045 | nn.ReLU(True)] 1046 | else: 1047 | model += [Upsample(ngf * mult), 1048 | nn.Conv2d(ngf * mult, int(ngf * mult / 2), 1049 | kernel_size=3, stride=1, 1050 | padding=1, 1051 | bias=use_bias), 1052 | norm_layer(int(ngf * mult / 2)), 1053 | nn.ReLU(True)] 1054 | model += [nn.ReflectionPad2d(3)] 1055 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 1056 | model += [nn.Tanh()] 1057 | 1058 | self.model = nn.Sequential(*model) 1059 | 1060 | def forward(self, input): 1061 | """Standard forward""" 1062 | return self.model(input) 1063 | 1064 | 1065 | class ResnetEncoder(nn.Module): 1066 | """Resnet-based encoder that consists of a few downsampling + several Resnet blocks 1067 | """ 1068 | 1069 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False): 1070 | """Construct a Resnet-based encoder 1071 | 1072 | Parameters: 1073 | input_nc (int) -- the number of channels in input images 1074 | output_nc (int) -- the number of channels in output images 1075 | ngf (int) -- the number of filters in the last conv layer 1076 | norm_layer -- normalization layer 1077 | use_dropout (bool) -- if use dropout layers 1078 | n_blocks (int) -- the number of ResNet blocks 1079 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero 1080 | """ 1081 | assert(n_blocks >= 0) 1082 | super(ResnetEncoder, self).__init__() 1083 | if type(norm_layer) == functools.partial: 1084 | use_bias = norm_layer.func == nn.InstanceNorm2d 1085 | else: 1086 | use_bias = norm_layer == nn.InstanceNorm2d 1087 | 1088 | model = [nn.ReflectionPad2d(3), 1089 | 
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 1090 | norm_layer(ngf), 1091 | nn.ReLU(True)] 1092 | 1093 | n_downsampling = 2 1094 | for i in range(n_downsampling): # add downsampling layers 1095 | mult = 2 ** i 1096 | if(no_antialias): 1097 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), 1098 | norm_layer(ngf * mult * 2), 1099 | nn.ReLU(True)] 1100 | else: 1101 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias), 1102 | norm_layer(ngf * mult * 2), 1103 | nn.ReLU(True), 1104 | Downsample(ngf * mult * 2)] 1105 | 1106 | mult = 2 ** n_downsampling 1107 | for i in range(n_blocks): # add ResNet blocks 1108 | 1109 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 1110 | 1111 | self.model = nn.Sequential(*model) 1112 | 1113 | def forward(self, input): 1114 | """Standard forward""" 1115 | return self.model(input) 1116 | 1117 | 1118 | class ResnetBlock(nn.Module): 1119 | """Define a Resnet block""" 1120 | 1121 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 1122 | """Initialize the Resnet block 1123 | 1124 | A resnet block is a conv block with skip connections 1125 | We construct a conv block with the build_conv_block function, 1126 | and implement skip connections in the forward() function. 1127 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf 1128 | """ 1129 | super(ResnetBlock, self).__init__() 1130 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 1131 | 1132 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 1133 | """Construct a convolutional block. 1134 | 1135 | Parameters: 1136 | dim (int) -- the number of channels in the conv layer. 1137 | padding_type (str) -- the name of padding layer: reflect | replicate | zero 1138 | norm_layer -- normalization layer 1139 | use_dropout (bool) -- if use dropout layers. 
1140 | use_bias (bool) -- if the conv layer uses bias or not 1141 | 1142 | Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) 1143 | """ 1144 | conv_block = [] 1145 | p = 0 1146 | if padding_type == 'reflect': 1147 | conv_block += [nn.ReflectionPad2d(1)] 1148 | elif padding_type == 'replicate': 1149 | conv_block += [nn.ReplicationPad2d(1)] 1150 | elif padding_type == 'zero': 1151 | p = 1 1152 | else: 1153 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 1154 | 1155 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] 1156 | if use_dropout: 1157 | conv_block += [nn.Dropout(0.5)] 1158 | 1159 | p = 0 1160 | if padding_type == 'reflect': 1161 | conv_block += [nn.ReflectionPad2d(1)] 1162 | elif padding_type == 'replicate': 1163 | conv_block += [nn.ReplicationPad2d(1)] 1164 | elif padding_type == 'zero': 1165 | p = 1 1166 | else: 1167 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 1168 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] 1169 | 1170 | return nn.Sequential(*conv_block) 1171 | 1172 | def forward(self, x): 1173 | """Forward function (with skip connections)""" 1174 | out = x + self.conv_block(x) # add skip connections 1175 | return out 1176 | 1177 | 1178 | class UnetGenerator(nn.Module): 1179 | """Create a Unet-based generator""" 1180 | 1181 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): 1182 | """Construct a Unet generator 1183 | Parameters: 1184 | input_nc (int) -- the number of channels in input images 1185 | output_nc (int) -- the number of channels in output images 1186 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, 1187 | image of size 128x128 will become of size 1x1 # at the bottleneck 1188 | ngf (int) -- the number of filters in the last conv layer 1189 | norm_layer -- normalization layer 1190 | 1191 | We construct the U-Net from the innermost layer to the outermost layer. 1192 | It is a recursive process. 1193 | """ 1194 | super(UnetGenerator, self).__init__() 1195 | # construct unet structure 1196 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer 1197 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters 1198 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 1199 | # gradually reduce the number of filters from ngf * 8 to ngf 1200 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 1201 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 1202 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 1203 | self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer 1204 | 1205 | def forward(self, input): 1206 | """Standard forward""" 1207 | return self.model(input) 1208 | 1209 | 1210 | class UnetSkipConnectionBlock(nn.Module): 1211 | """Defines the Unet submodule with skip connection. 
1212 | X -------------------identity---------------------- 1213 | |-- downsampling -- |submodule| -- upsampling --| 1214 | """ 1215 | 1216 | def __init__(self, outer_nc, inner_nc, input_nc=None, 1217 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 1218 | """Construct a Unet submodule with skip connections. 1219 | 1220 | Parameters: 1221 | outer_nc (int) -- the number of filters in the outer conv layer 1222 | inner_nc (int) -- the number of filters in the inner conv layer 1223 | input_nc (int) -- the number of channels in input images/features 1224 | submodule (UnetSkipConnectionBlock) -- previously defined submodules 1225 | outermost (bool) -- if this module is the outermost module 1226 | innermost (bool) -- if this module is the innermost module 1227 | norm_layer -- normalization layer 1228 | use_dropout (bool) -- if use dropout layers. 1229 | """ 1230 | super(UnetSkipConnectionBlock, self).__init__() 1231 | self.outermost = outermost 1232 | if type(norm_layer) == functools.partial: 1233 | use_bias = norm_layer.func == nn.InstanceNorm2d 1234 | else: 1235 | use_bias = norm_layer == nn.InstanceNorm2d 1236 | if input_nc is None: 1237 | input_nc = outer_nc 1238 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 1239 | stride=2, padding=1, bias=use_bias) 1240 | downrelu = nn.LeakyReLU(0.2, True) 1241 | downnorm = norm_layer(inner_nc) 1242 | uprelu = nn.ReLU(True) 1243 | upnorm = norm_layer(outer_nc) 1244 | 1245 | if outermost: 1246 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 1247 | kernel_size=4, stride=2, 1248 | padding=1) 1249 | down = [downconv] 1250 | up = [uprelu, upconv, nn.Tanh()] 1251 | model = down + [submodule] + up 1252 | elif innermost: 1253 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 1254 | kernel_size=4, stride=2, 1255 | padding=1, bias=use_bias) 1256 | down = [downrelu, downconv] 1257 | up = [uprelu, upconv, upnorm] 1258 | model = down + up 1259 | else: 1260 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 1261 | kernel_size=4, stride=2, 1262 | padding=1, bias=use_bias) 1263 | down = [downrelu, downconv, downnorm] 1264 | up = [uprelu, upconv, upnorm] 1265 | 1266 | if use_dropout: 1267 | model = down + [submodule] + up + [nn.Dropout(0.5)] 1268 | else: 1269 | model = down + [submodule] + up 1270 | 1271 | self.model = nn.Sequential(*model) 1272 | 1273 | def forward(self, x): 1274 | if self.outermost: 1275 | return self.model(x) 1276 | else: # add skip connections 1277 | return torch.cat([x, self.model(x)], 1) 1278 | 1279 | 1280 | class NLayerDiscriminator(nn.Module): 1281 | """Defines a PatchGAN discriminator""" 1282 | 1283 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False): 1284 | """Construct a PatchGAN discriminator 1285 | 1286 | Parameters: 1287 | input_nc (int) -- the number of channels in input images 1288 | ndf (int) -- the number of filters in the last conv layer 1289 | n_layers (int) -- the number of conv layers in the discriminator 1290 | norm_layer -- normalization layer 1291 | """ 1292 | super(NLayerDiscriminator, self).__init__() 1293 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 1294 | use_bias = norm_layer.func == nn.InstanceNorm2d 1295 | else: 1296 | use_bias = norm_layer == nn.InstanceNorm2d 1297 | 1298 | kw = 4 1299 | padw = 1 1300 | if(no_antialias): 1301 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 1302 | else: 1303 | 
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=1, padding=padw), nn.LeakyReLU(0.2, True), Downsample(ndf)] 1304 | nf_mult = 1 1305 | nf_mult_prev = 1 1306 | for n in range(1, n_layers): # gradually increase the number of filters 1307 | nf_mult_prev = nf_mult 1308 | nf_mult = min(2 ** n, 8) 1309 | if(no_antialias): 1310 | sequence += [ 1311 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 1312 | norm_layer(ndf * nf_mult), 1313 | nn.LeakyReLU(0.2, True) 1314 | ] 1315 | else: 1316 | sequence += [ 1317 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 1318 | norm_layer(ndf * nf_mult), 1319 | nn.LeakyReLU(0.2, True), 1320 | Downsample(ndf * nf_mult)] 1321 | 1322 | nf_mult_prev = nf_mult 1323 | nf_mult = min(2 ** n_layers, 8) 1324 | sequence += [ 1325 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 1326 | norm_layer(ndf * nf_mult), 1327 | nn.LeakyReLU(0.2, True) 1328 | ] 1329 | 1330 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 1331 | self.model = nn.Sequential(*sequence) 1332 | 1333 | def forward(self, input, layers: list = []): 1334 | """Standard forward.""" 1335 | if len(layers) > 0: 1336 | feat = input 1337 | feats = [] 1338 | for layer_id, layer in enumerate(self.model): 1339 | # print(layer_id, layer) 1340 | feat = layer(feat) 1341 | if layer_id in layers: 1342 | # print("%d: adding the output of %s %d" % (layer_id, layer.__class__.__name__, feat.size(1))) 1343 | feats.append(feat) 1344 | else: 1345 | # print("%d: skipping %s %d" % (layer_id, layer.__class__.__name__, feat.size(1))) 1346 | pass 1347 | if layer_id == layers[-1]: 1348 | # print('encoder only return features') 1349 | return feats # return intermediate features alone; stop in the last layers 1350 | 1351 | return feat, feats # return both output and intermediate features 1352 | else: 1353 | return self.model(input) 1354 | 1355 | 1356 | class PixelDiscriminator(nn.Module): 1357 | """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" 1358 | 1359 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): 1360 | """Construct a 1x1 PatchGAN discriminator 1361 | 1362 | Parameters: 1363 | input_nc (int) -- the number of channels in input images 1364 | ndf (int) -- the number of filters in the last conv layer 1365 | norm_layer -- normalization layer 1366 | """ 1367 | super(PixelDiscriminator, self).__init__() 1368 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 1369 | use_bias = norm_layer.func == nn.InstanceNorm2d 1370 | else: 1371 | use_bias = norm_layer == nn.InstanceNorm2d 1372 | 1373 | self.net = [ 1374 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), 1375 | nn.LeakyReLU(0.2, True), 1376 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), 1377 | norm_layer(ndf * 2), 1378 | nn.LeakyReLU(0.2, True), 1379 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] 1380 | 1381 | self.net = nn.Sequential(*self.net) 1382 | 1383 | def forward(self, input): 1384 | """Standard forward.""" 1385 | return self.net(input) 1386 | 1387 | 1388 | class PatchDiscriminator(NLayerDiscriminator): 1389 | """Defines a PatchGAN discriminator""" 1390 | 1391 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False): 1392 | super().__init__(input_nc, ndf, 2, 
norm_layer, no_antialias) 1393 | 1394 | def forward(self, input): 1395 | B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3) 1396 | size = 16 1397 | Y = H // size 1398 | X = W // size 1399 | input = input.view(B, C, Y, size, X, size) 1400 | input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size) 1401 | return super().forward(input) 1402 | 1403 | 1404 | class GroupedChannelNorm(nn.Module): 1405 | def __init__(self, num_groups): 1406 | super().__init__() 1407 | self.num_groups = num_groups 1408 | 1409 | def forward(self, x): 1410 | shape = list(x.shape) 1411 | new_shape = [shape[0], self.num_groups, shape[1] // self.num_groups] + shape[2:] 1412 | x = x.view(*new_shape) 1413 | mean = x.mean(dim=2, keepdim=True) 1414 | std = x.std(dim=2, keepdim=True) 1415 | x_norm = (x - mean) / (std + 1e-7) 1416 | return x_norm.view(*shape) 1417 | 1418 | class MsDiscriminator(nn.Module): 1419 | def __init__(self, n_scales=3, input_nc=1): 1420 | super().__init__() 1421 | self.disc = nn.ModuleList() 1422 | 1423 | for _ in range(n_scales): 1424 | self.disc.append(define_D(input_nc=input_nc, ndf=64, netD='basic', 1425 | n_layers_D=3, norm="instance", 1426 | init_type="xavier", init_gain=0.02, 1427 | no_antialias=False)) 1428 | 1429 | def forward(self, x): 1430 | output = [] 1431 | for i, d in enumerate(self.disc): 1432 | factor = 1 / (2 ** i) 1433 | input = torch.nn.functional.interpolate(x, scale_factor=factor, align_corners=False, mode='bilinear') 1434 | output.append(d(input)) 1435 | 1436 | return output -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.1 2 | pytorch-lightning==1.4.4 3 | numpy==1.19.5 4 | matplotlib==3.3.4 5 | nibabel==3.2.1 6 | SimpleITK==2.0.2 7 | torchvision==0.9.1 8 | h5py==3.1.0 9 | sklearn --------------------------------------------------------------------------------