├── .gitignore ├── Dockerfile ├── README.md ├── data ├── __init__.py ├── aligned_dataset.py ├── aligned_ir_dataset.py ├── base_dataset.py ├── dataLoader.py ├── image_folder.py ├── sen12mscrts_dataset.py ├── single_dataset.py ├── temporal_dataset.py ├── temporal_ir_dataset.py └── unaligned_dataset.py ├── models ├── __init__.py ├── base_model.py ├── cycle_gan_model.py ├── network_resnet_branched.py ├── networks.py ├── networks_branched.py ├── pix2pix_ir_model.py ├── pix2pix_model.py ├── template_model.py ├── temporal_branched_ir_model.py ├── temporal_branched_ir_modified_model.py ├── temporal_branched_model.py ├── temporal_ir_model.py ├── temporal_model.py └── test_model.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── preview └── single_banner.png ├── standalone_dataloader.py ├── test.py ├── train.py └── util ├── __init__.py ├── detect_cloudshadow.py ├── dl_data.sh ├── get_data.py ├── hdf5converter ├── script_tif2hdf5.sh ├── sen12mscrts_to_hdf5.py └── tif2hdf5.py ├── html.py ├── image_pool.py ├── pytorch_ssim └── __init__.py ├── util.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | */*.pyc 2 | */**/*.pyc 3 | */**/**/*.pyc 4 | */**/**/**/*.pyc 5 | */**/**/**/**/*.pyc 6 | 7 | ./__pycache__ 8 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:latest 2 | 3 | # note: as of now, pytorch/pytorch:latest is not compiled for CUDA > 11.3 yet; 4 | # if you run CUDA > 11.3, please consider the base image nvcr.io/nvidia/pytorch:latest 5 | # on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch 6 | 7 | # in case you run CUDA > 11.3 and prefer pytorch/pytorch:latest, then consider this conda-forge build: 8 | # RUN conda install pytorch torchvision torchaudio cudatoolkit=11.6 -c pytorch -c conda-forge 9 | 10 | # install dependencies 11 | RUN conda install -c conda-forge cupy 12 | RUN conda install -c conda-forge opencv 13 | RUN pip install scipy rasterio natsort matplotlib scikit-image tqdm 14 | RUN pip install s2cloudless 15 | RUN conda install pillow=6.1 16 | RUN pip install dominate 17 | RUN pip install visdom 18 | 19 | # bake the repository into the image 20 | RUN mkdir -p ./data 21 | RUN mkdir -p ./models 22 | RUN mkdir -p ./options 23 | RUN mkdir -p ./util 24 | 25 | ADD data ./data 26 | ADD models ./models 27 | ADD options ./options 28 | ADD util ./util 29 | ADD . ./ 30 | 31 | WORKDIR /workspace 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SEN12MS-CR-TS Toolbox 2 | 3 | ![banner](preview/single_banner.png) 4 | > 5 | > _On average, the majority of all optical satellite data is affected by clouds. This observation shows a scene of agricultural land cover in Czechia from the SEN12MS-CR-TS data set for multi-modal multi-temporal cloud removal. SEN12MS-CR-TS contains whole-year time series of radar and optical satellite data distributed globally across our planet's surface._ 6 | ---- 7 | This repository contains code accompanying the paper 8 | > Ebel, P., Xu, Y., Schmitt, M., & Zhu, X. X. (2022). SEN12MS-CR-TS: A Remote Sensing Data Set for Multi-modal Multi-temporal Cloud Removal. IEEE Transactions on Geoscience and Remote Sensing, In Press. 
9 | 10 | It serves as a quick start for working with the associated SEN12MS-CR-TS data set. For additional information: 11 | 12 | * The open-access publication is available at [the IEEE TGRS page](https://ieeexplore.ieee.org/document/9691348). 13 | * The open-access SEN12MS-CR data set is available at the MediaTUM page [here](https://mediatum.ub.tum.de/1639953) (train split) and [here](https://mediatum.ub.tum.de/1659251) (test split). 14 | * You can find additional information on this and related projects on the associated [cloud removal projects page](https://patrickTUM.github.io/cloud_removal/). 15 | * For any further questions, please reach out to me here or via the credentials on my [website](https://pwjebel.com). 16 | --- 17 | 18 | ## Installation 19 | ### Dataset 20 | You can download the SEN12MS-CR-TS data set (or parts of it) via the MediaTUM website [here](https://mediatum.ub.tum.de/1639953) (train split) and [here](https://mediatum.ub.tum.de/1659251) (test split) or in the terminal (passwd: *m1639953* or *m1659251*) using wget or rsync, for instance via 21 | 22 | ```bash 23 | wget "ftp://m1639953:m1639953@dataserv.ub.tum.de/s1_africa.tar.gz" 24 | rsync -chavzP --stats rsync://m1639953@dataserv.ub.tum.de/m1639953/ . 25 | rsync -chavzP --stats rsync://m1659251@dataserv.ub.tum.de/m1659251/ . 26 | ``` 27 | 28 | For the sake of convenient downloading and unzipping, the data set is sharded into separate archives per sensor modality and geographical region. You can, if needed, download and work exclusively on e.g. the Sentinel-2 data for cloud removal in Africa. However, we recommend utilizing the global distribution of ROIs and emphasize that this code base is written with the full data set in mind. After all archives are downloaded and their subdirectories extracted (e.g. via `find . -name '*.tar.gz' -exec tar -xzvf {} \;`), you can simply merge them via `rsync -a */* .` in the parent directory to obtain the required structure that the repository's code expects. Handle the test split likewise. 29 | 30 | **Update:** You can now easily download SEN12MS-CR-TS (and SEN12MS-CR) via the shell script provided [here](https://github.com/PatrickTUM/SEN12MS-CR-TS/blob/master/util/dl_data.sh). 31 | 32 | ### Code 33 | Clone this repository via `git clone https://github.com/PatrickTUM/SEN12MS-CR-TS.git`. 34 | 35 | The code is written in Python 3 and uses PyTorch > 1.4. It is strongly recommended to run the code with CUDA and GPU support. The code has been developed and deployed on Ubuntu 20 LTS and should be able to run on any comparable OS. 36 | 37 | --- 38 | 39 | ## Usage 40 | ### Dataset 41 | If you already have your own model in place or wish to build one on the SEN12MS-CR-TS data loader for training and testing, the data loader can be used as a stand-alone script as demonstrated in `./standalone_dataloader.py`. This only requires the files `./data/dataLoader.py` (the actual data loader) and `./util/detect_cloudshadow.py` (if this type of cloud detector is chosen). 
42 | 43 | For using the dataset as a stand-alone with your own model, loading multi-temporal multi-modal data from SEN12MS-CR-TS is as simple as 44 | 45 | ``` python 46 | import torch 47 | from data.dataLoader import SEN12MSCRTS 48 | dir_SEN12MSCRTS = '/path/to/your/SEN12MSCRTS' 49 | sen12mscrts = SEN12MSCRTS(dir_SEN12MSCRTS, split='all', region='all', n_input_samples=3) 50 | dataloader = torch.utils.data.DataLoader(sen12mscrts) 51 | 52 | for pdx, samples in enumerate(dataloader): print(samples['input'].keys()) 53 | ``` 54 | 55 | and, likewise, if you wish to (pre-)train on the mono-temporal multi-modal SEN12MS-CR dataset: 56 | 57 | ``` python 58 | import torch 59 | from data.dataLoader import SEN12MSCR 60 | dir_SEN12MSCR = '/path/to/your/SEN12MSCR' 61 | sen12mscr = SEN12MSCR(dir_SEN12MSCR, split='all', region='all') 62 | dataloader = torch.utils.data.DataLoader(sen12mscr) 63 | 64 | for pdx, samples in enumerate(dataloader): print(samples['input'].keys()) 65 | ``` 66 | 67 | Depending on your choice of split, ROI, input time series length, and cloud detector algorithm, you may end up with different samples of input and output data. We encourage making use of as much of the data set as practicable. However, to ensure a well-defined and replicable test split of holdout data on which to benchmark, we provide separate files [here](https://u.pcloud.link/publink/show?code=kZXdbk0ZaAHNV2a5ofbB9UW4xCyCT0YFYAFk) that can be loaded with the `--import_data_path /path/to/files/file.npy` flag. Please use those if you wish to report your performance on the test split. 68 | 69 | ### Basic Commands 70 | You can train a new model via 71 | ```bash 72 | python train.py --dataroot /path/to/sen12mscrts --dataset_mode sen12mscrts --name exemplary_training_run --sample_type cloudy_cloudfree --model temporal_branched --netG resnet3d_9blocks_withoutBottleneck --gpu_ids 0 --max_dataset_size 100000 --checkpoints_dir /path/to/results --input_type train --cloud_masks s2cloudless_mask --include_S1 --input_nc 15 --output_nc 13 --G_loss L1 --lambda_GAN 0.0 --display_freq 1000 --alter_initial_model --initial_model_path /path/to/models/baseline_resnet.pth --n_input_samples 3 --region all 73 | ``` 74 | and you can test a (pre-)trained model via 75 | ```bash 76 | python test.py --dataroot /path/to/sen12mscrts --dataset_mode sen12mscrts --results_dir /path/to/results --checkpoints_dir /path/to/results --name exemplary_training_run --model temporal_branched --netG resnet3d_9blocks_withoutBottleneck --include_S1 --input_nc 15 --output_nc 13 --sample_type cloudy_cloudfree --cloud_masks s2cloudless_mask --input_type test --max_dataset_size 100000 --num_test 100000 --n_input_samples 3 --epoch latest --eval --phase test --alter_initial_model --initial_model_path /path/to/models/baseline_resnet.pth --min_cov 0.0 --max_cov 1.0 --region all 77 | ``` 78 | 79 | For a list and description of all flags, please see the parser files in directory `./options`. 
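If you wire the stand-alone loader into your own pipeline, batching and multi-process loading work as with any PyTorch dataset. The sketch below builds on the loader calls shown above; the `batch_size` and `num_workers` values are arbitrary placeholders, and the dictionary keys follow `./data/dataLoader.py` as consumed in `./data/sen12mscrts_dataset.py`:

``` python
import torch
from data.dataLoader import SEN12MSCRTS

dir_SEN12MSCRTS = '/path/to/your/SEN12MSCRTS'  # adjust to your local copy
sen12mscrts = SEN12MSCRTS(dir_SEN12MSCRTS, split='all', region='all', n_input_samples=3)

# batch_size and num_workers are exemplary choices, tune them to your hardware
dataloader = torch.utils.data.DataLoader(sen12mscrts, batch_size=4, shuffle=True, num_workers=4)

for pdx, samples in enumerate(dataloader):
    in_S1, in_S2 = samples['input']['S1'], samples['input']['S2']  # lists of n_input_samples images
    in_masks     = samples['input']['masks']                       # cloud masks of the input samples
    target_S2    = samples['target']['S2'][0]                      # cloud-free target optical image
    # ... feed in_S1, in_S2 and in_masks to your model, compare predictions against target_S2
```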
80 | 81 | --- 82 | 83 | 84 | ## References 85 | 86 | If you use this code, our models, or the data set for your research, please cite [this](https://ieeexplore.ieee.org/document/9691348) publication: 87 | ```bibtex 88 | @article{sen12mscrts, 89 | title = {{SEN12MS-CR-TS: A Remote Sensing Data Set for Multi-modal Multi-temporal Cloud Removal}}, 90 | author = {Ebel, Patrick and Xu, Yajin and Schmitt, Michael and Zhu, Xiao Xiang}, 91 | journal = {IEEE Transactions on Geoscience and Remote Sensing}, 92 | year = {2022}, 93 | publisher = {IEEE} 94 | } 95 | ``` 96 | You may also be interested in our SEN12MS-CR data set for mono-temporal cloud removal (available [here](https://mediatum.ub.tum.de/1554803)) and the related publication (see [related paper](https://ieeexplore.ieee.org/document/9211498)). Also check out our recently released model for quantifying uncertainties in cloud removal, [UnCRtainTS](https://github.com/PatrickTUM/UnCRtainTS). You can find further information on these and related projects on the accompanying [cloud removal website](https://patrickTUM.github.io/cloud_removal/). 97 | 98 | 99 | 100 | ## Credits 101 | 102 | This code was originally based on the [STGAN repository](https://github.com/ermongroup/STGAN), which in turn was based on the [pix2pix repository](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). Our seq2point network was inspired by the original STGAN architecture (see [related paper](https://arxiv.org/abs/1912.06838)) as well as the ResNet for cloud removal in mono-temporal optical satellite data (see [related paper](https://www.sciencedirect.com/science/article/pii/S0924271620301398)). Thanks for making your code publicly available! -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. 9 | 10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 11 | See our template dataset class 'template_dataset.py' for more details. 12 | """ 13 | import importlib 14 | import torch.utils.data 15 | from data.base_dataset import BaseDataset 16 | 17 | 18 | def find_dataset_using_name(dataset_name): 19 | """Import the module "data/[dataset_name]_dataset.py". 20 | 21 | In the file, the class called DatasetNameDataset() will 22 | be instantiated. It has to be a subclass of BaseDataset, 23 | and it is case-insensitive. 24 | """ 25 | dataset_filename = "data." + dataset_name + "_dataset" 26 | datasetlib = importlib.import_module(dataset_filename) 27 | 28 | dataset = None 29 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 30 | for name, cls in datasetlib.__dict__.items(): 31 | if name.lower() == target_dataset_name.lower() \ 32 | and issubclass(cls, BaseDataset): 33 | dataset = cls 34 | 35 | if dataset is None: 36 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." 
% (dataset_filename, target_dataset_name)) 37 | 38 | return dataset 39 | 40 | 41 | def get_option_setter(dataset_name): 42 | """Return the static method of the dataset class.""" 43 | dataset_class = find_dataset_using_name(dataset_name) 44 | return dataset_class.modify_commandline_options 45 | 46 | 47 | def create_dataset(opt): 48 | """Create a dataset given the option. 49 | 50 | This function wraps the class CustomDatasetDataLoader. 51 | This is the main interface between this package and 'train.py'/'test.py' 52 | 53 | Example: 54 | >>> from data import create_dataset 55 | >>> dataset = create_dataset(opt) 56 | """ 57 | data_loader = CustomDatasetDataLoader(opt) 58 | dataset = data_loader.load_data() 59 | return dataset 60 | 61 | 62 | class CustomDatasetDataLoader(): 63 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 64 | 65 | def __init__(self, opt): 66 | """Initialize this class 67 | 68 | Step 1: create a dataset instance given the name [dataset_mode] 69 | Step 2: create a multi-threaded data loader. 70 | """ 71 | self.opt = opt 72 | dataset_class = find_dataset_using_name(opt.dataset_mode) 73 | self.dataset = dataset_class(opt) 74 | print(len(self.dataset)) 75 | self.dataloader = torch.utils.data.DataLoader( 76 | self.dataset, 77 | batch_size=opt.batch_size, 78 | shuffle=not opt.serial_batches, 79 | num_workers=int(opt.num_threads)) 80 | print("dataset [%s] was created" % type(self.dataset).__name__) 81 | 82 | def load_data(self): 83 | return self 84 | 85 | def __len__(self): 86 | """Return the number of data in the dataset""" 87 | return min(len(self.dataset), self.opt.max_dataset_size) 88 | 89 | def __iter__(self): 90 | """Return a batch of data""" 91 | for i, data in enumerate(self.dataloader): 92 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 93 | break 94 | yield data 95 | -------------------------------------------------------------------------------- /data/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | from data.base_dataset import BaseDataset, get_params, get_transform 4 | import torchvision.transforms as transforms 5 | from data.image_folder import make_dataset 6 | from PIL import Image 7 | 8 | 9 | class AlignedDataset(BaseDataset): 10 | """A dataset class for paired image dataset. 11 | 12 | It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}. 13 | During test time, you need to prepare a directory '/path/to/data/test'. 14 | """ 15 | 16 | def __init__(self, opt): 17 | """Initialize this dataset class. 18 | 19 | Parameters: 20 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 21 | """ 22 | BaseDataset.__init__(self, opt) 23 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory 24 | self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) # get image paths 25 | assert(self.opt.load_size >= self.opt.crop_size) # crop_size should be smaller than the size of loaded image 26 | self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 27 | self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc 28 | 29 | def __getitem__(self, index): 30 | """Return a data point and its metadata information. 
31 | 32 | Parameters: 33 | index - - a random integer for data indexing 34 | 35 | Returns a dictionary that contains A, B, A_paths and B_paths 36 | A (tensor) - - an image in the input domain 37 | B (tensor) - - its corresponding image in the target domain 38 | A_paths (str) - - image paths 39 | B_paths (str) - - image paths (same as A_paths) 40 | """ 41 | # read an image given a random integer index 42 | AB_path = self.AB_paths[index] 43 | AB = Image.open(AB_path).convert('RGB') 44 | # split AB image into A and B 45 | w, h = AB.size 46 | w2 = int(w / 2) 47 | A = AB.crop((0, 0, w2, h)) 48 | B = AB.crop((w2, 0, w, h)) 49 | 50 | # apply the same transform to both A and B 51 | transform_params = get_params(self.opt, A.size) 52 | A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1)) 53 | B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1)) 54 | 55 | A = A_transform(A) 56 | B = B_transform(B) 57 | 58 | return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path} 59 | 60 | def __len__(self): 61 | """Return the total number of images in the dataset.""" 62 | return len(self.AB_paths) 63 | -------------------------------------------------------------------------------- /data/aligned_ir_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | from data.base_dataset import BaseDataset, get_params, get_transform 4 | import torchvision.transforms as transforms 5 | from data.image_folder import make_dataset 6 | from PIL import Image 7 | import torch 8 | 9 | class AlignedIrDataset(BaseDataset): 10 | """A dataset class for paired image dataset with an additional infrared channel. 11 | 12 | It assumes that the directory '/path/to/data/train' contains image triplets in the form of {A, A_ir, B}. 13 | During test time, you need to prepare a directory '/path/to/data/test'. 14 | """ 15 | 16 | def __init__(self, opt): 17 | """Initialize this dataset class. 18 | 19 | Parameters: 20 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 21 | """ 22 | BaseDataset.__init__(self, opt) 23 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory 24 | self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) # get image paths 25 | assert(self.opt.load_size >= self.opt.crop_size) # crop_size should be smaller than the size of loaded image 26 | self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 27 | self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc 28 | 29 | def __getitem__(self, index): 30 | """Return a data point and its metadata information. 
31 | 32 | Parameters: 33 | index - - a random integer for data indexing 34 | 35 | Returns a dictionary that contains A, B, A_paths and B_paths 36 | A (tensor) - - an image in the input domain 37 | B (tensor) - - its corresponding image in the target domain 38 | A_paths (str) - - image paths 39 | B_paths (str) - - image paths (same as A_paths) 40 | """ 41 | # read an image given a random integer index 42 | AB_path = self.AB_paths[index] 43 | AB = Image.open(AB_path).convert('RGB') 44 | # split AB image into A, A_ir and B 45 | w, h = AB.size 46 | w2 = int(w / 3) 47 | A = AB.crop((0, 0, w2, h)) 48 | A_ir = AB.crop((w2, 0, 2*w2, h)) 49 | B = AB.crop((2*w2, 0, w, h)) 50 | 51 | # apply the same transform to A, A_ir and B 52 | transform_params = get_params(self.opt, A.size) 53 | A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1)) 54 | B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1)) 55 | 56 | A = A_transform(A) 57 | A_ir = A_transform(A_ir) 58 | B = B_transform(B) 59 | 60 | A_ir = A_ir[0,:,:].unsqueeze(0) 61 | A = torch.cat((A,A_ir), 0) # A_ir was stored with three identical channels; this makes A a 4-channel image 62 | return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path} 63 | 64 | def __len__(self): 65 | """Return the total number of images in the dataset.""" 66 | return len(self.AB_paths) 67 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 4 | """ 5 | import random 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | from abc import ABC, abstractmethod 11 | 12 | 13 | class BaseDataset(data.Dataset, ABC): 14 | """This class is an abstract base class (ABC) for datasets. 15 | 16 | To create a subclass, you need to implement the following four functions: 17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 18 | -- <__len__>: return the size of dataset. 19 | -- <__getitem__>: get a data point. 20 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. 21 | """ 22 | 23 | def __init__(self, opt): 24 | """Initialize the class; save the options in the class 25 | 26 | Parameters: 27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 28 | """ 29 | self.opt = opt 30 | self.root = opt.dataroot 31 | 32 | @staticmethod 33 | def modify_commandline_options(parser, is_train): 34 | """Add new dataset-specific options, and rewrite default values for existing options. 35 | 36 | Parameters: 37 | parser -- original option parser 38 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 39 | 40 | Returns: 41 | the modified parser. 42 | """ 43 | return parser 44 | 45 | @abstractmethod 46 | def __len__(self): 47 | """Return the total number of images in the dataset.""" 48 | return 0 49 | 50 | @abstractmethod 51 | def __getitem__(self, index): 52 | """Return a data point and its metadata information. 53 | 54 | Parameters: 55 | index - - a random integer for data indexing 56 | 57 | Returns: 58 | a dictionary of data with their names. 
It usually contains the data itself and its metadata information. 59 | """ 60 | pass 61 | 62 | 63 | def get_params(opt, size): 64 | w, h = size 65 | new_h = h 66 | new_w = w 67 | if opt.preprocess == 'resize_and_crop': 68 | new_h = new_w = opt.load_size 69 | elif opt.preprocess == 'scale_width_and_crop': 70 | new_w = opt.load_size 71 | new_h = opt.load_size * h // w 72 | 73 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 74 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 75 | 76 | flip = random.random() > 0.5 77 | 78 | return {'crop_pos': (x, y), 'flip': flip} 79 | 80 | 81 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 82 | transform_list = [] 83 | if grayscale: 84 | transform_list.append(transforms.Grayscale(1)) 85 | if 'resize' in opt.preprocess: 86 | osize = [opt.load_size, opt.load_size] 87 | transform_list.append(transforms.Resize(osize, method)) 88 | elif 'scale_width' in opt.preprocess: 89 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 90 | 91 | if 'crop' in opt.preprocess: 92 | if params is None: 93 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 94 | else: 95 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 96 | 97 | if opt.preprocess == 'none': 98 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 99 | 100 | if not opt.no_flip: 101 | if params is None: 102 | transform_list.append(transforms.RandomHorizontalFlip()) 103 | elif params['flip']: 104 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 105 | 106 | if convert: 107 | transform_list += [transforms.ToTensor(), 108 | transforms.Normalize((0.5, 0.5, 0.5), 109 | (0.5, 0.5, 0.5))] 110 | return transforms.Compose(transform_list) 111 | 112 | 113 | def __make_power_2(img, base, method=Image.BICUBIC): 114 | ow, oh = img.size 115 | h = int(round(oh / base) * base) 116 | w = int(round(ow / base) * base) 117 | if (h == oh) and (w == ow): 118 | return img 119 | 120 | __print_size_warning(ow, oh, w, h) 121 | return img.resize((w, h), method) 122 | 123 | 124 | def __scale_width(img, target_width, method=Image.BICUBIC): 125 | ow, oh = img.size 126 | if (ow == target_width): 127 | return img 128 | w = target_width 129 | h = int(target_width * oh / ow) 130 | return img.resize((w, h), method) 131 | 132 | 133 | def __crop(img, pos, size): 134 | ow, oh = img.size 135 | x1, y1 = pos 136 | tw = th = size 137 | if (ow > tw or oh > th): 138 | return img.crop((x1, y1, x1 + tw, y1 + th)) 139 | return img 140 | 141 | 142 | def __flip(img, flip): 143 | if flip: 144 | return img.transpose(Image.FLIP_LEFT_RIGHT) 145 | return img 146 | 147 | 148 | def __print_size_warning(ow, oh, w, h): 149 | """Print warning information about image size (only print once)""" 150 | if not hasattr(__print_size_warning, 'has_printed'): 151 | print("The image size needs to be a multiple of 4. " 152 | "The loaded image size was (%d, %d), so it was adjusted to " 153 | "(%d, %d). 
This adjustment will be done to all images " 154 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 155 | __print_size_warning.has_printed = True 156 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 5 | """ 6 | 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | import os 11 | import os.path 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 16 | ] 17 | 18 | 19 | def is_image_file(filename): 20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | 23 | def make_dataset(dir, max_dataset_size=float("inf")): 24 | images = [] 25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 26 | 27 | for root, _, fnames in sorted(os.walk(dir)): 28 | for fname in fnames: 29 | if is_image_file(fname): 30 | path = os.path.join(root, fname) 31 | images.append(path) 32 | return images[:min(max_dataset_size, len(images))] 33 | 34 | 35 | def default_loader(path): 36 | return Image.open(path).convert('RGB') 37 | 38 | 39 | class ImageFolder(data.Dataset): 40 | 41 | def __init__(self, root, transform=None, return_paths=False, 42 | loader=default_loader): 43 | imgs = make_dataset(root) 44 | if len(imgs) == 0: 45 | raise(RuntimeError("Found 0 images in: " + root + "\n" 46 | "Supported image extensions are: " + 47 | ",".join(IMG_EXTENSIONS))) 48 | 49 | self.root = root 50 | self.imgs = imgs 51 | self.transform = transform 52 | self.return_paths = return_paths 53 | self.loader = loader 54 | 55 | def __getitem__(self, index): 56 | path = self.imgs[index] 57 | img = self.loader(path) 58 | if self.transform is not None: 59 | img = self.transform(img) 60 | if self.return_paths: 61 | return img, path 62 | else: 63 | return img 64 | 65 | def __len__(self): 66 | return len(self.imgs) 67 | -------------------------------------------------------------------------------- /data/sen12mscrts_dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset class SEN12MSCRTS 2 | 3 | This class wraps around the SEN12MSCRTS dataloader in ./dataLoader.py 4 | """ 5 | 6 | import numpy as np 7 | import random 8 | from data.base_dataset import BaseDataset 9 | import torchvision.transforms as transforms 10 | from data.image_folder import make_dataset 11 | from data.dataLoader import SEN12MSCRTS 12 | 13 | 14 | class Sen12mscrtsDataset(BaseDataset): 15 | @staticmethod 16 | def modify_commandline_options(parser, is_train): 17 | """Add new dataset-specific options, and rewrite default values for existing options. 18 | 19 | Parameters: 20 | parser -- original option parser 21 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 22 | 23 | Returns: 24 | the modified parser. 25 | """ 26 | parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option') 27 | parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) 28 | return parser 29 | 30 | def __init__(self, opt): 31 | """Initialize this dataset class. 
32 | 33 | Parameters: 34 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 35 | 36 | A few things can be done here. 37 | - save the options (have been done in BaseDataset) 38 | - get image paths and meta information of the dataset. 39 | - define the image transformation. 40 | """ 41 | BaseDataset.__init__(self, opt) 42 | 43 | if opt.alter_initial_model or opt.benchmark_resnet_model: 44 | self.rescale_method = 'resnet' # rescale SAR to [0,2] and optical to [0,5] 45 | else: 46 | self.rescale_method = 'default' # rescale all to [-1,1] (gets rescaled to [0,1]) 47 | 48 | self.opt = opt 49 | self.data_loader = SEN12MSCRTS(opt.dataroot, split=opt.input_type, region=opt.region, cloud_masks=opt.cloud_masks, sample_type=opt.sample_type, n_input_samples=opt.n_input_samples, rescale_method=self.rescale_method, min_cov=opt.min_cov, max_cov=opt.max_cov, import_data_path=opt.import_data_path, export_data_path=opt.export_data_path) 50 | self.max_bands = 13 51 | 52 | def __getitem__(self, index): 53 | """Return a data point and its metadata information. 54 | 55 | Parameters: 56 | index -- a random integer for data indexing 57 | 58 | Returns: 59 | a dictionary of data with their names. It usually contains the data itself and its metadata information. 60 | 61 | Step 1: get a random image path: e.g., path = self.image_paths[index] 62 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). 63 | Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image) 64 | Step 4: return a data point as a dictionary. 65 | """ 66 | 67 | # call data loader to get item 68 | cloudy_cloudfree = self.data_loader.__getitem__(index) 69 | 70 | if self.opt.include_S1: 71 | input_channels = [i for i in range(self.max_bands)] 72 | 73 | # for each input sample, collect the SAR data 74 | A_S1 = [] 75 | for i in range(self.opt.n_input_samples): 76 | A_S1_01 = cloudy_cloudfree['input']['S1'][i] 77 | 78 | if self.rescale_method == 'default': 79 | A_S1.append((A_S1_01 * 2) - 1) # rescale from [0,1] to [-1,+1] 80 | elif self.rescale_method == 'resnet': 81 | A_S1.append(A_S1_01) # no need to rescale, keep at [0,2] 82 | 83 | # fetch the target S1 image (and optionally rescale) 84 | B_S1_01 = cloudy_cloudfree['target']['S1'][0] 85 | if self.rescale_method == 'default': 86 | B_S1 = (B_S1_01 * 2) - 1 # rescale from [0,1] to [-1,+1] 87 | elif self.rescale_method == 'resnet': 88 | B_S1 = B_S1_01 # no need to rescale, keep at [0,2] 89 | 90 | else: # not containing any S1 91 | assert self.opt.input_nc <= self.max_bands, "MS input channel number larger than 13 (S1 not included)!" 92 | input_channels = [i for i in range(self.opt.input_nc)] 93 | 94 | # use only NIR+BGR channels when training STGAN 95 | if self.opt.model == "temporal_branched_ir_modified": input_channels = [7, 1, 2, 3] 96 | 97 | A_S2, A_S2_mask = [], [] 98 | 99 | if self.opt.in_only_S1: # using only S1 input 100 | input_channels = [i for i in range(self.max_bands)] 101 | for i in range(self.opt.n_input_samples): 102 | A_S2_01 = cloudy_cloudfree['input']['S1'][i] 103 | if self.rescale_method == 'default': 104 | A_S2.append((A_S2_01 * 2) - 1) # rescale from [0,1] to [-1,+1] 105 | elif self.rescale_method == 'resnet': 106 | A_S2.append(A_S2_01) # no need to rescale, keep at [0,2] 107 | A_S2_mask.append(cloudy_cloudfree['target']['masks'][0].reshape((1, 256, 256))) 108 | else: # this is the typical case 109 | for i in range(self.opt.n_input_samples):
110 | A_S2_01 = cloudy_cloudfree['input']['S2'][i][input_channels] 111 | if self.rescale_method == 'default': 112 | A_S2.append((A_S2_01 * 2) - 1) # rescale from [0,1] to [-1,+1] 113 | elif self.rescale_method == 'resnet': 114 | A_S2.append(A_S2_01) # no need to rescale, keep at [0,5] 115 | A_S2_mask.append(cloudy_cloudfree['input']['masks'][i].reshape((1, 256, 256))) 116 | 117 | # get the target cloud-free optical image 118 | B_01 = cloudy_cloudfree['target']['S2'][0] 119 | if self.opt.output_nc == 4: B_01 = B_01[input_channels] 120 | if self.rescale_method == 'default': 121 | B = (B_01 * 2) - 1 # rescale from [0,1] to [-1,+1] 122 | elif self.rescale_method == 'resnet': 123 | B = B_01 # no need to rescale, keep at [0,5] 124 | B_mask = cloudy_cloudfree['target']['masks'][0].reshape((1, 256, 256)) 125 | image_path = cloudy_cloudfree['target']['S2 path'] 126 | 127 | coverage_bin = True 128 | if "coverage bin" in cloudy_cloudfree: coverage_bin = cloudy_cloudfree["coverage bin"] 129 | 130 | if self.opt.include_S1: 131 | return {'A_S1': A_S1, 'A_S2': A_S2, 'A_mask': A_S2_mask, 'B': B, 'B_S1': B_S1, 'B_mask': B_mask, 'image_path': image_path, "coverage_bin": coverage_bin} 132 | else: 133 | return {'A_S2': A_S2, 'A_mask': A_S2_mask, 'B': B, 'B_mask': B_mask, 'image_path': image_path, "coverage_bin": coverage_bin} 134 | 135 | def __len__(self): 136 | """Return the total number of images.""" 137 | return len(self.data_loader) 138 | -------------------------------------------------------------------------------- /data/single_dataset.py: -------------------------------------------------------------------------------- 1 | from data.base_dataset import BaseDataset, get_transform 2 | from data.image_folder import make_dataset 3 | from PIL import Image 4 | 5 | 6 | class SingleDataset(BaseDataset): 7 | """This dataset class can load a set of images specified by the path --dataroot /path/to/data. 8 | 9 | It can be used for generating CycleGAN results only for one side with the model option '--model test'. 10 | """ 11 | 12 | def __init__(self, opt): 13 | """Initialize this dataset class. 14 | 15 | Parameters: 16 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 17 | """ 18 | BaseDataset.__init__(self, opt) 19 | self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size)) 20 | input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 21 | self.transform = get_transform(opt, grayscale=(input_nc == 1)) 22 | 23 | def __getitem__(self, index): 24 | """Return a data point and its metadata information. 
25 | 26 | Parameters: 27 | index - - a random integer for data indexing 28 | 29 | Returns a dictionary that contains A and A_paths 30 | A(tensor) - - an image in one domain 31 | A_paths(str) - - the path of the image 32 | """ 33 | A_path = self.A_paths[index] 34 | A_img = Image.open(A_path).convert('RGB') 35 | A = self.transform(A_img) 36 | return {'A': A, 'A_paths': A_path} 37 | 38 | def __len__(self): 39 | """Return the total number of images in the dataset.""" 40 | return len(self.A_paths) 41 | -------------------------------------------------------------------------------- /data/temporal_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | from data.base_dataset import BaseDataset, get_params, get_transform 4 | import torchvision.transforms as transforms 5 | from data.image_folder import make_dataset 6 | from PIL import Image 7 | 8 | 9 | class TemporalDataset(BaseDataset): 10 | """A dataset class for temporal image dataset. 11 | 12 | It assumes that the directory '/path/to/data/train' contains image pairs in the form of {{A_0, A_1, A_2},B}. 13 | During test time, you need to prepare a directory '/path/to/data/test'. 14 | """ 15 | 16 | def __init__(self, opt): 17 | """Initialize this dataset class. 18 | 19 | Parameters: 20 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 21 | """ 22 | BaseDataset.__init__(self, opt) 23 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory 24 | self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) # get image paths 25 | assert(self.opt.load_size >= self.opt.crop_size) # crop_size should be smaller than the size of loaded image 26 | self.input_nc = self.opt.input_nc 27 | self.output_nc = self.opt.output_nc 28 | 29 | def __getitem__(self, index): 30 | """Return a data point and its metadata information. 
31 | 32 | Parameters: 33 | index - - a random integer for data indexing 34 | 35 | Returns a dictionary that contains A, B, A_paths and B_paths 36 | A (tensor) - - an image in the input domain 37 | B (tensor) - - its corresponding image in the target domain 38 | A_paths (str) - - image paths 39 | B_paths (str) - - image paths (same as A_paths) 40 | """ 41 | # read an image given a random integer index 42 | AB_path = self.AB_paths[index] 43 | AB = Image.open(AB_path).convert('RGB') 44 | # split AB image into A_0, A_1, A_2 and B 45 | 46 | w, h = AB.size 47 | w4 = int(w / 4) 48 | A_0 = AB.crop((0, 0, w4, h)) 49 | A_1 = AB.crop((w4, 0, 2*w4, h)) 50 | A_2 = AB.crop((2*w4, 0, 3*w4, h)) 51 | B = AB.crop((3*w4, 0, w, h)) 52 | 53 | # apply the same transform to A_0, A_1, A_2 and B 54 | transform_params = get_params(self.opt, A_0.size) 55 | A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1)) 56 | B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1)) 57 | 58 | A_0 = A_transform(A_0) 59 | A_1 = A_transform(A_1) 60 | A_2 = A_transform(A_2) 61 | B = B_transform(B) 62 | 63 | return {'A_0': A_0, 'A_1': A_1, 'A_2': A_2, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path} 64 | 65 | def __len__(self): 66 | """Return the total number of images in the dataset.""" 67 | return len(self.AB_paths) 68 | -------------------------------------------------------------------------------- /data/temporal_ir_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | from data.base_dataset import BaseDataset, get_params, get_transform 4 | import torchvision.transforms as transforms 5 | from data.image_folder import make_dataset 6 | from PIL import Image 7 | 8 | 9 | class TemporalIrDataset(BaseDataset): 10 | """A dataset class for temporal image dataset. 11 | 12 | It assumes that the directory '/path/to/data/train' contains image pairs in the form of {{A_0, A_1, A_2, ir},B}. 13 | During test time, you need to prepare a directory '/path/to/data/test'. 14 | """ 15 | 16 | def __init__(self, opt): 17 | """Initialize this dataset class. 18 | 19 | Parameters: 20 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 21 | """ 22 | BaseDataset.__init__(self, opt) 23 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory 24 | self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) # get image paths 25 | assert(self.opt.load_size >= self.opt.crop_size) # crop_size should be smaller than the size of loaded image 26 | self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 27 | self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc 28 | 29 | def __getitem__(self, index): 30 | """Return a data point and its metadata information. 
31 | 32 | Parameters: 33 | index - - a random integer for data indexing 34 | 35 | Returns a dictionary that contains A, B, A_paths and B_paths 36 | A (tensor) - - an image in the input domain 37 | B (tensor) - - its corresponding image in the target domain 38 | A_paths (str) - - image paths 39 | B_paths (str) - - image paths (same as A_paths) 40 | """ 41 | # read an image given a random integer index 42 | AB_path = self.AB_paths[index] 43 | AB = Image.open(AB_path).convert('RGB') 44 | # split AB image into A_0, A_1, A_2, ir and B 45 | w, h = AB.size 46 | w5 = int(w / 5) 47 | A_0 = AB.crop((0, 0, w5, h)) 48 | A_1 = AB.crop((w5, 0, 2*w5, h)) 49 | A_2 = AB.crop((2*w5, 0, 3*w5, h)) 50 | ir = AB.crop((3*w5, 0, 4*w5, h)) 51 | B = AB.crop((4*w5, 0, w, h)) 52 | 53 | # apply the same transform to A_0, A_1, A_2, ir and B 54 | transform_params = get_params(self.opt, A_0.size) 55 | A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1)) 56 | B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1)) 57 | 58 | A_0 = A_transform(A_0) 59 | A_1 = A_transform(A_1) 60 | A_2 = A_transform(A_2) 61 | ir = A_transform(ir) 62 | B = B_transform(B) 63 | 64 | # Now split ir into constituent channels 65 | A_0_ir = ir[0,:,:].unsqueeze(0) 66 | A_1_ir = ir[1,:,:].unsqueeze(0) 67 | A_2_ir = ir[2,:,:].unsqueeze(0) 68 | 69 | return {'A_0': A_0, 'A_1': A_1, 'A_2': A_2,'A_0_ir': A_0_ir, 'A_1_ir': A_1_ir, 'A_2_ir': A_2_ir, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path} 70 | 71 | def __len__(self): 72 | """Return the total number of images in the dataset.""" 73 | return len(self.AB_paths) 74 | -------------------------------------------------------------------------------- /data/unaligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | import random 6 | 7 | 8 | class UnalignedDataset(BaseDataset): 9 | """ 10 | This dataset class can load unaligned/unpaired datasets. 11 | 12 | It requires two directories to host training images from domain A '/path/to/data/trainA' 13 | and from domain B '/path/to/data/trainB' respectively. 14 | You can train the model with the dataset flag '--dataroot /path/to/data'. 15 | Similarly, you need to prepare two directories: 16 | '/path/to/data/testA' and '/path/to/data/testB' during test time. 17 | """ 18 | 19 | def __init__(self, opt): 20 | """Initialize this dataset class. 
21 | 22 | Parameters: 23 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 24 | """ 25 | BaseDataset.__init__(self, opt) 26 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA' 27 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB' 28 | 29 | self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA' 30 | self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB' 31 | self.A_size = len(self.A_paths) # get the size of dataset A 32 | self.B_size = len(self.B_paths) # get the size of dataset B 33 | btoA = self.opt.direction == 'BtoA' 34 | input_nc = self.opt.output_nc if btoA else self.opt.input_nc # get the number of channels of input image 35 | output_nc = self.opt.input_nc if btoA else self.opt.output_nc # get the number of channels of output image 36 | self.transform_A = get_transform(self.opt, grayscale=(input_nc == 1)) 37 | self.transform_B = get_transform(self.opt, grayscale=(output_nc == 1)) 38 | 39 | def __getitem__(self, index): 40 | """Return a data point and its metadata information. 41 | 42 | Parameters: 43 | index (int) -- a random integer for data indexing 44 | 45 | Returns a dictionary that contains A, B, A_paths and B_paths 46 | A (tensor) -- an image in the input domain 47 | B (tensor) -- its corresponding image in the target domain 48 | A_paths (str) -- image paths 49 | B_paths (str) -- image paths 50 | """ 51 | A_path = self.A_paths[index % self.A_size] # make sure index is within the range 52 | if self.opt.serial_batches: # make fixed pairs by taking images in order 53 | index_B = index % self.B_size 54 | else: # randomize the index for domain B to avoid fixed pairs. 55 | index_B = random.randint(0, self.B_size - 1) 56 | B_path = self.B_paths[index_B] 57 | A_img = Image.open(A_path).convert('RGB') 58 | B_img = Image.open(B_path).convert('RGB') 59 | # apply image transformation 60 | A = self.transform_A(A_img) 61 | B = self.transform_B(B_img) 62 | 63 | return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path} 64 | 65 | def __len__(self): 66 | """Return the total number of images in the dataset. 67 | 68 | As we have two datasets with potentially different number of images, 69 | we take a maximum of the two. 70 | """ 71 | return max(self.A_size, self.B_size) 72 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- <set_input>: unpack data from dataset and apply preprocessing. 7 | -- <forward>: produce intermediate results. 8 | -- <optimize_parameters>: calculate loss, gradients, and update network weights. 9 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 
14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from models.base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called ModelNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = "models." + model_name + "_model" 33 | modellib = importlib.import_module(model_filename) 34 | model = None 35 | target_model_name = model_name.replace('_', '') + 'model' 36 | for name, cls in modellib.__dict__.items(): 37 | if name.lower() == target_model_name.lower() \ 38 | and issubclass(cls, BaseModel): 39 | model = cls 40 | 41 | if model is None: 42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 43 | exit(0) 44 | 45 | return model 46 | 47 | 48 | def get_option_setter(model_name): 49 | """Return the static method <modify_commandline_options> of the model class.""" 50 | model_class = find_model_using_name(model_name) 51 | return model_class.modify_commandline_options 52 | 53 | 54 | def create_model(opt): 55 | """Create a model given the option. 56 | 57 | This function wraps the model class specified by opt.model. 58 | This is the main interface between this package and 'train.py'/'test.py' 59 | 60 | Example: 61 | >>> from models import create_model 62 | >>> model = create_model(opt) 63 | """ 64 | model = find_model_using_name(opt.model) 65 | instance = model(opt) 66 | print("model [%s] was created" % type(instance).__name__) 67 | return instance 68 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from abc import ABC, abstractmethod 5 | from . import networks 6 | 7 | 8 | class BaseModel(ABC): 9 | """This class is an abstract base class (ABC) for models. 10 | To create a subclass, you need to implement the following five functions: 11 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 12 | -- <set_input>: unpack data from dataset and apply preprocessing. 13 | -- <forward>: produce intermediate results. 14 | -- <optimize_parameters>: calculate losses, gradients, and update network weights. 15 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options. 16 | """ 17 | 18 | def __init__(self, opt): 19 | """Initialize the BaseModel class. 20 | 21 | Parameters: 22 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 23 | 24 | When creating your custom class, you need to implement your own initialization. 25 | In this function, you should first call <BaseModel.__init__(self, opt)>. 26 | Then, you need to define four lists: 27 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 28 | -- self.model_names (str list): define networks used in our training. 
29 | -- self.visual_names (str list): specify the images that you want to display and save. 30 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 31 | """ 32 | self.opt = opt 33 | self.gpu_ids = opt.gpu_ids 34 | self.isTrain = opt.isTrain 35 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU 36 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir 37 | if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark. 38 | torch.backends.cudnn.benchmark = True 39 | self.loss_names = [] 40 | self.model_names = [] 41 | self.visual_names = [] 42 | self.optimizers = [] 43 | self.image_paths = [] 44 | self.metric = 0 # used for learning rate policy 'plateau' 45 | 46 | @staticmethod 47 | def modify_commandline_options(parser, is_train): 48 | """Add new model-specific options, and rewrite default values for existing options. 49 | 50 | Parameters: 51 | parser -- original option parser 52 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 53 | 54 | Returns: 55 | the modified parser. 56 | """ 57 | return parser 58 | 59 | @abstractmethod 60 | def set_input(self, input): 61 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 62 | 63 | Parameters: 64 | input (dict): includes the data itself and its metadata information. 65 | """ 66 | pass 67 | 68 | @abstractmethod 69 | def forward(self): 70 | """Run forward pass; called by both functions <optimize_parameters> and <test>.""" 71 | pass 72 | 73 | @abstractmethod 74 | def optimize_parameters(self): 75 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 76 | pass 77 | 78 | def setup(self, opt): 79 | """Load and print networks; create schedulers 80 | 81 | Parameters: 82 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 83 | """ 84 | if self.isTrain: 85 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 86 | if not self.isTrain or opt.continue_train: 87 | load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch 88 | self.load_networks(load_suffix) 89 | self.print_networks(opt.verbose) 90 | 91 | def eval(self): 92 | """Set models to eval mode during test time""" 93 | for name in self.model_names: 94 | if isinstance(name, str): 95 | net = getattr(self, 'net' + name) 96 | net.eval() 97 | 98 | def test(self): 99 | """Forward function used in test time. 
100 | 101 | This function wraps <forward> in no_grad() so we don't save intermediate steps for backprop 102 | It also calls <compute_visuals> to produce additional visualization results 103 | """ 104 | with torch.no_grad(): 105 | self.forward() 106 | self.compute_visuals() 107 | 108 | def compute_visuals(self): 109 | """Calculate additional output images for visdom and HTML visualization""" 110 | pass 111 | 112 | def get_image_paths(self): 113 | """ Return image paths that are used to load current data""" 114 | return self.image_paths 115 | 116 | def update_learning_rate(self): 117 | """Update learning rates for all the networks; called at the end of every epoch""" 118 | for scheduler in self.schedulers: 119 | scheduler.step(self.metric) 120 | lr = self.optimizers[0].param_groups[0]['lr'] 121 | print('learning rate = %.7f' % lr) 122 | 123 | def get_current_visuals(self): 124 | """Return visualization images. train.py will display these images with visdom, and save the images to an HTML""" 125 | visual_ret = OrderedDict() 126 | for name in self.visual_names: 127 | if isinstance(name, str): 128 | visual_ret[name] = getattr(self, name) 129 | return visual_ret 130 | 131 | def get_current_losses(self): 132 | """Return training losses / errors. train.py will print out these errors on console, and save them to a file""" 133 | errors_ret = OrderedDict() 134 | for name in self.loss_names: 135 | if isinstance(name, str): 136 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number 137 | return errors_ret 138 | 139 | def save_networks(self, epoch): 140 | """Save all the networks to the disk. 141 | 142 | Parameters: 143 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 144 | """ 145 | for name in self.model_names: 146 | if isinstance(name, str): 147 | save_filename = '%s_net_%s.pth' % (epoch, name) 148 | save_path = os.path.join(self.save_dir, save_filename) 149 | net = getattr(self, 'net' + name) 150 | 151 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 152 | torch.save(net.module.cpu().state_dict(), save_path) 153 | net.cuda(self.gpu_ids[0]) 154 | else: 155 | torch.save(net.cpu().state_dict(), save_path) 156 | 157 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 158 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" 159 | key = keys[i] 160 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 161 | if module.__class__.__name__.startswith('InstanceNorm') and \ 162 | (key == 'running_mean' or key == 'running_var'): 163 | if getattr(module, key) is None: 164 | state_dict.pop('.'.join(keys)) 165 | if module.__class__.__name__.startswith('InstanceNorm') and \ 166 | (key == 'num_batches_tracked'): 167 | state_dict.pop('.'.join(keys)) 168 | else: 169 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 170 | 171 | def load_networks(self, epoch): 172 | """Load all the networks from the disk. 
173 | 174 | Parameters: 175 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 176 | """ 177 | for name in self.model_names: 178 | if isinstance(name, str): 179 | load_filename = '%s_net_%s.pth' % (epoch, name) 180 | load_path = os.path.join(self.save_dir, load_filename) 181 | net = getattr(self, 'net' + name) 182 | if isinstance(net, torch.nn.DataParallel): 183 | net = net.module 184 | print('loading the model from %s' % load_path) 185 | # if you are using PyTorch newer than 0.4 (e.g., built from 186 | # GitHub source), you can remove str() on self.device 187 | state_dict = torch.load(load_path, map_location=str(self.device)) 188 | if hasattr(state_dict, '_metadata'): 189 | del state_dict._metadata 190 | 191 | # patch InstanceNorm checkpoints prior to 0.4 192 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 193 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 194 | net.load_state_dict(state_dict) 195 | 196 | def print_networks(self, verbose): 197 | """Print the total number of parameters in the network and (if verbose) network architecture 198 | 199 | Parameters: 200 | verbose (bool) -- if verbose: print the network architecture 201 | """ 202 | print('---------- Networks initialized -------------') 203 | for name in self.model_names: 204 | if isinstance(name, str): 205 | net = getattr(self, 'net' + name) 206 | num_params = 0 207 | for param in net.parameters(): 208 | num_params += param.numel() 209 | if verbose: 210 | print(net) 211 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 212 | print('-----------------------------------------------') 213 | 214 | def set_requires_grad(self, nets, requires_grad=False): 215 | """Set requires_grad=False for all the networks to avoid unnecessary computations 216 | Parameters: 217 | nets (network list) -- a list of networks 218 | requires_grad (bool) -- whether the networks require gradients or not 219 | """ 220 | if not isinstance(nets, list): 221 | nets = [nets] 222 | for net in nets: 223 | if net is not None: 224 | for param in net.parameters(): 225 | param.requires_grad = requires_grad 226 | -------------------------------------------------------------------------------- /models/cycle_gan_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import itertools 3 | from util.image_pool import ImagePool 4 | from .base_model import BaseModel 5 | from . import networks 6 | 7 | 8 | class CycleGANModel(BaseModel): 9 | """ 10 | This class implements the CycleGAN model, for learning image-to-image translation without paired data. 11 | 12 | The model training requires '--dataset_mode unaligned' dataset. 13 | By default, it uses a '--netG resnet_9blocks' ResNet generator, 14 | a '--netD basic' discriminator (PatchGAN introduced by pix2pix), 15 | and a least-square GANs objective ('--gan_mode lsgan'). 16 | 17 | CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf 18 | """ 19 | @staticmethod 20 | def modify_commandline_options(parser, is_train=True): 21 | """Add new dataset-specific options, and rewrite default values for existing options. 22 | 23 | Parameters: 24 | parser -- original option parser 25 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 26 | 27 | Returns: 28 | the modified parser. 
29 | 
30 |         For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses.
31 |         A (source domain), B (target domain).
32 |         Generators: G_A: A -> B; G_B: B -> A.
33 |         Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
34 |         Forward cycle loss:  lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper)
35 |         Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper)
36 |         Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper)
37 |         Dropout is not used in the original CycleGAN paper.
38 |         """
39 |         parser.set_defaults(no_dropout=True)  # default CycleGAN did not use dropout
40 |         if is_train:
41 |             parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
42 |             parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
43 |             parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity to a nonzero value scales the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, set lambda_identity = 0.1')
44 | 
45 |         return parser
46 | 
47 |     def __init__(self, opt):
48 |         """Initialize the CycleGAN class.
49 | 
50 |         Parameters:
51 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
52 |         """
53 |         BaseModel.__init__(self, opt)
54 |         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
55 |         self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
56 |         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
57 |         visual_names_A = ['real_A', 'fake_B', 'rec_A']
58 |         visual_names_B = ['real_B', 'fake_A', 'rec_B']
59 |         if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_B=G_B(A) and idt_A=G_A(B)
60 |             visual_names_A.append('idt_B')
61 |             visual_names_B.append('idt_A')
62 | 
63 |         self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
64 |         # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
65 |         if self.isTrain:
66 |             self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
67 |         else:  # during test time, only load Gs
68 |             self.model_names = ['G_A', 'G_B']
69 | 
70 |         # define networks (both generators and discriminators)
71 |         # The naming is different from those used in the paper.
72 |         # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
73 |         self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
74 |                                         not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)
75 |         self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
76 |                                         not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)
77 | 
78 |         if self.isTrain:  # define discriminators
79 |             self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
80 |                                             opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
81 |             self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
82 |                                             opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
83 | 
84 |         if self.isTrain:
85 |             if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
86 |                 assert(opt.input_nc == opt.output_nc)
87 |             self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
88 |             self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
89 |             # define loss functions
90 |             self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.
91 |             self.criterionCycle = torch.nn.L1Loss()
92 |             self.criterionIdt = torch.nn.L1Loss()
93 |             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
94 |             self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
95 |             self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
96 |             self.optimizers.append(self.optimizer_G)
97 |             self.optimizers.append(self.optimizer_D)
98 | 
99 |     def set_input(self, input):
100 |         """Unpack input data from the dataloader and perform necessary pre-processing steps.
101 | 
102 |         Parameters:
103 |             input (dict): include the data itself and its metadata information.
104 | 
105 |         The option 'direction' can be used to swap domain A and domain B.
106 |         """
107 |         AtoB = self.opt.direction == 'AtoB'
108 |         self.real_A = input['A' if AtoB else 'B'].to(self.device)
109 |         self.real_B = input['B' if AtoB else 'A'].to(self.device)
110 |         self.image_paths = input['A_paths' if AtoB else 'B_paths']
111 | 
112 |     def forward(self):
113 |         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
114 |         self.fake_B = self.netG_A(self.real_A)  # G_A(A)
115 |         self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
116 |         self.fake_A = self.netG_B(self.real_B)  # G_B(B)
117 |         self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))
118 | 
119 |     def backward_D_basic(self, netD, real, fake):
120 |         """Calculate GAN loss for the discriminator
121 | 
122 |         Parameters:
123 |             netD (network)      -- the discriminator D
124 |             real (tensor array) -- real images
125 |             fake (tensor array) -- images generated by a generator
126 | 
127 |         Return the discriminator loss.
128 |         We also call loss_D.backward() to calculate the gradients.
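        Note: the combined real/fake loss below is scaled by 0.5, which halves the rate at which D
        learns relative to G (the training heuristic used in the original pix2pix/CycleGAN code).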
129 | """ 130 | # Real 131 | pred_real = netD(real) 132 | loss_D_real = self.criterionGAN(pred_real, True) 133 | # Fake 134 | pred_fake = netD(fake.detach()) 135 | loss_D_fake = self.criterionGAN(pred_fake, False) 136 | # Combined loss and calculate gradients 137 | loss_D = (loss_D_real + loss_D_fake) * 0.5 138 | loss_D.backward() 139 | return loss_D 140 | 141 | def backward_D_A(self): 142 | """Calculate GAN loss for discriminator D_A""" 143 | fake_B = self.fake_B_pool.query(self.fake_B) 144 | self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) 145 | 146 | def backward_D_B(self): 147 | """Calculate GAN loss for discriminator D_B""" 148 | fake_A = self.fake_A_pool.query(self.fake_A) 149 | self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) 150 | 151 | def backward_G(self): 152 | """Calculate the loss for generators G_A and G_B""" 153 | lambda_idt = self.opt.lambda_identity 154 | lambda_A = self.opt.lambda_A 155 | lambda_B = self.opt.lambda_B 156 | # Identity loss 157 | if lambda_idt > 0: 158 | # G_A should be identity if real_B is fed: ||G_A(B) - B|| 159 | self.idt_A = self.netG_A(self.real_B) 160 | self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt 161 | # G_B should be identity if real_A is fed: ||G_B(A) - A|| 162 | self.idt_B = self.netG_B(self.real_A) 163 | self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt 164 | else: 165 | self.loss_idt_A = 0 166 | self.loss_idt_B = 0 167 | 168 | # GAN loss D_A(G_A(A)) 169 | self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True) 170 | # GAN loss D_B(G_B(B)) 171 | self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True) 172 | # Forward cycle loss || G_B(G_A(A)) - A|| 173 | self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A 174 | # Backward cycle loss || G_A(G_B(B)) - B|| 175 | self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B 176 | # combined loss and calculate gradients 177 | self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B 178 | self.loss_G.backward() 179 | 180 | def optimize_parameters(self): 181 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 182 | # forward 183 | self.forward() # compute fake images and reconstruction images. 
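        # The alternating schedule below follows the standard GAN recipe: freeze the discriminators
        # while stepping the generators, then unfreeze them and step the discriminators on fresh fakes.
        # A minimal sketch of the same pattern (illustrative placeholder names, not repo code):
        #     set_requires_grad(D, False); opt_G.zero_grad(); backward_G(); opt_G.step()
        #     set_requires_grad(D, True);  opt_D.zero_grad(); backward_D(); opt_D.step()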
184 |         # G_A and G_B
185 |         self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds require no gradients when optimizing Gs
186 |         self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
187 |         self.backward_G()             # calculate gradients for G_A and G_B
188 |         self.optimizer_G.step()       # update G_A and G_B's weights
189 |         # D_A and D_B
190 |         self.set_requires_grad([self.netD_A, self.netD_B], True)
191 |         self.optimizer_D.zero_grad()  # set D_A and D_B's gradients to zero
192 |         self.backward_D_A()           # calculate gradients for D_A
193 |         self.backward_D_B()           # calculate gradients for D_B
194 |         self.optimizer_D.step()       # update D_A and D_B's weights
195 | 
--------------------------------------------------------------------------------
/models/network_resnet_branched.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import init
6 | 
7 | from .base_model import BaseModel
8 | 
9 | class ResnetStackedArchitecture(nn.Module):
10 | 
11 |     def __init__(self, opt=None):
12 |         super(ResnetStackedArchitecture, self).__init__()
13 | 
14 |         # architecture parameters
15 |         self.F = 256 if not opt else opt.resnet_F
16 |         self.B = 16 if not opt else opt.resnet_B
17 |         self.kernel_size = 3
18 |         self.padding_size = 1
19 |         self.scale_res = 0.1
20 |         self.dropout = False
21 |         self.use_64C = True  # rather than removing these layers in networks_branched.py
22 |         self.use_SAR = True if not opt else opt.include_S1
23 |         self.use_long = False
24 | 
25 |         model = [nn.Conv2d(self.use_SAR*2+13, self.F, kernel_size=self.kernel_size, padding=self.padding_size, bias=True), nn.ReLU(True)]
26 |         # generate a given number of blocks
27 |         for i in range(self.B):
28 |             model += [ResnetBlock(self.F, use_dropout=self.dropout, use_bias=True,
29 |                                   res_scale=self.scale_res, padding_size=self.padding_size)]
30 | 
31 |         # add an intermediate mapping layer from self.F to 64 channels for STGAN pre-training
32 |         if self.use_64C:
33 |             model += [nn.Conv2d(self.F, 64, kernel_size=self.kernel_size, padding=self.padding_size, bias=True)]
34 |             model += [nn.ReLU(True)]
35 |             if self.dropout: model += [nn.Dropout(0.2)]
36 | 
37 | 
38 |         if self.use_64C:
39 |             model += [nn.Conv2d(64, 13, kernel_size=self.kernel_size, padding=self.padding_size, bias=True)]
40 |         else:
41 |             model += [nn.Conv2d(self.F, 13, kernel_size=self.kernel_size, padding=self.padding_size, bias=True)]
42 | 
43 |         self.model = nn.Sequential(*model)
44 | 
45 |     def forward(self, input):
46 |         # long-skip connection: add cloudy MS input (excluding the trailing two SAR channels) and model output
47 |         return self.model(input)  # + self.use_long*input[:, :(-2*self.use_SAR), ...]
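    # If self.use_long were True, the commented-out long skip above would add the 13 optical (S2)
    # input channels back onto the prediction, dropping the two trailing SAR channels whenever
    # use_SAR is set. Written out, a sketch of that variant (illustrative, not active repo code):
    #     out = self.model(input)
    #     if self.use_long:
    #         out = out + input[:, :(-2 * self.use_SAR) or None, ...]  # keep all 13 S2 bands
    #     return out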
48 | 49 | 50 | # Define a resnet block 51 | class ResnetBlock(nn.Module): 52 | def __init__(self, dim, use_dropout, use_bias, res_scale=0.1, padding_size=1): 53 | super(ResnetBlock, self).__init__() 54 | self.res_scale = res_scale 55 | self.padding_size = padding_size 56 | self.conv_block = self.build_conv_block(dim, use_dropout, use_bias) 57 | 58 | # conv_block: 59 | # CONV (pad, conv, norm), 60 | # RELU (relu, dropout), 61 | # CONV (pad, conv, norm) 62 | def build_conv_block(self, dim, use_dropout, use_bias): 63 | conv_block = [] 64 | 65 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=self.padding_size, bias=use_bias)] 66 | conv_block += [nn.ReLU(True)] 67 | 68 | if use_dropout: 69 | conv_block += [nn.Dropout(0.2)] 70 | 71 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=self.padding_size, bias=use_bias)] 72 | 73 | return nn.Sequential(*conv_block) 74 | 75 | def forward(self, x): 76 | # add residual mapping 77 | out = x + self.res_scale * self.conv_block(x) 78 | return out 79 | -------------------------------------------------------------------------------- /models/pix2pix_ir_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . import networks 4 | 5 | 6 | class Pix2PixIrModel(BaseModel): 7 | """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. 8 | 9 | The model training requires '--dataset_mode aligned' dataset. 10 | By default, it uses a '--netG unet256' U-Net generator, 11 | a '--netD basic' discriminator (PatchGAN), 12 | and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). 13 | 14 | pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf 15 | """ 16 | @staticmethod 17 | def modify_commandline_options(parser, is_train=True): 18 | """Add new dataset-specific options, and rewrite default values for existing options. 19 | 20 | Parameters: 21 | parser -- original option parser 22 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 23 | 24 | Returns: 25 | the modified parser. 26 | 27 | For pix2pix, we do not use image buffer 28 | The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 29 | By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets. 30 | """ 31 | # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) 32 | parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned') 33 | if is_train: 34 | parser.set_defaults(pool_size=0, gan_mode='vanilla') 35 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') 36 | 37 | return parser 38 | 39 | def __init__(self, opt): 40 | """Initialize the pix2pix class. 41 | 42 | Parameters: 43 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 44 | """ 45 | BaseModel.__init__(self, opt) 46 | # specify the training losses you want to print out. The training/test scripts will call 47 | self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] 48 | # specify the images you want to save/display. The training/test scripts will call 49 | self.visual_names = ['fake_B', 'real_B'] 50 | # specify the models you want to save to the disk. 
The training/test scripts will call and 51 | if self.isTrain: 52 | self.model_names = ['G', 'D'] 53 | else: # during test time, only load G 54 | self.model_names = ['G'] 55 | # define networks (both generator and discriminator) 56 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, 57 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt) 58 | 59 | if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc 60 | self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 61 | opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) 62 | 63 | if self.isTrain: 64 | # define loss functions 65 | self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) 66 | self.criterionL1 = torch.nn.L1Loss() 67 | # initialize optimizers; schedulers will be automatically created by function . 68 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 69 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 70 | self.optimizers.append(self.optimizer_G) 71 | self.optimizers.append(self.optimizer_D) 72 | 73 | def set_input(self, input): 74 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 75 | 76 | Parameters: 77 | input (dict): include the data itself and its metadata information. 78 | 79 | The option 'direction' can be used to swap images in domain A and domain B. 80 | """ 81 | AtoB = self.opt.direction == 'AtoB' 82 | self.real_A = input['A' if AtoB else 'B'].to(self.device) 83 | self.real_B = input['B' if AtoB else 'A'].to(self.device) 84 | self.image_paths = input['A_paths' if AtoB else 'B_paths'] 85 | 86 | def forward(self): 87 | """Run forward pass; called by both functions and .""" 88 | self.fake_B = self.netG(self.real_A) # G(A) 89 | 90 | def backward_D(self): 91 | """Calculate GAN loss for the discriminator""" 92 | # Fake; stop backprop to the generator by detaching fake_B 93 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator 94 | pred_fake = self.netD(fake_AB.detach()) 95 | self.loss_D_fake = self.criterionGAN(pred_fake, False) 96 | # Real 97 | real_AB = torch.cat((self.real_A, self.real_B), 1) 98 | pred_real = self.netD(real_AB) 99 | self.loss_D_real = self.criterionGAN(pred_real, True) 100 | # combine loss and calculate gradients 101 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 102 | self.loss_D.backward() 103 | 104 | def backward_G(self): 105 | """Calculate GAN and L1 loss for the generator""" 106 | # First, G(A) should fake the discriminator 107 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) 108 | pred_fake = self.netD(fake_AB) 109 | self.loss_G_GAN = self.criterionGAN(pred_fake, True) 110 | # Second, G(A) = B 111 | self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 112 | # combine loss and calculate gradients 113 | self.loss_G = self.loss_G_GAN + self.loss_G_L1 114 | self.loss_G.backward() 115 | 116 | def optimize_parameters(self): 117 | self.forward() # compute fake images: G(A) 118 | # update D 119 | self.set_requires_grad(self.netD, True) # enable backprop for D 120 | self.optimizer_D.zero_grad() # set D's gradients to zero 121 | self.backward_D() # calculate gradients for D 122 | self.optimizer_D.step() # update D's weights 123 | # update G 124 
|         self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
125 |         self.optimizer_G.zero_grad()  # set G's gradients to zero
126 |         self.backward_G()             # calculate gradients for G
127 |         self.optimizer_G.step()       # update G's weights
128 | 
--------------------------------------------------------------------------------
/models/pix2pix_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_model import BaseModel
3 | from . import networks
4 | 
5 | 
6 | class Pix2PixModel(BaseModel):
7 |     """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.
8 | 
9 |     The model training requires '--dataset_mode aligned' dataset.
10 |     By default, it uses a '--netG unet256' U-Net generator,
11 |     a '--netD basic' discriminator (PatchGAN),
12 |     and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the original GAN paper).
13 | 
14 |     pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
15 |     """
16 |     @staticmethod
17 |     def modify_commandline_options(parser, is_train=True):
18 |         """Add new dataset-specific options, and rewrite default values for existing options.
19 | 
20 |         Parameters:
21 |             parser          -- original option parser
22 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
23 | 
24 |         Returns:
25 |             the modified parser.
26 | 
27 |         For pix2pix, we do not use an image buffer.
28 |         The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1
29 |         By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets.
30 |         """
31 |         # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/)
32 |         parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned')
33 |         if is_train:
34 |             parser.set_defaults(pool_size=0, gan_mode='vanilla')
35 |             parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
36 | 
37 |         return parser
38 | 
39 |     def __init__(self, opt):
40 |         """Initialize the pix2pix class.
41 | 
42 |         Parameters:
43 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
44 |         """
45 |         BaseModel.__init__(self, opt)
46 |         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
47 |         self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
48 |         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
49 |         self.visual_names = ['real_A', 'fake_B', 'real_B']
50 |         # specify the models you want to save to the disk.
The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
51 |         if self.isTrain:
52 |             self.model_names = ['G', 'D']
53 |         else:  # during test time, only load G
54 |             self.model_names = ['G']
55 |         # define networks (both generator and discriminator)
56 |         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
57 |                                       not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)
58 | 
59 |         if self.isTrain:  # define a discriminator; conditional GANs need to take both input and output images; therefore, #channels for D is input_nc + output_nc
60 |             self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
61 |                                           opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
62 | 
63 |         if self.isTrain:
64 |             # define loss functions
65 |             self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
66 |             self.criterionL1 = torch.nn.L1Loss()
67 |             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
68 |             self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
69 |             self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
70 |             self.optimizers.append(self.optimizer_G)
71 |             self.optimizers.append(self.optimizer_D)
72 | 
73 |     def set_input(self, input):
74 |         """Unpack input data from the dataloader and perform necessary pre-processing steps.
75 | 
76 |         Parameters:
77 |             input (dict): include the data itself and its metadata information.
78 | 
79 |         The option 'direction' can be used to swap images in domain A and domain B.
80 |         """
81 |         AtoB = self.opt.direction == 'AtoB'
82 |         self.real_A = input['A' if AtoB else 'B'].to(self.device)
83 |         self.real_B = input['B' if AtoB else 'A'].to(self.device)
84 |         self.image_paths = input['A_paths' if AtoB else 'B_paths']
85 | 
86 |     def forward(self):
87 |         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
88 |         self.fake_B = self.netG(self.real_A)  # G(A)
89 | 
90 |     def backward_D(self):
91 |         """Calculate GAN loss for the discriminator"""
92 |         # Fake; stop backprop to the generator by detaching fake_B
93 |         fake_AB = torch.cat((self.real_A, self.fake_B), 1)  # we use conditional GANs; we need to feed both input and output to the discriminator
94 |         pred_fake = self.netD(fake_AB.detach())
95 |         self.loss_D_fake = self.criterionGAN(pred_fake, False)
96 |         # Real
97 |         real_AB = torch.cat((self.real_A, self.real_B), 1)
98 |         pred_real = self.netD(real_AB)
99 |         self.loss_D_real = self.criterionGAN(pred_real, True)
100 |         # combine loss and calculate gradients
101 |         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
102 |         self.loss_D.backward()
103 | 
104 |     def backward_G(self):
105 |         """Calculate GAN and L1 loss for the generator"""
106 |         # First, G(A) should fake the discriminator
107 |         fake_AB = torch.cat((self.real_A, self.fake_B), 1)
108 |         pred_fake = self.netD(fake_AB)
109 |         self.loss_G_GAN = self.criterionGAN(pred_fake, True)
110 |         # Second, G(A) = B
111 |         self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
112 |         # combine loss and calculate gradients
113 |         self.loss_G = self.loss_G_GAN + self.loss_G_L1
114 |         self.loss_G.backward()
115 | 
116 |     def optimize_parameters(self):
117 |         self.forward()  # compute fake images: G(A)
118 |         # update D
119 |         self.set_requires_grad(self.netD, True)  # enable backprop for D
120 |         self.optimizer_D.zero_grad()  # set D's gradients to zero
121 |         self.backward_D()             # calculate gradients for D
122 |         self.optimizer_D.step()       # update D's weights
123 |         # update G
124 |         self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
125 |         self.optimizer_G.zero_grad()  # set G's gradients to zero
126 |         self.backward_G()             # calculate gradients for G
127 |         self.optimizer_G.step()       # update G's weights
128 | 
--------------------------------------------------------------------------------
/models/template_model.py:
--------------------------------------------------------------------------------
1 | """Model class template
2 | 
3 | This module provides a template for users to implement custom models.
4 | You can specify '--model template' to use this model.
5 | The class name should be consistent with both the filename and its model option.
6 | The filename should be <model>_model.py
7 | The class name should be <Model>Model
8 | It implements a simple image-to-image translation baseline based on regression loss.
9 | Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
10 |     min_<netG> ||netG(data_A) - data_B||_1
11 | You need to implement the following functions:
12 |     <modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
13 |     <__init__>: Initialize this model class.
14 |     <set_input>: Unpack input data and perform data pre-processing.
15 |     <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
16 |     <optimize_parameters>: Update network weights; it will be called in every training iteration.
17 | """
18 | import torch
19 | from .base_model import BaseModel
20 | from . import networks
21 | 
22 | 
23 | class TemplateModel(BaseModel):
24 |     @staticmethod
25 |     def modify_commandline_options(parser, is_train=True):
26 |         """Add new model-specific options and rewrite default values for existing options.
27 | 
28 |         Parameters:
29 |             parser   -- the option parser
30 |             is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
31 | 
32 |         Returns:
33 |             the modified parser.
34 |         """
35 |         parser.set_defaults(dataset_mode='aligned')  # you can rewrite default values for this model. For example, this model usually uses an aligned dataset as its dataset.
36 |         if is_train:
37 |             parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss')  # you can define new arguments for this model.
38 | 
39 |         return parser
40 | 
41 |     def __init__(self, opt):
42 |         """Initialize this model class.
43 | 
44 |         Parameters:
45 |             opt -- training/test options
46 | 
47 |         A few things can be done here.
48 |         - (required) call the initialization function of BaseModel
49 |         - define loss function, visualization images, model names, and optimizers
50 |         """
51 |         BaseModel.__init__(self, opt)  # call the initialization method of BaseModel
52 |         # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
53 |         self.loss_names = ['G']  # get_current_losses prepends 'loss_', so 'G' refers to self.loss_G
54 |         # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
55 |         self.visual_names = ['data_A', 'data_B', 'output']
56 |         # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
57 |         # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
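        # Each name listed in model_names maps to an attribute 'net<name>' and to checkpoint files
        # written by BaseModel.save_networks as '%s_net_%s.pth' % (epoch, name); e.g. with
        # model_names = ['G'] and epoch 20, netG is stored under <save_dir>/20_net_G.pth.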
58 |         self.model_names = ['G']
59 |         # define networks; you can use opt.isTrain to specify different behaviors for training and test.
60 |         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids, opt=opt)
61 |         if self.isTrain:  # only defined during training time
62 |             # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
63 |             # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
64 |             self.criterionLoss = torch.nn.L1Loss()
65 |             # define and initialize optimizers. You can define one optimizer for each network.
66 |             # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
67 |             self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
68 |             self.optimizers = [self.optimizer]
69 | 
70 |         # Our program will automatically call <model.setup(opt)> to define schedulers, load networks, and print networks
71 | 
72 |     def set_input(self, input):
73 |         """Unpack input data from the dataloader and perform necessary pre-processing steps.
74 | 
75 |         Parameters:
76 |             input: a dictionary that contains the data itself and its metadata information.
77 |         """
78 |         AtoB = self.opt.direction == 'AtoB'  # use <direction> to swap data_A and data_B
79 |         self.data_A = input['A' if AtoB else 'B'].to(self.device)  # get image data A
80 |         self.data_B = input['B' if AtoB else 'A'].to(self.device)  # get image data B
81 |         self.image_paths = input['A_paths' if AtoB else 'B_paths']  # get image paths
82 | 
83 |     def forward(self):
84 |         """Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
85 |         self.output = self.netG(self.data_A)  # generate output image given the input data_A
86 | 
87 |     def backward(self):
88 |         """Calculate losses, gradients, and update network weights; called in every training iteration"""
89 |         # calculate the intermediate results if necessary; here self.output has been computed during function <forward>
90 |         # calculate loss given the input and intermediate results
91 |         self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
92 |         self.loss_G.backward()  # calculate gradients of network G w.r.t. loss_G
93 | 
94 |     def optimize_parameters(self):
95 |         """Update network weights; it will be called in every training iteration."""
96 |         self.forward()              # first call forward to calculate intermediate results
97 |         self.optimizer.zero_grad()  # clear network G's existing gradients
98 |         self.backward()             # calculate gradients for network G
99 |         self.optimizer.step()       # update network G's weights
100 | 
--------------------------------------------------------------------------------
/models/temporal_branched_ir_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_model import BaseModel
3 | from . import networks_branched as networks
4 | 
5 | class TemporalBranchedIrModel(BaseModel):
6 |     """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.
7 | 
8 |     The model training requires '--dataset_mode aligned' dataset.
9 |     By default, it uses a '--netG unet256' U-Net generator,
10 |     a '--netD basic' discriminator (PatchGAN),
11 |     and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the original GAN paper).
12 | 13 | pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf 14 | """ 15 | @staticmethod 16 | def modify_commandline_options(parser, is_train=True): 17 | """Add new dataset-specific options, and rewrite default values for existing options. 18 | 19 | Parameters: 20 | parser -- original option parser 21 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 22 | 23 | Returns: 24 | the modified parser. 25 | 26 | For pix2pix, we do not use image buffer 27 | The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 28 | By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets. 29 | """ 30 | # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) 31 | parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned') 32 | parser.add_argument('--scramble', type=bool, default=False, help='scramble order of input images?') 33 | if is_train: 34 | parser.set_defaults(pool_size=0, gan_mode='vanilla') 35 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') 36 | 37 | return parser 38 | 39 | def __init__(self, opt): 40 | """Initialize the pix2pix class. 41 | 42 | Parameters: 43 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 44 | """ 45 | BaseModel.__init__(self, opt) 46 | # specify the training losses you want to print out. The training/test scripts will call 47 | self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] 48 | # specify the images you want to save/display. The training/test scripts will call 49 | self.visual_names = ['real_A_0', 'real_A_1', 'real_A_2', 'fake_B', 'real_B'] 50 | # specify the models you want to save to the disk. The training/test scripts will call and 51 | if self.isTrain: 52 | self.model_names = ['G', 'D'] 53 | else: # during test time, only load G 54 | self.model_names = ['G'] 55 | # define networks (both generator and discriminator) 56 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, 57 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt) 58 | 59 | if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc 60 | self.netD = networks.define_D(3*opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 61 | opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) 62 | 63 | if self.isTrain: 64 | # define loss functions 65 | self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) 66 | self.criterionL1 = torch.nn.L1Loss() 67 | # initialize optimizers; schedulers will be automatically created by function . 68 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 69 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 70 | self.optimizers.append(self.optimizer_G) 71 | self.optimizers.append(self.optimizer_D) 72 | 73 | def set_input(self, input): 74 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 75 | 76 | Parameters: 77 | input (dict): include the data itself and its metadata information. 78 | 79 | The option 'direction' can be used to swap images in domain A and domain B. 
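        For this temporal IR model, each of the three Sentinel-2 inputs is concatenated with its
        matching IR channels ('A_i_ir') before the time points are stacked along the channel axis.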
80 | """ 81 | self.real_A_0 = input['A_0'].to(self.device) 82 | self.real_A_1 = input['A_1'].to(self.device) 83 | self.real_A_2 = input['A_2'].to(self.device) 84 | self.real_A_0_ir = input['A_0_ir'].to(self.device) 85 | self.real_A_1_ir = input['A_1_ir'].to(self.device) 86 | self.real_A_2_ir = input['A_2_ir'].to(self.device) 87 | self.real_A_0_combined = torch.cat((self.real_A_0, self.real_A_0_ir), 1).to(self.device) 88 | self.real_A_1_combined = torch.cat((self.real_A_1, self.real_A_1_ir), 1).to(self.device) 89 | self.real_A_2_combined = torch.cat((self.real_A_2, self.real_A_2_ir), 1).to(self.device) 90 | self.real_A = torch.cat((self.real_A_0_combined, self.real_A_1_combined, self.real_A_2_combined), 1).to(self.device) 91 | self.real_A_input = [self.real_A_0_combined, self.real_A_1_combined, self.real_A_2_combined] 92 | self.real_B = input['B'].to(self.device) 93 | self.image_paths = input['A_paths'] 94 | 95 | def forward(self): 96 | """Run forward pass; called by both functions and .""" 97 | self.fake_B = self.netG(self.real_A_input) # G(A) 98 | # self.fake_B = self.netG(self.real_A_0) 99 | 100 | def backward_D(self): 101 | """Calculate GAN loss for the discriminator""" 102 | # Fake; stop backprop to the generator by detaching fake_B 103 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator 104 | pred_fake = self.netD(fake_AB.detach()) 105 | self.loss_D_fake = self.criterionGAN(pred_fake, False) 106 | # Real 107 | real_AB = torch.cat((self.real_A, self.real_B), 1) 108 | pred_real = self.netD(real_AB) 109 | self.loss_D_real = self.criterionGAN(pred_real, True) 110 | # combine loss and calculate gradients 111 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 112 | self.loss_D.backward() 113 | 114 | def backward_G(self): 115 | """Calculate GAN and L1 loss for the generator""" 116 | # First, G(A) should fake the discriminator 117 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) 118 | pred_fake = self.netD(fake_AB) 119 | self.loss_G_GAN = self.criterionGAN(pred_fake, True) 120 | # Second, G(A) = B 121 | self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 122 | # combine loss and calculate gradients 123 | self.loss_G = self.loss_G_GAN + self.loss_G_L1 124 | self.loss_G.backward() 125 | 126 | def optimize_parameters(self): 127 | self.forward() # compute fake images: G(A) 128 | # update D 129 | self.set_requires_grad(self.netD, True) # enable backprop for D 130 | self.optimizer_D.zero_grad() # set D's gradients to zero 131 | self.backward_D() # calculate gradients for D 132 | self.optimizer_D.step() # update D's weights 133 | # update G 134 | self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G 135 | self.optimizer_G.zero_grad() # set G's gradients to zero 136 | self.backward_G() # calculate graidents for G 137 | self.optimizer_G.step() # udpate G's weights 138 | -------------------------------------------------------------------------------- /models/temporal_branched_ir_modified_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . 
import networks_branched as networks 4 | import numpy as np 5 | from util import util 6 | import warnings 7 | 8 | 9 | #class TemporalBranchedModel(BaseModel): 10 | class TemporalBranchedIRModifiedModel(BaseModel): 11 | """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. 12 | 13 | The model training requires '--dataset_mode aligned' dataset. 14 | By default, it uses a '--netG unet256' U-Net generator, 15 | a '--netD basic' discriminator (PatchGAN), 16 | and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). 17 | 18 | pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf 19 | """ 20 | 21 | @staticmethod 22 | def modify_commandline_options(parser, is_train=True): 23 | """Add new dataset-specific options, and rewrite default values for existing options. 24 | 25 | Parameters: 26 | parser -- original option parser 27 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 28 | 29 | Returns: 30 | the modified parser. 31 | 32 | For pix2pix, we do not use image buffer 33 | The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 34 | By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets. 35 | """ 36 | # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) 37 | parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned') 38 | parser.add_argument('--scramble', type=bool, default=False, help='scramble order of input images?') 39 | if is_train: 40 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') 41 | parser.add_argument('--lambda_GAN', type=float, default=1.0, 42 | help='weight of GAN loss for generator and discriminator') 43 | return parser 44 | 45 | def __init__(self, opt): 46 | """Initialize the pix2pix class. 47 | 48 | Parameters: 49 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 50 | """ 51 | BaseModel.__init__(self, opt) 52 | 53 | # specify the images you want to save/display. The training/test scripts will call 54 | self.visual_names = [] 55 | if opt.include_S1: 56 | for i in range(opt.n_input_samples): 57 | self.visual_names.append(f'real_A_{i}') 58 | self.visual_names.append(f'A_{i}_S1') 59 | self.visual_names.append(f'A_{i}_mask') 60 | self.visual_names += ['real_B_S1'] 61 | else: 62 | for i in range(opt.n_input_samples): 63 | self.visual_names.append(f'real_A_{i}') 64 | self.visual_names.append(f'A_{i}_mask') 65 | self.visual_names = self.visual_names + ['real_B', 'B_mask', 'fake_B'] 66 | 67 | # specify the models you want to save to the disk. 
The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
68 |         if self.isTrain and opt.lambda_GAN != 0:
69 |             self.model_names = ['G', 'D']
70 |         else:  # during test time, only load G
71 |             self.model_names = ['G']
72 |         # define networks (both generator and discriminator)
73 |         # if opt.alter_initial_model:
74 |         #     assert opt.include_S1, Exception('Altering initial model must include S1 data')
75 |         self.netG = networks.define_G(self.device, opt.alter_initial_model, opt.unfreeze_iter, opt.initial_model_path,
76 |                                       opt.n_input_samples, opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
77 |                                       not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)
78 | 
79 |         if self.isTrain and opt.lambda_GAN != 0:  # define a discriminator; conditional GANs need to take both input and output images; therefore, #channels for D is input_nc + output_nc
80 |             # if not 0 -> include GAN loss and discriminator
81 |             self.netD = networks.define_D(opt.n_input_samples * opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
82 |                                           opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
83 | 
84 |         if self.isTrain:
85 |             # define loss functions
86 |             if opt.lambda_GAN != 0:  # not 0 -> include GAN and D
87 |                 self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
88 |                 # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
89 |                 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
90 |                 self.optimizers.append(self.optimizer_D)
91 |             self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
92 |             self.optimizers.append(self.optimizer_G)
93 | 
94 |         if opt.use_perceptual_loss:
95 |             # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
96 |             if self.isTrain and opt.lambda_GAN != 0:
97 |                 self.loss_names = ['G_GAN', 'G_loss', 'D_real', 'D_fake', 'perceptual']
98 |             else:
99 |                 self.loss_names = ['G_loss', 'perceptual']
100 |             assert not opt.vgg16_path == 'none', 'Missing input of VGG16 path.'
101 |             self.netL = util.LossNetwork(opt.vgg16_path, [8, 15, 22, 29], self.device)
102 |         else:
103 |             if self.isTrain and opt.lambda_GAN != 0:
104 |                 self.loss_names = ['G_GAN', 'G_loss', 'D_real', 'D_fake']
105 |             else:
106 |                 self.loss_names = ['G_loss']
107 |         self.total_iters = 0
108 | 
109 |     def set_input(self, input):
110 |         """Unpack input data from the dataloader and perform necessary pre-processing steps.
111 | 
112 |         Parameters:
113 |             input (dict): include the data itself and its metadata information.
114 | 
115 |         The option 'direction' can be used to swap images in domain A and domain B.
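        Unlike the fixed three-image models above, this model unpacks a variable number of input
        time points (opt.n_input_samples), each with an associated cloud mask, and optionally the
        matching Sentinel-1 patches when opt.include_S1 is set.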
116 | """ 117 | 118 | # dynamically process a variable number of input time points 119 | A_input, self.A_mask, self.input_is_SAR = [], [], [] 120 | for i in range(self.opt.n_input_samples): 121 | A_input.append(input['A_S2'][i].to(self.device)) 122 | setattr(self, f'real_A_{i}', A_input[i]) 123 | self.A_mask.append(input['A_mask'][i].to(self.device)) 124 | setattr(self, f'A_{i}_mask', self.A_mask[i]) 125 | 126 | if not self.opt.include_S1: 127 | self.real_A_input = A_input 128 | 129 | for i in range(self.opt.n_input_samples * A_input[0].shape[1]): 130 | self.input_is_SAR.append(False) 131 | else: 132 | self.real_A_input = [] 133 | for i in range(self.opt.n_input_samples): 134 | S1 = input['A_S1'][i].to(self.device) 135 | setattr(self, f'A_{i}_S1', S1) 136 | self.real_A_input.append(torch.cat((A_input[i], S1), 1).to(self.device)) 137 | 138 | for j in range(A_input[i].shape[1]): 139 | self.input_is_SAR.append(False) 140 | self.input_is_SAR.append([True, True]) 141 | 142 | # concatenate input patches across time (and across modalities, if including SAR) 143 | self.real_A = torch.cat(self.real_A_input, 1).to(self.device) 144 | # bookkeeping of target cloud-free patch 145 | self.real_B = input['B'].to(self.device) 146 | # bookkeeping of target mask 147 | self.B_mask = input['B_mask'].to(self.device) 148 | if self.opt.include_S1: 149 | self.real_B_S1 = input['B_S1'] 150 | 151 | self.S2_input = A_input 152 | 153 | self.image_paths = input['image_path'] 154 | 155 | def forward(self): 156 | """Run forward pass; called by both functions and .""" 157 | self.fake_B = self.netG(self.real_A_input) # G(A) 158 | # self.fake_B = self.netG(self.real_A_0) 159 | 160 | def backward_D(self): 161 | """Calculate GAN loss for the discriminator""" 162 | # Fake; stop backprop to the generator by detaching fake_B 163 | fake_AB = torch.cat((self.real_A, self.fake_B), 164 | 1) # we use conditional GANs; we need to feed both input and output to the discriminator 165 | pred_fake = self.netD(fake_AB.detach()) 166 | self.loss_D_fake = self.criterionGAN(pred_fake, False) 167 | # Real 168 | real_AB = torch.cat((self.real_A, self.real_B), 1) 169 | pred_real = self.netD(real_AB) 170 | self.loss_D_real = self.criterionGAN(pred_real, True) 171 | # combine loss and calculate gradients 172 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 * self.opt.lambda_GAN 173 | self.loss_D.backward() 174 | 175 | def criterionG(self, fake_B, real_B): 176 | if self.opt.G_loss == 'L1': 177 | loss = torch.nn.L1Loss() 178 | return loss(fake_B, real_B) * self.opt.lambda_L1 179 | else: 180 | raise Exception("Undefined G loss type.") 181 | 182 | def scale_to_01(self, im, method): 183 | if method == 'default': 184 | return (im + 1) / 2 185 | else: 186 | # dealing with only optical images, range 0,5 187 | return im / 5 188 | 189 | def get_perceptual_loss(self): 190 | loss = 0. 
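        # self.netL is a VGG16-based feature extractor (util.LossNetwork, hooked at layers
        # [8, 15, 22, 29]); the perceptual loss is the sum of MSE distances between the features of
        # the prediction and of the cloud-free target, after both are rescaled to [0, 1] via scale_to_01.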
191 | if self.opt.alter_initial_model: 192 | method = 'resnet' 193 | else: 194 | method = 'default' 195 | fake = self.netL(self.scale_to_01(self.fake_B, method)) 196 | real = self.netL(self.scale_to_01(self.real_B, method)) 197 | mse = torch.nn.MSELoss() 198 | for i in range(len(fake)): 199 | loss += mse(fake[i], real[i]) 200 | return loss 201 | 202 | def backward_G(self): 203 | """Calculate GAN and L1 loss for the generator""" 204 | if self.opt.lambda_GAN != 0: 205 | # First, G(A) should fake the discriminator 206 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) 207 | pred_fake = self.netD(fake_AB) 208 | self.loss_G_GAN = self.criterionGAN(pred_fake, True) 209 | # Second, G(A) = B 210 | self.loss_G_loss = self.criterionG(self.fake_B, self.real_B) 211 | # combine loss and calculate gradients 212 | if self.opt.use_perceptual_loss: 213 | self.loss_perceptual = self.get_perceptual_loss() * self.opt.lambda_percep 214 | if self.opt.lambda_GAN != 0: 215 | self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + self.loss_G_loss + self.loss_perceptual 216 | else: 217 | self.loss_G = self.loss_G_loss + self.loss_perceptual 218 | else: 219 | if self.opt.lambda_GAN != 0: 220 | self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + self.loss_G_loss 221 | else: 222 | self.loss_G = self.loss_G_loss 223 | self.loss_G.backward() 224 | 225 | def valid_grad(self, net): 226 | valid_gradients = True 227 | for name, param in net.named_parameters(): 228 | if param.grad is not None: 229 | valid_gradients = not (torch.isnan(param.grad).any() or torch.isinf(param.grad).any()) 230 | if not valid_gradients: break 231 | if not valid_gradients: 232 | warnings.warn(f'detected inf or nan values in gradients. not updating model parameters') 233 | return valid_gradients 234 | 235 | def optimize_parameters(self): 236 | self.forward() # compute fake images: G(A) 237 | if self.opt.lambda_GAN != 0: 238 | # update D 239 | self.set_requires_grad(self.netD, True) # enable backprop for D 240 | self.optimizer_D.zero_grad() # set D's gradients to zero 241 | self.backward_D() # calculate gradients for D 242 | if self.valid_grad(self.netD): 243 | self.optimizer_D.step() # update D's weights 244 | else: 245 | self.optimizer_D.zero_grad() # do not update D 246 | self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G 247 | # update G 248 | if self.opt.alter_initial_model: 249 | if self.opt.gpu_ids == []: 250 | if self.total_iters >= self.opt.unfreeze_iter and self.netG.initial_freezed: 251 | self.set_requires_grad(self.netG.model_initial, True) 252 | self.netG.model_initial.train(True) 253 | self.netG.initial_freezed = False 254 | else: 255 | if self.total_iters >= self.opt.unfreeze_iter and self.netG.module.initial_freezed: 256 | self.set_requires_grad(self.netG.module.model_initial, True) 257 | self.netG.module.model_initial.train(True) 258 | self.netG.module.initial_freezed = False 259 | self.optimizer_G.zero_grad() # set G's gradients to zero 260 | self.backward_G() # calculate graidents for G 261 | if self.valid_grad(self.netG): 262 | self.optimizer_G.step() # update G's weights 263 | else: 264 | self.optimizer_G.zero_grad() # do not update G 265 | -------------------------------------------------------------------------------- /models/temporal_ir_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . 
import networks 4 | import random 5 | 6 | class TemporalModel(BaseModel): 7 | """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. 8 | 9 | The model training requires '--dataset_mode aligned' dataset. 10 | By default, it uses a '--netG unet256' U-Net generator, 11 | a '--netD basic' discriminator (PatchGAN), 12 | and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). 13 | 14 | pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf 15 | """ 16 | @staticmethod 17 | def modify_commandline_options(parser, is_train=True): 18 | """Add new dataset-specific options, and rewrite default values for existing options. 19 | 20 | Parameters: 21 | parser -- original option parser 22 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 23 | 24 | Returns: 25 | the modified parser. 26 | 27 | For pix2pix, we do not use image buffer 28 | The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 29 | By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets. 30 | """ 31 | # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) 32 | parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned') 33 | parser.add_argument('--scramble', type=bool, default=False, help='scramble order of input images?') 34 | if is_train: 35 | parser.set_defaults(pool_size=0, gan_mode='vanilla') 36 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') 37 | 38 | return parser 39 | 40 | def __init__(self, opt): 41 | """Initialize the pix2pix class. 42 | 43 | Parameters: 44 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 45 | """ 46 | BaseModel.__init__(self, opt) 47 | # specify the training losses you want to print out. The training/test scripts will call 48 | self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] 49 | # specify the images you want to save/display. The training/test scripts will call 50 | self.visual_names = ['real_A_0', 'real_A_1', 'real_A_2', 'fake_B', 'real_B'] 51 | # specify the models you want to save to the disk. The training/test scripts will call and 52 | if self.isTrain: 53 | self.model_names = ['G', 'D'] 54 | else: # during test time, only load G 55 | self.model_names = ['G'] 56 | # define networks (both generator and discriminator) 57 | self.netG = networks.define_G(3*opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, 58 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt) 59 | 60 | if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc 61 | self.netD = networks.define_D(3*opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 62 | opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) 63 | 64 | if self.isTrain: 65 | # define loss functions 66 | self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) 67 | self.criterionL1 = torch.nn.L1Loss() 68 | # initialize optimizers; schedulers will be automatically created by function . 
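            # (the learning-rate schedulers are built from self.optimizers when BaseModel.setup is
            # called and are stepped once per epoch via BaseModel.update_learning_rate)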
69 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 70 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 71 | self.optimizers.append(self.optimizer_G) 72 | self.optimizers.append(self.optimizer_D) 73 | 74 | def set_input(self, input): 75 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 76 | 77 | Parameters: 78 | input (dict): include the data itself and its metadata information. 79 | 80 | The option 'direction' can be used to swap images in domain A and domain B. 81 | """ 82 | self.real_A_0 = input['A_0'].to(self.device) 83 | self.real_A_1 = input['A_1'].to(self.device) 84 | self.real_A_2 = input['A_2'].to(self.device) 85 | options = [self.real_A_0.clone(), self.real_A_1.clone(), self.real_A_2.clone()] 86 | scramble = random.sample(range(0,3), 3) 87 | try: 88 | if self.opt.scramble == True: 89 | self.real_A_0 = options[scramble[0]] 90 | self.real_A_1 = options[scramble[1]] 91 | self.real_A_2 = options[scramble[2]] 92 | except: 93 | pass 94 | self.real_A = torch.cat((self.real_A_0, self.real_A_1, self.real_A_2), 1).to(self.device) 95 | self.real_B = input['B'].to(self.device) 96 | self.image_paths = input['A_paths'] 97 | 98 | def forward(self): 99 | """Run forward pass; called by both functions and .""" 100 | self.fake_B = self.netG(self.real_A) # G(A) 101 | 102 | def backward_D(self): 103 | """Calculate GAN loss for the discriminator""" 104 | # Fake; stop backprop to the generator by detaching fake_B 105 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator 106 | pred_fake = self.netD(fake_AB.detach()) 107 | self.loss_D_fake = self.criterionGAN(pred_fake, False) 108 | # Real 109 | real_AB = torch.cat((self.real_A, self.real_B), 1) 110 | pred_real = self.netD(real_AB) 111 | self.loss_D_real = self.criterionGAN(pred_real, True) 112 | # combine loss and calculate gradients 113 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 114 | self.loss_D.backward() 115 | 116 | def backward_G(self): 117 | """Calculate GAN and L1 loss for the generator""" 118 | # First, G(A) should fake the discriminator 119 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) 120 | pred_fake = self.netD(fake_AB) 121 | self.loss_G_GAN = self.criterionGAN(pred_fake, True) 122 | # Second, G(A) = B 123 | self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 124 | # combine loss and calculate gradients 125 | self.loss_G = self.loss_G_GAN + self.loss_G_L1 126 | self.loss_G.backward() 127 | 128 | def optimize_parameters(self): 129 | self.forward() # compute fake images: G(A) 130 | # update D 131 | self.set_requires_grad(self.netD, True) # enable backprop for D 132 | self.optimizer_D.zero_grad() # set D's gradients to zero 133 | self.backward_D() # calculate gradients for D 134 | self.optimizer_D.step() # update D's weights 135 | # update G 136 | self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G 137 | self.optimizer_G.zero_grad() # set G's gradients to zero 138 | self.backward_G() # calculate graidents for G 139 | self.optimizer_G.step() # udpate G's weights 140 | -------------------------------------------------------------------------------- /models/temporal_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . 
import networks 4 | 5 | class TemporalModel(BaseModel): 6 | """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. 7 | 8 | The model training requires '--dataset_mode aligned' dataset. 9 | By default, it uses a '--netG unet256' U-Net generator, 10 | a '--netD basic' discriminator (PatchGAN), 11 | and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). 12 | 13 | pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf 14 | """ 15 | @staticmethod 16 | def modify_commandline_options(parser, is_train=True): 17 | """Add new dataset-specific options, and rewrite default values for existing options. 18 | 19 | Parameters: 20 | parser -- original option parser 21 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 22 | 23 | Returns: 24 | the modified parser. 25 | 26 | For pix2pix, we do not use image buffer 27 | The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 28 | By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets. 29 | """ 30 | # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) 31 | parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned') 32 | parser.add_argument('--scramble', type=bool, default=False, help='scramble order of input images?') 33 | if is_train: 34 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') 35 | return parser 36 | 37 | def __init__(self, opt): 38 | """Initialize the pix2pix class. 39 | 40 | Parameters: 41 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 42 | """ 43 | BaseModel.__init__(self, opt) 44 | # specify the training losses you want to print out. The training/test scripts will call 45 | self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] 46 | # specify the images you want to save/display. The training/test scripts will call 47 | if not opt.include_S1: 48 | self.visual_names = ['real_A_0', 'A_0_mask', 'real_A_1', 'A_1_mask', 'real_A_2', 'A_2_mask', 'fake_B', 'real_B', 'B_mask'] 49 | else: 50 | self.visual_names = ['real_A_0', 'A_0_S1', 'A_0_mask', 'real_A_1', 'A_1_S1', 'A_1_mask', 'real_A_2', 'A_2_S1', 'A_2_mask', 'fake_B', 'real_B', 'B_mask'] 51 | 52 | # specify the models you want to save to the disk. The training/test scripts will call and 53 | if self.isTrain: 54 | self.model_names = ['G', 'D'] 55 | else: # during test time, only load G 56 | self.model_names = ['G'] 57 | # define networks (both generator and discriminator) 58 | self.netG = networks.define_G(3*opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, 59 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt) 60 | 61 | if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc 62 | self.netD = networks.define_D(3*opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 63 | opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) 64 | 65 | if self.isTrain: 66 | # define loss functions 67 | self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) 68 | self.criterionL1 = torch.nn.L1Loss() 69 | # initialize optimizers; schedulers will be automatically created by function . 
70 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
71 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
72 | self.optimizers.append(self.optimizer_G)
73 | self.optimizers.append(self.optimizer_D)
74 | 
75 | self.opt = opt
76 | 
77 | def set_input(self, input):
78 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
79 | 
80 | Parameters:
81 | input (dict): includes the data itself and its metadata information.
82 | 
83 | The option 'direction' can be used to swap images in domain A and domain B.
84 | """
85 | self.real_A_0 = input['A_0'].to(self.device)
86 | self.A_0_mask = input['A_0_mask'].numpy()
87 | self.real_A_1 = input['A_1'].to(self.device)
88 | self.A_1_mask = input['A_1_mask'].numpy()
89 | self.real_A_2 = input['A_2'].to(self.device)
90 | self.A_2_mask = input['A_2_mask'].numpy()
91 | 
92 | if not self.opt.include_S1:
93 | self.real_A = torch.cat((self.real_A_0, self.real_A_1, self.real_A_2), 1).to(self.device)
94 | else:
95 | self.A_0_S1 = input['A_0_S1'].to(self.device)
96 | self.A_1_S1 = input['A_1_S1'].to(self.device)
97 | self.A_2_S1 = input['A_2_S1'].to(self.device)
98 | self.A_0_combined = torch.cat((self.real_A_0, self.A_0_S1), 1).to(self.device)
99 | self.A_1_combined = torch.cat((self.real_A_1, self.A_1_S1), 1).to(self.device)
100 | self.A_2_combined = torch.cat((self.real_A_2, self.A_2_S1), 1).to(self.device)
101 | self.real_A = torch.cat((self.A_0_combined, self.A_1_combined, self.A_2_combined), 1).to(self.device)
102 | 
103 | self.real_B = input['B'].to(self.device)
104 | self.B_mask = input['B_mask'].numpy()
105 | # self.image_paths = input['A_paths']
106 | 
107 | def forward(self):
108 | """Run forward pass; called by both functions <optimize_parameters> and <test>."""
109 | self.fake_B = self.netG(self.real_A)
110 | 
111 | def backward_D(self):
112 | """Calculate GAN loss for the discriminator"""
113 | # Fake; stop backprop to the generator by detaching fake_B
114 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator
115 | pred_fake = self.netD(fake_AB.detach())
116 | self.loss_D_fake = self.criterionGAN(pred_fake, False)
117 | # Real
118 | real_AB = torch.cat((self.real_A, self.real_B), 1)
119 | pred_real = self.netD(real_AB)
120 | self.loss_D_real = self.criterionGAN(pred_real, True)
121 | # combine loss and calculate gradients
122 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
123 | self.loss_D.backward()
124 | 
125 | def backward_G(self):
126 | """Calculate GAN and L1 loss for the generator"""
127 | # First, G(A) should fake the discriminator
128 | fake_AB = torch.cat((self.real_A, self.fake_B), 1)
129 | pred_fake = self.netD(fake_AB)
130 | self.loss_G_GAN = self.criterionGAN(pred_fake, True)
131 | # Second, G(A) = B
132 | self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
133 | # combine loss and calculate gradients
134 | self.loss_G = self.loss_G_GAN + self.loss_G_L1
135 | self.loss_G.backward()
136 | 
137 | def optimize_parameters(self):
138 | self.forward() # compute fake images: G(A)
139 | # update D
140 | self.set_requires_grad(self.netD, True) # enable backprop for D
141 | self.optimizer_D.zero_grad() # set D's gradients to zero
142 | self.backward_D() # calculate gradients for D
143 | self.optimizer_D.step() # update D's weights
144 | # update G
145 | self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G
146 | self.optimizer_G.zero_grad() # set G's gradients to zero
147 | self.backward_G() # calculate gradients for G
148 | self.optimizer_G.step() # update G's weights
149 | 
-------------------------------------------------------------------------------- /models/test_model.py: --------------------------------------------------------------------------------
1 | from .base_model import BaseModel
2 | from . import networks
3 | 
4 | 
5 | class TestModel(BaseModel):
6 | """ This TestModel can be used to generate CycleGAN results for only one direction.
7 | This model will automatically set '--dataset_mode single', which only loads the images from one collection.
8 | 
9 | See the test instruction for more details.
10 | """
11 | @staticmethod
12 | def modify_commandline_options(parser, is_train=True):
13 | """Add new dataset-specific options, and rewrite default values for existing options.
14 | 
15 | Parameters:
16 | parser -- original option parser
17 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
18 | 
19 | Returns:
20 | the modified parser.
21 | 
22 | The model can only be used during test time. It requires '--dataset_mode single'.
23 | You need to specify the network using the option '--model_suffix'.
24 | """
25 | assert not is_train, 'TestModel cannot be used during training time'
26 | parser.set_defaults(dataset_mode='single')
27 | parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.')
28 | 
29 | return parser
30 | 
31 | def __init__(self, opt):
32 | """Initialize the TestModel class.
33 | 
34 | Parameters:
35 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
36 | """
37 | assert(not opt.isTrain)
38 | BaseModel.__init__(self, opt)
39 | # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
40 | self.loss_names = []
41 | # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
42 | self.visual_names = ['real_A', 'fake_B']
43 | # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
44 | self.model_names = ['G' + opt.model_suffix] # only generator is needed.
45 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
46 | opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)
47 | 
48 | # assigns the model to self.netG_[suffix] so that it can be loaded
49 | # please see <BaseModel.load_networks>
50 | setattr(self, 'netG' + opt.model_suffix, self.netG) # store netG in self.
51 | 
52 | def set_input(self, input):
53 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
54 | 
55 | Parameters:
56 | input: a dictionary that contains the data itself and its metadata information.
57 | 
58 | We need to use 'single_dataset' dataset mode. It only loads images from one domain.
59 | """ 60 | self.real_A = input['A'].to(self.device) 61 | self.image_paths = input['A_paths'] 62 | 63 | def forward(self): 64 | """Run forward pass.""" 65 | self.fake_B = self.netG(self.real_A) # G(A) 66 | 67 | def optimize_parameters(self): 68 | """No optimization for test model.""" 69 | pass 70 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test).""" 2 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | import models 6 | import data 7 | 8 | 9 | class BaseOptions(): 10 | """This class defines options used during both training and test time. 11 | 12 | It also implements several helper functions such as parsing, printing, and saving the options. 13 | It also gathers additional options defined in functions in both dataset class and model class. 14 | """ 15 | 16 | def __init__(self): 17 | """Reset the class; indicates the class hasn't been initailized""" 18 | self.initialized = False 19 | 20 | def initialize(self, parser): 21 | """Define the common options that are used in both training and test.""" 22 | # basic parameters 23 | parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 24 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') 25 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 26 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 27 | # model parameters 28 | parser.add_argument('--cloud_masks', type=str, default='s2cloudless_mask', help='chooses which cloud algorithm to use. [cloud_cloudshadow_mask | s2cloudless_map | s2cloudless_mask]') 29 | parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]') 30 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') 31 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') 32 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') 33 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') 34 | parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. 
n_layers allows you to specify the layers in the discriminator') 35 | parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') 36 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') 37 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') 38 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') 39 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 40 | parser.add_argument('--include_S1', action='store_true', help='include SAR data') 41 | parser.add_argument('--in_only_S1', action='store_true', help='whether to train on S1 inputs only. This may be used for SAR->Optical domain translation tasks.') 42 | parser.add_argument('--use_perceptual_loss', action='store_true', help='use perceptual loss in training') 43 | parser.add_argument('--vgg16_path', type=str, default='none', help='the path of pretrained VGG16 network') 44 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 45 | # dataset parameters 46 | parser.add_argument('--sample_type', type=str, default='generic', help='choose the format of input data. [cloudy_cloudfree | generic]') 47 | parser.add_argument('--input_type', type=str, default='all', help='choose the type of input. [all | test | val | train]') 48 | parser.add_argument('--region', type=str, default='all', help='choose the region of data input. [all | africa | america | asiaEast | asiaWest | europa]') 49 | parser.add_argument('--dataset_mode', type=str, default='sen12mscrts', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization | sen12mscrts]') 50 | parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') 51 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 52 | parser.add_argument('--num_threads', default=10, type=int, help='# threads for loading data') 53 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 54 | parser.add_argument('--load_size', type=int, default=286, help='scale images to this size') 55 | parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size') 56 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
57 | parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
58 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
59 | parser.add_argument('--lambda_percep', type=float, default=1., help='weight of the perceptual loss in training')
60 | parser.add_argument('--layers_percep', type=str, default='original', help='layers of VGG16 to use for the perceptual loss (choose: dip, video, original, experimental)')
61 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
62 | # additional parameters
63 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
64 | parser.add_argument('--load_iter', type=int, default=0, help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
65 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
66 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
67 | parser.add_argument('--n_input_samples', default=3, type=int, help='number of input samples')
68 | parser.add_argument('--unfreeze_iter', type=int, default=100000, help='number of iterations after which the initial ResNet model gets unfrozen, if -1 then initialize randomly')
69 | # new parameters pertaining to ResNet model
70 | parser.add_argument('--alter_initial_model', action='store_true', help='change the initial model with pre-trained network')
71 | parser.add_argument('--initial_model_path', type=str, default='none', help='path to the pre-trained initial model network')
72 | parser.add_argument('--resnet_F', type=int, default=256, help='If using ResNet, specify number of feature maps F')
73 | parser.add_argument('--resnet_B', type=int, default=16, help='If using ResNet, specify number of residual blocks B')
74 | parser.add_argument('--no_64C', action='store_true', help='do not use the intermediate reduction to 64 channels')
75 | parser.add_argument('--benchmark_resnet_model', action='store_true', help='whether to use the single time point ResNet model during testing')
76 | # cloud coverage
77 | parser.add_argument('--min_cov', type=float, default=0.0, help='minimum acceptable cloud coverage')
78 | parser.add_argument('--max_cov', type=float, default=1.0, help='maximum acceptable cloud coverage')
79 | # data set handling
80 | parser.add_argument('--import_data_path', type=str, default='', help='Path to file containing split for data loader import')
81 | parser.add_argument('--export_data_path', type=str, default='', help='Path to file containing split for data loader export')
82 | 
83 | self.initialized = True
84 | return parser
85 | 
86 | def gather_options(self):
87 | """Initialize our parser with basic options (only once).
88 | Add additional model-specific and dataset-specific options.
89 | These options are defined in the <modify_commandline_options> function
90 | in model and dataset classes.
91 | """
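# the flow below parses in stages: a first parse_known_args() pass reads only
# the basic options defined above, then the chosen model and dataset classes
# may register extra flags and overwrite defaults before the final parse_args()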
91 | """ 92 | if not self.initialized: # check if it has been initialized 93 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 94 | parser = self.initialize(parser) 95 | 96 | # get the basic options 97 | opt, _ = parser.parse_known_args() 98 | 99 | # modify model-related parser options 100 | model_name = opt.model 101 | model_option_setter = models.get_option_setter(model_name) 102 | parser = model_option_setter(parser, self.isTrain) 103 | opt, _ = parser.parse_known_args() # parse again with new defaults 104 | 105 | # modify dataset-related parser options 106 | dataset_name = opt.dataset_mode 107 | dataset_option_setter = data.get_option_setter(dataset_name) 108 | parser = dataset_option_setter(parser, self.isTrain) 109 | 110 | # save and return the parser 111 | self.parser = parser 112 | return parser.parse_args() 113 | 114 | def print_options(self, opt): 115 | """Print and save options 116 | 117 | It will print both current options and default values(if different). 118 | It will save options into a text file / [checkpoints_dir] / opt.txt 119 | """ 120 | message = '' 121 | message += '----------------- Options ---------------\n' 122 | for k, v in sorted(vars(opt).items()): 123 | comment = '' 124 | default = self.parser.get_default(k) 125 | if v != default: 126 | comment = '\t[default: %s]' % str(default) 127 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 128 | message += '----------------- End -------------------' 129 | print(message) 130 | 131 | # save to the disk 132 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 133 | util.mkdirs(expr_dir) 134 | file_name = os.path.join(expr_dir, 'opt.txt') 135 | with open(file_name, 'wt') as opt_file: 136 | opt_file.write(message) 137 | opt_file.write('\n') 138 | 139 | def parse(self): 140 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 141 | opt = self.gather_options() 142 | opt.isTrain = self.isTrain # train or test 143 | 144 | # process opt.suffix 145 | if opt.suffix: 146 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 147 | opt.name = opt.name + suffix 148 | 149 | self.print_options(opt) 150 | 151 | # set gpu ids 152 | str_ids = opt.gpu_ids.split(',') 153 | opt.gpu_ids = [] 154 | for str_id in str_ids: 155 | id = int(str_id) 156 | if id >= 0: 157 | opt.gpu_ids.append(id) 158 | if len(opt.gpu_ids) > 0: 159 | torch.cuda.set_device(opt.gpu_ids[0]) 160 | 161 | self.opt = opt 162 | return self.opt 163 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | """This class includes test options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) # define shared options 12 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 13 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 14 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 16 | # Dropout and Batchnorm has different behavior during training and test. 
17 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
18 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
19 | parser.add_argument('--include_simple_baselines', action='store_true', help='include simple baselines')
20 | # rewrite default values
21 | parser.set_defaults(model='test')
22 | # To avoid cropping, the load_size should be the same as crop_size
23 | parser.set_defaults(load_size=parser.get_default('crop_size'))
24 | self.isTrain = False
25 | return parser
26 | 
-------------------------------------------------------------------------------- /options/train_options.py: --------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | 
4 | class TrainOptions(BaseOptions):
5 | """This class includes training options.
6 | 
7 | It also includes shared options defined in BaseOptions.
8 | """
9 | 
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser)
12 | # visdom and HTML visualization parameters
13 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
14 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
15 | parser.add_argument('--display_id', type=int, default=-1, help='window id of the web display')
16 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
17 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
18 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
19 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
20 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
21 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
22 | # network saving and loading parameters
23 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
24 | parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
25 | parser.add_argument('--save_by_iter', action='store_true', help='whether to save the model by iteration')
26 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
27 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
28 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
29 | # training parameters
30 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
31 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
32 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
33 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
34 | parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
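# as implemented by GANLoss in the pix2pix framework this repository builds on,
# 'vanilla' corresponds to BCEWithLogitsLoss, 'lsgan' to MSELoss, and 'wgangp'
# to a Wasserstein objective meant to be paired with a gradient penalty term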
35 | parser.add_argument('--G_loss', type=str, default='L1', help='generator loss, default is L1. [L1]')
36 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
37 | parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
38 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
39 | 
40 | self.isTrain = True
41 | return parser
42 | 
-------------------------------------------------------------------------------- /preview/single_banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PatrickTUM/SEN12MS-CR-TS/2af933422f58fdfd6fc826ba3a827b288aad36fb/preview/single_banner.png -------------------------------------------------------------------------------- /standalone_dataloader.py: --------------------------------------------------------------------------------
1 | # Python script to demonstrate utilizing the PyTorch data loader for SEN12MS-CR-TS
2 | 
3 | import numpy as np
4 | import torch
5 | from data.dataLoader import SEN12MSCRTS
6 | 
7 | if __name__ == '__main__':
8 | # main parameters for instantiating SEN12MS-CR-TS
9 | root = '' # path to your copy of SEN12MS-CR-TS
10 | split = 'all' # ROI to sample from, belonging to splits [all | train | val | test]
11 | input_t = 3 # number of input time points to sample
12 | region = 'all' # choose the region of data input. [all | africa | america | asiaEast | asiaWest | europa]
13 | import_path = None # path to importing the suppl. file specifying what time points to load for input and output
14 | sample_type = 'cloudy_cloudfree' # type of samples returned [cloudy_cloudfree | generic]
15 | sen12mscrts = SEN12MSCRTS(root, split=split, sample_type=sample_type, n_input_samples=input_t, region=region, import_data_path=import_path)
16 | dataloader = torch.utils.data.DataLoader(sen12mscrts, batch_size=1, shuffle=False, num_workers=10)
17 | 
18 | # iterate over split and do some data accessing for demonstration
19 | for pdx, patch in enumerate(dataloader):
20 | print(f'Fetching {pdx}. batch of data.')
21 | 
22 | input_s1 = patch['input']['S1']
23 | input_s2 = patch['input']['S2']
24 | input_c = np.sum(patch['input']['coverage']) / len(patch['input']['coverage']) # average cloud coverage of the input time points
25 | output_s2 = patch['target']['S2']
26 | dates_s1 = patch['input']['S1 TD']
27 | dates_s2 = patch['input']['S2 TD']
-------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | """General-purpose training script for image-to-image translation.
2 | 
3 | This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and
4 | different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization).
5 | You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model').
6 | 
7 | It first creates the model, dataset, and visualizer given the options.
8 | It then does standard network training. During training, it also visualizes/saves the images, prints/saves the loss plot, and saves models.
9 | The script supports resuming interrupted runs. Use '--continue_train' to resume your previous training.
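A hypothetical SEN12MS-CR-TS cloud-removal call might look like the following (illustrative, unverified flag values: the model name 'temporal' follows from models/temporal_model.py, and 13 channels match the Sentinel-2 bands used elsewhere in this repository):
python train.py --dataroot /path/to/SEN12MSCRTS --name cloud_removal --model temporal --input_nc 13 --output_nc 13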
10 | 
11 | Example:
12 | Train a CycleGAN model:
13 | python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
14 | Train a pix2pix model:
15 | python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
16 | 
17 | See options/base_options.py and options/train_options.py for more training options.
18 | See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
19 | See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
20 | """
21 | 
22 | import numpy as np
23 | import torch
24 | 
25 | import time
26 | import warnings
27 | from options.train_options import TrainOptions
28 | from data import create_dataset
29 | from models import create_model
30 | from util.visualizer import Visualizer
31 | import faulthandler; faulthandler.enable()
32 | 
33 | if __name__ == '__main__':
34 | opt = TrainOptions().parse() # get training options
35 | if opt.batch_size != 1:
36 | warnings.warn(f'Detected batch size {opt.batch_size}, but only supporting batch size 1! Defaulting to 1')
37 | opt.batch_size = 1 # the training code currently only supports batch_size = 1 # TODO: change this in future versions
38 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
39 | dataset_size = len(dataset) # get the number of images in the dataset.
40 | print('The number of training images = %d' % dataset_size)
41 | 
42 | model = create_model(opt) # create a model given opt.model and other options
43 | model.setup(opt) # regular setup: load and print networks; create schedulers
44 | visualizer = Visualizer(opt) # create a visualizer that display/save images and plots
45 | total_iters = 0 # the total number of training iterations
46 | 
47 | # adjust preprocessing of inputs and target patches
48 | if opt.alter_initial_model: method = 'resnet'
49 | else: method = 'default'
50 | 
51 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
52 | epoch_start_time = time.time() # timer for entire epoch
53 | iter_data_time = time.time() # timer for data loading per iteration
54 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
55 | 
56 | for i, data in enumerate(dataset): # inner loop within one epoch
57 | print(f"\nProcessing observation {i+1}")
58 | iter_start_time = time.time() # timer for computation per iteration
59 | if total_iters % opt.print_freq == 0:
60 | t_data = iter_start_time - iter_data_time
61 | visualizer.reset()
62 | total_iters += opt.batch_size
63 | epoch_iter += opt.batch_size
64 | model.total_iters = total_iters
65 | model.set_input(data) # unpack data from dataset and apply preprocessing
66 | model.optimize_parameters() # calculate loss functions, get gradients, update network weights
67 | 
68 | 
69 | if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file
70 | save_result = total_iters % opt.update_html_freq == 0
71 | model.compute_visuals()
72 | visualizer.display_current_results(model.get_current_visuals(), epoch, total_iters, save_result, method)
73 | 
74 | if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk
75 | losses = model.get_current_losses()
76 | t_comp = (time.time() - iter_start_time) / opt.batch_size
77 | visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
78 | if opt.display_id > 0:
79 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
80 | 
81 | if total_iters % opt.save_latest_freq == 0: # cache our latest model every <save_latest_freq> iterations
82 | print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
83 | save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
84 | model.save_networks(save_suffix)
85 | 
86 | iter_data_time = time.time()
87 | if epoch % opt.save_epoch_freq == 0: # cache our model every <save_epoch_freq> epochs
88 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
89 | model.save_networks('latest')
90 | model.save_networks(epoch)
91 | 
92 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
93 | model.update_learning_rate() # update learning rates at the end of every epoch.
94 | 
-------------------------------------------------------------------------------- /util/__init__.py: --------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 | 
-------------------------------------------------------------------------------- /util/detect_cloudshadow.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 | import scipy.signal as scisig
4 | 
5 | 
6 | def rescale(data, limits):
7 | return (data - limits[0]) / (limits[1] - limits[0])
8 | 
9 | 
10 | def normalized_difference(channel1, channel2):
11 | subchan = channel1 - channel2
12 | sumchan = channel1 + channel2
13 | sumchan[sumchan == 0] = 0.001 # avoid division by zero
14 | return subchan / sumchan
15 | 
16 | 
17 | def get_shadow_mask(data_image):
18 | data_image = data_image / 10000.
19 | 
20 | (ch, r, c) = data_image.shape
21 | shadowmask = np.zeros((r, c)).astype('float32')
22 | 
23 | BB = data_image[1]
24 | BNIR = data_image[7]
25 | BSWIR1 = data_image[11]
26 | 
27 | CSI = (BNIR + BSWIR1) / 2.
28 | 
29 | t3 = 3/4 # cloud-score index threshold
30 | T3 = np.min(CSI) + t3 * (np.mean(CSI) - np.min(CSI))
31 | 
32 | t4 = 5 / 6 # water-body index threshold
33 | T4 = np.min(BB) + t4 * (np.mean(BB) - np.min(BB))
34 | 
35 | shadow_tf = np.logical_and(CSI < T3, BB < T4)
36 | 
37 | shadowmask[shadow_tf] = -1
38 | shadowmask = scisig.medfilt2d(shadowmask, 5)
39 | 
40 | return shadowmask
41 | 
42 | 
43 | def get_cloud_mask(data_image, cloud_threshold, binarize=False, use_moist_check=False):
44 | '''Adapted from https://github.com/samsammurphy/cloud-masking-sentinel2/blob/master/cloud-masking-sentinel2.ipynb'''
45 | 
46 | data_image = data_image / 10000.
47 | (ch, r, c) = data_image.shape
48 | 
49 | # Cloud until proven otherwise
50 | score = np.ones((r, c)).astype('float32')
51 | # Clouds are reasonably bright in the blue and aerosol/cirrus bands.
52 | score = np.minimum(score, rescale(data_image[1], [0.1, 0.5]))
53 | score = np.minimum(score, rescale(data_image[0], [0.1, 0.3]))
54 | score = np.minimum(score, rescale((data_image[0] + data_image[10]), [0.4, 0.9]))
55 | score = np.minimum(score, rescale((data_image[3] + data_image[2] + data_image[1]), [0.2, 0.8]))
56 | 
57 | if use_moist_check:
58 | # Clouds are moist
59 | ndmi = normalized_difference(data_image[7], data_image[11])
60 | score = np.minimum(score, rescale(ndmi, [-0.1, 0.1]))
61 | 
62 | # However, clouds are not snow.
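# NDSI (normalized-difference snow index) compares the green band (index 2)
# against SWIR (index 11); the reversed rescale limits [0.8, 0.6] below lower
# the cloud score for snow-like, high-NDSI pixels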
63 | ndsi = normalized_difference(data_image[2], data_image[11]) 64 | score = np.minimum(score, rescale(ndsi, [0.8, 0.6])) 65 | 66 | boxsize = 7 67 | box = np.ones((boxsize, boxsize)) / (boxsize ** 2) 68 | 69 | score = scipy.ndimage.morphology.grey_closing(score, size=(5, 5)) 70 | score = scisig.convolve2d(score, box, mode='same') 71 | 72 | score = np.clip(score, 0.00001, 1.0) 73 | 74 | if binarize: 75 | score[score >= cloud_threshold] = 1 76 | score[score < cloud_threshold] = 0 77 | 78 | return score 79 | 80 | # IN: [13 x H x W] S2 image (of arbitrary resolution H,W), scalar cloud detection threshold 81 | # OUT: cloud & shadow segmentation mask (of same resolution) 82 | # the multispectral S2 images are expected to have their default ranges and not be value-standardized yet 83 | # cloud_threshold: the higher the more conservative the masks (i.e. less pixels labeled clouds/shadows) 84 | def get_cloud_cloudshadow_mask(data_image, cloud_threshold): 85 | cloud_mask = get_cloud_mask(data_image, cloud_threshold, binarize=True) 86 | shadow_mask = get_shadow_mask(data_image) 87 | 88 | # encode clouds and shadows as segmentation masks 89 | cloud_cloudshadow_mask = np.zeros_like(cloud_mask) 90 | cloud_cloudshadow_mask[shadow_mask < 0] = -1 91 | cloud_cloudshadow_mask[cloud_mask > 0] = 1 92 | 93 | return cloud_cloudshadow_mask 94 | -------------------------------------------------------------------------------- /util/dl_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Script to download, extract and arrange SEN12MS-CR-TS and SEN12MS-CR. 4 | # Make this script executable (by running: chmod +x dl_data.sh), 5 | # then give it a run (by calling: ./dl_data.sh) and 6 | # follow the prompts in order to get the desired data. 7 | 8 | clear 9 | echo "This script is for downloading the SEN12MS-CR-TS data set for cloud removal in satellite data." 10 | echo See the associated paper: Ebel et al \(2022\) \'SEN12MS-CR-TS: A Remote Sensing Data Set for Multi-modal Multi-temporal Cloud Removal\' 11 | echo -e 'Click \e]8;;https://patricktum.github.io/cloud_removal/\ahere\e]8;;\a for more information' 12 | echo 13 | echo 14 | 15 | while true; do 16 | read -p "Do you wish to download the multitemporal SEN12MS-CR-TS data set? " yn 17 | case $yn in 18 | [Yy]* ) SEN12MSCRTS=true; break;; 19 | [Nn]* ) SEN12MSCRTS=false; break;; 20 | * ) echo "Please answer yes or no.";; 21 | esac 22 | done 23 | 24 | if [ "$SEN12MSCRTS" = "true" ]; then 25 | while true; do 26 | read -p "What regions would you like to download? [all|africa|america|asiaEast|asiaWest|europa] " region 27 | case $region in 28 | all|africa|america|asiaEast|asiaWest|europa ) reg=$region; break;; 29 | * ) echo "Please answer [all|africa|america|asiaEast|asiaWest|europa].";; 30 | esac 31 | done 32 | fi 33 | 34 | while true; do 35 | read -p "Do you wish to also download the monotemporal SEN12MS-CR data set (all regions)? " yn 36 | case $yn in 37 | [Yy]* ) SEN12MSCR=true; break;; 38 | [Nn]* ) SEN12MSCR=false; break;; 39 | * ) echo "Please answer yes or no.";; 40 | esac 41 | done 42 | 43 | while true; do 44 | read -p "Do you wish to also download the Sentinel-1 radar data associated with your previous choices? 
" yn 45 | case $yn in 46 | [Yy]* ) S1=true; break;; 47 | [Nn]* ) S1=false; break;; 48 | * ) echo "Please answer yes or no.";; 49 | esac 50 | done 51 | 52 | declare -A url_dict # holding links to data 53 | declare -A vol_dict # bookkeeping size of data 54 | 55 | echo "Please enter the path to download and extract the data to: " 56 | read dl_extract_to 57 | 58 | 59 | echo 60 | echo 61 | if [ "$SEN12MSCRTS" = "true" ]; then 62 | 63 | echo "Downloading SEN12MS-CR-TS data set." 64 | mkdir -p $dl_extract_to'/SEN12MSCRTS' 65 | 66 | # train split 67 | case $region in 68 | 'all') url_dict['multi_s2_africa']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_africa.tar.gz' 69 | vol_dict['multi_s2_africa']='98233900' 70 | 71 | url_dict['multi_s2_america']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_america.tar.gz' 72 | vol_dict['multi_s2_america']='110245004' 73 | 74 | url_dict['multi_s2_asiaEast']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_asiaEast.tar.gz' 75 | vol_dict['multi_s2_asiaEast']='113948560' 76 | 77 | url_dict['multi_s2_asiaWest']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_asiaWest.tar.gz' 78 | vol_dict['multi_s2_asiaWest']='96082796' 79 | 80 | url_dict['multi_s2_europa']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_europa.tar.gz' 81 | vol_dict['multi_s2_europa']='196669740' 82 | ;; 83 | 'africa') url_dict['multi_s2_africa']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_africa.tar.gz' 84 | vol_dict['multi_s2_africa']='98233900' 85 | ;; 86 | 'america') url_dict['multi_s2_america']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_america.tar.gz' 87 | vol_dict['multi_s2_america']='110245004' 88 | ;; 89 | 'asiaEast') url_dict['multi_s2_asiaEast']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_asiaEast.tar.gz' 90 | vol_dict['multi_s2_asiaEast']='113948560' 91 | ;; 92 | 'asiaWest') url_dict['multi_s2_asiaWest']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_asiaWest.tar.gz' 93 | vol_dict['multi_s2_asiaWest']='96082796' 94 | ;; 95 | 'europa') url_dict['multi_s2_europa']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_europa.tar.gz' 96 | vol_dict['multi_s2_europa']='196669740' 97 | ;; 98 | esac 99 | 100 | 101 | # test split 102 | case $region in 103 | 'all') url_dict['multi_s2_africa_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_africa_test.tar.gz' 104 | vol_dict['multi_s2_africa_test']='25421744' 105 | 106 | url_dict['multi_s2_america_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_america_test.tar.gz' 107 | vol_dict['multi_s2_america_test']='25421824' 108 | 109 | url_dict['multi_s2_asiaEast_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_asiaEast_test.tar.gz' 110 | vol_dict['multi_s2_asiaEast_test']='40534760' 111 | 112 | url_dict['multi_s2_asiaWest_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_asiaWest_test.tar.gz' 113 | vol_dict['multi_s2_asiaWest_test']='15012924' 114 | 115 | url_dict['multi_s2_europa_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_europa_test.tar.gz' 116 | vol_dict['multi_s2_europa_test']='79568460' 117 | ;; 118 | 'africa') url_dict['multi_s2_africa_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_africa_test.tar.gz' 119 | vol_dict['multi_s2_africa_test']='25421744' 120 | ;; 121 | 'america') 
url_dict['multi_s2_america_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_america_test.tar.gz' 122 | vol_dict['multi_s2_america_test']='25421824' 123 | ;; 124 | 'asiaEast') url_dict['multi_s2_asiaEast_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_asiaEast_test.tar.gz' 125 | vol_dict['multi_s2_asiaEast_test']='40534760' 126 | ;; 127 | 'asiaWest') url_dict['multi_s2_asiaWest_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_asiaWest_test.tar.gz' 128 | vol_dict['multi_s2_asiaWest_test']='15012924' 129 | ;; 130 | 'europa') url_dict['multi_s2_europa_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_europa_test.tar.gz' 131 | vol_dict['multi_s2_europa_test']='79568460' 132 | ;; 133 | esac 134 | 135 | 136 | if [ "$S1" = "true" ]; then 137 | # train split 138 | case $region in 139 | 'all') url_dict['multi_s1_africa']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_africa.tar.gz' 140 | vol_dict['multi_s1_africa']='60544524' 141 | 142 | url_dict['multi_s1_america']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_america.tar.gz' 143 | vol_dict['multi_s1_america']='67947416' 144 | 145 | url_dict['multi_s1_asiaEast']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_asiaEast.tar.gz' 146 | vol_dict['multi_s1_asiaEast']='70230104' 147 | 148 | url_dict['multi_s1_asiaWest']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_asiaWest.tar.gz' 149 | vol_dict['multi_s1_asiaWest']='59218848' 150 | 151 | url_dict['multi_s1_europa']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_europa.tar.gz' 152 | vol_dict['multi_s1_europa']='121213836' 153 | ;; 154 | 'africa') url_dict['multi_s1_africa']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_africa.tar.gz' 155 | vol_dict['multi_s1_africa']='60544524' 156 | ;; 157 | 'america') url_dict['multi_s1_america']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_america.tar.gz' 158 | vol_dict['multi_s1_america']='67947416' 159 | ;; 160 | 'asiaEast') url_dict['multi_s1_asiaEast']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_asiaEast.tar.gz' 161 | vol_dict['multi_s1_asiaEast']='70230104' 162 | ;; 163 | 'asiaWest') url_dict['multi_s1_asiaWest']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_asiaWest.tar.gz' 164 | vol_dict['multi_s1_asiaWest']='59218848' 165 | ;; 166 | 'europa') url_dict['multi_s1_europa']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_europa.tar.gz' 167 | vol_dict['multi_s1_europa']='121213836' 168 | ;; 169 | esac 170 | 171 | 172 | # test split 173 | case $region in 174 | 'all') url_dict['multi_s1_africa_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_africa_test.tar.gz' 175 | vol_dict['multi_s1_africa_test']='15668120' 176 | 177 | url_dict['multi_s1_america_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_america_test.tar.gz' 178 | vol_dict['multi_s1_america_test']='15668160' 179 | 180 | url_dict['multi_s1_asiaEast_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_asiaEast_test.tar.gz' 181 | vol_dict['multi_s1_asiaEast_test']='24982736' 182 | 183 | url_dict['multi_s1_asiaWest_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_asiaWest_test.tar.gz' 184 | vol_dict['multi_s1_asiaWest_test']='9252904' 185 | 186 | url_dict['multi_s1_europa_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_europa_test.tar.gz' 187 | 
vol_dict['multi_s1_europa_test']='49040432' 188 | ;; 189 | 'africa') url_dict['multi_s1_africa_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_africa_test.tar.gz' 190 | vol_dict['multi_s1_africa_test']='15668120' 191 | ;; 192 | 'america') url_dict['multi_s1_america_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_america_test.tar.gz' 193 | vol_dict['multi_s1_america_test']='15668160' 194 | ;; 195 | 'asiaEast') url_dict['multi_s1_asiaEast_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_asiaEast_test.tar.gz' 196 | vol_dict['multi_s1_asiaEast_test']='24982736' 197 | ;; 198 | 'asiaWest') url_dict['multi_s1_asiaWest_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_asiaWest_test.tar.gz' 199 | vol_dict['multi_s1_asiaWest_test']='9252904' 200 | ;; 201 | 'europa') url_dict['multi_s1_europa_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_europa_test.tar.gz' 202 | vol_dict['multi_s1_europa_test']='49040432' 203 | ;; 204 | esac 205 | fi 206 | fi 207 | 208 | 209 | # mono-temporal data (all regions) 210 | if [ "$SEN12MSCR" = "true" ]; then 211 | echo "Also downloading SEN12MS-CR data set." 212 | mkdir -p $dl_extract_to'/SEN12MSCR' 213 | url_dict['mono_s2_spring']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1158_spring_s2.tar.gz' 214 | vol_dict['mono_s2_spring']='48568904' 215 | 216 | url_dict['mono_s2_summer']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1868_summer_s2.tar.gz' 217 | vol_dict['mono_s2_summer']='56425520' 218 | 219 | url_dict['mono_s2_fall']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1970_fall_s2.tar.gz' 220 | vol_dict['mono_s2_fall']='68291864' 221 | 222 | url_dict['mono_s2_winter']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs2017_winter_s2.tar.gz' 223 | vol_dict['mono_s2_winter']='30580552' 224 | 225 | url_dict['mono_s2_cloudy_spring']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1158_spring_s2_cloudy.tar.gz' 226 | vol_dict['mono_s2_cloudy_spring']='48569368' 227 | 228 | url_dict['mono_s2_cloudy_summer']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1868_summer_s2_cloudy.tar.gz' 229 | vol_dict['mono_s2_cloudy_summer']='56426004' 230 | 231 | url_dict['mono_s2_cloudy_fall']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1970_fall_s2_cloudy.tar.gz' 232 | vol_dict['mono_s2_cloudy_fall']='68292448' 233 | 234 | url_dict['mono_s2_cloudy_winter']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs2017_winter_s2_cloudy.tar.gz' 235 | vol_dict['mono_s2_cloudy_winter']='30580812' 236 | 237 | # S1 data of SEN12MS-CR 238 | if [ "$S1" = "true" ]; then 239 | echo "Also downloading associated S1 data." 
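# note: every archive is registered with its download URL (url_dict) and its
# approximate size in 512-byte sectors (vol_dict); the free-space check further
# below sums these sizes against the target disk before any download starts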
240 | url_dict['mono_s1_spring']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1158_spring_s1.tar.gz'
241 | vol_dict['mono_s1_spring']='15026120'
242 | 
243 | url_dict['mono_s1_summer']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1868_summer_s1.tar.gz'
244 | vol_dict['mono_s1_summer']='17456784'
245 | 
246 | url_dict['mono_s1_fall']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1970_fall_s1.tar.gz'
247 | vol_dict['mono_s1_fall']='21127832'
248 | 
249 | url_dict['mono_s1_winter']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs2017_winter_s1.tar.gz'
250 | vol_dict['mono_s1_winter']='9460956'
251 | fi
252 | fi
253 | 
254 | req=0
255 | # integrate file size across archives
256 | for key in "${!vol_dict[@]}"; do
257 | # for each archive: sum up
258 | curr=${vol_dict[$key]}
259 | req=$((req+curr))
260 | done
261 | 
262 | echo
263 | echo
264 | # df -h $dl_extract_to
265 | avail=$(df $dl_extract_to | awk 'NR==2 { print $4 }')
266 | if (( avail < req )); then
267 | echo "Not enough space (512-byte disk sectors) on path "$dl_extract_to". Available "$avail". Required "$req #>&2
268 | exit 1
269 | else
270 | echo "Consuming "$req" of "$avail" (512-byte disk sectors) on path "$dl_extract_to
271 | fi
272 | echo
273 | echo
274 | 
275 | # download each archive individually, then extract individually
276 | 
277 | # fetch the actual data
278 | for key in "${!url_dict[@]}"; do
279 | url=${url_dict[$key]}
280 | filename=$(basename "$url")
281 | filename=${filename:7}
282 | # download
283 | wget --no-check-certificate -c -O $dl_extract_to'/'$filename ${url_dict[$key]}
284 | # unzip and delete archive
285 | tar --extract --file $dl_extract_to'/'$filename -C $dl_extract_to
286 | rm $dl_extract_to'/'$filename
287 | done
288 | 
289 | # move the extracted data to its respective place (this may take a while, because we use rsync rather than mv)
290 | echo "Moving data in place, please don't stop this process."
291 | for key in "${!url_dict[@]}"; do
292 | url=${url_dict[$key]}
293 | filename=$(basename "$url")
294 | filename=${filename:7:-7} # remove base URL and trailing *.tar.gz
295 | if [[ ${url_dict[$key]} == *"m1554803"* ]]; then
296 | # move to SEN12MSCR directory
297 | mv $dl_extract_to'/'$filename $dl_extract_to'/SEN12MSCR'
298 | elif [[ ${url_dict[$key]} == *"m1639953"* ]]; then
299 | # move train ROI to SEN12MSCRTS directory
300 | no_prefix_filename=${filename:3}
301 | rsync -a --remove-source-files $dl_extract_to'/'$no_prefix_filename/* $dl_extract_to'/SEN12MSCRTS' 2>/dev/null
302 | rm -rf $dl_extract_to'/'$no_prefix_filename
303 | else
304 | # move test ROI to SEN12MSCRTS directory
305 | rsync -a --remove-source-files $dl_extract_to'/'$filename/* $dl_extract_to'/SEN12MSCRTS'
306 | rm -rf $dl_extract_to'/'$filename
307 | fi
308 | done
309 | 
310 | echo
311 | echo "Completed downloading, extracting and moving data! Enjoy :)"
312 | 
-------------------------------------------------------------------------------- /util/get_data.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import tarfile
4 | import requests
5 | from warnings import warn
6 | from zipfile import ZipFile
7 | from bs4 import BeautifulSoup
8 | from os.path import abspath, isdir, join, basename
9 | 
10 | 
11 | class GetData(object):
12 | """A Python class for downloading CycleGAN or pix2pix datasets.
13 | 
14 | Parameters:
15 | technique (str) -- One of: 'cyclegan' or 'pix2pix'.
16 | verbose (bool) -- If True, print additional information. 17 | 18 | Examples: 19 | >>> from util.get_data import GetData 20 | >>> gd = GetData(technique='cyclegan') 21 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. 22 | 23 | Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh' 24 | and 'scripts/download_cyclegan_model.sh'. 25 | """ 26 | 27 | def __init__(self, technique='cyclegan', verbose=True): 28 | url_dict = { 29 | 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/', 30 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' 31 | } 32 | self.url = url_dict.get(technique.lower()) 33 | self._verbose = verbose 34 | 35 | def _print(self, text): 36 | if self._verbose: 37 | print(text) 38 | 39 | @staticmethod 40 | def _get_options(r): 41 | soup = BeautifulSoup(r.text, 'lxml') 42 | options = [h.text for h in soup.find_all('a', href=True) 43 | if h.text.endswith(('.zip', 'tar.gz'))] 44 | return options 45 | 46 | def _present_options(self): 47 | r = requests.get(self.url) 48 | options = self._get_options(r) 49 | print('Options:\n') 50 | for i, o in enumerate(options): 51 | print("{0}: {1}".format(i, o)) 52 | choice = input("\nPlease enter the number of the " 53 | "dataset above you wish to download:") 54 | return options[int(choice)] 55 | 56 | def _download_data(self, dataset_url, save_path): 57 | if not isdir(save_path): 58 | os.makedirs(save_path) 59 | 60 | base = basename(dataset_url) 61 | temp_save_path = join(save_path, base) 62 | 63 | with open(temp_save_path, "wb") as f: 64 | r = requests.get(dataset_url) 65 | f.write(r.content) 66 | 67 | if base.endswith('.tar.gz'): 68 | obj = tarfile.open(temp_save_path) 69 | elif base.endswith('.zip'): 70 | obj = ZipFile(temp_save_path, 'r') 71 | else: 72 | raise ValueError("Unknown File Type: {0}.".format(base)) 73 | 74 | self._print("Unpacking Data...") 75 | obj.extractall(save_path) 76 | obj.close() 77 | os.remove(temp_save_path) 78 | 79 | def get(self, save_path, dataset=None): 80 | """ 81 | 82 | Download a dataset. 83 | 84 | Parameters: 85 | save_path (str) -- A directory to save the data to. 86 | dataset (str) -- (optional). A specific dataset to download. 87 | Note: this must include the file extension. 88 | If None, options will be presented for you 89 | to choose from. 90 | 91 | Returns: 92 | save_path_full (str) -- the absolute path to the downloaded data. 93 | 94 | """ 95 | if dataset is None: 96 | selected_dataset = self._present_options() 97 | else: 98 | selected_dataset = dataset 99 | 100 | save_path_full = join(save_path, selected_dataset.split('.')[0]) 101 | 102 | if isdir(save_path_full): 103 | warn("\n'{0}' already exists. 
Voiding Download.".format( 104 | save_path_full)) 105 | else: 106 | self._print('Downloading Data...') 107 | url = "{0}/{1}".format(self.url, selected_dataset) 108 | self._download_data(url, save_path=save_path) 109 | 110 | return abspath(save_path_full) 111 | -------------------------------------------------------------------------------- /util/hdf5converter/script_tif2hdf5.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # scripts kindly provided by Corinne Stucker 4 | # https://scholar.google.ch/citations?user=P-op4CgAAAAJ&hl=de 5 | # this code can be used to reconstruct the full-scene images in hdf5 format from the released individual patches in tif format 6 | 7 | #python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS/ val europa /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ 8 | #python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS/ val america /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ 9 | #python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS/ val africa /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ 10 | 11 | python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS/ val asiaWest /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ 12 | python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS/ val asiaEast /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ 13 | #python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS_testSplit/ test europa /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ 14 | #python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS_testSplit/ test america /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ 15 | #python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS_testSplit/ test africa /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ 16 | #python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS_testSplit/ test asiaWest /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ 17 | #python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS_testSplit/ test asiaEast /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ 18 | -------------------------------------------------------------------------------- /util/hdf5converter/sen12mscrts_to_hdf5.py: -------------------------------------------------------------------------------- 1 | # scripts kindly provided by Corinne Stucker 2 | # https://scholar.google.ch/citations?user=P-op4CgAAAAJ&hl=de 3 | # this code can be used to reconstruct the full-scene images in hdf5 format from the released individual patches in tif format 4 | 5 | from natsort import natsorted 6 | import numpy as np 7 | import os 8 | import rasterio 9 | from tqdm import tqdm 10 | from scipy.ndimage import gaussian_filter 11 | from s2cloudless import S2PixelCloudDetector 12 | 13 | from data.dataLoader import SEN12MSCRTS 14 | 15 | """ SEN12MSCRTS data loader class, used to load the data in the original format and prepare the data for hdf5 export 16 | 17 | IN: 18 | root: str, path to your copy of the SEN12MS-CR-TS data set 19 | split: str, in [all | train | val | test] 20 | region: str, [all | africa | america | asiaEast | asiaWest | europa] 21 | cloud_masks: str, type of cloud mask detector to run on optical data, in [None | cloud_cloudshadow_mask | s2cloudless_map | s2cloudless_mask] 22 | 23 | OUT: 24 | data_loader: SEN12MSCRTS instance, implements an iterator that can be traversed via __getitem__(pdx), 25 | which returns the pdx-th dictionary of patch-samples (whose structure depends on sample_type) 26 | """ 27 | 28 | 29 | class SEN12MSCRTS_to_hdf5(SEN12MSCRTS): 30 | def __init__(self, root, split="all", region='all', cloud_masks='s2cloudless_mask', modalities=["S1", "S2"]): 31 | 32 | self.root_dir = root # set root 
directory which contains all ROI 33 | self.region = region # region according to which the ROI are selected 34 | self.ROI = {'ROIs1158': ['106'], 35 | 'ROIs1868': ['17', '36', '56', '73', '85', '100', '114', '119', '121', '126', '127', '139', '142', 36 | '143'], 37 | 'ROIs1970': ['20', '21', '35', '40', '57', '65', '71', '82', '83', '91', '112', '116', '119', '128', 38 | '132', '133', '135', '139', '142', '144', '149'], 39 | 'ROIs2017': ['8', '22', '25', '32', '49', '61', '63', '69', '75', '103', '108', '115', '116', '117', 40 | '130', '140', '146']} 41 | 42 | # define splits conform with SEN12MS-CR 43 | self.splits = {} 44 | if self.region == 'all': 45 | all_ROI = [os.path.join(key, val) for key, vals in self.ROI.items() for val in vals] 46 | self.splits['test'] = [os.path.join('ROIs1868', '119'), os.path.join('ROIs1970', '139'), 47 | os.path.join('ROIs2017', '108'), os.path.join('ROIs2017', '63'), 48 | os.path.join('ROIs1158', '106'), os.path.join('ROIs1868', '73'), 49 | os.path.join('ROIs2017', '32'), 50 | os.path.join('ROIs1868', '100'), os.path.join('ROIs1970', '132'), 51 | os.path.join('ROIs2017', '103'), os.path.join('ROIs1868', '142'), 52 | os.path.join('ROIs1970', '20'), 53 | os.path.join('ROIs2017', '140')] # official test split, across continents 54 | self.splits['val'] = [os.path.join('ROIs2017', '22'), os.path.join('ROIs1970', '65'), 55 | os.path.join('ROIs2017', '117'), os.path.join('ROIs1868', '127'), 56 | os.path.join('ROIs1868', '17')] # insert your favorite validation split here 57 | self.splits['train'] = [roi for roi in all_ROI if roi not in self.splits['val'] and roi not in self.splits[ 58 | 'test']] # all remaining ROI are used for training 59 | elif self.region == 'africa': 60 | self.splits['test'] = [os.path.join('ROIs2017', '32'), os.path.join('ROIs2017', '140')] 61 | self.splits['val'] = [os.path.join('ROIs2017', '22')] 62 | self.splits['train'] = [os.path.join('ROIs1970', '21'), os.path.join('ROIs1970', '35'), 63 | os.path.join('ROIs1970', '40'), 64 | os.path.join('ROIs2017', '8'), os.path.join('ROIs2017', '61'), 65 | os.path.join('ROIs2017', '75')] 66 | elif self.region == 'america': 67 | self.splits['test'] = [os.path.join('ROIs1158', '106'), os.path.join('ROIs1970', '132')] 68 | self.splits['val'] = [os.path.join('ROIs1970', '65')] 69 | self.splits['train'] = [os.path.join('ROIs1868', '36'), os.path.join('ROIs1868', '85'), 70 | os.path.join('ROIs1970', '82'), os.path.join('ROIs1970', '142'), 71 | os.path.join('ROIs2017', '49'), os.path.join('ROIs2017', '116')] 72 | elif self.region == 'asiaEast': 73 | self.splits['test'] = [os.path.join('ROIs1868', '73'), os.path.join('ROIs1868', '119'), 74 | os.path.join('ROIs1970', '139')] 75 | self.splits['val'] = [os.path.join('ROIs2017', '117')] 76 | self.splits['train'] = [os.path.join('ROIs1868', '114'), os.path.join('ROIs1868', '126'), 77 | os.path.join('ROIs1868', '143'), 78 | os.path.join('ROIs1970', '116'), os.path.join('ROIs1970', '135'), 79 | os.path.join('ROIs2017', '25')] 80 | elif self.region == 'asiaWest': 81 | self.splits['test'] = [os.path.join('ROIs1868', '100')] 82 | self.splits['val'] = [os.path.join('ROIs1868', '127')] 83 | self.splits['train'] = [os.path.join('ROIs1970', '57'), os.path.join('ROIs1970', '83'), 84 | os.path.join('ROIs1970', '112'), 85 | os.path.join('ROIs2017', '69'), os.path.join('ROIs1970', '115'), 86 | os.path.join('ROIs1970', '130')] 87 | elif self.region == 'europa': 88 | self.splits['test'] = [os.path.join('ROIs2017', '63'), os.path.join('ROIs2017', '103'), 89 | 
os.path.join('ROIs2017', '108'), 90 | os.path.join('ROIs1868', '142'), os.path.join('ROIs1970', '20')] 91 | self.splits['val'] = [os.path.join('ROIs1868', '17')] 92 | self.splits['train'] = [os.path.join('ROIs1868', '56'), os.path.join('ROIs1868', '121'), 93 | os.path.join('ROIs1868', '139'), 94 | os.path.join('ROIs1970', '71'), os.path.join('ROIs1970', '91'), 95 | os.path.join('ROIs1970', '119'), 96 | os.path.join('ROIs1970', '128'), os.path.join('ROIs1970', '133'), 97 | os.path.join('ROIs1970', '144'), 98 | os.path.join('ROIs1970', '149'), 99 | os.path.join('ROIs2017', '146')] 100 | else: 101 | raise NotImplementedError 102 | 103 | self.splits["all"] = self.splits["train"] + self.splits["test"] + self.splits["val"] 104 | self.split = split 105 | 106 | assert split in ['all', 'train', 'val', 107 | 'test'], "Dataset split must be one of: all, train, val, test!" 108 | assert cloud_masks in [None, 'cloud_cloudshadow_mask', 's2cloudless_map', 109 | 's2cloudless_mask'], "Unknown cloud mask type!" 110 | 111 | self.modalities = modalities 112 | self.time_points = range(30) 113 | self.cloud_masks = cloud_masks # e.g. 'cloud_cloudshadow_mask', 's2cloudless_map', 's2cloudless_mask' 114 | 115 | if self.cloud_masks in ['s2cloudless_map', 's2cloudless_mask']: 116 | self.cloud_detector = S2PixelCloudDetector(threshold=0.4, all_bands=True, average_over=4, dilation_size=2) 117 | 118 | self.paths = self.get_paths() 119 | self.n_samples = len(self.paths) 120 | 121 | # raise a warning that no data has been found 122 | if not self.n_samples: self.throw_warn() 123 | 124 | def get_paths(self): # assuming for the same ROI+num, the patch numbers are the same 125 | print(f'\nProcessing paths for {self.split} split of region {self.region}') 126 | 127 | paths = [] 128 | for roi_dir, rois in self.ROI.items(): 129 | for roi in tqdm(rois): 130 | roi_path = os.path.join(self.root_dir, roi_dir, roi) 131 | # skip non-existent ROI or ROI not part of the current data split 132 | if not os.path.isdir(roi_path) or os.path.join(roi_dir, roi) not in self.splits[self.split]: continue 133 | path_s1_t, path_s2_t = [], [] 134 | for tdx in self.time_points: 135 | if 'S1' in self.modalities: 136 | path_s1_complete = os.path.join(roi_path, 'S1', str(tdx)) 137 | path_s1 = os.path.join(roi_dir, roi, 'S1', str(tdx)) 138 | s1_t = natsorted([os.path.join(path_s1, f) for f in os.listdir(path_s1_complete) if 139 | (os.path.isfile(os.path.join(path_s1_complete, f)) and ".tif" in f)]) 140 | if 'S2' in self.modalities: 141 | path_s2_complete = os.path.join(roi_path, 'S2', str(tdx)) 142 | path_s2 = os.path.join(roi_dir, roi, 'S2', str(tdx)) 143 | s2_t = natsorted([os.path.join(path_s2, f) for f in os.listdir(path_s2_complete) if 144 | (os.path.isfile(os.path.join(path_s2_complete, f)) and ".tif" in f)]) 145 | 146 | if 'S1' in self.modalities and 'S2' in self.modalities: 147 | # same number of patches 148 | assert len(s1_t) == len(s2_t) 149 | 150 | # sort via file names according to patch number and store 151 | if 'S1' in self.modalities: 152 | path_s1_t.append(s1_t) 153 | if 'S2' in self.modalities: 154 | path_s2_t.append(s2_t) 155 | 156 | # for each patch of the ROI, collect its time points and make this one sample 157 | for pdx in range(len(path_s1_t[0]) if 'S1' in self.modalities else len(path_s2_t[0])): # robust to S2-only loading 158 | sample = dict() 159 | if 'S1' in self.modalities: 160 | sample['S1'] = [path_s1_t[tdx][pdx] for tdx in self.time_points] 161 | if 'S2' in self.modalities: 162 | sample['S2'] = [path_s2_t[tdx][pdx] for tdx in self.time_points] 163 | 164 | paths.append(sample) 165 | 166 | return paths
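# Note (added for clarity): each element of `paths` returned above bundles one patch across
# all 30 time points, keyed by modality. A sketch of one entry, with hypothetical file names:
#   {'S1': ['ROIs1868/17/S1/0/..._patch_12.tif', ..., 'ROIs1868/17/S1/29/..._patch_12.tif'],
#    'S2': ['ROIs1868/17/S2/0/..._patch_12.tif', ..., 'ROIs1868/17/S2/29/..._patch_12.tif']}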
167 | 168 | def get_cloud_mask(self, img, mask_type): 169 | if mask_type == 'cloud_cloudshadow_mask': 170 | threshold = 0.2 # set to e.g. 0.2 or 0.4 171 | mask = self.get_cloud_cloudshadow_mask(np.clip(img, 0, 10000), threshold) 172 | elif mask_type == 's2cloudless_map': 173 | threshold = 0.5 174 | mask = self.cloud_detector.get_cloud_probability_maps(np.moveaxis(np.clip(img, 0, 10000)/10000, 0, -1)[None, ...])[0, ...] 175 | mask[mask < threshold] = 0 176 | mask = gaussian_filter(mask, sigma=2).astype(np.float32) 177 | elif mask_type == 's2cloudless_mask': 178 | mask = self.cloud_detector.get_cloud_masks(np.moveaxis(np.clip(img, 0, 10000)/10000, 0, -1)[None, ...])[0, ...] 179 | elif mask_type == 's2cloud_prob': 180 | mask = self.cloud_detector.get_cloud_probability_maps(np.moveaxis(np.clip(img, 0, 10000) / 10000, 0, -1)[None, ...])[0, ...] 181 | else: raise ValueError(f'Unknown cloud mask type: {mask_type}') # fail loudly instead of returning an unbound mask 182 | return mask 183 | 184 | def __getitem__(self, pdx): # get the time series of one patch 185 | 186 | sample = dict() 187 | 188 | if 'S1' in self.modalities: 189 | s1 = [self.read_img(os.path.join(self.root_dir, img)) for img in self.paths[pdx]['S1']] 190 | s1_dates = [img.split('/')[-1].split('_')[5] for img in self.paths[pdx]['S1']] 191 | sample['S1'] = s1 192 | sample['S1_dates'] = s1_dates 193 | sample['S1_paths'] = self.paths[pdx]['S1'] 194 | 195 | if 'S2' in self.modalities: 196 | s2 = [self.read_img(os.path.join(self.root_dir, img)) for img in self.paths[pdx]['S2']] 197 | s2_dates = [img.split('/')[-1].split('_')[5] for img in self.paths[pdx]['S2']] 198 | 199 | cloud_prob = [self.get_cloud_mask(img, 's2cloud_prob') for img in s2] 200 | cloud_mask = [self.get_cloud_mask(img, 's2cloudless_mask') for img in s2] 201 | 202 | sample['S2'] = s2 203 | sample['S2_dates'] = s2_dates 204 | sample['S2_paths'] = self.paths[pdx]['S2'] 205 | sample['cloud_prob'] = cloud_prob 206 | 207 | sample['cloud_mask'] = cloud_mask 208 | 209 | return sample 210 | 211 | def __len__(self): 212 | # length of generated list 213 | return self.n_samples 214 | -------------------------------------------------------------------------------- /util/hdf5converter/tif2hdf5.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import h5py 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | import os, sys 7 | import re 8 | 9 | from sen12mscrts_to_hdf5 import SEN12MSCRTS_to_hdf5 10 | 11 | 12 | def extract_ROI_tile_patch_index(filename): 13 | ROI, tile, patch = map(re.split('_|\.tif', filename.split('/')[-1]).__getitem__, [1, 2, -2]) # escape the dot so it only matches the literal extension 14 | return ROI, tile, patch 15 | 16 | def create_hdf5_group(hdf5_file, group): 17 | with h5py.File(hdf5_file, 'a', libver='latest') as f: 18 | if group not in f: 19 | f.create_group(group) 20 | 21 | def process_sample_to_hdf5(hdf5_file, batch, verbose=0): 22 | if 'S2' in batch: 23 | ROI, tile, patch = extract_ROI_tile_patch_index(batch['S2_paths'][0][0]) 24 | 25 | # Create an HDF5 group: ROI -> tile -> patch -> S2 26 | create_hdf5_group(hdf5_file, os.path.join(ROI, tile, f'{ROI}_{tile}_patch_{patch}', 'S2')) 27 | 28 | # Populate the group with hdf5 datasets 29 | with h5py.File(hdf5_file, 'a', libver='latest') as f: 30 | 31 | # S2 image time series, T x C x H x W 32 | group = os.path.join(ROI, tile, f'{ROI}_{tile}_patch_{patch}', 'S2') 33 | s2 = torch.cat(batch['S2'], dim=0) 34 | dset = f[group].create_dataset('S2', data=s2.numpy().astype(np.uint16), compression='gzip', compression_opts=9) 35 | 36 | # Cloud probability mask, T x 1 x H x W
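# (Added note: with DataLoader batch_size=1, each entry of batch['S2'] / batch['cloud_prob'] is a
#  list of T tensors carrying a leading batch axis of 1, so torch.cat(..., dim=0) stacks the time
#  points along dim 0 and unsqueeze(1) restores the singleton channel axis for the masks.)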
37 | cloud_prob = torch.cat(batch['cloud_prob'], dim=0).unsqueeze(1) 38 | dset = f[group].create_dataset('cloud_prob', data=cloud_prob.float(), compression='gzip', compression_opts=9) 39 | 40 | # Cloud mask, T x 1 x H x W 41 | cloud_mask = torch.cat(batch['cloud_mask'], dim=0).unsqueeze(1) 42 | dset = f[group].create_dataset('cloud_mask', data=cloud_mask, compression='gzip', compression_opts=9) 43 | 44 | # Date per observation 45 | dset = f[group].create_dataset('S2_dates', data=[date[0] for date in batch['S2_dates']], compression='gzip', compression_opts=9) 46 | 47 | if 'S1' in batch: 48 | ROI, tile, patch = extract_ROI_tile_patch_index(batch['S1_paths'][0][0]) 49 | 50 | # Create an HDF5 group: ROI -> tile -> patch -> S1 51 | create_hdf5_group(hdf5_file, os.path.join(ROI, tile, f'{ROI}_{tile}_patch_{patch}', 'S1')) 52 | 53 | # Populate the group with hdf5 datasets 54 | with h5py.File(hdf5_file, 'a', libver='latest') as f: 55 | 56 | # S1 image time series, T x C x H x W 57 | group = os.path.join(ROI, tile, f'{ROI}_{tile}_patch_{patch}', 'S1') 58 | s1 = torch.cat(batch['S1'], dim=0) 59 | dset = f[group].create_dataset('S1', data=s1, compression='gzip', compression_opts=9) 60 | 61 | # Date per observation 62 | dset = f[group].create_dataset('S1_dates', data=[date[0] for date in batch['S1_dates']], compression='gzip', compression_opts=9) 63 | 64 | if verbose == 1: 65 | print(f'Sample {ROI}_{tile}_patch_{patch} processed.') 66 | 67 | 68 | parser = ArgumentParser() 69 | parser.add_argument('root_source', type=str) 70 | parser.add_argument('split', type=str) 71 | parser.add_argument('region', type=str) 72 | parser.add_argument('root_dest', type=str) 73 | 74 | 75 | def main(args): 76 | 77 | dataset = SEN12MSCRTS_to_hdf5(args.root_source, split=args.split, region=args.region) 78 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4) 79 | 80 | hdf5_file = os.path.join(args.root_dest, args.split + '.hdf5') 81 | 82 | # Create/open the HDF5 file up front; process_sample_to_hdf5 re-opens it in append mode per batch 83 | f = h5py.File(hdf5_file, 'a', libver='latest') 84 | 85 | # Iterate over all data samples in the given data split: tiff to hdf5 conversion 86 | for i, batch in enumerate(tqdm(dataloader)): 87 | process_sample_to_hdf5(hdf5_file, batch, verbose=0) 88 | 89 | f.close() 90 | print('Done') 91 | 92 | 93 | if __name__ == '__main__': 94 | 95 | if len(sys.argv) < 2: 96 | parser.print_help() 97 | sys.exit(1) 98 | else: 99 | main(parser.parse_args()) 100 |
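For verifying a conversion, a minimal read-back sketch (illustrative only: it assumes h5py is available, that a file such as `train.hdf5` was produced by the script above, and a hypothetical patch name; the group layout mirrors the ROI/tile/patch hierarchy created in process_sample_to_hdf5):

import h5py

with h5py.File('train.hdf5', 'r') as f:
    f.visit(print)  # walk and print the full group/dataset hierarchy
    patch = f['ROIs1868/17/ROIs1868_17_patch_12']  # hypothetical patch group
    s2 = patch['S2/S2'][:]           # T x C x H x W uint16 time series
    dates = patch['S2/S2_dates'][:]  # acquisition date per time point
    print(s2.shape, dates[:3])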
-------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br 3 | import os 4 | 5 | 6 | class HTML: 7 | """This HTML class allows us to save images and write texts into a single HTML file. 8 | 9 | It consists of functions such as <add_header> (add a text header to the HTML file), 10 | <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk). 11 | It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 12 | """ 13 | 14 | def __init__(self, web_dir, title, refresh=0): 15 | """Initialize the HTML classes 16 | 17 | Parameters: 18 | web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/ 19 | title (str) -- the webpage name 20 | refresh (int) -- how often the website refreshes itself; if 0, no refreshing 21 | """ 22 | self.title = title 23 | self.web_dir = web_dir 24 | self.img_dir = os.path.join(self.web_dir, 'images') 25 | if not os.path.exists(self.web_dir): 26 | os.makedirs(self.web_dir) 27 | if not os.path.exists(self.img_dir): 28 | os.makedirs(self.img_dir) 29 | 30 | self.doc = dominate.document(title=title) 31 | if refresh > 0: 32 | with self.doc.head: 33 | meta(http_equiv="refresh", content=str(refresh)) 34 | 35 | def get_image_dir(self): 36 | """Return the directory that stores images""" 37 | return self.img_dir 38 | 39 | def add_header(self, text): 40 | """Insert a header to the HTML file 41 | 42 | Parameters: 43 | text (str) -- the header text 44 | """ 45 | with self.doc: 46 | h3(text) 47 | 48 | def add_images(self, ims, txts, links, width=400): 49 | """add images to the HTML file 50 | 51 | Parameters: 52 | ims (str list) -- a list of image paths 53 | txts (str list) -- a list of image names shown on the website 54 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page 55 | """ 56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 57 | self.doc.add(self.t) 58 | with self.t: 59 | with tr(): 60 | for im, txt, link in zip(ims, txts, links): 61 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 62 | with p(): 63 | with a(href=os.path.join('images', link)): 64 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 65 | br() 66 | p(txt) 67 | 68 | def save(self): 69 | """save the current content to the HTML file""" 70 | html_file = '%s/index.html' % self.web_dir 71 | f = open(html_file, 'wt') 72 | f.write(self.doc.render()) 73 | f.close() 74 | 75 | 76 | if __name__ == '__main__': # we show an example usage here. 77 | html = HTML('web/', 'test_html') 78 | html.add_header('hello world') 79 | 80 | ims, txts, links = [], [], [] 81 | for n in range(4): 82 | ims.append('image_%d.png' % n) 83 | txts.append('text_%d' % n) 84 | links.append('image_%d.png' % n) 85 | html.add_images(ims, txts, links) 86 | html.save() 87 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class ImagePool(): 6 | """This class implements an image buffer that stores previously generated images. 7 | 8 | This buffer enables us to update discriminators using a history of generated images 9 | rather than the ones produced by the latest generators. 10 | """ 11 | 12 | def __init__(self, pool_size): 13 | """Initialize the ImagePool class 14 | 15 | Parameters: 16 | pool_size (int) -- the size of the image buffer; if pool_size=0, no buffer will be created 17 | """ 18 | self.pool_size = pool_size 19 | if self.pool_size > 0: # create an empty pool 20 | self.num_imgs = 0 21 | self.images = [] 22 | 23 | def query(self, images): 24 | """Return an image from the pool. 25 | 26 | Parameters: 27 | images: the latest generated images from the generator 28 | 29 | Returns images from the buffer. 30 | 31 | With 50% probability, the buffer will return the input images. 32 | With 50% probability, the buffer will return images previously stored in the buffer, 33 | and insert the current images into the buffer.
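Illustrative example (hypothetical numbers): with pool_size=50 and a full buffer, a batch of
4 newly generated images yields, in expectation, 2 of them echoed back unchanged and 2 older
images drawn from (and replaced in) the buffer, so the discriminator sees a mix of current
and historical fakes.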
34 | """ 35 | if self.pool_size == 0: # if the buffer size is 0, do nothing 36 | return images 37 | return_images = [] 38 | for image in images: 39 | image = torch.unsqueeze(image.data, 0) 40 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer 41 | self.num_imgs = self.num_imgs + 1 42 | self.images.append(image) 43 | return_images.append(image) 44 | else: 45 | p = random.uniform(0, 1) 46 | if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer 47 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 48 | tmp = self.images[random_id].clone() 49 | self.images[random_id] = image 50 | return_images.append(tmp) 51 | else: # by another 50% chance, the buffer will return the current image 52 | return_images.append(image) 53 | return_images = torch.cat(return_images, 0) # collect all the images and return 54 | return return_images 55 | -------------------------------------------------------------------------------- /util/pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable # note: Variable is a legacy no-op wrapper in modern PyTorch 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): # 1D Gaussian kernel, normalized to sum to 1 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): # separable 2D Gaussian window, one copy per channel 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) # local means via Gaussian-weighted convolution 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq # local variances ... 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 # ... and covariance 28 | 29 | C1 = 0.01**2 # stability constants, assuming inputs scaled to [0, 1] 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 1 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | (_, channel, _, _) = img1.size() 49 | 50 | if channel == self.channel and self.window.data.type() == img1.data.type(): 51 | window = self.window 52 | else: 53 | window = create_window(self.window_size, channel) # rebuild the window if channel count or dtype changed 54 | 55 | if img1.is_cuda: 56 | window = window.cuda(img1.get_device()) 57 | window = window.type_as(img1) 58 | 59 | self.window = window 60 | self.channel = channel 61 | 62 | 63 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 64 |
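# For reference, _ssim above evaluates the standard SSIM index under an 11x11 Gaussian window
# (sigma=1.5):
#
#   SSIM(x, y) = ((2*mu_x*mu_y + C1) * (2*sigma_xy + C2)) / ((mu_x^2 + mu_y^2 + C1) * (sigma_x^2 + sigma_y^2 + C2))
#
# with C1 = 0.01**2 and C2 = 0.03**2, the usual constants for inputs scaled to [0, 1].
# A minimal illustrative use as a training objective (assuming `pred` and `target` are image
# batches in [0, 1]):
#
#   loss = 1 - SSIM(window_size=11)(pred, target)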
65 | def ssim(img1, img2, window_size = 11, size_average = True): 66 | (_, channel, _, _) = img1.size() 67 | window = create_window(window_size, channel) 68 | 69 | if img1.is_cuda: 70 | window = window.cuda(img1.get_device()) 71 | window = window.type_as(img1) 72 | 73 | return _ssim(img1, img2, window, window_size, channel, size_average) 74 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | import torch.nn as nn 8 | import torchvision.models as models 9 | import torch.nn.init as init 10 | 11 | 12 | def tensor2im(input_image, method, imtype=np.uint8): 13 | """Converts a Tensor array into a numpy image array. 14 | 15 | Parameters: 16 | input_image (tensor) -- the input image tensor array 17 | imtype (type) -- the desired type of the converted numpy array (the 'method' argument selects the 'default' or 'resnet' re-normalization below) 18 | """ 19 | if not isinstance(input_image, np.ndarray): 20 | if isinstance(input_image, torch.Tensor): # get the data from a variable 21 | image_tensor = input_image.data 22 | else: 23 | return input_image 24 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 25 | # no need to do anything if image_numpy is 3-dimensional already but for the other dimensions ... 26 | 27 | if image_numpy.shape[0] == 1: # grayscale to RGB 28 | image_numpy = np.tile(image_numpy, (3, 1, 1)) # triple channel 29 | image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0 30 | 31 | if image_numpy.shape[0] == 13 or image_numpy.shape[0] == 4: # 13 bands multispectral (or 4 bands NIR) to RGB 32 | # RGB bands are [3, 2, 1] 33 | image_numpy = image_numpy[[3, 2, 1], ...]
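# (Added note: bands are 0-indexed, so channels [3, 2, 1] pick Sentinel-2's B4/B3/B2, i.e. the
#  red, green and blue 10 m bands; all remaining bands are dropped for visualization.)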
34 | 35 | # method is either 'resnet' (if opt.alter_initial_mode) or 'default' 36 | if method == 'default': # re-normalize from [-1,+1] to [0, 255] 37 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 38 | elif method == 'resnet': # re-normalize from [0, 5] to [0, 255] 39 | image_numpy = (np.transpose(image_numpy, (1, 2, 0))) / 5.0 * 255.0 40 | 41 | if image_numpy.shape[0] == 2: # (VV,VH) SAR to RGB (just taking VV band) 42 | image_numpy = np.tile(image_numpy[[0]], (3, 1, 1)) 43 | if method == 'default': # re-normalize from [-1,+1] to [0, 255] 44 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 45 | elif method == 'resnet': # re-normalize from [0, 2] to [0, 255] 46 | image_numpy = (np.transpose(image_numpy, (1, 2, 0))) / 2.0 * 255.0 47 | # post-processing: transpose and scaling 48 | else: # if it is a numpy array, do nothing 49 | image_numpy = input_image 50 | return image_numpy.astype(imtype) 51 | 52 | 53 | def diagnose_network(net, name='network'): 54 | """Calculate and print the mean of average absolute(gradients) 55 | 56 | Parameters: 57 | net (torch network) -- Torch network 58 | name (str) -- the name of the network 59 | """ 60 | mean = 0.0 61 | count = 0 62 | for param in net.parameters(): 63 | if param.grad is not None: 64 | mean += torch.mean(torch.abs(param.grad.data)) 65 | count += 1 66 | if count > 0: 67 | mean = mean / count 68 | print(name) 69 | print(mean) 70 | 71 | 72 | def save_image(image_numpy, image_path): 73 | """Save a numpy image to the disk 74 | 75 | Parameters: 76 | image_numpy (numpy array) -- input numpy array 77 | image_path (str) -- the path of the image 78 | """ 79 | image_pil = Image.fromarray(image_numpy) 80 | image_pil.save(image_path) 81 | 82 | 83 | def print_numpy(x, val=True, shp=False): 84 | """Print the mean, min, max, median, std, and size of a numpy array 85 | 86 | Parameters: 87 | val (bool) -- if print the values of the numpy array 88 | shp (bool) -- if print the shape of the numpy array 89 | """ 90 | x = x.astype(np.float64) 91 | if shp: 92 | print('shape,', x.shape) 93 | if val: 94 | x = x.flatten() 95 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 96 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 97 | 98 | 99 | def mkdirs(paths): 100 | """create empty directories if they don't exist 101 | 102 | Parameters: 103 | paths (str list) -- a list of directory paths 104 | """ 105 | if isinstance(paths, list) and not isinstance(paths, str): 106 | for path in paths: 107 | mkdir(path) 108 | else: 109 | mkdir(paths) 110 | 111 | 112 | def mkdir(path): 113 | """create a single empty directory if it didn't exist 114 | 115 | Parameters: 116 | path (str) -- a single directory path 117 | """ 118 | if not os.path.exists(path): 119 | os.makedirs(path) 120 | 121 | 122 | def weights_init_kaiming(m): # initialize the weights (kaiming method) 123 | classname = m.__class__.__name__ 124 | if classname.find('Conv2d') != -1: 125 | init.kaiming_normal_(m.weight.data) 126 | 127 | 128 | def fc_init_weights(m): 129 | if type(m) == nn.Linear: 130 | init.kaiming_normal_(m.weight.data) 131 | 132 | 133 | class VGG16(nn.Module): 134 | def __init__(self, n_inputs=12, numCls=17): # num. 
of classes 135 | super().__init__() 136 | 137 | vgg = models.vgg16(pretrained=False) 138 | 139 | self.encoder = nn.Sequential( 140 | nn.Conv2d(n_inputs, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), # 12 bands as input (s1 + s2) 141 | *vgg.features[1:] 142 | ) 143 | self.classifier = nn.Sequential( 144 | nn.Linear(8 * 8 * 512, 4096, bias=True), 145 | # 8*8*512: output size from encoder (origin img pixel 256*256-> 5 pooling = 8) 146 | nn.ReLU(inplace=True), 147 | nn.Dropout(), 148 | nn.Linear(4096, 4096, bias=True), 149 | nn.ReLU(inplace=True), 150 | nn.Dropout(), 151 | nn.Linear(4096, numCls, bias=True) 152 | ) 153 | 154 | self.apply(weights_init_kaiming) 155 | self.apply(fc_init_weights) 156 | 157 | self.names = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 158 | 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 159 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 160 | 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 161 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'pool5', 162 | 'linear6', 'relu6', 'drop6', 'linear7', 'relu7', 'drop7', 'linear8'] 163 | 164 | def forward(self, x): 165 | x = self.encoder(x) 166 | x = x.view(x.size(0), -1) 167 | logits = self.classifier(x) 168 | 169 | return logits 170 | 171 | 172 | def load_vgg16(f, device): 173 | 174 | input_channels = 13 # number of bands the VGG16 was trained on 175 | classes = 10 # output units (set to number of classes the VGG16 was trained on) 176 | net = VGG16(input_channels, classes) 177 | 178 | ''' 179 | if torch.cuda.is_available(): 180 | state_dict = torch.load(f, map_location=device)['model_state_dict'] 181 | else: 182 | state_dict = torch.load(f, map_location=torch.device('cpu'))['model_state_dict'] 183 | ''' 184 | state_dict = torch.load(f, map_location=device)['model_state_dict'] 185 | 186 | net.load_state_dict(state_dict) 187 | net.to(device) 188 | net.eval() 189 | #net.requires_grad = False 190 | 191 | return net 192 | 193 | 194 | class LossNetwork(nn.Module): 195 | """ 196 | Extract certain feature maps from pretrained VGG model, used for computing perceptual loss 197 | """ 198 | 199 | def __init__(self, f, output_layer, device): 200 | super(LossNetwork, self).__init__() 201 | 202 | self.net = load_vgg16(f, device) 203 | self.output_layer = output_layer 204 | 205 | def forward(self, x): 206 | feature_list = [] 207 | for i, (n, module) in enumerate(self.net._modules.items()): 208 | if n == 'encoder': 209 | for idx in range(len(module)): 210 | x = module[idx](x) 211 | if idx in self.output_layer: 212 | feature_list.append(x) 213 | else: 214 | return feature_list 215 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import ntpath 5 | import time 6 | from . import util, html 7 | from subprocess import Popen, PIPE 8 | 9 | import tifffile as tif 10 | import matplotlib.pyplot as plt 11 | 12 | if sys.version_info[0] == 2: 13 | VisdomExceptionBase = Exception 14 | else: 15 | VisdomExceptionBase = ConnectionError 16 | 17 | 18 | def save_images(method, webpage, visuals, image_path, aspect_ratio=1.0, width=256, saveTiff=True, savePng=False, image_dir=None): 19 | """Save images to the disk. 
20 | 21 | Parameters: 22 | webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details) 23 | visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs 24 | image_path (str) -- the string is used to create image paths 25 | aspect_ratio (float) -- the aspect ratio of saved images 26 | width (int) -- the images will be resized to width x width 27 | 28 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. 29 | """ 30 | image_dir = webpage.get_image_dir() if image_dir==None else os.path.join(webpage.get_image_dir(), image_dir) 31 | if not os.path.exists(image_dir): os.makedirs(image_dir) 32 | short_path = ntpath.basename(image_path[0]) 33 | name = os.path.splitext(short_path)[0] 34 | 35 | webpage.add_header(name) 36 | ims, txts, links = [], [], [] 37 | 38 | for label, im_data in visuals.items(): 39 | #if label.split(sep='_')[-1] == 'mask' or label == 'fake_B': 40 | if 'real_A' in label or label == 'real_B' or label == 'fake_B': 41 | image_numpy = util.tensor2im(im_data, method) 42 | image_name = '%s_%s' % (name, label) 43 | img_path = os.path.join(image_dir, image_name) 44 | 45 | if saveTiff: tif.imsave(img_path+'.tiff', image_numpy) 46 | if savePng: plt.imsave(img_path+'.png', image_numpy) 47 | 48 | ims.append(image_name+'.tiff') 49 | txts.append(label) 50 | links.append(image_name+'.tiff') 51 | webpage.add_images(ims, txts, links, width=width) 52 | return name 53 | 54 | 55 | class Visualizer(): 56 | """This class includes several functions that can display/save images and print/save logging information. 57 | 58 | It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. 59 | """ 60 | 61 | def __init__(self, opt): 62 | """Initialize the Visualizer class 63 | 64 | Parameters: 65 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions 66 | Step 1: Cache the training/test options 67 | Step 2: connect to a visdom server 68 | Step 3: create an HTML object for saving HTML files 69 | Step 4: create a logging file to store training losses 70 | """ 71 | self.opt = opt # cache the option 72 | self.display_id = opt.display_id 73 | self.use_html = opt.isTrain and not opt.no_html 74 | self.win_size = opt.display_winsize 75 | self.name = opt.name 76 | self.port = opt.display_port 77 | self.saved = False 78 | if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server> 79 | import visdom 80 | self.ncols = opt.display_ncols 81 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) 82 | if not self.vis.check_connection(): 83 | self.create_visdom_connections() 84 | 85 | if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/ 86 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 87 | self.img_dir = os.path.join(self.web_dir, 'images') 88 | print('create web directory %s...' 
% self.web_dir) 89 | util.mkdirs([self.web_dir, self.img_dir]) 90 | # create a logging file to store training losses 91 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 92 | with open(self.log_name, "a") as log_file: 93 | now = time.strftime("%c") 94 | log_file.write('================ Training Loss (%s) ================\n' % now) 95 | 96 | def reset(self): 97 | """Reset the self.saved status""" 98 | self.saved = False 99 | 100 | def create_visdom_connections(self): 101 | """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """ 102 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port 103 | print('\n\nCould not connect to Visdom server. \n Trying to start a server....') 104 | print('Command: %s' % cmd) 105 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) 106 | 107 | # # method is either 'resnet' (if opt.alter_initial_mode) or 'default' 108 | def display_current_results(self, visuals, epoch, total_iter, save_result, method): 109 | """Display current results on visdom; save current results to an HTML file. 110 | 111 | Parameters: 112 | visuals (OrderedDict) - - dictionary of images to display or save 113 | epoch (int) - - the current epoch 114 | save_result (bool) - - if save the current results to an HTML file 115 | """ 116 | if self.display_id > 0: # show images in the browser using visdom 117 | ncols = self.ncols 118 | if ncols > 0: # show all the images in one visdom panel 119 | ncols = min(ncols, len(visuals)) 120 | h, w = next(iter(visuals.values())).shape[:2] 121 | table_css = """<style> 122 | table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center} 123 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black} 124 | </style>""" % (w, h) # create a table css 125 | # create a table of images. 126 | title = self.name 127 | label_html = '' 128 | label_html_row = '' 129 | images = [] 130 | idx = 0 131 | for label, image in visuals.items(): 132 | image_numpy = util.tensor2im(image, method) 133 | label_html_row += '<td>%s</td>' % label 134 | images.append(image_numpy.transpose([2, 0, 1])) 135 | idx += 1 136 | if idx % ncols == 0: 137 | label_html += '<tr>%s</tr>' % label_html_row 138 | label_html_row = '' 139 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 140 | while idx % ncols != 0: 141 | images.append(white_image) 142 | label_html_row += '<td></td>' 143 | idx += 1 144 | if label_html_row != '': 145 | label_html += '<tr>%s</tr>' % label_html_row 146 | try: 147 | self.vis.images(images, nrow=ncols, win=self.display_id + 1, 148 | padding=2, opts=dict(title=title + ' images')) 149 | label_html = '<table>%s</table>
' % label_html 150 | self.vis.text(table_css + label_html, win=self.display_id + 2, 151 | opts=dict(title=title + ' labels')) 152 | except VisdomExceptionBase: 153 | self.create_visdom_connections() 154 | 155 | else: # show each image in a separate visdom panel; 156 | idx = 1 157 | try: 158 | for label, image in visuals.items(): 159 | image_numpy = util.tensor2im(image, method) 160 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 161 | win=self.display_id + idx) 162 | idx += 1 163 | except VisdomExceptionBase: 164 | self.create_visdom_connections() 165 | 166 | if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. 167 | self.saved = True 168 | # save images to the disk 169 | for label, image in visuals.items(): 170 | image_numpy = util.tensor2im(image, method) 171 | img_path = os.path.join(self.img_dir, 'epoch%.3d_it%.3d_%s.png' % (epoch, total_iter, label)) 172 | util.save_image(image_numpy, img_path) 173 | 174 | ''' 175 | # update website 176 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1) 177 | for n in range(epoch, 0, -1): 178 | webpage.add_header('epoch [%d]' % n) 179 | ims, txts, links = [], [], [] 180 | 181 | for label, image_numpy in visuals.items(): 182 | image_numpy = util.tensor2im(image) 183 | img_path = 'epoch%.3d_%s.png' % (n, label) 184 | ims.append(img_path) 185 | txts.append(label) 186 | links.append(img_path) 187 | webpage.add_images(ims, txts, links, width=self.win_size) 188 | webpage.save() 189 | ''' 190 | 191 | def plot_current_losses(self, epoch, counter_ratio, losses): 192 | """display the current losses on visdom display: dictionary of error labels and values 193 | 194 | Parameters: 195 | epoch (int) -- current epoch 196 | counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1 197 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 198 | """ 199 | if not hasattr(self, 'plot_data'): 200 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 201 | self.plot_data['X'].append(epoch + counter_ratio) 202 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) 203 | try: 204 | self.vis.line( 205 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 206 | Y=np.array(self.plot_data['Y']), 207 | opts={ 208 | 'title': self.name + ' loss over time', 209 | 'legend': self.plot_data['legend'], 210 | 'xlabel': 'epoch', 211 | 'ylabel': 'loss'}, 212 | win=self.display_id) 213 | except VisdomExceptionBase: 214 | self.create_visdom_connections() 215 | 216 | # losses: same format as |losses| of plot_current_losses 217 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data): 218 | """print current losses on console; also save the losses to the disk 219 | 220 | Parameters: 221 | epoch (int) -- current epoch 222 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) 223 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 224 | t_comp (float) -- computational time per data point (normalized by batch_size) 225 | t_data (float) -- data loading time per data point (normalized by batch_size) 226 | """ 227 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) 228 | for k, v in losses.items(): 229 | message += '%s: %.3f ' % (k, v) 230 | 231 | print(message) # print the message 232 | with open(self.log_name, "a") as log_file: 
233 | log_file.write('%s\n' % message) # save the message 234 | --------------------------------------------------------------------------------