├── requirements.txt ├── .gitignore ├── Dockerfile ├── optimizer.py ├── discriminator_config.py ├── download_dataset.py ├── config.py ├── train.py ├── register.py ├── dataset.py ├── critertion.py ├── train_without_trainer.py ├── LICENSE ├── model.py └── measurements.py /requirements.txt: -------------------------------------------------------------------------------- 1 | autoencoder 2 | h5py 3 | imageio 4 | torch 5 | gdown 6 | googledrivedownloader -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.tar 3 | .ipynb_checkpoints 4 | /_python_build 5 | *.pyc 6 | __pycache__ 7 | *.swp 8 | 9 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:19.10-py3 2 | 3 | ADD . /opt/project 4 | WORKDIR /opt/project 5 | 6 | 7 | RUN pip install --upgrade pip 8 | RUN pip install --disable-pip-version-check -r requirements.txt 9 | 10 | 11 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | def backward(loss, **kwargs): 2 | loss.backward(**kwargs) 3 | 4 | 5 | class OptimizerCollection(object): 6 | 7 | def __init__(self, *optimizers): 8 | 9 | for optimizer in optimizers: 10 | if not hasattr(optimizer, 'backward'): 11 | setattr(optimizer, 'backward', backward) 12 | self.optimizers = optimizers 13 | 14 | def backward(self, losses): 15 | for loss, optimizer in zip(losses, self.optimizers): 16 | optimizer.zero_grad() 17 | optimizer.backward(loss) 18 | optimizer.step() 19 | -------------------------------------------------------------------------------- /discriminator_config.py: -------------------------------------------------------------------------------- 1 | """module 
containing configurations for the model and training routine""" 2 | 3 | from torch.nn.functional import relu, selu, leaky_relu, tanh 4 | from torch.nn import MSELoss, L1Loss, SmoothL1Loss 5 | from torch.optim import Adam 6 | 7 | discriminator = { 8 | 'n_layers': 7, 9 | 'kernel_size': (3, 3), 10 | 'activation': selu, 11 | 'channel_factor': 8, 12 | 'max_channels': 1024, 13 | 'input_channels': 1, 14 | 'n_residual': (0, 0), 15 | 'affine': True 16 | } 17 | 18 | dataset = { 19 | 'vmin': 'mean-0.5', 20 | 'vmax': 1.0, 21 | 'whitening': False 22 | } 23 | 24 | dataloader = { 25 | 'batch_size': 32, 26 | 'shuffle': True, 27 | 'num_workers': 2, 28 | } 29 | 30 | optimizer = { 31 | 'optimizer_type': Adam, 32 | 'learning_rate': 0.005, 33 | 'apex': False, 34 | 'lr_decay': (0.99, 1e5) 35 | } 36 | 37 | training = { 38 | 'epochs': 25, 39 | 'name': 'discriminator-test', 40 | } 41 | 42 | 43 | def get_complete_config(): 44 | import config 45 | complete_config = {key: value for key, value in config.__dict__.items() 46 | if isinstance(value, dict) and key != '__builtins__'} 47 | return complete_config 48 | 49 | 50 | -------------------------------------------------------------------------------- /download_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import zipfile 4 | 5 | from google_drive_downloader import GoogleDriveDownloader as gdd 6 | 7 | PARSER = argparse.ArgumentParser() 8 | 9 | PARSER.add_argument('--data_dir', 10 | type=str, 11 | default='./data', 12 | help="""Directory where to download the dataset""") 13 | 14 | 15 | def main(): 16 | FLAGS = PARSER.parse_args() 17 | 18 | if not os.path.exists(FLAGS.data_dir): 19 | os.makedirs(FLAGS.data_dir) 20 | 21 | filename = 'oct_quality.zip' 22 | 23 | gdd.download_file_from_google_drive(file_id='19p1KDG2j93mBJp9O_yenwMHQMaqgJiWG', 24 | dest_path=os.path.join(FLAGS.data_dir, filename), 25 | unzip=False) 26 | 27 | print('Unpacking...') 28 | 29 | with 
zipfile.ZipFile(os.path.join(FLAGS.data_dir, filename), 'r') as zip_ref:
        zip_ref.extractall(FLAGS.data_dir)

    print('Cleaning up...')

    # the zip archive is removed after extraction
    os.remove(os.path.join(FLAGS.data_dir, filename))

    print("Finished downloading files for HDCycleGan to {}".format(FLAGS.data_dir))


if __name__ == '__main__':
    main()

/config.py:
--------------------------------------------------------------------------------
"""module containing configurations for the model and training routine"""

from torch.nn.functional import relu
from torch.nn import L1Loss, CrossEntropyLoss, BCEWithLogitsLoss
from torch.optim import Adam
from torch import float16

from model import Discriminator, HDCycleGAN, Generator, CycleGAN
from dataset import OCTQualityDataset
from critertion import ClassificationLoss, ClassificationLossHD


# global training settings; fp16 is only effective when apex is available (below)
dtype = float16
cuda = True
seed = 0

MODEL = HDCycleGAN
LOSS = ClassificationLossHD
DATASET = OCTQualityDataset
OPTIMIZER = Adam
LOGDIR = f'./trained/{seed}'

# optional mixed-precision support: downstream scripts check
# `hasattr(config, 'APEX')` to decide between fp16 and fp32
try:
    from apex.fp16_utils import FP16_Optimizer
    APEX = FP16_Optimizer
    apex = {
        'dynamic_loss_scale': True,
        'dynamic_loss_args': {'init_scale': 2 ** 10},
        'verbose': False
    }
except ImportError:
    pass

# generator architecture (see model.Generator for parameter semantics)
generator = {
    'scale_factor': 3,
    'channel_factor': 16,
    'activation': relu,
    'kernel_size': (3, 3),
    'n_residual': (6, 3),
    'input_channels': 1,
    'skip_conn': 'concat'
}

# discriminator architecture (see model.Discriminator for parameter semantics)
discriminator = {
    'n_layers': 7,
    'kernel_size': (3, 3),
    'activation': relu,
    'channel_factor': 16,
    'max_channels': 1024,
    'input_channels': 1,
    'n_residual': (1, 2),
    'affine': False
}

model = {
    'discriminator': (Discriminator, discriminator),
    'generator': (Generator, 
generator),
    'input_size': (1, 512, 512),
    'pool_size': 32,
    'pool_write_probability': 1
}

dataset = {
    'parent_folder': '/home/ilja/Datasets/oct_quality',  # replace with the path to the parent folder on your system
    'fraction': 0.1
}

dataloader = {
    'batch_size': 4,
    'shuffle': True,
    'num_workers': 0,
}

loss = {
    'cycle_loss': L1Loss,
    'discriminator_loss': CrossEntropyLoss,
    'cycle_factor': 10
}

optimizer = {
    'lr': 0.0005
}

trainer = {
    'loss_decay': 0.8,
    'split_sample': lambda x: (x[0], x)  # (network input, full sample for the criterion)
}

/train.py:
--------------------------------------------------------------------------------
from trainer import Trainer, events, Config
from trainer.handlers import EventSave
from optimizer import OptimizerCollection
import torch as pt
from functools import partial
from itertools import chain

pt.cuda.set_device(0)

print('initializing trainer...', end='')
trainer = Trainer.from_config_file('config.py', False)

config = Config.from_file('config.py')

# one optimizer over both generators, a separate one for the discriminator
gen_optimizer = config.OPTIMIZER(chain(trainer.model.generator['hn'].parameters(),
                                       trainer.model.generator['ln'].parameters())
                                 , **config.optimizer)
disc_optimizer = config.OPTIMIZER(trainer.model.discriminator.parameters(), **config.optimizer)
optimizers = []
# wrap each optimizer with apex mixed precision if it is configured
for optimizer in (disc_optimizer, gen_optimizer):
    if hasattr(config, 'APEX'):
        optimizer = config.APEX(optimizer, **config.apex)
    optimizers.append(optimizer)
trainer.optimizer = OptimizerCollection(*optimizers)
trainer.backward_pass = trainer.optimizer.backward

# fixed sample used for periodic visual inspection during training
# NOTE(review): slicing the batch tuple with [:4] keeps its first 4 elements,
# presumably the image tensors — confirm against the Trainer's batch layout
sample = next(iter(trainer.dataloader))[:4]


def sample_inference(trainer, part, sample, ind, *args, **kwargs):
    # run one sub-network (`part`) on the ind-th transformed input
    return 
part(trainer._transform(sample)[0][ind])


# log sample generator outputs every 100 steps, save a checkpoint each epoch
trainer.register_event_handler(events.EACH_STEP, sample_inference, name='gen_hn', interval=100, sample=sample,
                               part=trainer.model.generator['hn'], ind=0)
trainer.register_event_handler(events.EACH_STEP, sample_inference, name='gen_ln', interval=100, sample=sample,
                               part=trainer.model.generator['ln'], ind=1)
#trainer.register_event_handler(events.EACH_STEP, trainer, name='sample', interval=250, sample=sample)
trainer.register_event_handler(events.EACH_EPOCH, EventSave(), monitor=False)
#trainer.monitor(name='criterion.ln_discriminator_loss')
#trainer.monitor(name='criterion.hn_discriminator_loss')
trainer.monitor(name='criterion.discriminator_loss')
trainer.monitor(name='criterion.ln_generator_loss')
trainer.monitor(name='criterion.hn_generator_loss')
trainer.monitor(name='criterion.cycle_loss')
print('done!')

print('\ncommencing training!')
trainer.train(n_epochs=100, resume=True)

/register.py:
--------------------------------------------------------------------------------
import h5py
import numpy as np

from utils.registrator import Registrator


def progress(i, total):
    """Render a simple console progress bar for i of total items."""
    fraction = i / total
    bar_length = 50
    bar_fraction = int(fraction*bar_length)
    bar = '#'*bar_fraction + '='*(bar_length-bar_fraction)
    print(f'\r|{bar}| {round(fraction*100, 2)}%\t', end='')


def preprocess(sample, factor=0.5):
    """Normalize, clip at (mean - factor*std) of the foreground, renormalize."""
    sample = sample/sample.max()
    mean = sample[sample>0].mean()
    std = sample[sample>0].std()
    level = mean - factor*std
    sample = sample.clip(level, 1.0) - level
    sample = sample/sample.max()
    # zero out the first 8 rows — presumably an acquisition artifact region; confirm
    sample[0:8] = 0
    return sample


storage_dir = '/media/network/DL_PC/Datasets/oct_quality_validation/'
lq = 
h5py.File(storage_dir+'low.hdf5', 'r') 29 | hq = h5py.File(storage_dir+'high.hdf5', 'r') 30 | output = h5py.File(storage_dir+'registered.hdf5', 'a') 31 | 32 | registrator = Registrator(order=3, offset=0) 33 | 34 | lq_keys, hq_keys = list(lq.keys()), list(hq.keys()) 35 | n_samples = len(lq_keys) 36 | 37 | try: 38 | for i, (lq_key, hq_key) in enumerate(zip(lq_keys, hq_keys)): 39 | 40 | subject, template = preprocess(lq[lq_key].value), preprocess(hq[hq_key].value) 41 | try: 42 | registered, transformation, difference = registrator.register_single(template, subject) 43 | transformation.pop('timg') 44 | except (IndexError, ValueError): 45 | print(f'\nsomething went wrong in the registration of {lq_key}') 46 | registered = np.zeros_like(subject) 47 | difference = 1 48 | transformation = {} 49 | output[lq_key] = registered 50 | for key, value in transformation.items(): 51 | output[lq_key].attrs[key] = value 52 | output[lq_key].attrs['difference'] = difference 53 | # output.flush() 54 | progress(i, n_samples) 55 | finally: 56 | output.close() 57 | lq.close() 58 | hq.close() 59 | 60 | """ 61 | something went wrong in the registration of 17182-624913-00 62 | something went wrong in the registration of 17182-624913-01 63 | something went wrong in the registration of 17182-624913-02 64 | something went wrong in the registration of 17802-626829-14 65 | something went wrong in the registration of 17802-626829-15 66 | something went wrong in the registration of 7681-624748-00 67 | something went wrong in the registration of 7681-624748-01 68 | something went wrong in the registration of 7681-624748-10 69 | something went wrong in the registration of 7681-624748-11 70 | something went wrong in the registration of 7681-624748-13 71 | """ -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Ilja Manakov 3 | 4 | Permission is hereby 
granted, free of charge, to any person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or
substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

import os
from os.path import join

import h5py
import torch as pt
from torch.utils.data import Dataset
from imageio import imread


class OCTQualityDatasetHDF5(Dataset):
    """Paired high-noise / low-noise OCT dataset backed by two hdf5 files.

    Items are ((hn_image, ln_image), (hn_label, ln_label)); labels are derived
    from the 'frames' attribute of each hdf5 dataset (see __getitem__).
    """

    def __init__(self, hn, ln, storage_dir, vmin=0.0, vmax=1.0, whitening=False, preprocess=True, transforms=None):

        # hn/ln start out as filenames and are replaced by open h5py.File
        # objects on first access (see init())
        self._preprocess = preprocess
        self.storage_dir = storage_dir
        self.hn = hn
        self.ln = ln
        hn = h5py.File(os.path.join(self.storage_dir, hn), 'r')
        ln = h5py.File(os.path.join(self.storage_dir, ln), 'r')
        self.keys = {'hn': list(hn.keys()), 'ln': list(ln.keys())}
        hn.close()
        ln.close() # hdf5 files are closed and reopened later, otherwise there are problems with dataloader workers
        self.vmin = vmin
        self.vmax = vmax
        self.whitening = whitening
        self.transforms = transforms

    def init(self):
        """
        delayed init due to problems with hdf5 and multiprocessing in dataloader that arise otherwise
        :return: None
        """

        self.hn = h5py.File(os.path.join(self.storage_dir, self.hn), 'r')
        self.ln = h5py.File(os.path.join(self.storage_dir, self.ln), 'r')

    def __len__(self):
        # pairing is positional, so length is the shorter of the two key lists
        return min(len(self.keys['ln']), len(self.keys['hn']))

    def __getitem__(self, item):

        # late init because of how hdf5 and dataloader work together
        if isinstance(self.hn, str) or isinstance(self.ln, str):
            self.init()

        hn_key = self.keys['hn'][item]
        hn = [self.preprocess(self.hn[hn_key][()]),
              self.hn[hn_key].attrs['frames'] * pt.ones(1)]
        ln_key = self.keys['ln'][item]
        ln = [self.preprocess(self.ln[ln_key][()]),
              self.ln[ln_key].attrs['frames'] * pt.ones(1)]

        # convert number of frames to classes for classification
        # (frames/12 capped at class 2 — presumably 12 frames per quality step; confirm)
        ln[1] = ln[1] / 12
        ln[1][ln[1] > 1] = 2

        hn[1] = hn[1] / 12
        hn[1][hn[1] > 1] = 2

        sample = ((hn[0], ln[0]), (hn[1].long(), ln[1].long())) 
79 | return sample 80 | 81 | def preprocess(self, image): 82 | if self._preprocess: 83 | image = image / image.max() 84 | vmin = self.parse_threshold(self.vmin, image) 85 | vmax = self.parse_threshold(self.vmax, image) 86 | image = image.clip(vmin, vmax) 87 | image -= vmin 88 | image = image / image.max() 89 | 90 | if self.transforms: 91 | image = self.transforms(image) 92 | 93 | if self.whitening: 94 | image = image - image.mean() 95 | image = image / image.std() 96 | 97 | return pt.from_numpy(image[None, ...]) 98 | 99 | def close(self): 100 | 101 | if self.hn is not None: self.hn.close() 102 | if self.ln is not None: self.ln.close() 103 | 104 | @staticmethod 105 | def parse_threshold(threshold, image): 106 | 107 | if isinstance(threshold, float): 108 | return threshold 109 | elif threshold == 'mean': 110 | return image[image > 0].mean() 111 | elif isinstance(threshold, str) and 'mean' in threshold: 112 | factor = float(threshold[4:]) 113 | std = image[image > 0].std() 114 | mean = image[image > 0].mean() 115 | return mean + factor * std 116 | 117 | 118 | class OCTQualityDataset(Dataset): 119 | 120 | def __init__(self, parent_folder, fraction=0.8, transformation=lambda x: x): 121 | 122 | self.transformation = transformation 123 | 124 | # get lists of filenames 125 | self.hn_files = self.gather_filenames(join(parent_folder, 'high-noise')) 126 | self.ln_files = self.gather_filenames(join(parent_folder, 'low-noise')) 127 | 128 | # keep fraction 129 | last_index = int(len(self.hn_files) * abs(fraction)) 130 | self.hn_files = self.hn_files[:last_index] if fraction > 0 else self.hn_files[-last_index:] 131 | last_index = int(len(self.ln_files) * abs(fraction)) 132 | self.ln_files = self.ln_files[:last_index] if fraction > 0 else self.ln_files[-last_index:] 133 | 134 | @staticmethod 135 | def gather_filenames(folder): 136 | 137 | # walk through directories and collect filenames with path (in numerical order to preserve pairing) 138 | filenames = [] 139 | for root, dirs, 
files in os.walk(folder):
            # NOTE(review): assumes every subfolder name is numeric —
            # a non-numeric directory would raise ValueError here; confirm
            dirs.sort(key=int)
            if not files: continue
            files.sort(key=lambda x: int(x.split('.')[0]))
            filenames += [f'{root}/{f}' for f in files]
        return filenames

    def prepare_image(self, filename):

        # load and convert to normalized float32 tensor
        image = imread(filename)
        image = self.transformation(image)
        image = pt.from_numpy(image)[None, ...].float()
        image = image / image.max()
        return image

    def __len__(self):
        # pairing is positional, so length is the shorter of the two lists
        return min(len(self.hn_files), len(self.ln_files))

    def __getitem__(self, item):

        hn = self.prepare_image(self.hn_files[item])
        ln = self.prepare_image(self.ln_files[item])
        images = (hn, ln)
        # fixed class labels: 1 = high-noise, 2 = low-noise
        labels = (1, 2)

        return images, labels

/critertion.py:
--------------------------------------------------------------------------------
"""
Copyright 2019 Ilja Manakov

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or
substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""


from collections import namedtuple

import torch as pt
from torch.nn import L1Loss, CrossEntropyLoss


class ClassificationLoss(object):
    """CycleGAN criterion returning separate hn/ln discriminator losses and a
    combined generator loss (adversarial + weighted cycle-consistency)."""

    def __init__(self, cycle_loss=L1Loss, discriminator_loss=CrossEntropyLoss, cycle_factor=10):

        self.cycle_factor = cycle_factor
        self.cyc_loss = cycle_loss()
        self.disc_loss = discriminator_loss()

    def __call__(self, prediction, sample):
        """Compute (disc_hn_loss, disc_ln_loss, gen_loss) from the model
        prediction and the (images, labels) sample."""

        # convert components to namedtuples for easier handling
        images = namedtuple('images', ('hn', 'ln'))(*sample[0])
        cycled = namedtuple('cycled', ('hn', 'ln'))(*prediction.cycled)

        # reals are flipped due to architecture
        scores = namedtuple('scores', ('real', 'fake', 'pool_fake'))
        hn_scores = scores(prediction.ln_scores.real.float(),
                           prediction.hn_scores.fake.float(),
                           prediction.hn_scores.pool_fake.float())
        ln_scores = scores(prediction.hn_scores.real.float(),
                           prediction.ln_scores.fake.float(),
                           prediction.ln_scores.pool_fake.float())

        disc_hn_loss = self.hn_discriminator_loss(hn_scores)
        disc_ln_loss = self.ln_discriminator_loss(ln_scores)
        gen_loss = self.ln_generator_loss(ln_scores) + self.hn_generator_loss(hn_scores) + self.cycle_loss(images, cycled)

        return disc_hn_loss, disc_ln_loss, gen_loss

    def hn_discriminator_loss(self, hn_scores):
        # real scored as 1, pooled fakes as 0
        return self.disc_loss(hn_scores.real, self.create_target(hn_scores.real, 1)) +\
               self.disc_loss(hn_scores.pool_fake, self.create_target(hn_scores.pool_fake, 0))

    def ln_discriminator_loss(self, ln_scores):
        # real scored as 1, pooled fakes as 0
        return self.disc_loss(ln_scores.real, self.create_target(ln_scores.real, 1)) +\
               self.disc_loss(ln_scores.pool_fake, self.create_target(ln_scores.pool_fake, 0))

    def ln_generator_loss(self, ln_scores):
        # generator wants its fakes scored as real (1)
        return self.disc_loss(ln_scores.fake, self.create_target(ln_scores.fake, 1))

    def hn_generator_loss(self, hn_scores):
        # generator wants its fakes scored as real (1)
        return self.disc_loss(hn_scores.fake, self.create_target(hn_scores.fake, 1))

    def cycle_loss(self, images, cycled):
        # cycle-consistency in both directions, weighted by cycle_factor
        return self.cycle_factor * (self.cyc_loss(cycled.hn.float(), images.hn) +
                                    self.cyc_loss(cycled.ln.float(), images.ln))

    def create_target(self, tensor, value):
        # constant target tensor with the same shape/device/dtype as `tensor`
        if value:
            return value*pt.ones_like(tensor)
        else:
            return pt.zeros_like(tensor)


class ClassificationLossHD(object):
    """Criterion for the HD variant: one shared 3-class discriminator
    (fake / high-noise / low-noise), returning (disc_loss, gen_loss)."""

    def __init__(self, cycle_loss=L1Loss, discriminator_loss=CrossEntropyLoss, cycle_factor=10):

        self.cycle_factor = cycle_factor
        self.cyc_loss = cycle_loss()
        self.disc_loss = discriminator_loss()

    def __call__(self, prediction, sample):
        """Compute (0.5 * discriminator loss, generator + cycle loss)."""

        # convert components to namedtuples for easier handling
        images = namedtuple('images', ('hn', 'ln'))(*sample[0])
        targets = namedtuple('targets', ('hn', 'ln'))(*sample[1])
        cycled = namedtuple('cycled', ('hn', 'ln'))(*prediction.cycled)

        # reals are flipped due to architecture
        scores = namedtuple('scores', ('real', 'fake', 'pool_fake'))
        hn_scores = scores(prediction.ln_scores.real, prediction.hn_scores.fake, prediction.hn_scores.pool_fake)
        ln_scores = scores(prediction.hn_scores.real, prediction.ln_scores.fake, prediction.ln_scores.pool_fake)

        # f_disc_hn, f_disc_ln, f_gen_hn, f_gen_ln = self.generate_fake_targets(hn_scores, ln_scores)
        fake_t, hn_t, ln_t = self.generate_fake_targets(hn_scores)
        disc_loss = self.discriminator_loss(hn_scores, ln_scores, fake_t, hn_t, ln_t)
        gen_loss = self.ln_generator_loss(ln_scores, ln_t) + self.hn_generator_loss(hn_scores, hn_t)
        cyc_loss = self.cycle_loss(images, cycled)

        return 0.5*disc_loss, gen_loss + cyc_loss

    def generate_fake_targets(self, hn_scores):
        """Build one-hot target rows [fake, hn, ln] matching the batch size
        and dtype/device of the pooled-fake scores."""

        template = hn_scores.pool_fake.float()
        batch_size = template.shape[0]
        # disc_hn = pt.zeros(batch_size).to(hn_scores.pool_fake).long()
        # disc_ln = 2*pt.zeros(batch_size).to(ln_scores.pool_fake).long()
        # gen_hn = pt.ones(batch_size).to(hn_scores.fake).long()
        # gen_ln = 2*pt.ones(batch_size).to(ln_scores.fake).long()
        fake_t = pt.tensor([[1, 0, 0]]*batch_size).to(template)
        hn_t = pt.tensor([[0, 1, 0]]*batch_size).to(template)
        ln_t = pt.tensor([[0, 0, 1]]*batch_size).to(template)

        return fake_t, hn_t, ln_t

    def discriminator_loss(self, hn_scores, ln_scores, fake_t, hn_t, ln_t):
        # reals must be classified as their noise class, pooled fakes as 'fake'
        return self.disc_loss(hn_scores.real.float(), hn_t) +\
               self.disc_loss(ln_scores.real.float(), ln_t) + \
               self.disc_loss(hn_scores.pool_fake.float(), fake_t) +\
               self.disc_loss(ln_scores.pool_fake.float(), fake_t)

    def ln_generator_loss(self, ln_scores, f_gen_ln):
        # generator wants ln fakes classified as the ln class
        return self.disc_loss(ln_scores.fake.float(), f_gen_ln)

    def hn_generator_loss(self, hn_scores, f_gen_hn):
        # generator wants hn fakes classified as the hn class
        return self.disc_loss(hn_scores.fake.float(), f_gen_hn)

    def cycle_loss(self, images, cycled):
        # cycle-consistency in both directions, weighted by cycle_factor
        return self.cycle_factor * (self.cyc_loss(cycled.hn.float(), images.hn.float()) +
                                    self.cyc_loss(cycled.ln.float(), images.ln.float()))

/train_without_trainer.py:
--------------------------------------------------------------------------------
import importlib
import argparse
import os
from os.path import join, isdir
from itertools import chain

import torch as pt
import torch.nn as nn
from torch.nn import init
from torch.utils.data import DataLoader
import numpy as np

from optimizer import OptimizerCollection


def 
load_config(file):
    """
    initialize module from .py config file
    :param file: filename of the config .py file
    :return: config as module
    """

    spec = importlib.util.spec_from_file_location("config", file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    return module


def set_seed(seed):
    """Make training deterministic: fix cudnn, numpy and torch RNG seeds."""
    pt.backends.cudnn.deterministic = True
    pt.backends.cudnn.benchmark = False
    np.random.seed(seed)
    pt.manual_seed(seed)
    pt.cuda.manual_seed_all(seed)


def weight_init(m):
    """
    Usage:
        model = Model()
        model.apply(weight_init)

    Per-layer initialization: xavier-normal for conv/linear weights,
    N(1, 0.02) for norm scales, orthogonal for recurrent weight matrices.
    """

    if isinstance(m, nn.Conv1d):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.Conv2d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.Conv3d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose1d):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose2d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose3d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.BatchNorm1d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm3d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        init.normal_(m.bias.data)
    elif isinstance(m, nn.LSTM):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.LSTMCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.GRU):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.GRUCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)


def backward(loss):
    # plain backward; matches the interface of apex FP16_Optimizer.backward
    loss.backward()


# cmd arguments
parser = argparse.ArgumentParser()
parser.add_argument('--config', '-c', required=True, dest='config')
parser.add_argument('--n_epochs', '-n', default=50, type=int, dest='n_epochs')
parser.add_argument('--logdir', '-l', default='./', dest='logdir')
parser.add_argument('--seed', '-s', default=0, type=int, dest='seed')
args = parser.parse_args()

# fix seed and load config
set_seed(args.seed)
config = load_config(args.config)

# use fp16 if possible
dtype = pt.float16 if hasattr(config, 'APEX') else pt.float32

# initialize components
model = config.MODEL(**config.model).cuda().to(dtype)
model.apply(weight_init)
dataset = config.DATASET(**config.dataset)
criterion = config.LOSS(**config.loss)
gen_optimizer = config.OPTIMIZER(chain(model.generator['hn'].parameters(),
                                       model.generator['ln'].parameters()) , **config.optimizer)
disc_optimizer = config.OPTIMIZER(model.discriminator.parameters(), **config.optimizer)
optimizers = []
# wrap each optimizer with apex mixed precision if it is configured
for optimizer in (disc_optimizer, gen_optimizer):
    if hasattr(config, 'APEX'):
        optimizer = config.APEX(optimizer, 
**config.apex) 138 | optimizers.append(optimizer) 139 | optimizer = OptimizerCollection(*optimizers) 140 | dataloader = DataLoader(dataset, **config.dataloader) 141 | 142 | # convert optimizer and unify backward call 143 | if dtype == pt.float16: 144 | optimizer = config.APEX(optimizer, **config.apex) 145 | if not hasattr(optimizer, 'backward'): 146 | setattr(optimizer, 'backward', backward) 147 | 148 | # init log directory 149 | if not isdir(args.logdir): 150 | os.makedirs(args.logdir) 151 | 152 | # training loop 153 | steps_in_epoch = len(dataloader) 154 | try: 155 | for epoch in range(args.n_epochs): 156 | for step, (images, labels) in enumerate(dataloader): 157 | 158 | # convert data to match network 159 | images = [i.cuda().to(dtype) for i in images] 160 | labels = [l.cuda().long() for l in labels] 161 | 162 | # forward pass 163 | result = model(images) 164 | loss = criterion(result, ([i.float() for i in images], labels)) 165 | 166 | # backward pass 167 | optimizer.zero_grad() 168 | optimizer.backward(loss) 169 | optimizer.step() 170 | 171 | # reporting 172 | progress = (epoch*steps_in_epoch+step) / (args.n_epochs*steps_in_epoch) 173 | progress = round(100*progress, 2) 174 | print(f'\repoch: {epoch} of {args.n_epochs}, step: {step}, progress: {progress}%,' 175 | f' loss: {round(loss.item(), 4)}', end='') 176 | 177 | # write loss 178 | with open(join(args.logdir, 'losses.csv'), 'a') as file: 179 | print(epoch, step, loss, sep=',', file=file) 180 | 181 | # always save model and optimizer before exiting 182 | finally: 183 | pt.save(model.state_dict(), join(args.logdir, f'{model.__class__.__name__}_epoch-{epoch}_step-{step}.pt')) 184 | for name, opt in zip(['gen', 'disc'], optimizer.optimizers): 185 | pt.save(opt.state_dict(), join(args.logdir, f'{optimizer.__class__.__name__}_{name}_epoch-{epoch}_step-{step}.pt')) 186 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | Apache License 4 | 5 | Version 2.0, January 2004 6 | 7 | http://www.apache.org/licenses/ 8 | 9 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 10 | 11 | 1. Definitions. 12 | 13 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 14 | 15 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 18 | 19 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 20 | 21 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 22 | 23 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 24 | 25 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 
26 | 27 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 28 | 29 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 30 | 31 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 32 | 33 | 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 34 | 35 | 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 36 | 37 | 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 38 | 39 | You must give any other recipients of the Work or Derivative Works a copy of this License; and 40 | You must cause any modified files to carry prominent notices stating that You changed the files; and 41 | You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and 42 | If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed 
as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. 43 | 44 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 45 | 46 | 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 47 | 48 | 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 49 | 50 | 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 51 | 52 | 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 53 | 54 | 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 
55 | 56 | END OF TERMS AND CONDITIONS 57 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019 Ilja Manakov 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 6 | rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit 7 | persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | The above copyright notice and this permission notice shall be included in all copies or 10 | substantial portions of the Software. 11 | 12 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE 13 | WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 15 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
16 | """ 17 | 18 | 19 | import torch as pt 20 | from torch.nn.functional import selu, relu 21 | from torch.nn.parameter import Parameter 22 | from autoencoder import model_parts as parts 23 | from collections import namedtuple 24 | 25 | 26 | class Generator(pt.nn.Module): 27 | 28 | def __init__(self, scale_factor, n_residual=(6, 3), activation=relu, input_channels=1, 29 | channel_factor=8, kernel_size=(3, 3), skip_conn='concat', 30 | norm_func=pt.nn.InstanceNorm2d, up_conv=parts.ConvResize2d, pad=pt.nn.ZeroPad2d): 31 | super().__init__() 32 | 33 | self.skip_conn = skip_conn 34 | self.downsampling = [] 35 | self.residuals = [] 36 | self.upsampling = [] 37 | 38 | # initial 7x7 convolution 39 | self.initial_conv = parts.GeneralConvolution(input_channels, channel_factor, (7, 7), (1, 1), pt.nn.ZeroPad2d, 40 | norm_func, activation, pt.nn.Conv2d, True) 41 | 42 | # downsampling with strided convolutions, channels are double after each convolution 43 | for i in range(scale_factor): 44 | channel_factor *= 2 45 | self.downsampling.append( 46 | parts.GeneralConvolution(channel_factor // 2, channel_factor, kernel_size, (2, 2), pt.nn.ZeroPad2d, 47 | norm_func, activation, pt.nn.Conv2d, True) 48 | ) 49 | self.add_module(f'down_conv_{i}', self.downsampling[-1]) 50 | 51 | # add residual blocks 52 | for i in range(n_residual[0]): 53 | self.residuals.append( 54 | parts.ResBlock2d(channel_factor, n_residual[1], kernel_size, pt.nn.ZeroPad2d, norm_func) 55 | ) 56 | self.add_module(f'res_block_{i}', self.residuals[-1]) 57 | 58 | # upsampling 59 | for i in range(scale_factor): 60 | in_channels = channel_factor*2 if skip_conn == 'concat' else channel_factor 61 | channel_factor = channel_factor // 2 62 | self.upsampling.append( 63 | up_conv(in_channels, channel_factor, kernel_size, (1, 1), 64 | norm=norm_func, activation=activation, affine=True, padding=pad) 65 | ) 66 | self.add_module(f'up_conv_{i}', self.upsampling[-1]) 67 | 68 | # final convolution 69 | in_channels = channel_factor * 2 
if skip_conn == 'concat' else channel_factor 70 | self.final_conv = parts.GeneralConvolution(in_channels, input_channels, kernel_size, (1, 1), pt.nn.ZeroPad2d, 71 | None, activation, pt.nn.Conv2d, True) 72 | 73 | def forward(self, x): 74 | 75 | out = x 76 | 77 | out = self.initial_conv(out) 78 | 79 | skips = [] 80 | for down_conv in self.downsampling: 81 | skips.append(out) 82 | out = down_conv(out) 83 | skips.append(out) 84 | 85 | for residual in self.residuals: 86 | out = residual(out) 87 | 88 | for up_conv in self.upsampling: 89 | if self.skip_conn == 'concat': 90 | out = pt.cat([out, skips.pop()], dim=1) 91 | elif self.skip_conn == 'add': 92 | out = out + skips.pop() 93 | out = up_conv(out) 94 | 95 | if self.skip_conn == 'concat': 96 | out = pt.cat([out, skips.pop()], dim=1) 97 | elif self.skip_conn == 'add': 98 | out = out + skips.pop() 99 | out = self.final_conv(out) 100 | 101 | return out 102 | 103 | 104 | class Discriminator(pt.nn.Module): 105 | 106 | def __init__(self, n_out, channel_factor=2, n_layers=7, activation=relu, kernel_size=(4, 4), 107 | n_residual=(0, 0), max_channels=1024, input_channels=1, affine=False, **kwargs): 108 | 109 | super(Discriminator, self).__init__() 110 | self.layers = [] 111 | current_channels = input_channels 112 | for depth_index in range(n_layers): 113 | out_channels = channel_factor * 2 ** depth_index 114 | if out_channels > max_channels: out_channels = max_channels 115 | for res_index in range(n_residual[0]): 116 | self.layers.append( 117 | parts.ResBlock2d(current_channels, n_residual[1], kernel_size, activation=activation, affine=affine, **kwargs)) 118 | self.add_module('r-block{}-{}'.format(depth_index + 1, res_index + 1), self.layers[-1]) 119 | 120 | self.layers.append( 121 | parts.GeneralConvolution(current_channels, out_channels, kernel_size, (2, 2), activation=activation, 122 | padding=pt.nn.ReflectionPad2d, norm=pt.nn.InstanceNorm2d, 123 | convolution=pt.nn.Conv2d, affine=affine, **kwargs)) 124 | 
class ImagePool(pt.nn.Module):
    """Fixed-size buffer of past generator outputs, initialized with noise.

    The buffer is stored as a non-trainable Parameter so it follows the
    module across devices and is included in the state_dict.
    """

    def __init__(self, size, shape, write_probability=1):
        super(ImagePool, self).__init__()
        # requires_grad=False: the pool is a buffer, not a trainable weight
        self.pool = Parameter(pt.rand(size, *shape), False)
        self.write_probability = write_probability

    def write(self, item):
        """With probability ``write_probability``, overwrite one random slot
        with a single (detached) image taken from *item*."""
        if pt.rand(1) > self.write_probability:
            return
        if item.shape[0] > 1:
            # pick one image out of the batch; tensor indexing keeps the batch dim
            item = item[self.random_index(len(item)), ...]
        self.pool[self.random_index()] = item.detach()

    def sample(self, batch_size):
        """Return *batch_size* pool entries drawn uniformly with replacement."""
        return pt.cat([self.pool[self.random_index()] for _ in range(batch_size)])

    def random_index(self, size=None):
        """Uniform random index in [0, size) as a LongTensor; *size* defaults
        to the pool length."""
        if size is None:
            size = len(self.pool)
        return (pt.rand(1) * size).long()
class HDCycleGAN(pt.nn.Module):
    """CycleGAN variant with a single shared discriminator (3 output classes)
    instead of one discriminator per domain.

    *generator* and *discriminator* are (factory, kwargs) pairs; *input_size*
    is the image shape kept in the per-domain image pools. The two quality
    domains are keyed 'hn' and 'ln' (presumably high-/low-noise — confirm
    against the dataset naming).
    """

    def __init__(self, generator, discriminator, input_size, pool_size=64, pool_write_probability=1):

        super(HDCycleGAN, self).__init__()

        # one shared discriminator with three output classes
        self.discriminator = discriminator[0](n_out=3, **discriminator[1])
        # one generator per target quality; 'hn' is built first to keep the
        # parameter-initialization order stable
        self.generator = {quality: generator[0](**generator[1]) for quality in ('hn', 'ln')}

        self.add_module('discriminator', self.discriminator)
        self.add_module('generator_ln', self.generator['ln'])
        self.add_module('generator_hn', self.generator['hn'])

        self.pool_size = pool_size
        self.pool_write_probability = pool_write_probability
        # per-domain history of generated images for the discriminator
        self.pool = {quality: ImagePool(pool_size, input_size, write_probability=pool_write_probability)
                     for quality in ('ln', 'hn')}
        self.add_module('pool_ln', self.pool['ln'])
        self.add_module('pool_hn', self.pool['hn'])

    def generate(self, x, quality):
        """Run the generator that produces images of *quality*."""
        return self.generator[quality](x)

    def discriminate(self, x):
        """Score *x* with the shared discriminator."""
        return self.discriminator(x)

    def cycle(self, x, start_quality):
        """Translate *x* to the opposite quality and back to *start_quality*."""
        other = 'ln' if start_quality == 'hn' else 'hn'
        return self.generate(self.generate(x, other), start_quality)

    def discriminate_from_pool(self, quality, batch_size):
        """Score a batch of historical fakes drawn from the *quality* pool."""
        return self.discriminate(self.pool[quality].sample(batch_size))

    def _forward(self, x, target_quality):
        """Translate *x* into *target_quality* and score input, fake and pool fakes."""
        generated = self.generate(x, target_quality)
        self.pool[target_quality].write(generated)

        # CAUTION: the 'real' score is computed on a sample from the other domain
        scores = namedtuple('scores', ('real', 'fake', 'pool_fake'))(
            self.discriminate(x),
            self.discriminate(generated),
            self.discriminate_from_pool(target_quality, len(generated)),
        )
        return namedtuple(target_quality, ('generated', 'scores'))(generated, scores)

    def forward(self, x):
        """x is a (hn_batch, ln_batch) pair; returns cycled images plus
        discriminator scores for both translation directions."""
        hn, ln = x

        generated_ln, scores_ln = self._forward(hn, 'ln')
        generated_hn, scores_hn = self._forward(ln, 'hn')
        cycled = self.generate(generated_ln, 'hn'), self.generate(generated_hn, 'ln')

        result = namedtuple('Result', ('cycled', 'hn_scores', 'ln_scores'))
        return result(cycled, scores_hn, scores_ln)
def median(image, filter_size=1):
    """Median-filter *image* with a disk footprint and rescale so the
    maximum value is 1.0.

    :param image: 2D array; converted to float64 before filtering
    :param filter_size: radius of the disk-shaped median footprint
    :return: median-filtered image normalized to [0, 1]

    Bugfix: the deprecated alias ``np.float`` was removed in NumPy 1.24;
    the documented replacement is the builtin ``float`` (float64), which
    behaves identically here.
    """
    # renamed from `filter`, which shadowed the builtin
    footprint = morphology.disk(filter_size)
    # NOTE(review): `selem=` was renamed to `footprint=` in scikit-image 0.19;
    # kept as-is to match the pinned environment — confirm the installed version
    filtered = filters.median(image.astype(float), selem=footprint)
    # normalize to peak 1.0; assumes a non-empty image with a non-zero maximum
    return filtered / filtered.max()
def cnr(datapoint):
    """Mean contrast-to-noise ratio over the datapoint's regions of interest.

    For each ROI mask:  CNR_i = |mean(roi) - mean(bg)| / sqrt(0.5 * (var(roi) + var(bg)))
    where ROI/background pixels are selected from ``datapoint.image`` by the
    masks in ``datapoint.rois`` and ``datapoint.background``.

    :return: mean of the per-ROI CNR values

    Bugfix: the denominator used to be ``sqrt(0.5*(roi.std()**2 + background_std)**2)``
    — the square was applied to the whole sum and the background std was not
    squared. The standard definition averages the two *variances* inside the
    square root.
    """
    rois = [datapoint.image[roi > 0] for roi in datapoint.rois]
    background = datapoint.image[datapoint.background > 0]
    background_mean = background.mean()
    background_var = background.std() ** 2
    cnrs = [np.abs(roi.mean() - background_mean) / np.sqrt(0.5 * (roi.std() ** 2 + background_var))
            for roi in rois]
    return np.array(cnrs).mean()
class Datapoint(object):
    """Bundle of one image plus the metadata needed to measure it.

    ``data`` accumulates the row that eventually lands in the results
    DataFrame: key, method, compute_time and every measurement value
    (key/method/compute_time are mirrored into it via properties).
    """

    def __init__(self, key=None, image=None, reference=None, method=None, background=None, rois=None, mask=None,
                 transformation=None):
        """
        :param key: identifier of the source slice in the hdf5 storage
        :param image: image under evaluation; copied so later in-place edits
            do not touch the caller's array
        :param reference: high-quality reference image (kept by reference)
        :param method: name of the denoising method that produced *image*
        :param background: mask selecting background pixels (used by CNR)
        :param rois: list of masks selecting regions of interest
        :param mask: segmentation mask applied before PSNR/SSIM
        :param transformation: registration transform dict (imreg_dft format)
        """
        self.data = {}
        self.key = key
        self.method = method
        # bugfix: image defaults to None, but None has no .copy() — guard it
        self.image = image.copy() if image is not None else None
        self.mask = mask
        self.reference = reference
        self.rois = rois
        self.background = background
        self.contains_measurement = False
        self.transformation = transformation
        self.compute_time = 0

    @property
    def key(self):
        """Storage key, mirrored into the ``data`` row."""
        return self.data['key']

    @key.setter
    def key(self, key):
        self.data['key'] = key

    @property
    def method(self):
        """Denoising method name, mirrored into the ``data`` row."""
        return self.data['method']

    @method.setter
    def method(self, method):
        self.data['method'] = method

    @property
    def compute_time(self):
        """Denoising wall-clock duration, mirrored into the ``data`` row."""
        return self.data['compute_time']

    @compute_time.setter
    def compute_time(self, duration):
        self.data['compute_time'] = duration

    def extract_information(self):
        """Return the accumulated ``data`` row.

        :raises AssertionError: if no measurement was recorded yet
        """
        # test if some data is missing
        assert self.contains_measurement, f'no measurement was performed on this datapoint!'

        return self.data

    def add_measurement(self, name, value):
        """Record a named measurement value in the ``data`` row."""
        self.data[name] = value
        self.contains_measurement = True

    def copy(self):
        """Return a copy sharing reference/rois/background but with its own
        ``data`` dict (and its own copy of ``image``, made by ``__init__``).

        NOTE(review): ``contains_measurement`` is reset to False on the copy
        even when measurements were already recorded — presumably because
        copies are made before re-measuring; confirm with callers.
        """
        new_datapoint = Datapoint(image=self.image,
                                  reference=self.reference,
                                  rois=self.rois,
                                  background=self.background,
                                  key=self.key,
                                  method=self.method,
                                  mask=self.mask,
                                  transformation=self.transformation)
        new_datapoint.data = self.data.copy()

        return new_datapoint
inverted_mask = np.ones_like(image) 253 | inverted_mask[registrator.segment(image, offset=0) > 0] = 0 254 | 255 | return inverted_mask 256 | 257 | def apply_measurements(self, datapoint): 258 | 259 | for measurement_name, measurement in self.measurements.items(): 260 | result = measurement(datapoint) 261 | datapoint.add_measurement(measurement_name, result) 262 | 263 | result = datapoint.extract_information() 264 | return result 265 | 266 | def apply_denoising(self, method_name, datapoint): 267 | 268 | method = self.methods[method_name] 269 | datapoint = datapoint.copy() 270 | start = time() 271 | image = method(datapoint.image) 272 | duration = time() - start 273 | datapoint.image = image 274 | datapoint.compute_time = duration 275 | datapoint.method = method_name 276 | 277 | return datapoint 278 | 279 | def export_images(self, datapoints, index): 280 | 281 | indices = list(range(len(datapoints[1:]))) 282 | index = str(index) 283 | exported = {'index': index, 284 | 'key': datapoints[0].key} 285 | filedir = os.path.join(self.export_denoised, index) 286 | if not os.path.isdir(filedir): 287 | os.makedirs(filedir) 288 | 289 | # save original and reference 290 | datapoint = datapoints[0] 291 | image = datapoint.image.copy() 292 | plt.figure(figsize=(5.12, 5.12), frameon=False) 293 | plt.imshow(image, 'gray', interpolation='none') 294 | plt.axis('off') 295 | plt.tight_layout() 296 | plt.savefig(os.path.join(filedir, 'original.png'), dpi=100) 297 | plt.close() 298 | plt.clf() 299 | plt.cla() 300 | image = datapoint.reference.copy() 301 | plt.figure(figsize=(5.12, 5.12), frameon=False) 302 | plt.imshow(image, 'gray', interpolation='none') 303 | plt.axis('off') 304 | plt.tight_layout() 305 | plt.savefig(os.path.join(filedir, 'reference.png'), dpi=100) 306 | plt.close() 307 | plt.clf() 308 | plt.cla() 309 | 310 | for datapoint in datapoints[1:]: 311 | inner_index = indices.pop(np.random.randint(len(indices))) 312 | filename = os.path.join(filedir, f'{inner_index}.png') 313 
| 314 | image = datapoint.image.copy() 315 | plt.figure(figsize=(5.12, 5.12), frameon=False) 316 | plt.imshow(image, 'gray', interpolation='none') 317 | plt.axis('off') 318 | plt.tight_layout() 319 | plt.savefig(filename, dpi=100) 320 | exported[datapoint.method] = f'{inner_index}.png' 321 | plt.close() 322 | plt.clf() 323 | plt.cla() 324 | 325 | return exported 326 | 327 | def __call__(self, acceptance=None): 328 | 329 | lq = self.open_storage(self.lq) 330 | hq = self.open_storage(self.hq) 331 | 332 | print('beginning analysis...') 333 | 334 | total_slices = len(self.lq_keys) 335 | 336 | print(f'\tnumber of datapoins:\t{total_slices}') 337 | 338 | # initialize result list for accumulating measurements 339 | results = [] 340 | exported = [] 341 | indices = list(range(len(self.lq_keys))) 342 | 343 | # cycle over samples 344 | for i, key in enumerate(self.lq_keys[150:]): 345 | 346 | i+= 150 347 | # skip failed registrations 348 | entry = lq[key] 349 | if acceptance is not None: 350 | if entry.attrs['difference'] > acceptance or np.isnan(entry.attrs['difference']): 351 | continue 352 | 353 | transform = dict(entry.attrs) 354 | transform.pop('difference') 355 | transform.pop('frames') 356 | image = self.preprocess(entry.value) 357 | reference = self.preprocess(hq[self.hq_keys[i]].value) 358 | 359 | rois = self.get_rois(reference, self.registrator) 360 | background = self.get_background(reference, self.registrator) 361 | mask = self.registrator.segment(reference, offset=-1) 362 | 363 | # initialize original image as datapoint 364 | datapoints = [Datapoint(key=key, reference=reference, 365 | rois=rois, background=background, 366 | method='original', image=image, transformation=transform, mask=mask)] 367 | 368 | # generate all denoised images 369 | pool = mp.Pool(self.n_processes) 370 | denoising = partial(self.apply_denoising, datapoint=datapoints[0]) 371 | datapoints += pool.map(denoising, self.methods.keys()) 372 | 373 | # perform measurements on all denoised images 
374 | pool = mp.Pool(self.n_processes) 375 | results += pool.map(self.apply_measurements, datapoints) 376 | 377 | if self.export_denoised is not None: 378 | index = indices.pop(np.random.randint(len(indices))) 379 | exported_row = self.export_images(datapoints, index) 380 | exported.append(exported_row) 381 | 382 | print('\r\tprogress: \t{}%'.format(round(100 * i / total_slices, 2)), end='') 383 | 384 | # convert results to a dataframe and save 385 | results_frame = pd.DataFrame(data=results) 386 | results_frame.to_csv(self.output_path) 387 | if self.export_denoised is not None: 388 | exported_frame = pd.DataFrame(data=exported) 389 | exported_frame.to_csv(os.path.join(self.export_denoised, 'exports.csv')) 390 | 391 | hq.close() 392 | lq.close() 393 | 394 | return results_frame 395 | 396 | 397 | if __name__ == '__main__': 398 | 399 | checkpoint = 132 400 | config = '/media/network/DL_PC/Results/ilja/pt-cycoct/run_024/config.py' 401 | 402 | methods = {'median': median, 403 | 'ours': CycGAN(checkpoint, config), 404 | 'wavelet': wavelet, 405 | 'bilateral': bilateral, 406 | 'nl_means': nl_means, 407 | 'bm3d': bm3d} 408 | measurements = {'PSNR': psnr, 409 | 'CNR': cnr, 410 | 'MSR': msr, 411 | 'SSIM': ssim} 412 | savefile = './measurements.csv' 413 | lq = '/media/network/DL_PC/Datasets/oct_quality_validation/low.hdf5' 414 | hq = '/media/network/DL_PC/Datasets/oct_quality_validation/high.hdf5' 415 | 416 | 417 | def preprocess(image): 418 | image = image / image.max() 419 | mean = image[image>0].mean() 420 | std = image[image>0].std() 421 | level = mean - 0.5*std 422 | image = np.clip(image, level, 1.0) - level 423 | image = image / image.max() 424 | return image 425 | 426 | 427 | analysis = Analysis(lq, hq, methods, measurements, savefile, 4, preprocess=preprocess, export_denoised='./exports/') 428 | results = analysis() 429 | 430 | --------------------------------------------------------------------------------