├── README.md
├── datasets
│   ├── dfc.py
│   ├── nrw.py
│   └── transforms.py
├── fid_comp.py
├── loss.py
├── models
│   ├── arch.py
│   ├── common.py
│   ├── discriminator.py
│   ├── generator.py
│   └── unet.py
├── options
│   ├── common.py
│   ├── gan.py
│   └── segment.py
├── pix_acc_iou_comp.py
├── synth_img.py
├── test.py
├── test_unet.py
├── train.py
├── train_unet.py
├── trainer.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Building a Parallel Universe - Image Synthesis from Land Cover Maps and Auxiliary Raster Data
2 | 
3 | ![examples](../assets/nrw_sar_rgb_comp.jpg?raw=true)
4 | 
5 | This repository contains the code for our paper [Building a Parallel Universe - Image Synthesis from Land Cover Maps and Auxiliary Raster Data](https://arxiv.org/abs/2011.11314).
6 | 
7 | ## Installation
8 | 
9 | Just clone this repository and make sure you have a recent version of [PIL](https://github.com/python-pillow/Pillow) with JPEG2000 support.
10 | 
11 | ```bash
12 | git clone https://github.com/gbaier/rs_img_synth.git
13 | ```
14 | 
15 | ## Datasets
16 | 
17 | We employ two datasets of high and medium resolution.
18 | Both are freely available at IEEE DataPort.
19 | 
20 | ### GeoNRW
21 | 
22 | The [GeoNRW](https://ieee-dataport.org/open-access/geonrw) dataset consists of 1m aerial photographs, digital elevation models and land cover maps.
23 | It can additionally be augmented with TerraSAR-X spotlight acquisitions, provided you have access to the corresponding data.
24 | In that case, please visit https://github.com/gbaier/geonrw to see how to process that data.
25 | 
26 | ### DFC2020
27 | 
28 | This dataset was used for the 2020 IEEE data fusion contest and consists of Sentinel-1 and Sentinel-2 patches of 256x256 pixels, together with semantic maps.
29 | All data has a resolution of roughly 10 m.
30 | You can get the dataset from https://ieee-dataport.org/competitions/2020-ieee-grss-data-fusion-contest, where you need to download **DFC_Public_Dataset.zip**.
31 | 
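Both datasets are wrapped by small `torchvision`-style dataset classes under `datasets/` (listed further below). As a minimal sketch (not one of the repository's scripts), GeoNRW samples can be loaded like this, assuming the data was extracted to `./data/geonrw`:

```python
# Minimal sketch: load a GeoNRW sample as tensors.
# Assumes the GeoNRW dataset was extracted to ./data/geonrw.
import torchvision

import datasets.nrw
import datasets.transforms

transforms = torchvision.transforms.Compose(
    [
        datasets.transforms.RandomCrop(256),
        datasets.transforms.ToTensor(),
    ]
)

nrw_train = datasets.nrw.NRW("./data/geonrw", split="train", transforms=transforms)
sample = nrw_train[0]  # dict with "rgb", "dem" and "seg" tensors
print({k: v.shape for k, v in sample.items()})
```

The `DFC2020` class in `datasets/dfc.py` works the same way, except that its samples contain a `"sar"` entry and no `"dem"`.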
32 | ## Training new models
33 | 
34 | 1. Training the models as in the paper requires multiple GPUs due to memory constraints.
35 | Alternatively, the batch size or model capacity can be adjusted using command line parameters.
36 | 
37 | 2. We advise using multiple workers when training on the GeoNRW dataset.
38 | Reading the JPEG2000 files of GeoNRW's aerial photographs is CPU intensive, and with only a single worker the data loading bottlenecks the GPU.
39 | 
40 | 3. The two datasets differ in the types of input and output that the generator can consume or produce.
41 | These can be set using the corresponding command line parameters.
42 | The following table lists all of them.
43 | 
44 | Dataset | Input    | Output
45 | ------- | -------- | --------
46 | NRW     | dem, seg | rgb, sar
47 | DFC2020 | seg      | rgb, sar
48 | 
49 | 4. The *crop* and *resize* parameters are ignored for the DFC2020 dataset.
50 | 
51 | 5. The following two examples show how to use the GeoNRW dataset to generate RGB images from digital elevation models and land cover maps:
52 | ```bash
53 | python -m torch.distributed.launch --nproc_per_node=8 train.py \
54 |     --crop 256 \
55 |     --resize 256 \
56 |     --epochs 200 \
57 |     --batch_size 32 \
58 |     --model_cap 64 \
59 |     --lbda 5.0 \
60 |     --num_workers 4 \
61 |     --dataset 'nrw' \
62 |     --dataroot './data/geonrw' \
63 |     --input 'dem' 'seg' \
64 |     --output 'rgb'
65 | ```
66 | and dual-pol SAR images from land cover maps alone, using the dataset of the 2020 IEEE GRSS data fusion contest:
67 | ```bash
68 | python -m torch.distributed.launch --nproc_per_node=8 train.py \
69 |     --epochs 200 \
70 |     --batch_size 32 \
71 |     --model_cap 64 \
72 |     --lbda 5.0 \
73 |     --num_workers 0 \
74 |     --dataset 'dfc' \
75 |     --dataroot './data/DFC_Public_Dataset' \
76 |     --input 'seg' \
77 |     --output 'sar'
78 | ```
79 | The training script creates a directory named something like **nrw_seg_dem2rgb_bs32_ep200_cap64_2020_09_05_07_58**, where the training configuration, logs, and the generator and discriminator models are stored.
80 | 
81 | ## Testing
82 | 
83 | Run the testing script with a generator model as a command line argument.
84 | ```bash
85 | python test.py results/nrw_seg_dem2rgb_bs32_ep200_cap64_2020_09_05_07_58/model_gnet.pt
86 | ```
87 | This goes through the test set and plots the results in the corresponding directory.
88 | 
89 | ## Computing FID scores
90 | 
91 | 1. Train a U-Net segmentation network
92 | ```bash
93 | python train_unet.py \
94 |     --crop 256 \
95 |     --resize 256 \
96 |     --epochs 100 \
97 |     --batch_size 32 \
98 |     --num_workers 4 \
99 |     --dataset 'nrw' \
100 |     --dataroot './data/geonrw' \
101 |     --input 'rgb'
102 | ```
103 | 1. Compute FID scores by passing the generator to be tested and the just-trained U-Net as arguments
104 | ```bash
105 | python fid_comp.py \
106 |     path_to_generator/model_gnet.pt \
107 |     path_to_unet/nrw_unet.pt
108 | ```
109 | 1. Compute intersection-over-union and pixel accuracy
110 | ```bash
111 | python pix_acc_iou_comp.py \
112 |     path_to_generator/model_gnet.pt \
113 |     path_to_unet/nrw_unet.pt \
114 |     output_dir
115 | ```
116 | 1. You can *optionally* also compute the segmentation results
117 | ```bash
118 | python test_unet.py \
119 |     path_to_unet/nrw_unet.pt
120 | ```
121 | which stores the segmentation results in the model's directory.
122 | 
123 | 
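For reference, `fid_comp.py` computes the FID from intermediate features of the trained U-Net rather than an Inception network: features of real and generated images are globally average-pooled, a Gaussian is fitted to each feature set, and the Fréchet distance between the two Gaussians is reported. A condensed sketch of that final step (mirroring `calculate_frechet_distance` further below, here with random placeholder features):

```python
# Condensed sketch of the FID computation in fid_comp.py.
# Random placeholder features stand in for pooled U-Net activations.
import numpy as np
from scipy import linalg

real_feats = np.random.randn(500, 64)  # pooled features of real images
fake_feats = np.random.randn(500, 64)  # pooled features of generated images

mu_r, sigma_r = real_feats.mean(axis=0), np.cov(real_feats, rowvar=False)
mu_f, sigma_f = fake_feats.mean(axis=0), np.cov(fake_feats, rowvar=False)

diff = mu_r - mu_f
covmean = linalg.sqrtm(sigma_r.dot(sigma_f)).real  # matrix square root of the covariance product
fid = diff.dot(diff) + np.trace(sigma_r) + np.trace(sigma_f) - 2 * np.trace(covmean)
print("FID: {:5.4f}".format(fid))
```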
124 | ## Code structure
125 | 
126 | ### GAN training and image synthesis
127 | * `train.py` and `test.py` for training and testing image synthesis.
128 | * `options/` defines command line arguments for training the GANs and the U-Net.
129 | * `datasets/` contains data loaders and transforms.
130 | * `models/` contains the various network architectures.
131 | * `loss.py` defines the GAN loss functions.
132 | * `trainer.py` our general GAN trainer.
133 | * `utils.py` utility functions.
134 | 
135 | ### Numerical analysis
136 | * `train_unet.py` and `test_unet.py` for training and testing the U-Net.
137 | * `fid_comp.py` calculates the Fréchet inception distance (FID) using a pretrained U-Net.
138 | * `pix_acc_iou_comp.py` calculates pixel accuracy and IoU using a pretrained U-Net.
139 | 
140 | 
141 | ## Citation
142 | 
143 | If you use this code in your research, please consider citing
144 | ```
145 | @misc{baier2020building,
146 |   title={Building a Parallel Universe - Image Synthesis from Land Cover Maps and Auxiliary Raster Data},
147 |   author={Gerald Baier and Antonin Deschemps and Michael Schmitt and Naoto Yokoya},
148 |   year={2020},
149 |   eprint={2011.11314},
150 |   archivePrefix={arXiv},
151 |   primaryClass={cs.CV}
152 | }
153 | ```
154 | 
--------------------------------------------------------------------------------
/datasets/dfc.py:
--------------------------------------------------------------------------------
1 | """ The dataset of the IEEE GRSS data fusion contest """
2 | 
3 | import pathlib
4 | import logging
5 | 
6 | import rasterio
7 | import numpy as np
8 | import matplotlib
9 | from torchvision.datasets.vision import VisionDataset
10 | from torchvision.datasets.utils import verify_str_arg
11 | 
12 | logging.getLogger("rasterio").setLevel(logging.WARNING)
13 | 
14 | classes = [
15 |     "Forest",
16 |     "Shrubland",
17 |     "Savanna",
18 |     "Grassland",
19 |     "Wetlands",
20 |     "Croplands",
21 |     "Urban/Built-up",
22 |     "Snow/Ice",
23 |     "Barren",
24 |     "Water",
25 | ]
26 | 
27 | 
28 | # check http://www.grss-ieee.org/community/technical-committees/data-fusion/2020-ieee-grss-data-fusion-contest/
29 | lcov_cmap = matplotlib.colors.ListedColormap(
30 |     [
31 |         "#009900",  # Forest
32 |         "#c6b044",  # Shrubland
33 |         "#fbff13",  # Savanna
34 |         "#b6ff05",  # Grassland
35 |         "#27ff87",  # Wetlands
36 |         "#c24f44",  # Croplands
37 |         "#a5a5a5",  # Urban/Built-up
38 |         "#69fff8",  # Snow/Ice
39 |         "#f9ffa4",  # Barren
40 |         "#1c0dff",  # Water
41 |     ]
42 | )
43 | lcov_norm = matplotlib.colors.Normalize(vmin=1, vmax=10)
44 | 
45 | N_LABELS = 10 + 1  # +1 due to 0 having no label
46 | N_CHANNELS = {"rgb": 3, "sar": 2, "seg": N_LABELS}
47 | 
48 | 
49 | class DFC2020(VisionDataset):
50 |     """ IEEE GRSS data fusion contest dataset
51 | 
52 |     http://www.grss-ieee.org/community/technical-committees/data-fusion/2020-ieee-grss-data-fusion-contest/
53 | 
54 |     Parameters
55 |     ----------
56 |     root : string
57 |         Root directory of dataset
58 |     split : string, optional
59 |         Image split to use, ``train`` or ``test``
60 |     transforms : callable, optional
61 |         A function/transform that takes input sample and returns a transformed version.
62 |     """
63 | 
64 |     splits = ["train", "test"]
65 |     datatypes = ["s1", "s2", "dfc"]
66 | 
67 |     def __init__(self, root, split="train", transforms=None):
68 |         super().__init__(pathlib.Path(root), transforms=transforms)
69 |         verify_str_arg(split, "split", self.splits)
70 |         self.split = split
71 |         self.tif_paths = {dt: self._get_tif_paths(dt) for dt in self.datatypes}
72 | 
73 |     def _get_tif_paths(self, datatype):
74 |         if self.split == "test":
75 |             pat = "ROIs0000*/{}_*/*9.tif".format(datatype)
76 |         else:
77 |             pat = "ROIs0000*/{}_*/*[0-8].tif".format(datatype)
78 |         return list(sorted(self.root.glob(pat)))
79 | 
80 |     def __len__(self):
81 |         return len(self.tif_paths["dfc"])
82 | 
83 |     def __getitem__(self, index):
84 |         def read_tif_as_np_array(path):
85 |             with rasterio.open(path) as src:
86 |                 return src.read()
87 | 
88 |         sample = {
89 |             dt: read_tif_as_np_array(self.tif_paths[dt][index]) for dt in self.datatypes
90 |         }
91 | 
92 |         # Rename keys and extract rgb bands from Sentinel-2.
93 | # Also move channels to last dimension, which is expected by pytorch's to_tensor 94 | sample["sar"] = ( 95 | self.sar_norm(sample.pop("s1")).transpose((1, 2, 0)).astype(np.float32) 96 | ) 97 | sample["rgb"] = self.extract_rgb_from_s2(sample.pop("s2")) 98 | sample["seg"] = sample.pop("dfc").transpose((1, 2, 0)) 99 | 100 | if self.transforms: 101 | sample = self.transforms(sample) 102 | return sample 103 | 104 | @staticmethod 105 | def sar_norm(arr): 106 | """ normalizes SAR to the interval [0, 1] """ 107 | arr = np.clip(arr, -20, 5) 108 | arr = (arr + 20.0) / 25.0 109 | return arr 110 | 111 | @staticmethod 112 | def extract_rgb_from_s2(s2_bands): 113 | """ extracts RGB bands from Sentinel-2 """ 114 | rgb = np.empty((*s2_bands.shape[1:], 3), dtype=s2_bands.dtype) 115 | rgb[:, :, 0] = s2_bands[3] 116 | rgb[:, :, 1] = s2_bands[2] 117 | rgb[:, :, 2] = s2_bands[1] 118 | rgb = np.clip(rgb, 0, 3500) 119 | rgb = rgb / 3500 120 | return rgb.astype(np.float32) 121 | 122 | @staticmethod 123 | def sar2rgb(arr): 124 | """ converts SAR to a plotable RGB image """ 125 | co_pol = arr[:, :, 0] 126 | cx_pol = arr[:, :, 1] 127 | 128 | rgb = np.empty((*arr.shape[:2], 3), dtype=arr.dtype) 129 | rgb[:, :, 0] = cx_pol + 0.25 130 | rgb[:, :, 1] = co_pol 131 | rgb[:, :, 2] = cx_pol + 0.25 132 | rgb = np.clip(rgb, 0.0, 1.0) 133 | return rgb.astype(np.float32) 134 | 135 | @staticmethod 136 | def seg2rgb(arr): 137 | """ converts segmentation map to a plotable RGB image """ 138 | return lcov_cmap(lcov_norm(np.squeeze(arr)))[:, :, :3] 139 | -------------------------------------------------------------------------------- /datasets/nrw.py: -------------------------------------------------------------------------------- 1 | """ The GeoNRW dataset """ 2 | 3 | import pathlib 4 | import itertools 5 | from PIL import Image 6 | 7 | import matplotlib 8 | import matplotlib.cm 9 | import numpy as np 10 | 11 | from torchvision.datasets.utils import verify_str_arg 12 | from torchvision.datasets.vision import VisionDataset 13 | 14 | classes = [ 15 | "Forest", 16 | "Water", 17 | "Agricultural", 18 | "Urban", 19 | "Grassland", 20 | "Railway", 21 | "Highway", 22 | "Airport, shipyard", 23 | "Roads", 24 | "Buildings", 25 | ] 26 | 27 | lcov_cmap = matplotlib.colors.ListedColormap( 28 | [ 29 | "#2ca02c", # matplotlib green for forest 30 | "#1f77b4", # matplotlib blue for water 31 | "#8c564b", # matplotlib brown for agricultural 32 | "#7f7f7f", # matplotlib gray residential_commercial_industrial 33 | "#bcbd22", # matplotlib olive for grassland_swamp_shrubbery 34 | "#ff7f0e", # matplotlib orange for railway_trainstation 35 | "#9467bd", # matplotlib purple for highway_squares 36 | "#17becf", # matplotlib cyan for airport_shipyard 37 | "#d62728", # matplotlib red for roads 38 | "#e377c2", # matplotlib pink for buildings 39 | ] 40 | ) 41 | lcov_norm = matplotlib.colors.Normalize(vmin=1, vmax=10) 42 | 43 | # number of classes + invalid 44 | N_LABELS = 11 45 | 46 | N_CHANNELS = {"rgb": 3, "sar": 1, "dem": 1, "seg": N_LABELS} 47 | 48 | 49 | class NRW(VisionDataset): 50 | """ Optical, SAR, LiDAR and landcover data from North Rhine-Westphalia. 51 | 52 | There are fewer SAR images then for the other types of data. 53 | If you don't need SAR, set include_sar to ``False`` for a bigger dataset. 
54 | 55 | Parameters 56 | ---------- 57 | root : string 58 | Root directory of dataset 59 | split : string, optional 60 | Image split to use, ``train`` or ``test`` 61 | include_sar : boolean, optional 62 | Include SAR imagery when returning samples 63 | transforms : callable, optional 64 | A function/transform that takes input sample and returns a transformed version. 65 | 66 | """ 67 | 68 | splits = ["train", "test"] 69 | 70 | train_list = [ 71 | "aachen", 72 | "bergisch", 73 | "bielefeld", 74 | "bochum", 75 | "bonn", 76 | "borken", 77 | "bottrop", 78 | "coesfeld", 79 | "dortmund", 80 | "dueren", 81 | "duisburg", 82 | "ennepetal", 83 | "erftstadt", 84 | "essen", 85 | "euskirchen", 86 | "gelsenkirchen", 87 | "guetersloh", 88 | "hagen", 89 | "hamm", 90 | "heinsberg", 91 | "herford", 92 | "hoexter", 93 | "kleve", 94 | "koeln", 95 | "krefeld", 96 | "leverkusen", 97 | "lippetal", 98 | "lippstadt", 99 | "lotte", 100 | "moenchengladbach", 101 | "moers", 102 | "muelheim", 103 | "muenster", 104 | "oberhausen", 105 | "paderborn", 106 | "recklinghausen", 107 | "remscheid", 108 | "siegen", 109 | "solingen", 110 | "wuppertal", 111 | ] 112 | 113 | test_list = ["duesseldorf", "herne", "neuss"] 114 | 115 | # Convert segmentation map to different PIL mode. 116 | # Otherwise PyTorch later normalizes 117 | readers = { 118 | "sar": lambda path: Image.open(path).copy(), 119 | "rgb": lambda path: Image.open(path).convert("RGB"), 120 | "dem": lambda path: Image.open(path).copy(), 121 | "seg": lambda path: Image.open(path).convert("I;16"), 122 | } 123 | 124 | filenames = { 125 | "sar": lambda utm_coords: "{}_{}_sar.tif".format(*utm_coords), 126 | "rgb": lambda utm_coords: "{}_{}_rgb.jp2".format(*utm_coords), 127 | "dem": lambda utm_coords: "{}_{}_dem.tif".format(*utm_coords), 128 | "seg": lambda utm_coords: "{}_{}_seg.tif".format(*utm_coords), 129 | } 130 | 131 | def __init__(self, root, split="train", include_sar=False, transforms=None): 132 | super().__init__(pathlib.Path(root), transforms=transforms) 133 | verify_str_arg(split, "split", self.splits) 134 | if split == "test": 135 | self.city_names = self.test_list 136 | elif split == "train": 137 | self.city_names = self.train_list 138 | self.datatypes = ["rgb", "dem", "seg"] 139 | if include_sar: 140 | self.file_list = self._get_file_list("*sar.tif") 141 | self.datatypes.append("sar") 142 | else: 143 | self.file_list = self._get_file_list("*rgb.jp2") 144 | 145 | def _get_file_list(self, pattern): 146 | # iterate over citynames 147 | return list( 148 | sorted( 149 | itertools.chain.from_iterable( 150 | (self.root / cn).glob(pattern) for cn in self.city_names 151 | ) 152 | ) 153 | ) 154 | 155 | def __len__(self): 156 | return len(self.file_list) 157 | 158 | def __getitem__(self, index): 159 | path = self.file_list[index] 160 | utm_coords = path.stem.split("_")[:2] 161 | 162 | sample = {} 163 | for datatype in self.datatypes: 164 | path = path.parents[0] / self.filenames[datatype](utm_coords) 165 | sample[datatype] = self.readers[datatype](path) 166 | 167 | try: 168 | sample["sar"] = Image.fromarray(self.sar_norm(sample["sar"])) 169 | except KeyError: 170 | pass 171 | 172 | if self.transforms: 173 | sample = self.transforms(sample) 174 | 175 | return sample 176 | 177 | @staticmethod 178 | def sar_norm(arr): 179 | """ normalizes SAR to the interval [0, 1] """ 180 | arr = 20.0 * np.log10(arr) 181 | return np.clip(arr / 100.0, 0, 1) 182 | 183 | @staticmethod 184 | def seg2rgb(segm): 185 | """ converts segmentation map to a plotable RGB image """ 186 | return 
lcov_cmap(lcov_norm(segm))[:, :, :3] 187 | 188 | @staticmethod 189 | def depth2rgb(depth): 190 | """ converts DEM to a plotable RGB image """ 191 | depth -= depth.min() 192 | depth /= depth.max() 193 | return matplotlib.cm.viridis(depth)[:, :, :3] 194 | 195 | @staticmethod 196 | def sar2rgb(sar): 197 | """ converts SAR to a plotable RGB image """ 198 | sar = np.squeeze(np.clip(255 * sar, 0, 255).astype(np.uint8)) 199 | 200 | return matplotlib.cm.gray(sar)[:, :, :3] 201 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | """ Transforms for GeoNRW and DFC2020 samples. 2 | Samples are dictionaries, with the keys as the datatype (seg, rgb, dem or sar) 3 | """ 4 | 5 | import random 6 | import numbers 7 | 8 | import PIL.Image 9 | import torchvision.transforms.functional as TF 10 | 11 | # Segmentation maps require nearest neighbour resampling to preserve discrete classes. 12 | # Available resample methods 13 | # https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-filters. 14 | RESAMPLE_METHODS = { 15 | "seg": PIL.Image.NEAREST, 16 | "rgb": PIL.Image.BILINEAR, 17 | "dem": PIL.Image.BILINEAR, 18 | "sar": PIL.Image.BILINEAR, 19 | } 20 | 21 | 22 | class ToTensor: 23 | def __call__(self, sample): 24 | return {k: TF.to_tensor(v) for k, v in sample.items()} 25 | 26 | 27 | class TensorApply: 28 | """ Applies functions to some datatype of a sample """ 29 | 30 | valid_kwargs = set(RESAMPLE_METHODS.keys()) 31 | 32 | def __init__(self, **kwargs): 33 | """ Pass functions as keyword arguments. 34 | The key defines the datatype the function operatoes on 35 | 36 | """ 37 | if not self.valid_kwargs.issuperset(set(kwargs.keys())): 38 | raise ValueError( 39 | "Keywords must be chosen from {}".format(self.valid_kwargs) 40 | ) 41 | self.funcs = kwargs 42 | 43 | def __call__(self, sample): 44 | for key, func in self.funcs.items(): 45 | sample[key] = func(sample[key]) 46 | 47 | return sample 48 | 49 | 50 | class RandomCrop: 51 | def __init__(self, size): 52 | if isinstance(size, numbers.Number): 53 | self.size = (int(size), int(size)) 54 | else: 55 | self.size = size 56 | 57 | def __call__(self, sample): 58 | w, h = next(iter(sample.values())).size 59 | try: 60 | i = random.randrange(0, h - self.size[1]) 61 | except ValueError: # empty range because image is too small for crop 62 | i = 0 63 | try: 64 | j = random.randrange(0, w - self.size[0]) 65 | except ValueError: # empty range 66 | j = 0 67 | 68 | return {k: TF.crop(v, i, j, *self.size) for k, v in sample.items()} 69 | 70 | 71 | class RandomHorizontalFlip: 72 | def __init__(self, p=0.5): 73 | self.p = p 74 | 75 | def __call__(self, sample): 76 | if random.random() < self.p: 77 | return {k: TF.hflip(v) for k, v in sample.items()} 78 | return sample 79 | 80 | 81 | class CenterCrop: 82 | def __init__(self, size): 83 | if isinstance(size, numbers.Number): 84 | self.size = (int(size), int(size)) 85 | else: 86 | self.size = size 87 | 88 | def __call__(self, sample): 89 | return {k: TF.center_crop(v, self.size) for k, v in sample.items()} 90 | 91 | 92 | class Resize: 93 | def __init__(self, size): 94 | if isinstance(size, numbers.Number): 95 | self.size = (int(size), int(size)) 96 | else: 97 | self.size = size 98 | 99 | def __call__(self, sample): 100 | # resize crop to desired dimensions 101 | return { 102 | k: TF.resize(v, self.size, RESAMPLE_METHODS[k]) for k, v in sample.items() 103 | } 104 | 
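A usage note (not part of the original file): these transforms operate on the sample dictionaries returned by the datasets above and are meant to be chained with `torchvision.transforms.Compose`, as `options/common.py` does when building the GeoNRW training pipeline. A minimal sketch, assuming the 11 land cover classes of `datasets.nrw.N_LABELS`:

```python
# Sketch: chaining the dictionary transforms defined above for GeoNRW training.
# For GeoNRW, the sample is a dict of PIL images keyed by datatype ("rgb", "dem", "seg", ...).
import torch
import torchvision

import datasets.transforms

train_transforms = torchvision.transforms.Compose(
    [
        datasets.transforms.RandomCrop(256),
        datasets.transforms.RandomHorizontalFlip(),
        datasets.transforms.ToTensor(),
        # one-hot encode the segmentation map and move classes to the channel axis
        datasets.transforms.TensorApply(
            seg=lambda x: torch.nn.functional.one_hot(x.long(), 11)
            .squeeze()
            .permute(2, 0, 1)
            .float()
        ),
    ]
)
# sample = train_transforms(dataset[0])  # dataset: a datasets.nrw.NRW instance
```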
-------------------------------------------------------------------------------- /fid_comp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | 4 | import yaml 5 | import torch 6 | import torch.nn.functional as TF 7 | import tqdm 8 | 9 | import numpy as np 10 | from scipy import linalg 11 | 12 | import options.gan 13 | import datasets.nrw 14 | import datasets.dfc 15 | import models.unet 16 | import utils 17 | 18 | 19 | ########################## 20 | # # 21 | # Comannd line arguments # 22 | # # 23 | ########################## 24 | 25 | parser = argparse.ArgumentParser( 26 | description="compare real and fake images using a segmentation network" 27 | ) 28 | parser.add_argument("generator", help="pt file of the generator model") 29 | parser.add_argument("segmentor", help="pt file of the segmentation model") 30 | args = parser.parse_args() 31 | 32 | if torch.cuda.is_available(): 33 | device = torch.device("cuda") 34 | else: 35 | raise RuntimeError("This scripts expects CUDA to be available") 36 | 37 | # Get the directory of the generator and load its configuration. 38 | # From the configuration we also get the dataset, etc. 39 | GEN_DIR = pathlib.Path(args.generator).absolute().parents[0] 40 | with open(GEN_DIR / "config.yml", "r") as stream: 41 | CONFIG = yaml.load(stream) 42 | 43 | 44 | ######### 45 | # # 46 | # U-Net # 47 | # # 48 | ######### 49 | 50 | # Create a U-Net model and load weights 51 | dset_class = getattr(datasets, CONFIG["dataset"]["name"]) 52 | n_labels = dset_class.N_LABELS 53 | # Number of channels from the generator's output defines 54 | # The U-Net's number of input channels 55 | input_nc = dset_class.N_CHANNELS[CONFIG["dataset"]["output"]] 56 | 57 | seg_model = models.unet.UNet(input_nc, n_labels).to(device) 58 | seg_model.load_state_dict(torch.load(args.segmentor).state_dict()) 59 | seg_model.to(device).eval() 60 | # return intermediate features to compute FID 61 | seg_model.return_intermed = True 62 | 63 | ############# 64 | # # 65 | # Generator # 66 | # # 67 | ############# 68 | 69 | gen_model = options.gan.get_generator(CONFIG) 70 | # remove distributed wrapping, i.e. module. from keynames 71 | state_dict = utils.unwrap_state_dict(torch.load(args.generator)) 72 | gen_model.load_state_dict(state_dict) 73 | gen_model = gen_model.to(device).eval() 74 | 75 | ################ 76 | # # 77 | # Dataset prep # 78 | # # 79 | ################ 80 | 81 | _, test_transforms = options.common.get_transforms(CONFIG) 82 | dataset_test = options.common.get_dataset( 83 | CONFIG, split="test", transforms=test_transforms 84 | ) 85 | # dataset_test = torch.utils.data.Subset(dataset_test, list(range(16))) 86 | 87 | BATCH_SIZE = 8 88 | 89 | test_dataloader = torch.utils.data.DataLoader( 90 | dataset_test, batch_size=BATCH_SIZE, num_workers=4 91 | ) 92 | 93 | 94 | def process_intermed_features(intermed_features): 95 | """ processes intermediate features before computing FID 96 | 97 | Applies global average pooling to the features of all layers and 98 | concatenates the resulting feature maps to a single vector. 
99 | 100 | Parameters 101 | ---------- 102 | 103 | intermed_features: list of pytorch tensors 104 | each element of the list contains a batch of intermediate features from a different layer 105 | 106 | """ 107 | 108 | # global average pooling of spatial dimensions 109 | pooled = [TF.avg_pool2d(t, t.shape[-2:]).squeeze() for t in intermed_features] 110 | 111 | # concatenate features of different layers 112 | concat = torch.cat(pooled, dim=1) 113 | 114 | return concat.tolist() 115 | 116 | 117 | # Taken from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/fid_score.py 118 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 119 | """Numpy implementation of the Frechet Distance. 120 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 121 | and X_2 ~ N(mu_2, C_2) is 122 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 123 | Stable version by Dougal J. Sutherland. 124 | Params: 125 | -- mu1 : Numpy array containing the activations of a layer of the 126 | inception net (like returned by the function 'get_predictions') 127 | for generated samples. 128 | -- mu2 : The sample mean over activations, precalculated on an 129 | representative data set. 130 | -- sigma1: The covariance matrix over activations for generated samples. 131 | -- sigma2: The covariance matrix over activations, precalculated on an 132 | representative data set. 133 | Returns: 134 | -- : The Frechet Distance. 135 | """ 136 | 137 | mu1 = np.atleast_1d(mu1) 138 | mu2 = np.atleast_1d(mu2) 139 | 140 | sigma1 = np.atleast_2d(sigma1) 141 | sigma2 = np.atleast_2d(sigma2) 142 | 143 | assert ( 144 | mu1.shape == mu2.shape 145 | ), "Training and test mean vectors have different lengths" 146 | assert ( 147 | sigma1.shape == sigma2.shape 148 | ), "Training and test covariances have different dimensions" 149 | 150 | diff = mu1 - mu2 151 | 152 | # Product might be almost singular 153 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 154 | if not np.isfinite(covmean).all(): 155 | msg = ( 156 | "fid calculation produces singular product; " 157 | "adding %s to diagonal of cov estimates" 158 | ) % eps 159 | print(msg) 160 | offset = np.eye(sigma1.shape[0]) * eps 161 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 162 | 163 | # Numerical error might give slight imaginary component 164 | if np.iscomplexobj(covmean): 165 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 166 | m = np.max(np.abs(covmean.imag)) 167 | raise ValueError("Imaginary component {}".format(m)) 168 | covmean = covmean.real 169 | 170 | tr_covmean = np.trace(covmean) 171 | 172 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 173 | 174 | 175 | # Taken from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/fid_score.py 176 | def calculate_activation_statistics(features): 177 | """Calculation of the statistics used by the FID. 178 | Returns: 179 | -- mu : The mean over samples of the activations of the pool_3 layer of 180 | the inception model. 181 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 182 | the inception model. 
183 | """ 184 | features = np.array(features) 185 | mu = np.mean(features, axis=0) 186 | sigma = np.cov(features, rowvar=False) 187 | return mu, sigma 188 | 189 | 190 | ########### 191 | # # 192 | # Testing # 193 | # # 194 | ########### 195 | 196 | with torch.no_grad(): 197 | real = [] 198 | fake = [] 199 | for idx, sample in tqdm.tqdm( 200 | enumerate(test_dataloader), total=len(test_dataloader) 201 | ): 202 | sample = {k: v.to(device) for k, v in sample.items()} 203 | 204 | # generator fake images 205 | gen_input = {dt: sample[dt] for dt in CONFIG["dataset"]["input"]} 206 | gen_output = gen_model(gen_input) 207 | 208 | # get features from real and fake images 209 | seg_real, features_real = seg_model(sample[CONFIG["dataset"]["output"]]) 210 | seg_fake, features_fake = seg_model(gen_output) 211 | 212 | real += process_intermed_features(features_real) 213 | fake += process_intermed_features(features_fake) 214 | 215 | mu_real, sigma_real = calculate_activation_statistics(real) 216 | mu_fake, sigma_fake = calculate_activation_statistics(fake) 217 | 218 | fid_value = calculate_frechet_distance(mu_real, sigma_real, mu_fake, sigma_fake) 219 | print("FID: {:5.4f}".format(fid_value)) 220 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class HingeDiscriminator(nn.Module): 6 | """ Hinge loss for discriminator 7 | 8 | [1] Jae Hyun Lim, Jong Chul Ye, "Geometric GAN", 2017 9 | [2] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida, 10 | "Spectral normalization for generative adversarial networks", 2018 11 | 12 | """ 13 | 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def forward(self, disc_real_output, disc_fake_output): 18 | """ 19 | 20 | Args: 21 | disc_real_output: the discriminators output for a real sample 22 | disc_fake_output: the discriminators output for a fake sample 23 | 24 | """ 25 | 26 | loss = -torch.mean( 27 | torch.min(disc_real_output - 1, torch.zeros_like(disc_real_output)) 28 | ) 29 | loss -= torch.mean( 30 | torch.min(-disc_fake_output - 1, torch.zeros_like(disc_fake_output)) 31 | ) 32 | 33 | return loss 34 | 35 | 36 | class HingeGenerator(nn.Module): 37 | """ Hinge loss for discriminator 38 | 39 | [1] Jae Hyun Lim, Jong Chul Ye, "Geometric GAN", 2017 40 | [2] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida, 41 | "Spectral normalization for generative adversarial networks", 2018 42 | 43 | """ 44 | 45 | def __init__(self): 46 | super().__init__() 47 | 48 | def forward(self, disc_fake_output): 49 | return -torch.mean(disc_fake_output) 50 | 51 | 52 | def iou(pr, gt, eps=1e-7, axis=(0, 2, 3)): 53 | """ 54 | intersection over union loss 55 | 56 | Source: 57 | https://github.com/catalyst-team/catalyst/ 58 | https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/utils/functions.py 59 | https://discuss.pytorch.org/t/iou-score-for-muilticlass-segmentation/89350 60 | 61 | Args: 62 | pr (torch.Tensor): A list of predicted elements as softmax 63 | gt (torch.Tensor): A list of elements that are to be predicted as one hot encoded 64 | eps (float): epsilon to avoid zero division 65 | Returns: 66 | float: IoU (Jaccard) score 67 | """ 68 | 69 | intersection = torch.sum(gt.float() * pr.float(), axis) 70 | union = torch.sum(gt, axis).float() + torch.sum(pr, axis).float() - intersection + eps 71 | 72 | asdf = intersection / union 73 | 74 | 
return torch.mean(asdf) 75 | 76 | def pixel_acc(pred, target): 77 | """ pixel accuracy 78 | 79 | Args: 80 | pred (torch.Tensor): predicted classes. Not one-hot encoded 81 | targer (torch.Tensor):correct classes. Not one-hot encoded 82 | 83 | """ 84 | 85 | corr_pix = torch.sum(pred == target).float() 86 | 87 | return corr_pix / pred.nelement() 88 | -------------------------------------------------------------------------------- /models/arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.utils.spectral_norm as spectral_norm 4 | import torch.nn.functional as F 5 | 6 | 7 | ######################### 8 | # # 9 | # Basic building blocks # 10 | # # 11 | ######################### 12 | 13 | 14 | class ResnetBasicBlock(nn.Module): 15 | """ ResNet block """ 16 | 17 | def __init__(self, inplanes, planes, kernel_size=3, norm_layer=nn.BatchNorm2d): 18 | super().__init__() 19 | self.kernel_size = kernel_size 20 | 21 | if inplanes != planes: 22 | self.shortcut = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 23 | else: 24 | self.shortcut = None 25 | 26 | self.pad = nn.ReflectionPad2d(self.kernel_size // 2) 27 | self.relu = nn.ReLU(inplace=False) 28 | 29 | self.conv1 = spectral_norm(nn.Conv2d(inplanes, planes, self.kernel_size)) 30 | self.conv2 = spectral_norm(nn.Conv2d(planes, planes, self.kernel_size)) 31 | 32 | if norm_layer is not None: 33 | self.norm1 = norm_layer(inplanes) 34 | self.norm2 = norm_layer(planes) 35 | else: 36 | self.norm1 = self.norm2 = lambda x: x 37 | 38 | def forward(self, x): 39 | """ Ordering of operations differs in ResNet blocks depending on the publication. 40 | 41 | [1] : conv -> norm -> relu 42 | [2] and [3] : norm -> relu -> conv 43 | 44 | We follow [2] and [3] as they specifically target image synthesis. 45 | 46 | [1] He et. al. "Deep residual learning for image recognition, CVPR 2016 47 | [2] Brock et. al. "Large scale GAN training for high fidelity natural 48 | image synthesis", ICLR 2019 49 | [3] Park et. al. "Semantic Image Synthesis with Spatially-Adaptive 50 | Normalization.", CVPR 2019 51 | 52 | """ 53 | 54 | identity = x 55 | 56 | out = self.pad(x) 57 | out = self.norm1(out) 58 | out = self.relu(out) 59 | out = self.conv1(out) 60 | 61 | out = self.pad(out) 62 | out = self.norm2(out) 63 | out = self.relu(out) 64 | out = self.conv2(out) 65 | 66 | if self.shortcut: 67 | identity = self.shortcut(identity) 68 | out += identity 69 | 70 | return out 71 | 72 | 73 | class Downsampler(nn.Module): 74 | """ Typical downsampler by strided convolution, norm layer and ReLU """ 75 | 76 | def __init__(self, input_nc, norm_layer=nn.BatchNorm2d): 77 | super().__init__() 78 | self.conv = spectral_norm( 79 | nn.Conv2d(input_nc, 2 * input_nc, kernel_size=3, stride=2, padding=1) 80 | ) 81 | self.norm = norm_layer(2 * input_nc) 82 | self.relu = nn.ReLU() 83 | 84 | def forward(self, x): 85 | x = self.conv(x) 86 | x = self.norm(x) 87 | x = self.relu(x) 88 | 89 | return x 90 | 91 | 92 | ########################## 93 | # # 94 | # SPADE buildings blocks # 95 | # # 96 | ########################## 97 | 98 | 99 | class SPADE(nn.Module): 100 | """ SPADE normalization layer 101 | 102 | Code taken and modified from 103 | https://github.com/NVlabs/SPADE/blob/master/models/networks/normalization.py 104 | 105 | SPADE consists of two steps. First, it normalizes the activations using 106 | your favorite normalization method, such as Batch Norm or Instance Norm. 
107 | Second, it applies scale and bias to the normalized output, conditioned on 108 | the segmentation map. 109 | 110 | Parameters 111 | ---------- 112 | num_features : int 113 | The number of channels of the normalized activations, 114 | i.e., SPADE's output dimension 115 | label_nc: int 116 | The number of channels of the input semantic map, 117 | i.e., SPADE's input dimension 118 | norm_layer: torch.nn.BatchNorm2d or InstanceNorm2d. 119 | Which normalization method to use together with SPADE. 120 | Generators often use batch or instance normalization. 121 | 122 | """ 123 | 124 | def __init__(self, num_features, label_nc, norm_layer=nn.BatchNorm2d): 125 | super().__init__() 126 | 127 | self.norm = norm_layer(num_features) 128 | 129 | # The dimension of the intermediate embedding space. Yes, hardcoded. 130 | nhidden = 128 131 | 132 | conv_opts = { 133 | "kernel_size": 3, 134 | "padding": 1, 135 | } 136 | 137 | self.mlp_shared = nn.Sequential( 138 | nn.Conv2d(label_nc, nhidden, **conv_opts), nn.ReLU() 139 | ) 140 | self.mlp_gamma = nn.Conv2d(nhidden, num_features, **conv_opts) 141 | self.mlp_beta = nn.Conv2d(nhidden, num_features, **conv_opts) 142 | 143 | def forward(self, x, segmap): 144 | 145 | # Part 1. generate parameter-free normalized activations 146 | normalized = self.norm(x) 147 | 148 | # Part 2. produce scaling and bias conditioned on semantic map 149 | segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest") 150 | actv = self.mlp_shared(segmap) 151 | gamma = self.mlp_gamma(actv) 152 | beta = self.mlp_beta(actv) 153 | 154 | # apply scale and bias 155 | out = normalized * (1 + gamma) + beta 156 | 157 | return out 158 | 159 | 160 | class DataSegmapTuple: 161 | 162 | data: torch.tensor 163 | segmap: torch.tensor 164 | 165 | def __init__(self, data, segmap): 166 | self.data = data 167 | self.segmap = segmap 168 | 169 | def __add__(self, other): 170 | return DataSegmapTuple(self.data + other.data, self.segmap) 171 | 172 | 173 | def pass_segmap(method): 174 | """ SPADE normalization requires the segmentation map as additional input 175 | 176 | By wrapping all forward methods with this wrapper SPADE blocks can be used 177 | just like regular nn.Modules, i.e. 178 | 179 | x = block_1(x) 180 | x = block_2(x) 181 | 182 | return x 183 | 184 | where the forward methods of block_1 and block_2 was wrapped using this function. 
185 | 186 | ToDo: 187 | checkout register_forward_hook and register_forward_pre_hook 188 | 189 | """ 190 | 191 | def wrapper(dst: DataSegmapTuple): 192 | # treat SPADE modules differently 193 | if isinstance(method.__self__, SPADE): 194 | segmap_ds = F.interpolate( 195 | dst.segmap, size=dst.data.size()[2:], mode="nearest" 196 | ) 197 | x = method(dst.data, segmap_ds) 198 | else: 199 | x = method(dst.data) 200 | return DataSegmapTuple( 201 | x, dst.segmap 202 | ) # return new features and segmap of original size 203 | 204 | wrapper.__name__ = method.__name__ 205 | wrapper.__doc__ = method.__doc__ 206 | return wrapper 207 | 208 | 209 | class SPADEResnetBlock(ResnetBasicBlock): 210 | def __init__(self, inplanes, planes, label_nc, kernel_size=3): 211 | super().__init__( 212 | inplanes, 213 | planes, 214 | kernel_size, 215 | lambda num_features: SPADE(num_features, label_nc), 216 | ) 217 | 218 | for name, child in self.named_children(): 219 | # wraps all contained modules to pass the segmentation map 220 | # together with the feature tensor 221 | child.forward = pass_segmap(child.forward) 222 | setattr(self, name, child) 223 | 224 | def forward(self, x, segmap): 225 | # Call the base class's forward method but pass in a 226 | # DataSegmapTuple instead of just a tensor. Since all 227 | # child modules were modified with pass_segmap they can 228 | # process the DataSegmapTuple 229 | dst = super().forward(DataSegmapTuple(x, segmap)) 230 | 231 | return dst.data 232 | 233 | 234 | class SpadeDownsampler(Downsampler): 235 | def __init__(self, input_nc, label_nc): 236 | super().__init__(input_nc, lambda num_features: SPADE(num_features, label_nc)) 237 | 238 | for name, child in self.named_children(): 239 | # wraps all contained modules to pass the segmentation map 240 | # together with the feature tensor 241 | child.forward = pass_segmap(child.forward) 242 | setattr(self, name, child) 243 | 244 | def forward(self, x, segmap): 245 | # Call the base class's forward method but pass in a 246 | # DataSegmapTuple instead of just a tensor. Since all 247 | # child modules were modified with pass_segmap they can 248 | # process the DataSegmapTuple 249 | dst = super().forward(DataSegmapTuple(x, segmap)) 250 | 251 | return dst.data 252 | -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | # Named tuple for a neural networks output with the 4 | # final output and intermediate features. 5 | # Intermediate features are useful for computing FID 6 | # or the discriminator feature loss of SPADE. 7 | NNOutput = namedtuple("NNOutput", "final features") 8 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.utils.spectral_norm as spectral_norm 4 | 5 | from . import common 6 | 7 | 8 | class PatchGAN(nn.Module): 9 | """ PatchGAN discriminator used by Pix2Pix HD and SPADE 10 | 11 | Original code taken from [1] but with lots of modifications. 12 | 13 | [1] https://github.com/NVlabs/SPADE/blob/master/models/networks/discriminator.py. 
14 | 15 | """ 16 | 17 | def __init__(self, input_nc, n_layers=4, init_nc=64): 18 | super().__init__() 19 | 20 | self.n_layers = n_layers 21 | 22 | self.kernel_size = 4 23 | self.init_nc = init_nc 24 | self._return_intermed = False 25 | 26 | self.init_conv = nn.Sequential( 27 | nn.Conv2d( 28 | input_nc, self.init_nc, self.kernel_size, stride=2, padding=self.padding 29 | ), 30 | nn.LeakyReLU(0.2), 31 | ) 32 | 33 | # number of channels 34 | ncs = [self.init_nc * 2 ** n_layer for n_layer in range(n_layers - 1)] 35 | # every layer downsamples except for last 36 | strides = (len(ncs) - 1) * [2] + [1] 37 | 38 | for idx, (nc, stride) in enumerate(zip(ncs, strides)): 39 | self.add_module(self.layername(idx), self.layer(nc, stride)) 40 | 41 | self.final_conv = nn.Conv2d( 42 | 2 * ncs[-1], 1, kernel_size=self.kernel_size, stride=1, padding=self.padding 43 | ) 44 | 45 | @property 46 | def return_intermed(self): 47 | return self._return_intermed 48 | 49 | @return_intermed.setter 50 | def return_intermed(self, value): 51 | self._return_intermed = value 52 | 53 | @property 54 | def padding(self): 55 | return self.kernel_size // 2 56 | 57 | def layer(self, input_nc, stride): 58 | return nn.Sequential( 59 | spectral_norm( 60 | nn.Conv2d( 61 | input_nc, 62 | 2 * input_nc, 63 | self.kernel_size, 64 | stride=stride, 65 | padding=self.padding, 66 | ) 67 | ), 68 | nn.InstanceNorm2d(2 * input_nc), 69 | nn.LeakyReLU(0.2), 70 | ) 71 | 72 | @staticmethod 73 | def layername(idx): 74 | return "conv_{}".format(idx) 75 | 76 | @property 77 | def intermed_layers(self): 78 | # The last layer is not considered a feature layer for computing feature loss, 79 | # since it already contributed to the generator loss. 80 | return [self.init_conv] + [ 81 | getattr(self, self.layername(idx)) for idx in range(self.n_layers - 1) 82 | ] 83 | 84 | def forward(self, gen_input, real_or_fake): 85 | """ a conditioned discriminator's forward method 86 | 87 | Parameters 88 | ---------- 89 | 90 | gen_input : dict of torch.Tensor 91 | input the generator received, i.e., the condition variable 92 | real_or_fake: torch.Tensor 93 | either the real sample or the fake one created by the generator 94 | 95 | """ 96 | 97 | x = torch.cat([gi for gi in gen_input.values()] + [real_or_fake], dim=1) 98 | 99 | xs = [] 100 | 101 | for layer in self.intermed_layers: 102 | x = layer(x) 103 | if self.return_intermed: 104 | xs.append(x) 105 | 106 | if self.final_conv is not None: 107 | x = self.final_conv(x) 108 | x = [x] # make consistent with multiscale discriminator 109 | 110 | return common.NNOutput(final=x, features=xs) 111 | 112 | 113 | class Multiscale(nn.Module): 114 | """ Multiscale discriminator 115 | 116 | Parameters 117 | ---------- 118 | discriminators : list[nn.Module] 119 | list of discriminators, each will operate at a different scale 120 | 121 | """ 122 | 123 | def __init__(self, discriminators): 124 | super().__init__() 125 | 126 | self.n_scales = len(discriminators) 127 | self._return_intermed = False 128 | 129 | for idx, disc in enumerate(discriminators): 130 | self.add_module(self.disc_name(idx), disc) 131 | 132 | @property 133 | def return_intermed(self): 134 | return self._return_intermed 135 | 136 | @return_intermed.setter 137 | def return_intermed(self, value): 138 | self._return_intermed = value 139 | for idx in range(self.n_scales): 140 | disc_name = self.disc_name(idx) 141 | getattr(self, disc_name).return_intermed = value 142 | 143 | @staticmethod 144 | def disc_name(idx): 145 | return "disc_{}".format(idx) 146 | 147 | def 
forward(self, gen_input, real_or_fake): 148 | """ a conditioned discriminator's forward method 149 | 150 | Parameters 151 | ---------- 152 | 153 | gen_input : dict of torch.Tensor 154 | input the generator received, i.e., the condition variable 155 | real_or_fake: torch.Tensor 156 | either the real sample or the fake one created by the generator 157 | 158 | """ 159 | 160 | xs = {} 161 | 162 | # cycle through all scales 163 | for idx in range(self.n_scales): 164 | disc_name = self.disc_name(idx) 165 | disc = getattr(self, disc_name) 166 | xs[disc_name] = disc(gen_input, real_or_fake) 167 | 168 | # downsample input 169 | gen_input = { 170 | k: torch.nn.functional.avg_pool2d(v, 2) for k, v in gen_input.items() 171 | } 172 | real_or_fake = torch.nn.functional.avg_pool2d(real_or_fake, 2) 173 | 174 | # concatenate output of discriminators 175 | final = [final for nno in xs.values() for final in nno.final] 176 | # concatenate and flatten features of different discriminators 177 | features = [feat for nno in xs.values() for feat in nno.features] 178 | 179 | return common.NNOutput(final, features) 180 | -------------------------------------------------------------------------------- /models/generator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional 3 | import torch.nn.utils.spectral_norm as spectral_norm 4 | 5 | from . import arch 6 | 7 | 8 | ############################## 9 | # # 10 | # Generator buildings blocks # 11 | # # 12 | ############################## 13 | 14 | 15 | class Encoder(nn.Module): 16 | def __init__( 17 | self, 18 | input_nc, 19 | init_nc=64, 20 | init_kernel_size=7, 21 | n_downsample=4, 22 | downsampler=arch.Downsampler, 23 | ): 24 | super().__init__() 25 | 26 | self.input_nc = input_nc 27 | 28 | self.init_conv = nn.Sequential( 29 | nn.ReflectionPad2d(init_kernel_size // 2), 30 | spectral_norm(nn.Conv2d(input_nc, init_nc, init_kernel_size, padding=0)), 31 | nn.ReLU(), 32 | ) 33 | 34 | self.n_downsample = n_downsample 35 | 36 | # nn.sequential does not support multiple inputs 37 | for i in range(n_downsample): 38 | self.add_module("down{}".format(i), downsampler(init_nc * 2 ** i)) 39 | 40 | def forward(self, x, segmap=None): 41 | # no normalization layer in first convolution 42 | x = self.init_conv(x) 43 | 44 | for i in range(self.n_downsample): 45 | down = getattr(self, "down{}".format(i)) 46 | if segmap is not None: 47 | x = down(x, segmap) 48 | else: 49 | x = down(x) 50 | 51 | return x 52 | 53 | 54 | class Body(nn.Module): 55 | def __init__(self, input_nc, n_stages, res_block=arch.ResnetBasicBlock): 56 | 57 | super().__init__() 58 | 59 | self.n_stages = n_stages 60 | 61 | for i in range(n_stages): 62 | self.add_module("body{}".format(i), res_block(input_nc, input_nc)) 63 | 64 | def forward(self, x, segmap=None): 65 | """ expects output either from Encoder or an inital convolution layer""" 66 | 67 | for i in range(self.n_stages): 68 | rnb = getattr(self, "body{}".format(i)) 69 | if segmap is not None: 70 | x = rnb(x, segmap) 71 | else: 72 | x = rnb(x) 73 | 74 | return x 75 | 76 | 77 | class Decoder(nn.Module): 78 | """ Decoder part for an image synthesis generator. 79 | 80 | This decoder follow the philosophy of [1] and [2] by pairing ResNet-blocks with simple upsampling. 81 | Upsampling is done using nearest neighbour to prevent checkerboard artifacts. 82 | Decoder must be used together with Body, since both [1] and [2] have some layers without upsampling. 83 | 84 | [1] Brock et. al. 
"Large scale GAN training for high fidelity natural image synthesis", ICLR 2019 85 | [2] Park et. al. "Semantic Image Synthesis with Spatially-Adaptive Normalization.", CVPR 2019 86 | [3] Odena, et al., "Deconvolution and Checkerboard Artifacts", Distill, 2016. http://doi.org/10.23915/distill.00003 87 | 88 | """ 89 | 90 | def __init__( 91 | self, 92 | input_nc, 93 | output_nc=3, 94 | n_up_stages=4, 95 | res_block=arch.ResnetBasicBlock, 96 | final_kernel_size=3, 97 | return_intermed=False, 98 | ): 99 | 100 | super().__init__() 101 | 102 | self.input_nc = input_nc 103 | self.output_nc = output_nc 104 | self.n_up_stages = n_up_stages 105 | self._return_intermed = return_intermed 106 | 107 | for i in range(n_up_stages): 108 | in_channels = self.input_nc // (2 ** i) 109 | out_channels = in_channels // 2 110 | self.add_module("up{}".format(i), res_block(in_channels, out_channels)) 111 | 112 | self.final_conv = nn.Sequential( 113 | nn.ReflectionPad2d(final_kernel_size // 2), 114 | spectral_norm( 115 | nn.Conv2d(out_channels, output_nc, final_kernel_size, padding=0) 116 | ), 117 | nn.Tanh(), 118 | ) 119 | 120 | def input_shape(self, output_shape): 121 | """ given the output spatial dimension compute the dimension the input has to have """ 122 | return tuple((x // 2 ** self.n_up_stages) for x in output_shape) 123 | 124 | @property 125 | def return_intermed(self): 126 | return self._return_intermed 127 | 128 | @return_intermed.setter 129 | def return_intermed(self, value): 130 | self._return_intermed = value 131 | 132 | def forward(self, x, segmap=None): 133 | """ expects output from Body """ 134 | 135 | xs = [] 136 | 137 | for i in range(self.n_up_stages): 138 | # simple nearest neighbour upsampling 139 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest") 140 | rnb = getattr(self, "up{}".format(i)) 141 | if segmap is not None: 142 | x = rnb(x, segmap) 143 | else: 144 | x = rnb(x) 145 | 146 | if self.return_intermed: 147 | xs.append(x) 148 | if self.return_intermed: 149 | return xs 150 | 151 | return self.final_conv(x) 152 | 153 | 154 | ############################# 155 | # # 156 | # Generator implementations # 157 | # # 158 | ############################# 159 | 160 | 161 | class ResnetEncoderDecoder(nn.Module): 162 | """ Encoder-Decoder architecture heavily inspired by [1], [2] and [3] for style transfer also used in Pix2Pix and Pix2PixHD 163 | 164 | The default parameterization corresponds to the global generator of the Pix2Pix HD [2]. 165 | 166 | There are slight differences between [1] and [2]. [2] uses reflection padding for all Resnet blocks, [1] notes that zero-padding let to 167 | artifacts and only pads in at the initial convolution layer, i.e. each Resnet block reduces the spatial dimensions. 168 | 169 | [1] Johnson, Justin, Alexandre Alahi, and Li Fei-Fei. “Perceptual Losses for Real-Time Style Transfer and Super-Resolution.”, ECCV 2016 170 | https://arxiv.org/abs/1603.08155 and supplementary material https://cs.stanford.edu/people/jcjohns/papers/fast-style/fast-style-supp.pdf 171 | [2] Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu, Andrew Tao, Jan Kautz, and Bryan Catanzaro. 172 | "High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs", CVPR, 2018 https://arxiv.org/abs/1711.11585 173 | 174 | [3] Park et. al. 
"Semantic Image Synthesis with Spatially-Adaptive Normalization.", CVPR 2019, https://nvlabs.github.io/SPADE/ 175 | 176 | 177 | """ 178 | 179 | def __init__( 180 | self, 181 | input_nc, 182 | init_nc=64, 183 | output_nc=3, 184 | init_kernel_size=7, 185 | n_downsample=4, 186 | n_resnet_blocks=9, 187 | ): 188 | super().__init__() 189 | 190 | self.encoder = Encoder(input_nc, init_nc, init_kernel_size, n_downsample) 191 | 192 | body_nc = init_nc * 2 ** n_downsample 193 | 194 | self.body = Body(body_nc, n_resnet_blocks) 195 | 196 | self.decoder = Decoder(body_nc, output_nc, n_downsample) 197 | 198 | def forward(self, gen_input): 199 | x = torch.cat([gi for gi in gen_input.values()], dim=1) 200 | 201 | x = self.encoder(x) 202 | x = self.body(x) 203 | x = self.decoder(x) 204 | 205 | return x 206 | 207 | @property 208 | def input_nc(self): 209 | return self.encoder.input_nc 210 | 211 | 212 | class SPADEResnetEncoderDecoder(nn.Module): 213 | def __init__( 214 | self, 215 | input_nc, 216 | label_nc, 217 | init_nc=64, 218 | output_nc=3, 219 | init_kernel_size=7, 220 | n_downsample=4, 221 | n_resnet_blocks=9, 222 | ): 223 | super().__init__() 224 | 225 | get_spade_block = lambda inc, outc: arch.SPADEResnetBlock(inc, outc, label_nc) 226 | get_spade_downsampler = lambda inc: arch.SpadeDownsampler(inc, label_nc) 227 | 228 | self.encoder = Encoder( 229 | input_nc, init_nc, init_kernel_size, n_downsample, get_spade_downsampler 230 | ) 231 | 232 | body_nc = init_nc * 2 ** n_downsample 233 | 234 | self.body = Body(body_nc, n_resnet_blocks, get_spade_block) 235 | 236 | self.decoder = Decoder(body_nc, output_nc, n_downsample, get_spade_block) 237 | 238 | def forward(self, gen_input): 239 | # get segmentation map 240 | segmap = gen_input["seg"] 241 | 242 | # concatenate all other inputs 243 | x = torch.cat([v for k, v in gen_input.items() if k != "seg"], dim=1) 244 | 245 | x = self.encoder(x, segmap) 246 | x = self.body(x, segmap) 247 | x = self.decoder(x, segmap) 248 | 249 | return x 250 | 251 | 252 | class SPADEGenerator(nn.Module): 253 | def __init__( 254 | self, label_nc, init_nc=256, output_nc=3, n_mid_stages=2, n_up_stages=4 255 | ): 256 | super().__init__() 257 | 258 | self.label_nc = label_nc 259 | 260 | # iniital embedding of segmentation map 261 | init_kernel_size = 3 262 | self.init_conv = nn.Sequential( 263 | nn.ReflectionPad2d(init_kernel_size // 2), 264 | spectral_norm(nn.Conv2d(label_nc, init_nc, init_kernel_size, padding=0)), 265 | nn.ReLU(), 266 | ) 267 | 268 | get_spade_block = lambda inc, outc: arch.SPADEResnetBlock(inc, outc, label_nc) 269 | 270 | self.head = Body(init_nc, n_mid_stages, get_spade_block) 271 | 272 | self.decoder = Decoder(init_nc, output_nc, n_up_stages, get_spade_block) 273 | 274 | def forward(self, gen_input): 275 | segmap = gen_input["seg"] 276 | # Input dimension deterime final output dimension 277 | shape_ds = self.decoder.input_shape(segmap.shape[-2:]) 278 | segmap_ds = torch.nn.functional.interpolate(segmap, size=shape_ds) 279 | 280 | x = self.init_conv(segmap_ds) 281 | 282 | x = self.head(x, segmap) 283 | x = self.decoder(x, segmap) 284 | 285 | return x 286 | 287 | @property 288 | def input_nc(self): 289 | return self.label_nc 290 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as TF 4 | 5 | from . 
import common 6 | 7 | 8 | class UNet(nn.Module): 9 | """ U-Net implemenation 10 | 11 | U-Net: Convolutional Networks for Biomedical Image Segmentation 12 | (Ronneberger et al., 2015) https://arxiv.org/abs/1505.04597 13 | 14 | Parameters 15 | ---------- 16 | in_channels : int 17 | number of input channels 18 | n_classes : int 19 | number of output channels 20 | n_downsample : int 21 | depth of the network 22 | """ 23 | 24 | def __init__(self, in_channels, n_classes, n_downsample=4): 25 | super().__init__() 26 | 27 | # input and output channels for the downsampling path 28 | out_channels = [64 * (2 ** i) for i in range(n_downsample)] 29 | in_channels = [in_channels] + out_channels[:-1] 30 | 31 | self.down_path = nn.ModuleList( 32 | [ 33 | self._make_unet_conv_block(ich, och) 34 | for ich, och in zip(in_channels, out_channels) 35 | ] 36 | ) 37 | 38 | # input channels of the upsampling path 39 | in_channels = [64 * (2 ** i) for i in range(n_downsample, 0, -1)] 40 | self.body = self._make_unet_conv_block(out_channels[-1], in_channels[0]) 41 | 42 | self.upsamplers = nn.ModuleList( 43 | [self._make_upsampler(nch, nch // 2) for nch in in_channels] 44 | ) 45 | 46 | self.up_path = nn.ModuleList( 47 | [self._make_unet_conv_block(nch, nch // 2) for nch in in_channels] 48 | ) 49 | 50 | self.last = nn.Conv2d(64, n_classes, kernel_size=1) 51 | 52 | self._return_intermed = False 53 | 54 | @staticmethod 55 | def _make_unet_conv_block(in_channels, out_channels): 56 | return nn.Sequential( 57 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 58 | nn.ReLU(), 59 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 60 | nn.ReLU(), 61 | ) 62 | 63 | @staticmethod 64 | def _make_upsampler(in_channels, out_channels): 65 | return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2) 66 | 67 | @property 68 | def return_intermed(self): 69 | return self._return_intermed 70 | 71 | @return_intermed.setter 72 | def return_intermed(self, value): 73 | self._return_intermed = value 74 | 75 | def forward(self, x): 76 | blocks = [] 77 | for down in self.down_path: 78 | # UNet conv block increases the number of channels 79 | x = down(x) 80 | blocks.append(x) 81 | # Downsampling, by mass pooling 82 | x = TF.max_pool2d(x, 2) 83 | 84 | x = self.body(x) 85 | 86 | for upsampler, up, block in zip( 87 | self.upsamplers, self.up_path, reversed(blocks) 88 | ): 89 | # upsample and reduce number of channels 90 | x = upsampler(x) 91 | x = torch.cat([x, block], dim=1) 92 | # UNet conv block reduces the number of channels again 93 | x = up(x) 94 | 95 | x = self.last(x) 96 | 97 | if self.return_intermed: 98 | return common.NNOutput(x, blocks) 99 | 100 | return x 101 | -------------------------------------------------------------------------------- /options/common.py: -------------------------------------------------------------------------------- 1 | """ common command line arguments and helper functions for GANs and U-Net """ 2 | 3 | import argparse 4 | import pathlib 5 | 6 | import torch 7 | import torchvision 8 | 9 | import datasets.transforms 10 | 11 | 12 | def get_parser(): 13 | """ returns common ArgumentParser for GANs and U-Net """ 14 | 15 | parser = argparse.ArgumentParser( 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 17 | ) 18 | parser.add_argument( 19 | "--batch_size", type=int, default=32, help="Batch size for training" 20 | ) 21 | parser.add_argument( 22 | "--crop", 23 | type=int, 24 | default=(256,), 25 | nargs="+", 26 | help="Size of crop. 
Can be a tuple of height and width", 27 | ) 28 | parser.add_argument( 29 | "--resize", 30 | type=int, 31 | default=(256,), 32 | nargs="+", 33 | help="Resizing after cropping. Can be a tuple of height and width", 34 | ) 35 | parser.add_argument( 36 | "--epochs", type=int, default=200, help="number of epochs to train" 37 | ) 38 | parser.add_argument( 39 | "--seed", default=None, type=int, help="seed for initializing training. " 40 | ) 41 | parser.add_argument( 42 | "--suffix", help="suffix appended to the otuput directory", default=None 43 | ) 44 | 45 | parser.add_argument( 46 | "--dataset", 47 | type=str, 48 | default="nrw", 49 | choices=["nrw", "dfc"], 50 | help="Which dataset to use: GeoNRWa or Data Fusion Contest 2020?", 51 | ) 52 | parser.add_argument( 53 | "--dataroot", type=str, default="./data/nrw", help="Path to dataset" 54 | ) 55 | 56 | parser.add_argument( 57 | "--num_workers", default=2, type=int, help="Number of workers for data loader.", 58 | ) 59 | 60 | parser.add_argument( 61 | "--out_dir", 62 | type=pathlib.Path, 63 | default="./results", 64 | help="Where to store models, log, etc.", 65 | ) 66 | 67 | return parser 68 | 69 | 70 | def get_transforms(config): 71 | """ returns dataset transforms 72 | 73 | Parameters 74 | ---------- 75 | config : dict 76 | configuration returned by args2dict 77 | 78 | Returns 79 | ------- 80 | train and test transforms for the dataset 81 | 82 | """ 83 | 84 | n_labels = getattr(datasets, config["dataset"]["name"]).N_LABELS 85 | 86 | if config["dataset"]["name"] == "nrw": 87 | train_transforms = [ 88 | datasets.transforms.RandomCrop(config["training"]["crop"]), 89 | datasets.transforms.RandomHorizontalFlip(), 90 | datasets.transforms.Resize(config["training"]["resize"]), 91 | datasets.transforms.ToTensor(), 92 | datasets.transforms.TensorApply( 93 | seg=lambda x: torch.nn.functional.one_hot(x.long(), n_labels) 94 | .squeeze() 95 | .permute(2, 0, 1) 96 | .float() 97 | ), 98 | ] 99 | 100 | # remove resize layer if size of crop and resize are identical 101 | if config["training"]["crop"] == config["training"]["resize"]: 102 | train_transforms = [ 103 | tt 104 | for tt in train_transforms 105 | if not isinstance(tt, datasets.transforms.Resize) 106 | ] 107 | 108 | train_transforms = torchvision.transforms.Compose(train_transforms) 109 | 110 | # Get test transform from train trainsform. 111 | # Test transform should be deterministic. 112 | # 1) replace random crop with center crop 113 | # 2) remove horizontal flip 114 | test_transforms = torchvision.transforms.Compose( 115 | [ 116 | datasets.transforms.CenterCrop(config["training"]["crop"]), 117 | *train_transforms.transforms[2:], 118 | ] 119 | ) 120 | elif config["dataset"]["name"] == "dfc": 121 | train_transforms = torchvision.transforms.Compose( 122 | [ 123 | datasets.transforms.ToTensor(), 124 | datasets.transforms.TensorApply( 125 | seg=lambda x: torch.nn.functional.one_hot(x.long(), n_labels) 126 | .squeeze() 127 | .permute(2, 0, 1) 128 | .float() 129 | ), 130 | ] 131 | ) 132 | test_transforms = train_transforms 133 | else: 134 | raise RuntimeError("Invalid dataset. 
This should never happen") 135 | return train_transforms, test_transforms 136 | 137 | 138 | def get_dataset(config, split, transforms): 139 | """ returns dataset 140 | 141 | Parameters 142 | ---------- 143 | config : dict 144 | configuration returned by args2dict 145 | split : string 146 | use train or test split 147 | transforms 148 | train or test transforms returned by get_transforms 149 | 150 | Returns 151 | ------- 152 | dataset class 153 | 154 | """ 155 | 156 | name = config["dataset"]["name"] 157 | root = config["dataset"]["root"] 158 | 159 | if name == "dfc": 160 | return datasets.dfc.DFC2020(root, split, transforms) 161 | if name == "nrw": 162 | # extra check whether also to load SAR acquisitions 163 | try: 164 | include_sar = config["dataset"]["output"] == "sar" 165 | except KeyError: 166 | include_sar = False 167 | return datasets.nrw.NRW(root, split, include_sar, transforms) 168 | # raising should never happen 169 | raise ValueError("Dataset must be nrw or dfc, but is {}".format(name)) 170 | -------------------------------------------------------------------------------- /options/gan.py: -------------------------------------------------------------------------------- 1 | """ command line arguments and helper functions for GANs """ 2 | 3 | import argparse 4 | import datetime 5 | 6 | import models.generator 7 | import models.discriminator 8 | import datasets 9 | from . import common 10 | 11 | 12 | def get_parser(): 13 | """ returns ArgumentParser for GANs """ 14 | 15 | # Get common ArgumentParser. 16 | parser = common.get_parser() 17 | 18 | # Add GAN specific arguments to the parser. 19 | parser.add_argument( 20 | "--input", 21 | nargs="+", 22 | default="seg", 23 | help="Input of the generator. Depends on the dataset.", 24 | ) 25 | parser.add_argument( 26 | "--concat", 27 | action="store_true", 28 | help="Concatenate inputs before feeding the generator.", 29 | ) 30 | parser.add_argument( 31 | "--output", 32 | default="rgb", 33 | choices=["rgb", "sar"], 34 | help="Output of the generator.", 35 | ) 36 | parser.add_argument( 37 | "--local_rank", type=int, default=0, help="for distributed training" 38 | ) 39 | parser.add_argument( 40 | "--n_sampling", 41 | default=4, 42 | type=int, 43 | help="number of upsampling/downsampling in the generator", 44 | ) 45 | parser.add_argument( 46 | "--model_cap", 47 | default=64, 48 | type=int, 49 | choices=[16, 32, 48, 64], 50 | help="Model capacity, i.e. 
number of features.", 51 | ) 52 | parser.add_argument( 53 | "--n_scales", 54 | default=2, 55 | type=int, 56 | help="Number of scales for multiscale discriminator", 57 | ) 58 | parser.add_argument( 59 | "--lbda", default=None, type=float, help="weighting of feature loss" 60 | ) 61 | 62 | return parser 63 | 64 | 65 | def args2str(args): 66 | """ converts arguments to string 67 | 68 | Parameters 69 | ---------- 70 | args: arguments returned by parser 71 | 72 | Returns 73 | ------- 74 | string of arguments 75 | 76 | """ 77 | 78 | # translate what to what 79 | trans_str = "_".join(args.input) + "2{}".format(args.output) 80 | 81 | # training arguments 82 | train_str = "{args.dataset}_{w2w}_bs{args.batch_size}_ep{args.epochs}_cap{args.model_cap}".format( 83 | args=args, w2w=trans_str 84 | ) 85 | 86 | if args.seed: 87 | train_str += "_rs{args.seed}".format(args=args) 88 | if args.concat: 89 | train_str += "_concat" 90 | 91 | datestr = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M") 92 | 93 | idstr = "_".join([train_str, datestr]) 94 | if args.suffix: 95 | idstr = idstr + "_{}".format(args.suffix) 96 | return idstr 97 | 98 | 99 | def args2dict(args): 100 | """ converts arguments to dict 101 | 102 | Parameters 103 | ---------- 104 | args: arguments returned by parser 105 | 106 | Returns 107 | ------- 108 | dict of arguments 109 | 110 | """ 111 | 112 | # model_parameters 113 | model_args = ["model_cap", "n_sampling", "n_scales"] 114 | if args.concat: 115 | model_args.append("concat") 116 | train_args = ["crop", "resize", "batch_size", "epochs", "lbda"] 117 | if args.seed: 118 | train_args.append("seed") 119 | 120 | model = {param: getattr(args, param) for param in model_args} 121 | train = {param: getattr(args, param) for param in train_args} 122 | data = { 123 | "name": args.dataset, 124 | "root": args.dataroot, 125 | "input": args.input, 126 | "output": args.output, 127 | } 128 | 129 | return {"model": model, "training": train, "dataset": data} 130 | 131 | 132 | def get_generator(config): 133 | """ returns generator 134 | 135 | Parameters 136 | ---------- 137 | config : dict 138 | configuration returned by args2dict 139 | 140 | Returns 141 | ------- 142 | torch.nn.Model of generator 143 | 144 | """ 145 | 146 | dset_class = getattr(datasets, config["dataset"]["name"]) 147 | n_labels = dset_class.N_LABELS 148 | output_nc = dset_class.N_CHANNELS[config["dataset"]["output"]] 149 | 150 | # only segmentation map as input -> SPADE generator 151 | if config["dataset"]["input"] == ["seg"]: 152 | return models.generator.SPADEGenerator( 153 | n_labels, 154 | config["model"]["model_cap"] * 2 ** config["model"]["n_sampling"], 155 | output_nc, 156 | n_up_stages=config["model"]["n_sampling"], 157 | ) 158 | # no segmentation map as input -> Pix2Pix generator 159 | if "seg" not in config["dataset"]["input"]: 160 | input_nc = sum(dset_class.N_CHANNELS[it] for it in config["dataset"]["input"]) 161 | return models.generator.ResnetEncoderDecoder( 162 | input_nc, 163 | config["model"]["model_cap"], 164 | output_nc, 165 | n_downsample=config["model"]["n_sampling"], 166 | ) 167 | 168 | # Deal with generator architectures that deal with segmentation maps 169 | # and continous raster data as input 170 | # 1) concatenate and use regular Pix2Pix 171 | # 2) use proposed archicture from the paper 172 | 173 | # number of channels for all input types except the segmentation_map, 174 | input_nc = sum( 175 | dset_class.N_CHANNELS[it] for it in config["dataset"]["input"] if it != "seg" 176 | ) 177 | 178 | # Which generator 
architecture to use with multiple inputs. 179 | # Conventional generator (pix2pix) with concatenated input 180 | try: 181 | if config["model"]["concat"]: 182 | input_nc += n_labels 183 | return models.generator.ResnetEncoderDecoder( 184 | input_nc, 185 | config["model"]["model_cap"], 186 | output_nc, 187 | n_downsample=config["model"]["n_sampling"], 188 | ) 189 | except KeyError: 190 | # Proposed conventional generator with SPADE norm layers everywhere 191 | return models.generator.SPADEResnetEncoderDecoder( 192 | input_nc, 193 | n_labels, 194 | config["model"]["model_cap"], 195 | output_nc, 196 | n_downsample=config["model"]["n_sampling"], 197 | ) 198 | 199 | 200 | def get_discriminator(config): 201 | """ returns discriminator 202 | 203 | Parameters 204 | ---------- 205 | config : dict 206 | configuration returned by args2dict 207 | 208 | Returns 209 | ------- 210 | torch.nn.Model of discriminator 211 | """ 212 | dset_class = getattr(datasets, config["dataset"]["name"]) 213 | # generator conditioned on this input 214 | gen_input_nc = sum( 215 | dset_class.N_CHANNELS[it] for it in config["dataset"]["input"] if it != "seg" 216 | ) 217 | if "seg" in config["dataset"]["input"]: 218 | gen_input_nc += dset_class.N_LABELS 219 | 220 | disc_input_nc = gen_input_nc + dset_class.N_CHANNELS[config["dataset"]["output"]] 221 | 222 | # Downsampling is done in the multiscale discriminator, 223 | # i.e., all discriminators are identically configures 224 | d_nets = [ 225 | models.discriminator.PatchGAN(input_nc=disc_input_nc, init_nc=64) 226 | for _ in range(config["model"]["n_scales"]) 227 | ] 228 | 229 | return models.discriminator.Multiscale(d_nets) 230 | -------------------------------------------------------------------------------- /options/segment.py: -------------------------------------------------------------------------------- 1 | """ command line arguments and helper functions for segmentation network """ 2 | 3 | import argparse 4 | import datetime 5 | 6 | from . 
import common 7 | 8 | 9 | def get_parser(): 10 | """ returns ArgumentParser for segmentation networks """ 11 | 12 | parser = common.get_parser() 13 | 14 | parser.add_argument( 15 | "--input", 16 | default="rgb", 17 | choices=["rgb", "sar"], 18 | help="Input of the segmentation network.", 19 | ) 20 | parser.add_argument( 21 | "--learning_rate", type=float, default=0.0002, help="learning rate" 22 | ) 23 | parser.add_argument( 24 | "--weight_decay", type=float, default=0.0005, help="weight decay" 25 | ) 26 | return parser 27 | 28 | 29 | def args2str(args): 30 | """ converts arguments to string 31 | 32 | Parameters 33 | ---------- 34 | args: arguments returned by parser 35 | 36 | Returns 37 | ------- 38 | string of arguments 39 | 40 | """ 41 | 42 | # training arguments 43 | train_str = "{args.dataset}_unet_{args.input}_bs{args.batch_size}_ep{args.epochs}_lr{args.learning_rate}".format( 44 | args=args 45 | ) 46 | if args.seed: 47 | train_str += "_rs{args.seed}".format(args=args) 48 | 49 | datestr = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M") 50 | 51 | idstr = "_".join([train_str, datestr]) 52 | if args.suffix: 53 | idstr = idstr + "_{}".format(args.suffix) 54 | return idstr 55 | 56 | 57 | def args2dict(args): 58 | """ converts arguments to dict 59 | 60 | Parameters 61 | ---------- 62 | args: arguments returned by parser 63 | 64 | Returns 65 | ------- 66 | dict of arguments 67 | 68 | """ 69 | 70 | # model_parameters 71 | train_args = ["crop", "resize", "batch_size", "epochs", "learning_rate"] 72 | if args.seed: 73 | train_args.append("seed") 74 | 75 | train = {param: getattr(args, param) for param in train_args} 76 | data = { 77 | "name": args.dataset, 78 | "root": args.dataroot, 79 | "input": args.input, 80 | } 81 | 82 | return {"training": train, "dataset": data} 83 | -------------------------------------------------------------------------------- /pix_acc_iou_comp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | import logging 4 | 5 | import yaml 6 | import torch 7 | import tqdm 8 | import numpy as np 9 | import matplotlib 10 | # headless 11 | matplotlib.use('Agg') 12 | import matplotlib.pyplot as plt 13 | import ignite 14 | 15 | # shitty workaround to include main directory in the python path 16 | import sys 17 | sys.path.insert(0, str(pathlib.Path(__file__).absolute().parents[2])) 18 | 19 | import options.segment 20 | import options.gan 21 | import datasets.nrw 22 | import datasets.dfc 23 | import models.unet 24 | import utils 25 | import loss 26 | 27 | 28 | ########################## 29 | # # 30 | # Comannd line arguments # 31 | # # 32 | ########################## 33 | 34 | parser = argparse.ArgumentParser( 35 | description="compare real and fake images using a segmentation network" 36 | ) 37 | parser.add_argument("generator", help="pt file of the generator model") 38 | parser.add_argument("segmentor", help="pt file of the segmentation model") 39 | parser.add_argument("out_dir", type=pathlib.Path, help="output directory") 40 | args = parser.parse_args() 41 | 42 | if torch.cuda.is_available(): 43 | device = torch.device("cuda") 44 | else: 45 | raise RuntimeError("This scripts expects CUDA to be available") 46 | 47 | OUT_DIR = args.out_dir 48 | OUT_DIR.mkdir(exist_ok=True) 49 | logging.basicConfig( 50 | format="%(asctime)s [%(levelname)-8s] %(message)s", 51 | level=logging.INFO, 52 | filename=OUT_DIR / "log_segmentatation_analysis.txt", 53 | ) 54 | logger = logging.getLogger() 55 | logger.info("Saving logs, 
configs and models to %s", OUT_DIR) 56 | 57 | GEN_DIR = pathlib.Path(args.generator).absolute().parents[0] 58 | # loading config 59 | with open(GEN_DIR / "config.yml", "r") as stream: 60 | CONFIG = yaml.load(stream) 61 | logging.info("Generator config: {}".format(CONFIG)) 62 | 63 | ################ 64 | # # 65 | # Dataset prep # 66 | # # 67 | ################ 68 | 69 | _, test_transforms = options.common.get_transforms(CONFIG) 70 | dataset_test = options.common.get_dataset(CONFIG, split="test", transforms=test_transforms) 71 | #dataset_test = torch.utils.data.Subset(dataset_test, list(range(16))) 72 | 73 | BATCH_SIZE = 8 74 | 75 | test_dataloader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE) 76 | 77 | ############### 78 | # # 79 | # Model setup # 80 | # # 81 | ############### 82 | 83 | 84 | dset_class = getattr(datasets, CONFIG["dataset"]["name"]) 85 | n_labels = dset_class.N_LABELS 86 | # get channels from output from generator 87 | input_nc = dset_class.N_CHANNELS[CONFIG["dataset"]["output"]] 88 | seg_model = models.unet.UNet(input_nc, n_labels).to(device) 89 | 90 | seg_model.load_state_dict(torch.load(args.segmentor).state_dict()) 91 | seg_model.to(device).eval() 92 | 93 | gen_model = options.gan.get_generator(CONFIG) 94 | # remove distributed wrapping, i.e. module. from keynames 95 | state_dict = utils.unwrap_state_dict(torch.load(args.generator)) 96 | gen_model.load_state_dict(state_dict) 97 | gen_model = gen_model.to(device).eval() 98 | 99 | ########### 100 | # # 101 | # Testing # 102 | # # 103 | ########### 104 | 105 | def vis_sample(real_img, fake_img, true_seg, pred_seg_real, pred_seg_fake): 106 | 107 | fig = plt.figure(figsize=(6., 4.2)) 108 | fig.subplots_adjust(top=.95, bottom=0.0, left=0.0, right=1.0, hspace=0.01, wspace=0.01) 109 | 110 | ax_rgb = fig.add_subplot(2, 3, 2) 111 | ax_rgb.imshow(real_img) 112 | ax_rgb.axis('off') 113 | ax_rgb.set_title('real') 114 | 115 | ax_rgb = fig.add_subplot(2, 3, 3) 116 | ax_rgb.imshow(fake_img) 117 | ax_rgb.axis('off') 118 | ax_rgb.set_title('fake') 119 | 120 | ax_seg_true = fig.add_subplot(2, 3, 4) 121 | ax_seg_true.imshow(dataset_test.seg2rgb(true_seg)) 122 | ax_seg_true.axis('off') 123 | ax_seg_true.set_title('ground truth') 124 | 125 | ax_seg_pred = fig.add_subplot(2, 3, 5) 126 | ax_seg_pred.imshow(dataset_test.seg2rgb(pred_seg_real)) 127 | ax_seg_pred.axis('off') 128 | 129 | ax_seg_pred = fig.add_subplot(2, 3, 6) 130 | ax_seg_pred.imshow(dataset_test.seg2rgb(pred_seg_fake)) 131 | ax_seg_pred.axis('off') 132 | 133 | return fig 134 | 135 | from ignite.metrics import Accuracy, IoU, mIoU, ConfusionMatrix 136 | from ignite.engine import Engine, create_supervised_evaluator 137 | from ignite.metrics.metrics_lambda import MetricsLambda 138 | from typing import Optional 139 | 140 | def cmAccuracy(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda: 141 | """Calculates accuracy using :class:`~ignite.metrics.ConfusionMatrix` metric. 
142 | Args: 143 | cm (ConfusionMatrix): instance of confusion matrix metric 144 | Returns: 145 | MetricsLambda 146 | """ 147 | # Increase floating point precision and pass to CPU 148 | cm = cm.type(torch.DoubleTensor) 149 | 150 | correct_pixels = cm.diag() 151 | total_class_pixels = cm.sum(dim=1) 152 | 153 | pix_accs = correct_pixels / (total_class_pixels + 1e-15) 154 | 155 | if ignore_index is not None: 156 | 157 | def ignore_index_fn(pix_accs_vector): 158 | if ignore_index >= len(pix_accs_vector): 159 | raise ValueError( 160 | "ignore_index {} is larger than the length of pix_accs vector {}".format(ignore_index, len(pix_accs_vector)) 161 | ) 162 | indices = list(range(len(pix_accs_vector))) 163 | indices.remove(ignore_index) 164 | return pix_accs_vector[indices] 165 | 166 | return MetricsLambda(ignore_index_fn, pix_accs) 167 | else: 168 | return pix_accs 169 | 170 | 171 | def output_transform(y_pred_and_y): 172 | y_pred, y = y_pred_and_y 173 | # remove one one encoding 174 | y = torch.argmax(y, dim=1) 175 | return y_pred, y 176 | 177 | def make_engine(process_function): 178 | evaluator = Engine(process_function) 179 | 180 | cm = ConfusionMatrix(num_classes=getattr(datasets, CONFIG["dataset"]["name"]).N_LABELS, output_transform=output_transform) 181 | IoU(cm, ignore_index=0).attach(evaluator, 'IoU') 182 | mIoU(cm, ignore_index=0).attach(evaluator, 'mIoU') 183 | Accuracy(output_transform=output_transform).attach(evaluator, 'Accuracy') 184 | cmAccuracy(cm, ignore_index=0).attach(evaluator, 'ClasswiseAccuracy') 185 | 186 | return evaluator 187 | 188 | def log_metrics(metrics): 189 | logging.info("mIoU: {:0>6.4f}".format(metrics['mIoU'])) 190 | logging.info("class-wise IoU:") 191 | for ds_class, iou_val in zip(getattr(datasets, CONFIG["dataset"]["name"]).classes, metrics['IoU']): 192 | logging.info("{:>40s}: {:0>6.4f}".format(ds_class, iou_val)) 193 | logging.info("pixel accuracy: {:0>6.4f}".format(metrics['Accuracy'])) 194 | logging.info("class-wise pixel accuracy:") 195 | for ds_class, iou_val in zip(getattr(datasets, CONFIG["dataset"]["name"]).classes, metrics['ClasswiseAccuracy']): 196 | logging.info("{:>40s}: {:0>6.4f}".format(ds_class, iou_val)) 197 | 198 | 199 | ####################### 200 | # # 201 | # Validation original # 202 | # # 203 | ####################### 204 | 205 | def validation_step_original(engine, batch): 206 | seg_model.eval() 207 | with torch.no_grad(): 208 | batch = {k: v.to(device) for k, v in batch.items()} 209 | y_pred = seg_model(batch[CONFIG["dataset"]["output"]]) 210 | y = batch["seg"] 211 | return y_pred, y 212 | 213 | evaluator = make_engine(validation_step_original) 214 | state = evaluator.run(test_dataloader) 215 | logging.info("real + ground truth labels") 216 | logging.info("==========================") 217 | log_metrics(evaluator.state.metrics) 218 | 219 | 220 | ########################################### 221 | # # 222 | # Validation with respect to ground truth # 223 | # # 224 | ########################################### 225 | 226 | def validation_step_wrt_gt(engine, batch): 227 | seg_model.eval() 228 | with torch.no_grad(): 229 | gen_input = {dt: batch[dt].to(device) for dt in CONFIG["dataset"]["input"]} 230 | gen_output = gen_model(gen_input) 231 | y_pred = seg_model(gen_output) 232 | y = batch["seg"].to(device) 233 | return y_pred, y 234 | 235 | evaluator = make_engine(validation_step_wrt_gt) 236 | state = evaluator.run(test_dataloader) 237 | logging.info("fake + ground truth labels") 238 | logging.info("==========================") 239 | 
log_metrics(evaluator.state.metrics) 240 | 241 | 242 | ####################################### 243 | # # 244 | # Validation with respect to Original # 245 | # # 246 | ####################################### 247 | 248 | def validation_step_wrt_gt(engine, batch): 249 | seg_model.eval() 250 | with torch.no_grad(): 251 | gen_input = {dt: batch[dt].to(device) for dt in CONFIG["dataset"]["input"]} 252 | gen_output = gen_model(gen_input) 253 | y_pred = seg_model(gen_output) 254 | y = seg_model(batch[CONFIG["dataset"]["output"]].to(device)) 255 | return y_pred, y 256 | 257 | evaluator = make_engine(validation_step_wrt_gt) 258 | state = evaluator.run(test_dataloader) 259 | logging.info("fake + labels from real") 260 | logging.info("=======================") 261 | log_metrics(evaluator.state.metrics) 262 | 263 | ############ 264 | # # 265 | # Plotting # 266 | # # 267 | ############ 268 | 269 | 270 | with torch.no_grad(): 271 | for idx, sample in tqdm.tqdm( 272 | enumerate(test_dataloader), total=len(test_dataloader) 273 | ): 274 | sample = {k: v.to(device) for k,v in sample.items()} 275 | 276 | # generator fake images 277 | gen_input = {dt: sample[dt] for dt in CONFIG["dataset"]["input"]} 278 | gen_output = gen_model(gen_input) 279 | 280 | seg_real = seg_model(sample[CONFIG["dataset"]["output"]]) 281 | est_real = torch.argmax(seg_real, dim=1) 282 | est_real_one_hot = torch.nn.functional.one_hot(est_real, seg_real.shape[1]).permute(0, 3, 1, 2) 283 | 284 | seg_fake = seg_model(gen_output) 285 | est_fake = torch.argmax(seg_fake, dim=1) 286 | est_fake_one_hot = torch.nn.functional.one_hot(est_fake, seg_fake.shape[1]).permute(0, 3, 1, 2) 287 | 288 | seg_gt = sample["seg"] 289 | seg_gt_not_one_hot = torch.argmax(seg_gt, 1) 290 | 291 | for idy, (rgb_real, rgb_fake, true_seg, pred_seg_real, pred_seg_fake) in enumerate(zip(sample[CONFIG["dataset"]["output"]], gen_output, sample["seg"], est_real, est_fake)): 292 | rgb_real = np.moveaxis(rgb_real.cpu().numpy(), 0, 2) 293 | rgb_fake = np.clip(np.moveaxis(rgb_fake.cpu().numpy(), 0, 2), 0, 1) 294 | true_seg = torch.argmax(true_seg, 0).cpu().numpy() 295 | pred_seg_real = pred_seg_real.cpu().numpy() 296 | pred_seg_fake = pred_seg_fake.cpu().numpy() 297 | fig = vis_sample(rgb_real, rgb_fake, true_seg, pred_seg_real, pred_seg_fake) 298 | fig.savefig(OUT_DIR / "{:04d}_{:04d}.jpg".format(idx, idy), dpi=200) 299 | plt.close(fig) 300 | -------------------------------------------------------------------------------- /synth_img.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | 4 | import yaml 5 | import torch 6 | import torchvision.transforms.functional as TF 7 | import matplotlib.colors 8 | from PIL import Image 9 | import numpy as np 10 | 11 | import options.gan 12 | import datasets.nrw 13 | from utils import unwrap_state_dict 14 | 15 | 16 | from IPython import embed 17 | 18 | def invert_colormap(img, cmap, norm): 19 | img_invert = np.zeros(img.shape[:2], dtype=np.int32) 20 | for color, idx in zip(cmap.colors, range(int(norm.vmin), int(norm.vmax)+1)): 21 | # conversion from hex to rgb and rescaling 22 | color_rgb = matplotlib.colors.to_rgb(color) 23 | red, green, blue = (255*x for x in color_rgb) 24 | red_mask = img[:, :, 0] == red 25 | green_mask = img[:, :, 1] == green 26 | blue_mask = img[:, :, 2] == blue 27 | 28 | mask = np.logical_and(red_mask, green_mask) 29 | mask = np.logical_and(mask, blue_mask) 30 | 31 | img_invert[mask] = idx 32 | return img_invert 33 | 34 | 35 | ################### 
36 | # # 37 | # Parse arguments # 38 | # # 39 | ################### 40 | 41 | parser = argparse.ArgumentParser( 42 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 43 | ) 44 | parser.add_argument( 45 | "--seg", type=str, help="segmentation map" 46 | ) 47 | parser.add_argument( 48 | "--dem", type=str, help="digitial elevation model" 49 | ) 50 | parser.add_argument("model", type=str) 51 | parser.add_argument("output", type=str) 52 | args = parser.parse_args() 53 | 54 | ######################## 55 | # # 56 | # Get config and model # 57 | # # 58 | ######################## 59 | 60 | OUT_DIR = pathlib.Path(args.model).absolute().parents[0] 61 | 62 | # loading config 63 | with open(OUT_DIR / "config.yml", "r") as stream: 64 | CONFIG = yaml.load(stream) 65 | print("config: {}".format(CONFIG)) 66 | 67 | if torch.cuda.device_count() >= 1: 68 | device = torch.device("cuda") 69 | else: 70 | device = torch.device("cpu") 71 | 72 | print("loading model {}".format(args.model)) 73 | model = options.gan.get_generator(CONFIG) 74 | # remove distributed wrapping, i.e. module. from keynames 75 | state_dict = unwrap_state_dict(torch.load(args.model)) 76 | model.load_state_dict(state_dict) 77 | model.eval() 78 | model.to(device) 79 | 80 | ############## 81 | # # 82 | # Load image # 83 | # # 84 | ############## 85 | 86 | def seg2tensor(seg): 87 | seg = np.array(Image.open(seg)) 88 | seg_inv = invert_colormap(seg, datasets.nrw.lcov_cmap, datasets.nrw.lcov_norm) 89 | seg_inv_one_hot = torch.nn.functional.one_hot(TF.to_tensor(seg_inv).long(), 11).squeeze().permute(2, 0, 1).float() 90 | return seg_inv_one_hot.unsqueeze(0) 91 | 92 | def dem2tensor(dem): 93 | dem = np.array(Image.open(dem)) 94 | return TF.to_tensor(dem).unsqueeze(0) 95 | 96 | sample = {} 97 | if args.seg: 98 | sample["seg"] = seg2tensor(args.seg) 99 | if args.dem: 100 | sample["dem"] = dem2tensor(args.dem) 101 | 102 | with torch.no_grad(): 103 | fake_rgb = model({k: v.to(device) for k, v in sample.items()}) 104 | 105 | 106 | def sar2rgb(sar): 107 | return np.squeeze(np.clip(255*sar, 0, 255).astype(np.uint8)) 108 | 109 | # for SAR 110 | # fake_rgb = sar2rgb(fake_rgb.squeeze().cpu().numpy()) 111 | 112 | # for RGB 113 | fake_rgb = (fake_rgb.squeeze().cpu().numpy() * 255).astype(np.uint8) 114 | fake_rgb = np.moveaxis(fake_rgb, 0, 2) 115 | 116 | result = Image.fromarray(fake_rgb) 117 | result.save(args.output) 118 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | 4 | import torch 5 | import torch.nn 6 | import torchvision 7 | import numpy as np 8 | import tqdm 9 | import yaml 10 | 11 | import datasets.nrw 12 | import datasets.dfc 13 | import options.gan as options 14 | from utils import unwrap_state_dict 15 | 16 | 17 | ############################################### 18 | # # 19 | # Parsing and checking command line arguments # 20 | # # 21 | ############################################### 22 | 23 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 24 | parser.add_argument("model", type=str) 25 | args = parser.parse_args() 26 | 27 | print("loading model {}".format(args.model)) 28 | 29 | # infer output directory from model path 30 | OUT_DIR = pathlib.Path(args.model).absolute().parents[0] 31 | 32 | # loading config 33 | with open(OUT_DIR / "config.yml", "r") as stream: 34 | CONFIG = yaml.load(stream) 35 | print("config: {}".format(CONFIG)) 36 | 37 | 
train_transforms, test_transforms = options.common.get_transforms(CONFIG) 38 | 39 | dataset = options.common.get_dataset(CONFIG, split='test', transforms=test_transforms) 40 | 41 | 42 | ########### 43 | # # 44 | # Testing # 45 | # # 46 | ########### 47 | 48 | if torch.cuda.device_count() >= 1: 49 | device = torch.device("cuda") 50 | else: 51 | device = torch.device("cpu") 52 | 53 | model = options.get_generator(CONFIG) 54 | # remove distributed wrapping, i.e. module. from keynames 55 | state_dict = unwrap_state_dict(torch.load(args.model)) 56 | model.load_state_dict(state_dict) 57 | model.eval() 58 | model.to(device) 59 | 60 | 61 | ############ 62 | # # 63 | # Plotting # 64 | # # 65 | ############ 66 | 67 | BATCH_SIZE = 8 68 | 69 | test_dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE) 70 | 71 | with torch.no_grad(): 72 | for idx, sample in tqdm.tqdm( 73 | enumerate(test_dataloader), total=len(test_dataloader) 74 | ): 75 | imgs = [] 76 | 77 | gen_input = {dt: sample[dt] for dt in CONFIG["dataset"]["input"]} 78 | 79 | fake = model({k: v.to(device) for k, v in gen_input.items()}).cpu() 80 | real = sample[CONFIG["dataset"]["output"]] 81 | 82 | if CONFIG["dataset"]["output"] == "sar": 83 | fake = [ 84 | np.moveaxis(dataset.sar2rgb(np.moveaxis(x.numpy(), 0, -1)), -1, 0) 85 | for x in fake.clone().detach() 86 | ] 87 | fake = torch.tensor(fake).float() 88 | 89 | real = [ 90 | np.moveaxis(dataset.sar2rgb(np.moveaxis(x.numpy(), 0, -1)), -1, 0) 91 | for x in real.clone().detach() 92 | ] 93 | real = torch.tensor(real).float() 94 | 95 | if "dem" in gen_input: 96 | depth_as_rgb = [ 97 | np.moveaxis(dataset.depth2rgb(x.squeeze().numpy()), -1, 0) 98 | for x in sample["dem"].clone().detach() 99 | ] 100 | depth_as_rgb = torch.tensor(depth_as_rgb).float() 101 | imgs.append(depth_as_rgb) 102 | 103 | if "seg" in gen_input: 104 | seg_no_one_hot = torch.argmax(sample["seg"], 1).unsqueeze(1) 105 | seg_as_rgb = [ 106 | np.moveaxis(dataset.seg2rgb(x.squeeze()), -1, 0) for x in seg_no_one_hot 107 | ] 108 | seg_as_rgb = torch.tensor(seg_as_rgb).float() 109 | imgs.append(seg_as_rgb) 110 | 111 | if "sar" in gen_input: 112 | sar_as_rgb = [ 113 | np.moveaxis(dataset.sar2rgb(np.moveaxis(x, 0, -1)), -1, 0) 114 | for x in sample["sar"].clone().detach().numpy() 115 | ] 116 | sar_as_rgb = torch.tensor(sar_as_rgb).float() 117 | imgs.append(sar_as_rgb) 118 | 119 | imgs.append(fake) 120 | imgs.append(real) 121 | 122 | torchvision.utils.save_image( 123 | torch.cat(imgs), OUT_DIR / "{:04}.jpg".format(idx), 124 | ) 125 | -------------------------------------------------------------------------------- /test_unet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | 4 | import numpy as np 5 | import tqdm 6 | import yaml 7 | import torch 8 | import matplotlib 9 | 10 | # headless 11 | matplotlib.use("Agg") 12 | import matplotlib.pyplot as plt 13 | 14 | import options.segment 15 | import datasets.nrw 16 | import datasets.dfc 17 | import models.unet 18 | 19 | ########################## 20 | # # 21 | # Comannd line arguments # 22 | # # 23 | ########################## 24 | 25 | parser = argparse.ArgumentParser(description="apply a model to the test data set") 26 | parser.add_argument("model", help="pt file of the segmentation model") 27 | args = parser.parse_args() 28 | 29 | print("loading model {}".format(args.model)) 30 | 31 | OUT_DIR = pathlib.Path(args.model).absolute().parents[0] 32 | # loading config 33 | with open(OUT_DIR / 
"config.yml", "r") as stream: 34 | CONFIG = yaml.load(stream) 35 | print("config: {}".format(CONFIG)) 36 | 37 | 38 | train_transforms, test_transforms = options.common.get_transforms(CONFIG) 39 | 40 | dataset_test = options.common.get_dataset( 41 | CONFIG, split="test", transforms=test_transforms 42 | ) 43 | 44 | ########### 45 | # # 46 | # Testing # 47 | # # 48 | ########### 49 | 50 | if torch.cuda.is_available(): 51 | device = torch.device("cuda") 52 | else: 53 | raise RuntimeError("This scripts expects CUDA to be available") 54 | 55 | dset_class = getattr(datasets, CONFIG["dataset"]["name"]) 56 | n_labels = dset_class.N_LABELS 57 | input_nc = dset_class.N_CHANNELS[CONFIG["dataset"]["input"]] 58 | 59 | model = models.unet.UNet(input_nc, n_labels).to(device) 60 | model.load_state_dict(torch.load(args.model).state_dict()) 61 | model.to(device).eval() 62 | 63 | 64 | def vis_sample(input_img, true_seg, pred_seg): 65 | 66 | fig = plt.figure(figsize=(6.05, 2)) 67 | fig.subplots_adjust( 68 | top=1.0, bottom=0.0, left=0.0, right=1.0, hspace=0.01, wspace=0.01 69 | ) 70 | 71 | ax_rgb = fig.add_subplot(1, 3, 1) 72 | ax_rgb.imshow(input_img) 73 | ax_rgb.axis("off") 74 | 75 | ax_seg_pred = fig.add_subplot(1, 3, 2) 76 | ax_seg_pred.imshow(dataset_test.seg2rgb(pred_seg)) 77 | ax_seg_pred.axis("off") 78 | 79 | ax_seg_true = fig.add_subplot(1, 3, 3) 80 | ax_seg_true.imshow(dataset_test.seg2rgb(true_seg)) 81 | ax_seg_true.axis("off") 82 | 83 | return fig 84 | 85 | 86 | def vis_single(img): 87 | fig = plt.figure(figsize=(1, 1)) 88 | fig.subplots_adjust(top=1.0, bottom=0.0, left=0.0, right=1.0) 89 | 90 | ax = fig.add_subplot(1, 1, 1) 91 | ax.imshow(img) 92 | ax.axis("off") 93 | 94 | return fig 95 | 96 | 97 | test_dataloader = torch.utils.data.DataLoader(dataset_test, batch_size=1) 98 | 99 | RGB_DIR = OUT_DIR / "rgb" 100 | LABEL_DIR = OUT_DIR / "label" 101 | SEGM_DIR = OUT_DIR / "segm" 102 | 103 | for img_dir in [RGB_DIR, LABEL_DIR, SEGM_DIR]: 104 | img_dir.mkdir(exist_ok=True) 105 | 106 | for idx, sample in tqdm.tqdm(enumerate(dataset_test), total=len(dataset_test)): 107 | asdf = sample[CONFIG["dataset"]["input"]].to(device).unsqueeze(0) 108 | output = model(asdf) 109 | pred = torch.argmax(output, dim=1) 110 | 111 | input_img = np.moveaxis(sample[CONFIG["dataset"]["input"]].numpy(), 0, -1) 112 | true_seg = torch.argmax(sample["seg"], 0).numpy() 113 | pred_seg = pred.cpu().squeeze().numpy() 114 | 115 | fig = vis_sample(input_img, true_seg, pred_seg) 116 | fig.savefig(OUT_DIR / "{:04d}.jpg".format(idx), dpi=200) 117 | 118 | fig_rgb = vis_single(input_img) 119 | fig_rgb.savefig(RGB_DIR / "{:04d}.jpg".format(idx), dpi=200) 120 | 121 | fig_seg = vis_single(dataset_test.seg2rgb(pred_seg)) 122 | fig_seg.savefig(SEGM_DIR / "{:04d}.png".format(idx), dpi=200) 123 | 124 | fig_lab = vis_single(dataset_test.seg2rgb(true_seg)) 125 | fig_lab.savefig(LABEL_DIR / "{:04d}.png".format(idx), dpi=200) 126 | 127 | plt.close(fig) 128 | plt.close(fig_rgb) 129 | plt.close(fig_seg) 130 | plt.close(fig_lab) 131 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn 5 | import torch.distributed 6 | 7 | import numpy as np 8 | import yaml 9 | 10 | import datasets.nrw 11 | import datasets.dfc 12 | import options.common 13 | import options.gan 14 | from trainer import Trainer 15 | 16 | 17 | ################################## 18 | # # 19 | # Parsing command line 
arguments # 20 | # # 21 | ################################## 22 | 23 | parser = options.gan.get_parser() 24 | args = parser.parse_args() 25 | 26 | OUT_DIR = args.out_dir / options.gan.args2str(args) 27 | # All process make the directory. 28 | # This avoids errors when setting up logging later due to race conditions. 29 | OUT_DIR.mkdir(exist_ok=True) 30 | 31 | 32 | ########### 33 | # # 34 | # Logging # 35 | # # 36 | ########### 37 | 38 | logging.basicConfig( 39 | format="%(asctime)s [%(levelname)-8s] %(message)s", 40 | level=logging.INFO, 41 | filename=OUT_DIR / "log_training.txt", 42 | ) 43 | logger = logging.getLogger() 44 | if args.local_rank == 0: 45 | logger.info("Saving logs, configs and models to %s", OUT_DIR) 46 | 47 | 48 | ################################### 49 | # # 50 | # Checking command line arguments # 51 | # # 52 | ################################### 53 | 54 | # Reproducibilty config https://pytorch.org/docs/stable/notes/randomness.html 55 | if args.seed is not None: 56 | torch.manual_seed(args.seed) 57 | np.random.seed(args.seed) 58 | torch.backends.cudnn.deterministic = True 59 | torch.backends.cudnn.benchmark = False 60 | logger.warning( 61 | "You have chosen to seed training. " 62 | "This will turn on the CUDNN deterministic setting, " 63 | "which can slow down your training considerably! " 64 | "You may see unexpected behavior when restarting " 65 | "from checkpoints." 66 | ) 67 | 68 | if len(args.crop) == 1: 69 | args.crop = args.crop[0] 70 | 71 | if len(args.resize) == 1: 72 | args.resize = args.resize[0] 73 | 74 | CONFIG = options.gan.args2dict(args) 75 | 76 | with open(OUT_DIR / "config.yml", "w") as cfg_file: 77 | yaml.dump(CONFIG, cfg_file) 78 | 79 | 80 | if not torch.cuda.is_available(): 81 | raise RuntimeError("This scripts expects CUDA to be available") 82 | 83 | device = torch.device("cuda:{}".format(args.local_rank)) 84 | 85 | # set device of this process. Otherwise apex.amp throws errors. 
86 | # see https://github.com/NVIDIA/apex/issues/319 87 | torch.cuda.set_device(device) 88 | torch.distributed.init_process_group( 89 | "nccl", 90 | init_method="env://", 91 | world_size=torch.cuda.device_count(), 92 | rank=args.local_rank, 93 | ) 94 | 95 | 96 | ######################### 97 | # # 98 | # Dataset configuration # 99 | # # 100 | ######################### 101 | 102 | train_transforms, test_transforms = options.common.get_transforms(CONFIG) 103 | 104 | dataset = options.common.get_dataset(CONFIG, split="train", transforms=train_transforms) 105 | 106 | if args.local_rank == 0: 107 | logger.info(dataset) 108 | 109 | 110 | ################################ 111 | # # 112 | # Neural network configuration # 113 | # # 114 | ################################ 115 | 116 | g_net = options.gan.get_generator(CONFIG).to(device) 117 | d_net = options.gan.get_discriminator(CONFIG).to(device) 118 | 119 | ##################### 120 | # # 121 | # Distributed setup # 122 | # # 123 | ##################### 124 | 125 | # separate processing groups for generator and discriminator 126 | # https://discuss.pytorch.org/t/calling-distributeddataparallel-on-multiple-modules/38055 127 | g_pg = torch.distributed.new_group(range(torch.distributed.get_world_size())) 128 | g_net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(g_net, process_group=g_pg) 129 | g_net = torch.nn.parallel.DistributedDataParallel( 130 | g_net.cuda(args.local_rank), 131 | device_ids=[args.local_rank], 132 | output_device=args.local_rank, 133 | process_group=g_pg, 134 | ) 135 | 136 | d_pg = torch.distributed.new_group(range(torch.distributed.get_world_size())) 137 | # no batch norms in discriminator that need to be synced 138 | d_net = torch.nn.parallel.DistributedDataParallel( 139 | d_net.cuda(args.local_rank), 140 | device_ids=[args.local_rank], 141 | output_device=args.local_rank, 142 | process_group=d_pg, 143 | ) 144 | 145 | ############ 146 | # # 147 | # Training # 148 | # # 149 | ############ 150 | 151 | trainer = Trainer( 152 | g_net, 153 | d_net, 154 | args.input, 155 | args.output, 156 | feat_loss=CONFIG["training"]["lbda"], 157 | ) 158 | 159 | train_sampler = torch.utils.data.distributed.DistributedSampler( 160 | dataset, shuffle=True, num_replicas=torch.cuda.device_count(), rank=args.local_rank, 161 | ) 162 | train_dataloader = torch.utils.data.DataLoader( 163 | dataset, 164 | batch_size=args.batch_size // torch.cuda.device_count(), 165 | sampler=train_sampler, 166 | num_workers=args.num_workers, 167 | ) 168 | 169 | trainer.train(train_dataloader, args.epochs) 170 | 171 | 172 | ########## 173 | # # 174 | # Saving # 175 | # # 176 | ########## 177 | 178 | if args.local_rank == 0: 179 | torch.save(trainer.g_net.state_dict(), OUT_DIR / "model_gnet.pt") 180 | torch.save(trainer.d_net.state_dict(), OUT_DIR / "model_dnet.pt") 181 | -------------------------------------------------------------------------------- /train_unet.py: -------------------------------------------------------------------------------- 1 | # Script for distributed training with synchronized batch norm 2 | # Guide for converting a regular training script to a distributed one can 3 | # be found at https://github.com/dougsouza/pytorch-sync-batchnorm-example 4 | 5 | import argparse 6 | import datetime 7 | import numbers 8 | import pathlib 9 | import warnings 10 | import logging 11 | 12 | import torch 13 | import tqdm 14 | import yaml 15 | import numpy as np 16 | 17 | import loss 18 | import options.segment 19 | import datasets.nrw 20 | import datasets.dfc 21 | import 
models.unet 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def train_and_eval( 27 | model, optimizer, criterion, train_dataloader, test_dataloader, num_epochs, device 28 | ): 29 | 30 | pbar = tqdm.tqdm(total=num_epochs) 31 | for n_epoch in range(num_epochs): 32 | 33 | loss_train, iou_val_train = train(model, optimizer, criterion, train_dataloader, 'rgb', device) 34 | 35 | loss_test, iou_val_test = evaluate( 36 | model, criterion, test_dataloader, 'rgb', device 37 | ) 38 | 39 | info_str = "epoch {:3d} train: loss={:6.3f} iou={:6.3f}, test loss={:6.3f} iou={:6.3f}".format( 40 | n_epoch, loss_train, iou_val_train, loss_test, iou_val_test 41 | ) 42 | 43 | logger.info(info_str) 44 | 45 | pbar.update(1) 46 | pbar.set_description(info_str) 47 | pbar.write(info_str) 48 | return model 49 | 50 | 51 | def evaluate(model, criterion, dataloader, src="rgb", device='cuda'): 52 | """ evaluates a model on the given dataset 53 | 54 | Parameters 55 | ---------- 56 | 57 | model: (torch.nn.Module) 58 | the neural network 59 | criterion: function 60 | takes batch_output and batch_labels and computes the loss for the batch 61 | dataloader: torch.utils.data.DataLoader 62 | fetches test data 63 | src: str 64 | key of the input modality in each sample 65 | 66 | """ 67 | 68 | model.eval() 69 | 70 | running_loss = 0.0 71 | iou_val = 0.0 72 | for sample in dataloader: 73 | sample = {k: v.to(device) for k, v in sample.items()} 74 | 75 | with torch.no_grad(): 76 | output = model(sample[src]) 77 | 78 | pred = (output > 0).long() 79 | 80 | loss_t = criterion(output, sample["seg"]) 81 | running_loss += loss_t.item() 82 | iou_val += loss.iou(pred, sample["seg"]).item() 83 | 84 | running_loss /= len(dataloader) 85 | iou_val /= len(dataloader) 86 | 87 | return running_loss, iou_val 88 | 89 | 90 | def train(model, optimizer, criterion, dataloader, src="rgb", device='cuda'): 91 | """ trains a model for one epoch 92 | 93 | Parameters 94 | ---------- 95 | 96 | model: (torch.nn.Module) 97 | the neural network 98 | optimizer: (torch.optim) 99 | optimizer for parameters of model 100 | criterion: function 101 | takes batch_output and batch_labels and computes the loss for the batch 102 | dataloader: torch.utils.data.DataLoader 103 | fetches training data 104 | 105 | """ 106 | 107 | model.train() 108 | 109 | running_loss = 0.0 110 | iou_val = 0.0 111 | for sample in dataloader: 112 | optimizer.zero_grad() 113 | 114 | sample = {k: v.to(device) for k, v in sample.items()} 115 | 116 | output = model(sample[src]) 117 | pred = (output > 0).long() 118 | 119 | loss_t = criterion(output, sample["seg"]) 120 | 121 | loss_t.backward() 122 | running_loss += loss_t.item() 123 | iou_val += loss.iou(pred, sample["seg"]).item() 124 | optimizer.step() 125 | 126 | running_loss /= len(dataloader) 127 | iou_val /= len(dataloader) 128 | 129 | return running_loss, iou_val 130 | 131 | 132 | if __name__ == "__main__": 133 | 134 | parser = options.segment.get_parser() 135 | args = parser.parse_args() 136 | 137 | # Reproducibility config https://pytorch.org/docs/stable/notes/randomness.html 138 | if args.seed is not None: 139 | torch.manual_seed(args.seed) 140 | np.random.seed(args.seed) 141 | torch.backends.cudnn.deterministic = True 142 | torch.backends.cudnn.benchmark = False 143 | warnings.warn( 144 | "You have chosen to seed training. " 145 | "This will turn on the CUDNN deterministic setting, " 146 | "which can slow down your training considerably! " 147 | "You may see unexpected behavior when restarting " 148 | "from checkpoints." 
149 | ) 150 | 151 | if len(args.crop) == 1: 152 | args.crop = args.crop[0] 153 | 154 | if len(args.resize) == 1: 155 | args.resize = args.resize[0] 156 | 157 | if torch.cuda.is_available(): 158 | device = torch.device("cuda") 159 | else: 160 | raise RuntimeError("This scripts expects CUDA to be available") 161 | OUT_DIR = args.out_dir / options.segment.args2str(args) 162 | # All process make the directory. 163 | # This avoids errors when setting up logging later due to race conditions. 164 | OUT_DIR.mkdir(exist_ok=True) 165 | 166 | ########### 167 | # # 168 | # Logging # 169 | # # 170 | ########### 171 | 172 | logging.basicConfig( 173 | format="%(asctime)s [%(levelname)-8s] %(message)s", 174 | level=logging.INFO, 175 | filename=OUT_DIR / "log_training.txt", 176 | ) 177 | logger = logging.getLogger() 178 | logger.info("Saving logs, configs and models to %s", OUT_DIR) 179 | 180 | CONFIG = options.segment.args2dict(args) 181 | with open(OUT_DIR / "config.yml", "w") as cfg_file: 182 | yaml.dump(CONFIG, cfg_file) 183 | 184 | ######################### 185 | # # 186 | # Dataset configuration # 187 | # # 188 | ######################### 189 | 190 | train_transforms, test_transforms = options.common.get_transforms(CONFIG) 191 | 192 | dataset_train = options.common.get_dataset(CONFIG, split='train', transforms=train_transforms) 193 | dataset_test = options.common.get_dataset(CONFIG, split='test', transforms=test_transforms) 194 | 195 | print("training dataset statistics") 196 | print(dataset_train) 197 | 198 | print("testing dataset statistics") 199 | print(dataset_test) 200 | 201 | dataloader_train = torch.utils.data.DataLoader( 202 | dataset_train, 203 | batch_size=args.batch_size, 204 | shuffle=True, 205 | num_workers=args.num_workers, 206 | ) 207 | 208 | dataloader_test = torch.utils.data.DataLoader( 209 | dataset_test, 210 | batch_size=args.batch_size, 211 | shuffle=False, 212 | num_workers=args.num_workers, 213 | ) 214 | 215 | ############### 216 | # # 217 | # Model setup # 218 | # # 219 | ############### 220 | 221 | criterion = torch.nn.BCEWithLogitsLoss() 222 | 223 | dset_class = getattr(datasets, CONFIG["dataset"]["name"]) 224 | n_labels = dset_class.N_LABELS 225 | input_nc = dset_class.N_CHANNELS[CONFIG["dataset"]["input"]] 226 | model = models.unet.UNet(input_nc, n_labels).to(device) 227 | 228 | ############ 229 | # # 230 | # TRAINING # 231 | # # 232 | ############ 233 | 234 | optimizer = torch.optim.Adam( 235 | model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, 236 | ) 237 | 238 | train_and_eval( 239 | model, 240 | optimizer, 241 | criterion, 242 | dataloader_train, 243 | dataloader_test, 244 | num_epochs=args.epochs, 245 | device=device, 246 | ) 247 | 248 | ############## 249 | # # 250 | # save model # 251 | # # 252 | ############## 253 | 254 | torch.save(model, OUT_DIR / "{}.pt".format(options.segment.args2str(args))) 255 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import tqdm 5 | 6 | import loss 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | # ToDo 12 | # Implement and test different learning rates for generator and discriminator and 13 | # do multiple discriminator steps per generator step. See [1] and [2] 14 | # 15 | # [1] Heusel et. al., "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", 2018 16 | # [2] Zhang et. 
al., "Self-Attention Generative Adversarial Networks", 2019 17 | 18 | 19 | class Trainer: 20 | def __init__( 21 | self, g_net, d_net, src=["dem", "seg"], dest="rgb", feat_loss=None 22 | ): 23 | self.rank = torch.distributed.get_rank() 24 | 25 | self.src = src 26 | self.dest = dest 27 | 28 | self.d_net = d_net 29 | self.g_net = g_net 30 | 31 | # parameters taken from SPADE https://github.com/NVlabs/SPADE/issues/50#issuecomment-494217696 32 | self.g_optim = torch.optim.Adam( 33 | self.g_net.parameters(), lr=0.0001, betas=(0, 0.9) 34 | ) 35 | self.d_optim = torch.optim.Adam( 36 | self.d_net.parameters(), lr=0.0004, betas=(0, 0.9) 37 | ) 38 | 39 | self.g_loss = loss.HingeGenerator() 40 | self.g_feat_lambda = feat_loss 41 | if feat_loss is not None: 42 | self.g_feat_loss = torch.nn.functional.l1_loss 43 | self.d_net.module.return_intermed = True 44 | self.d_loss = loss.HingeDiscriminator() 45 | 46 | def sample2gen_input(self, sample): 47 | return {src: sample[src] for src in self.src} 48 | 49 | def g_one_step(self, sample): 50 | self.g_optim.zero_grad() 51 | 52 | g_input = self.sample2gen_input(sample) 53 | 54 | dest_fake = self.g_net(g_input) 55 | d_output_fake = self.d_net(g_input, dest_fake) 56 | 57 | loss_val = sum(self.g_loss(o) for o in d_output_fake.final) 58 | 59 | if self.g_feat_lambda is not None: 60 | dest_real = sample[self.dest] 61 | d_output_real = self.d_net(g_input, dest_real) 62 | if not d_output_real.features: 63 | logger.error("Trying to compute feature loss on empty list") 64 | raise RuntimeError("Trying to compute feature loss on empty list") 65 | 66 | feat_loss = sum( 67 | self.g_feat_loss(fake, real) 68 | for real, fake in zip( 69 | d_output_real.features, d_output_fake.features 70 | ) 71 | ) 72 | loss_val += self.g_feat_lambda * feat_loss 73 | 74 | loss_val.backward() 75 | self.g_optim.step() 76 | 77 | return loss_val 78 | 79 | def d_one_step(self, sample): 80 | self.d_optim.zero_grad() 81 | 82 | g_input = self.sample2gen_input(sample) 83 | # call detach to not compute gradients for generator 84 | dest_fake = self.g_net(g_input).detach() 85 | dest_real = sample[self.dest] 86 | 87 | disc_real = self.d_net(g_input, dest_real).final 88 | disc_fake = self.d_net(g_input, dest_fake).final 89 | 90 | loss_val = sum( 91 | self.d_loss(*disc_out) for disc_out in zip(disc_real, disc_fake) 92 | ) 93 | 94 | loss_val.backward() 95 | self.d_optim.step() 96 | 97 | return loss_val 98 | 99 | def train(self, dataloader, n_epochs): 100 | pbar = tqdm.tqdm(total=n_epochs) 101 | for n_epoch in range(1, n_epochs + 1): 102 | running_g_loss = torch.tensor(0.0, requires_grad=False) 103 | running_d_loss = torch.tensor(0.0, requires_grad=False) 104 | for idx, sample in enumerate(dataloader): 105 | g_loss = self.g_one_step(sample) 106 | torch.distributed.all_reduce(g_loss) 107 | running_g_loss += g_loss.item() 108 | 109 | d_loss = self.d_one_step(sample) 110 | torch.distributed.all_reduce(d_loss) 111 | running_d_loss += d_loss.item() 112 | 113 | if self.rank == 0: 114 | logger.debug( 115 | "batch idx {:3d}, g_loss:{:7.3f}, d_loss:{:7.3f}".format( 116 | idx, g_loss.item(), d_loss.item() 117 | ) 118 | ) 119 | 120 | running_g_loss /= len(dataloader) 121 | running_d_loss /= len(dataloader) 122 | 123 | info_str = "epoch {:3d}, g_loss:{:7.3f}, d_loss:{:7.3f}".format( 124 | n_epoch, running_g_loss, running_d_loss 125 | ) 126 | pbar.update(1) 127 | pbar.set_description(info_str) 128 | if self.rank == 0: 129 | pbar.write(info_str) 130 | logger.info(info_str) 131 | 132 | return None 133 | 
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | def unwrap_state_dict(state_dict): 4 | """ cleans up the keys of a model state dictionary 5 | 6 | Methods and classes such as convert_sync_batchnorm or DataParallel wrap their 7 | respective module, which also alters the state dictionary's keys. 8 | This function removes the leading module. prefix from the keys. 9 | 10 | """ 11 | 12 | unwrap = lambda x: ".".join(x.split(".")[1:]) 13 | 14 | # PyTorch state dictionaries are just regular ordered dictionaries 15 | return OrderedDict((unwrap(k), v) for k, v in state_dict.items()) 16 | --------------------------------------------------------------------------------
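For reference, here is a small, self-contained usage sketch of `unwrap_state_dict`. The toy module below is illustrative only; in this repository the same pattern is applied in `test.py`, `synth_img.py` and `pix_acc_iou_comp.py` to load `model_gnet.pt` files saved from `DistributedDataParallel`-wrapped models.

```python
# Illustration: wrapping a module prefixes every state dict key with "module.";
# unwrap_state_dict strips that prefix so the weights load into the plain module again.
import torch

from utils import unwrap_state_dict

net = torch.nn.Linear(4, 2)
wrapped = torch.nn.DataParallel(net)

print(list(wrapped.state_dict().keys()))                     # ['module.weight', 'module.bias']
print(list(unwrap_state_dict(wrapped.state_dict()).keys()))  # ['weight', 'bias']

# the cleaned dictionary matches the unwrapped module's keys
net.load_state_dict(unwrap_state_dict(wrapped.state_dict()))
```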