├── .gitignore
├── Colorization
│   ├── __init__.py
│   ├── config
│   │   └── configuration.json
│   ├── dataset
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── big_earth_utils.py
│   │   ├── dataset_big_earth.py
│   │   ├── dataset_big_earth_torch.py
│   │   ├── patches_with_cloud_and_shadow.csv
│   │   ├── patches_with_seasonal_snow.csv
│   │   └── quantiles_3000.json
│   ├── job_config.py
│   ├── losses
│   │   ├── __init__.py
│   │   └── loss.py
│   ├── main.py
│   ├── models
│   │   ├── Decoder.py
│   │   ├── Decoder_utils.py
│   │   ├── Resnet18.py
│   │   ├── Resnet50.py
│   │   └── __init__.py
│   ├── test.py
│   ├── train.py
│   └── utils.py
├── LICENSE
├── Multi_label_classification
│   ├── __init__.py
│   ├── config
│   │   ├── __init__.py
│   │   └── configuration.json
│   ├── dataset
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── change_labels.py
│   │   ├── dataset_big_earth_mlc.py
│   │   └── dataset_big_earth_torch_mlc.py
│   ├── job_config.py
│   ├── main.py
│   ├── main_ensemble.py
│   ├── metrics
│   │   ├── __init__.py
│   │   └── metric.py
│   ├── models
│   │   ├── Ensemble.py
│   │   ├── ResnetMLC.py
│   │   └── __init__.py
│   ├── test.py
│   └── train.py
├── README.md
└── colorization_framework-1.png

/.gitignore:
--------------------------------------------------------------------------------
Colorization/dataset/BigEarthNet_all_refactored_no_clouds_and_snow_v2_new_path.csv
Colorization/dataset/BigEarthNet_all_refactored.csv
Multi_label_classification/dataset/BigEarthNet_all_refactored_no_clouds_and_snow_v2_new_labels.csv
Multi_label_classification/dataset/BigEarthNet_all_refactored_no_clouds_and_snow_v2_new_path_new_labels.csv

--------------------------------------------------------------------------------
/Colorization/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevinc/TheColorOutOfSpace/f8451a6b4d27ac98937eb21da6188e41f6cd254a/Colorization/__init__.py
--------------------------------------------------------------------------------
/Colorization/config/configuration.json:
--------------------------------------------------------------------------------
{
    "batch_size": 32,
    "img_size": 128,
    "dataset_nsamples": 150000,
    "epochs": 50,
    "seed": 42,
    "test_split": 0.2,
    "val_split": 0.4,


    "path_nas": "/nas/softechict-nas-2/svincenzi/colorization_resnet/experiments/",
    "dataset": "dataset/BigEarthNet_all_refactored_no_clouds_and_snow_v2_new_path.csv",
    "log_interval": 100,
    "log_dir": "Debug",
    "save_checkpoint": 1,
    "load_checkpoint": 0,
    "path_model_dict": "/nas/softechict-nas-2/svincenzi/colorization_resnet/experiments/",

    "num_workers": 4,
    "lr": 0.01,
    "momentum": 0.9,
    "weight_decay": 1e-5,
    "val_step": 10,
    "input_channels": 9,
    "out_channels": 2,

    "optim": "SGD",
    "scheduler": 1,
    "sched_step": 40,
    "sched_type": "step",
    "sched_milestones": [10, 60, 80],


    "loss": "L1",
    "grad_loss": 1,
    "weight_grad_loss": 0.01,
    "weight_rec_loss": 10,

    "backbone": 18,
    "decoder_version": 18,
    "pretrained": 0,
    "augmentation": 1,
    "dropout": 0.3
}
--------------------------------------------------------------------------------
/Colorization/dataset/README.md:
--------------------------------------------------------------------------------
# BigEarthNet Dataset
BigEarthNet is a new large-scale Sentinel-2 benchmark archive, consisting of 590,326 Sentinel-2 image patches. It can be downloaded [here](http://bigearth.net/), together with the files
``Image patches with seasonal snow`` and ``Image patches with cloud & shadow``.
## Create csv file of the dataset
Once the files are downloaded, the first step is to create a csv file containing the paths to all the images.
To do that, run the following command:
```
python big_earth_utils.py --big_e_path [path to BigEarthNet dataset] --num_samples [-1 for all the images] --csv_filename [name of the file] --mode [csv_creation]
```
Since this operation may take a while, as an alternative I uploaded my csv file to Google Drive [here](https://drive.google.com/drive/folders/19MsGGVveafgS5IG1A61brAoxsjCCBg3k?usp=sharing);
to replace the paths it contains with your own, run the command below.
```
python big_earth_utils.py --big_e_path [path to BigEarthNet dataset] --csv_filename [name of the file] --mode [replace_path_csv] --new_path_csv [your path]
```

The BigEarthNet dataset also contains images covered by snow or clouds; to remove them, run:
```
python big_earth_utils.py --csv_filename [name of the file] --mode [delete_patches_v2]
```
## Quantiles
To conclude the pre-processing stage, you need to compute the min and max quantiles (2% and 98%) of each band, which are used to clip extremely high or low pixel values.
The file ``quantiles_3000.json`` (computed on 3000 samples) is already included; to change the number of samples used or to recompute the values, run ``big_earth_utils.py`` with ``--mode quantiles``.

## BigEarth dataset vs BigEarth dataset torch version
The .csv file and the quantiles file created above are used by two files: ``dataset_big_earth.py`` and ``dataset_big_earth_torch.py``. The torch version was created to speed up
training, as it loads a single tensor per image instead of 12 .tif bands.
Creating the tensors takes a long time (so consider running it in a tmux session) and can be done with the following command:
```
python dataset_big_earth.py --csv_filename [name of the file] --n_samples [Number of samples to use] --create_torch_dataset [1]
```



## Credits
*G. Sumbul, M. Charfuelan, B. Demir, V. Markl, "BigEarthNet: A Large-Scale Benchmark Archive for Remote Sensing Image Understanding", IEEE International Geoscience and Remote Sensing Symposium, pp. 5901-5904, Yokohama, Japan, 2019.*

--------------------------------------------------------------------------------
/Colorization/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevinc/TheColorOutOfSpace/f8451a6b4d27ac98937eb21da6188e41f6cd254a/Colorization/dataset/__init__.py
--------------------------------------------------------------------------------
/Colorization/dataset/big_earth_utils.py:
--------------------------------------------------------------------------------
import argparse
import csv
import glob
import json
import time
from pathlib import Path

import cv2
import numpy as np
import pandas as pd

Land_cover_Classes = {
    'Mixed forest': 0,
    'Coniferous forest': 1,
    'Non-irrigated arable land': 2,
    'Transitional woodland/shrub': 3,
    'Broad-leaved forest': 4,
    'Land principally occupied by agriculture, with significant areas of natural vegetation': 5,
    'Complex cultivation patterns': 6,
    'Pastures': 7,
    'Water bodies': 8,
    'Sea and ocean': 9,
    'Discontinuous urban fabric': 10,
    'Agro-forestry areas': 11,
    'Peatbogs': 12,
    'Permanently irrigated land': 13,
    'Industrial or commercial units': 14,
    'Natural grassland': 15,
    'Olive groves': 16,
    'Sclerophyllous vegetation': 17,
    'Continuous urban fabric': 18,
    'Water courses': 19,
    'Vineyards': 20,
    'Annual crops associated with permanent crops': 21,
    'Inland marshes': 22,
    'Moors and heathland': 23,
    'Sport and leisure facilities': 24,
    'Fruit trees and berry plantations': 25,
    'Mineral extraction sites': 26,
    'Rice fields': 27,
    'Road and rail networks and associated land': 28,
    'Bare rock': 29,
    'Green urban areas': 30,
    'Beaches, dunes, sands': 31,
    'Sparsely vegetated areas': 32,
    'Salt marshes': 33,
    'Coastal lagoons': 34,
    'Construction sites': 35,
    'Estuaries': 36,
    'Intertidal flats': 37,
    'Airports': 38,
    'Dump sites': 39,
    'Port areas': 40,
    'Salines': 41,
    'Burnt areas': 42
}


class BigEarthUtils:
    def __init__(self):
        pass

    @staticmethod
    def big_earth_to_csv(big_e_path: str, num_samples: int, csv_filename: str) -> bool:
        """
        Generate the csv file for all or part of the BigEarth dataset
        :param big_e_path: path to the BigEarth dataset
        :param num_samples: number of samples to include in the csv file (-1 to select the whole dataset)
        :param csv_filename: name of the created file
        :return: True
        """
        path = Path(big_e_path)
        print("collecting dirs...")
        start_time = time.time()
        labels_names = []
        labels_values = []
        if num_samples == -1:
            dirs = [str(e) for e in path.iterdir() if e.is_dir()]
        else:
            # zip and range() to choose only a specific number of examples
            dirs = [str(e) for _, e in zip(range(num_samples), path.iterdir()) if e.is_dir()]
        for idx, d in enumerate(dirs):
            # every patch folder contains a json metadata file with its labels
            for e in glob.glob(d + "/*.json"):
                with open(e) as f:
                    j_file = json.load(f)
                labels_names.append(j_file['labels'])
                labels_values.append([Land_cover_Classes[label] for label in j_file['labels']])
        # write the dirs on a csv file
        print("writing on csv...")
        things_to_write = zip(dirs, labels_names, labels_values)
        with open(csv_filename, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerows(things_to_write)
        print(f"finishing in : {time.time() - start_time}")
        return True

    @staticmethod
    def min_max_quantile(csv_filename: str, n_samples: int) -> dict:
        """
        Compute the min (2%) and max (98%) quantiles of every Sentinel-2 band
        :param csv_filename: path of the csv file of the BigEarth dataset
        :param n_samples: number of samples used to compute the min and max quantiles
        :return: a dict containing the min and max quantiles for every Sentinel-2 band
        """
        bands = ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B09", "B11", "B12"]
        data = pd.read_csv(csv_filename, header=None)
        paths = data.iloc[:, 0].tolist()
        quantiles = {}
        for b in bands:
            imgs = []
            # use only the first n_samples paths of the list
            for path in paths[:n_samples]:
                for filename in glob.iglob(path + "/*" + b + ".tif"):
                    img = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
                    imgs.append(img)
            imgs = np.stack(imgs, axis=0).reshape(-1)
            quantiles[b] = {
                'min_q': np.quantile(imgs, 0.02),
                'max_q': np.quantile(imgs, 0.98)
            }
            print(b, quantiles[b])
        return quantiles

    @staticmethod
    def save_dict_to_json(d: dict, json_path: str) -> None:
        with open(json_path, 'w') as f:
            json.dump(d, f, indent=4)

    @staticmethod
    def delete_patches(csv_filename: str) -> bool:
        """
        Delete images covered by clouds or snow (substring match on the image paths)
        :param csv_filename: dataset file created above
        :return: True
        """
        csv_snow_patches = 'patches_with_seasonal_snow.csv'
        csv_clouds_patches = 'patches_with_cloud_and_shadow.csv'
        data = pd.read_csv(csv_filename, header=None)
        snow_patches = pd.read_csv(Path.cwd() / csv_snow_patches, header=None)
        clouds_patches = pd.read_csv(Path.cwd() / csv_clouds_patches, header=None)
        patches = snow_patches.iloc[:, 0].tolist() + clouds_patches.iloc[:, 0].tolist()
        df = data[~data.iloc[:, 0].str.contains('|'.join(patches))]
        df.to_csv(csv_filename[:-4] + '_no_clouds_and_snow_server' + csv_filename[-4:], header=None, index=False)
        return True

    @staticmethod
    def delete_patches_v2(csv_filename: str) -> bool:
        """
        Delete images covered by clouds or snow (exact match on the patch folder name)
        :param csv_filename: dataset file created above
        :return: True
        """
        data = pd.read_csv(csv_filename, header=None)
        # compare only the patch folder names, so the filter works whatever the dataset root path is
        patch_names = data.iloc[:, 0].map(lambda p: Path(p).name)
        csv_snow_patches = 'patches_with_seasonal_snow.csv'
        csv_clouds_patches = 'patches_with_cloud_and_shadow.csv'
        snow_patches = pd.read_csv(Path.cwd() / csv_snow_patches, header=None)
        clouds_patches = pd.read_csv(Path.cwd() / csv_clouds_patches, header=None)
        patches = snow_patches.iloc[:, 0].tolist() + clouds_patches.iloc[:, 0].tolist()
        data = data[~patch_names.isin(patches)]
        data.to_csv(csv_filename[:-4] + '_no_clouds_and_snow_v2' + csv_filename[-4:], header=None, index=False)
        return True

    @staticmethod
    def replace_path_csv(csv_filename: str, new_path: str) -> bool:
        """
        Replace the dataset root path in the csv file
        :param csv_filename: dataset file
        :param new_path: new root path to set in the csv file
        :return: True
        """
        data = pd.read_csv(csv_filename, header=None)
        data = data.replace({"/nas/softechict-nas-2/svincenzi/BigEarthNet-v1.0/": new_path}, regex=True)
        data.to_csv(csv_filename[:-4] + '_new_path' + csv_filename[-4:], header=None, index=False)
        return True


if __name__ == '__main__':
    argparser = argparse.ArgumentParser(description='BigEarthNet utils')
    argparser.add_argument('--big_e_path', type=str, default=None,
                           help='path to the BigEarth dataset (required for csv_creation)')
    argparser.add_argument('--num_samples', type=int, default=-1, help='Number of samples to create the csv file')
    argparser.add_argument('--csv_filename', type=str, default='BigEarth.csv', help='Name of the csv dataset file')
    argparser.add_argument('--n_samples', type=int, default=3000,
                           help='Number of samples to calculate the min-max quantile')
    argparser.add_argument('--mode', default='csv_creation', choices=['csv_creation', 'delete_patches', 'delete_patches_v2',
                                                                      'quantiles', 'replace_path_csv'],
                           type=str, help='select the action to perform: csv_creation, delete_patches, '
                                          'delete_patches_v2, quantiles or replace_path_csv')
    argparser.add_argument('--new_path_csv', type=str, default=None, help='indicate the new path to change the csv')

    args = argparser.parse_args()
    # csv creation
    if args.mode == 'csv_creation':
        if args.big_e_path is None:
            argparser.error('--big_e_path is required with --mode csv_creation')
        BigEarthUtils.big_earth_to_csv(args.big_e_path, args.num_samples, args.csv_filename)
    # delete patches
    elif args.mode == 'delete_patches':
        BigEarthUtils.delete_patches(args.csv_filename)
    # delete patches_v2
    elif args.mode == 'delete_patches_v2':
        BigEarthUtils.delete_patches_v2(args.csv_filename)
    # min-max quantiles
    elif args.mode == 'quantiles':
        quantiles = BigEarthUtils.min_max_quantile(args.csv_filename, args.n_samples)
        # save the quantiles on a json file
        BigEarthUtils.save_dict_to_json(quantiles, f"quantiles_{args.n_samples}.json")
    # replace csv path
    elif args.mode == 'replace_path_csv':
        BigEarthUtils.replace_path_csv(args.csv_filename, args.new_path_csv)


--------------------------------------------------------------------------------
/Colorization/dataset/dataset_big_earth.py:
--------------------------------------------------------------------------------
import argparse
import glob
import json
import time
from pathlib import Path
from typing import Tuple, Union

import cv2
import numpy as np
import pandas as pd
import torch
from torch.utils.data import SubsetRandomSampler
from torch.utils.data.dataset import Dataset
from torchvision import transforms


# "bands": ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B09", "B11", "B12"]
def torch2numpy(tensor: torch.Tensor) -> np.ndarray:
    """
    Convert a (C, H, W) torch tensor, removing a batch dimension of size 1 if present, to an (H, W, C) ndarray
    :param tensor: input torch tensor
    :return: ndarray
    """
    tensor = torch.squeeze(tensor, dim=0)
    return np.transpose(tensor.numpy(), (1, 2, 0))


def load_dict_from_json(json_path: str) -> dict:
    """
    Load a json file into a dict
    :param json_path: path to the file
    :return: dict
    """
    with open(json_path) as f:
        params = json.load(f)
    return params


class Color:
    def __init__(self):
        pass

    @staticmethod
    def rgb2gray(rgb: np.ndarray) -> np.ndarray:
        """
        Convert an RGB image to grayscale
        :param rgb: input image
        :return: grayscale image of shape (H, W, 1)
        """
        return np.reshape(np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140]),
                          (rgb.shape[0], rgb.shape[1], 1)).astype(np.float32)

    @staticmethod
    def rgb2lab(rgb: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Convert a float32 RGB image in [0, 1] to Lab, with every channel normalized to [0, 1]
        :param rgb: input image
        :return: the L channel and the ab channels
        """
        lab_img = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
        lab_img[:, :, 0] *= 255 / 100
        lab_img[:, :, 1] += 128
        lab_img[:, :, 2] += 128
        lab_img /= 255
        return lab_img[:, :, 0], lab_img[:, :, 1:]

    @staticmethod
    def lab2rgb(L: torch.Tensor, ab: torch.Tensor) -> torch.Tensor:
        """
        Convert a batch of Lab images (channels normalized to [0, 1]) back to RGB
        :param L: batch of L channels, shape (B, 1, H, W)
        :param ab: batch of ab channels, shape (B, 2, H, W)
        :return: batch of RGB images, shape (B, 3, H, W)
        """
        L = L.cpu().detach().numpy()
        ab = ab.cpu().detach().numpy()
        Lab = np.concatenate((L, ab), axis=1)
        Lab = np.transpose(Lab, (0, 2, 3, 1))
        B, H, W, C = Lab.shape
        # reshape to convert all the images in the batch without iteration
        Lab = np.reshape(Lab, (B * H, W, C))
        Lab *= 255
        Lab[:, :, 0] *= 100 / 255
        Lab[:, :, 1] -= 128
        Lab[:, :, 2] -= 128
        rgb = cv2.cvtColor(Lab, cv2.COLOR_LAB2RGB)
        rgb = np.reshape(rgb, (B, H, W, C))
        rgb = np.transpose(rgb, (0, 3, 1, 2))
        rgb = torch.from_numpy(rgb)
        return rgb


class BigEarthDataset(Dataset):
    def __init__(self, csv_path: str, quantiles: str, random_seed: int, bands: list,
                 create_torch_dataset=0, n_samples=100000):
        """
        Args:
            csv_path: path to csv file containing paths to images
            quantiles: path to json file containing the quantiles of each band
            random_seed: seed value
            bands: list of the bands to consider for the training
            create_torch_dataset: set to 1 to also save a torch version of every loaded image
            n_samples: number of samples to use for the training
        """
        # Transforms
        self.to_tensor = transforms.ToTensor()
        # Load dataset
        self.folder_path, self.labels_name, self.labels_class = self.load_dataset(csv_path, random_seed, n_samples)
        # Calculate len
        self.data_len = len(self.folder_path)
        print("Dataset len: ", self.data_len)
        # load quantiles json file
        self.quantiles = load_dict_from_json(quantiles)
        # bands
        self.bands = bands
        # flag for create dataset torch version
        self.create_torch_dataset = create_torch_dataset

    @staticmethod
    def load_dataset(csv_path: str, random_seed: int, n_samples: int) -> Tuple[list, list, list]:
        """
        Load the dataset from the csv file and shuffle it
        :param csv_path: path to the csv file
        :param random_seed: seed
        :param n_samples: number of samples to consider for the training
        :return: image folder paths, text labels and numeric labels
        """
        # Read the csv file
        data_info = pd.read_csv(csv_path, header=None)
        # First column contains the folder paths
        folder_path = data_info.iloc[:n_samples, 0].tolist()
        # Second column contains the text labels
        labels_name = data_info.iloc[:n_samples, 1].tolist()
        # Third column contains the number labels
        labels_class = data_info.iloc[:n_samples, 2].tolist()
        # shuffle the entries, specify the seed
        tmp_shuffle = list(zip(folder_path, labels_name, labels_class))
        np.random.seed(random_seed)
        np.random.shuffle(tmp_shuffle)
        folder_path, labels_name, labels_class = zip(*tmp_shuffle)
        # the colorization task only uses the image paths, but the labels are returned as well
        return folder_path, labels_name, labels_class

    def split_dataset(self, thresh_test: float, thresh_val: float) -> Union[Tuple[list, list, list], Tuple[list, list]]:
        """
        Split the (already shuffled) indices into train/val/test or train/test sets
        :param thresh_test: fraction of the dataset used as the test set (the first indices)
        :param thresh_val: fraction at which the validation set ends (indices between thresh_test and thresh_val);
                           None to split only in training and test set
        :return: (train, val, test) indices, or (train, test) indices if thresh_val is None
        """
        indices = list(range(self.data_len))
        split_test = int(np.floor(thresh_test * self.data_len))
        if thresh_val is not None:
            split_val = int(np.floor(thresh_val * self.data_len))
            return indices[split_val:], indices[split_test:split_val], indices[:split_test]
        else:
            return indices[split_test:], indices[:split_test]

    def quantiles_std(self, img: np.ndarray, band: str, quantiles: dict) -> np.ndarray:
        """
        Clip a band to its min-max quantiles, then min-max normalize it to [0, 1]
        :param img: input image
        :param band: name of the band (e.g. "B02")
        :param quantiles: dict containing the min-max quantiles for each band
        :return: normalized image
        """
        min_q = quantiles[band]['min_q']
        max_q = quantiles[band]['max_q']
        img[img < min_q] = min_q
        img[img > max_q] = max_q
        img_dest = np.zeros_like(img)
        img_dest = cv2.normalize(img, img_dest, 0, 255, cv2.NORM_MINMAX)
        img_dest = img_dest.astype(np.float32) / 255.
        return img_dest

    def custom_loader(self, path: str, band: str, quantiles: dict) -> np.ndarray:
        """
        Open a band, resize it to 128x128 and normalize it
        :param path: path to the image to load
        :param band: name of the band
        :param quantiles: dict containing the min-max quantiles for each band
        :return: resized and normalized image of shape (128, 128, 1)
        """
        # read the band as it is, with IMREAD_UNCHANGED
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        img = cv2.resize(img, dsize=(128, 128), interpolation=cv2.INTER_CUBIC)
        img = self.quantiles_std(img, band, quantiles)
        w, h = img.shape
        return img.reshape(w, h, 1)

    # split the bands between rgb and all the others
    def split_bands(self, spectral_img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Split the bands between RGB (B04, B03, B02) and the other 9 bands
        :param spectral_img: complete image, shape (H, W, 12)
        :return: the spectral bands and the rgb image
        """
        indices = [0, 4, 5, 6, 7, 8, 9, 10, 11]
        indices_rgb = [3, 2, 1]
        spectral_bands = np.take(spectral_img, indices=indices, axis=2)
        rgb = np.take(spectral_img, indices=indices_rgb, axis=2)
        return spectral_bands, rgb

    def save_torch_dataset(self, imgs_file: str, spectral_img: np.ndarray) -> bool:
        """
        Save a torch version of the current image
        :param imgs_file: path to the current image folder
        :param spectral_img: current image
        :return: True
        """
        parts = list(Path(imgs_file).parts)
        # replace the dataset folder (path component at index 4) to build the output path
        parts[4] = 'BigEarthNet_torch_version_v2'
        new_path = Path(*parts)
        new_path.mkdir(parents=True, exist_ok=True)
        torch.save(self.to_tensor(spectral_img), new_path / 'all_bands.pt')
        return True

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # obtain the right folder
        imgs_file = self.folder_path[index]
        imgs_bands = []
        for b in self.bands:
            for filename in glob.iglob(imgs_file + "/*" + b + ".tif"):
                band = self.custom_loader(filename, b, self.quantiles)
                imgs_bands.append(band)
        spectral_img = np.concatenate(imgs_bands, axis=2)
        if self.create_torch_dataset:
            self.save_torch_dataset(imgs_file, spectral_img)
        spectral_bands, rgb = self.split_bands(spectral_img)
        L, ab = Color.rgb2lab(rgb)
        return self.to_tensor(spectral_bands), self.to_tensor(L), self.to_tensor(ab)

    def __len__(self) -> int:
        return self.data_len


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(description='BigEarthNet dataset tiff version')
    argparser.add_argument('--csv_filename', type=str, default='BigEarthNet_all_refactored_no_clouds_and_snow_server.csv',
                           required=True, help='csv containing dataset paths')
    argparser.add_argument('--n_samples', type=int, default=3000, help='Number of samples to load from the csv file')
    argparser.add_argument('--create_torch_dataset', type=int, default=0, choices=[0, 1],
                           help='set 1 to create the torch version and 0 to use the training dataset')
    args = argparser.parse_args()
    # Dataset definition
    big_earth = BigEarthDataset(csv_path=args.csv_filename, quantiles='quantiles_3000.json',
                                random_seed=19,
                                bands=["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B09", "B11", "B12"],
                                create_torch_dataset=args.create_torch_dataset, n_samples=args.n_samples)
    # dataset split
    train_idx, val_idx, test_idx = big_earth.split_dataset(0.2, 0.4)
    # dataset sampler
    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)
    test_sampler = SubsetRandomSampler(test_idx)
    # dataset loader
    train_loader = torch.utils.data.DataLoader(big_earth, batch_size=16,
                                               sampler=train_sampler, num_workers=4)
    test_loader = torch.utils.data.DataLoader(big_earth, batch_size=16,
                                              sampler=test_sampler, num_workers=4)
    val_loader = torch.utils.data.DataLoader(big_earth, batch_size=16,
                                             sampler=val_sampler, num_workers=4)
    start_time = time.time()

    for idx, (spectral_img, L, ab) in enumerate(train_loader):
        print(idx)

    for idx, (spectral_img, L, ab) in enumerate(test_loader):
        print(idx)

    for idx, (spectral_img, L, ab) in enumerate(val_loader):
        print(idx)

    print("time: ", time.time() - start_time)
--------------------------------------------------------------------------------
/Colorization/dataset/dataset_big_earth_torch.py:
--------------------------------------------------------------------------------
import argparse
import time
from typing import Tuple, Union

import numpy as np
import pandas as pd
import torch
import torchvision.transforms.functional as TF
from torch import nn
from torch.utils.data import SubsetRandomSampler, SequentialSampler
from torch.utils.data.dataset import Dataset
from torchvision import transforms

from Colorization.dataset.dataset_big_earth import Color
import albumentations as A
from matplotlib import pyplot as plt


class BigEarthDatasetTorch(Dataset):
    def __init__(self, csv_path: str, random_seed: int, bands_indices: list, img_size: int,
                 augmentation: int, n_samples=100000):
        """
        Args:
            csv_path: path to csv file containing paths to images
            random_seed: seed value
            bands_indices: 0/1 mask (one entry per band stored in the tensor) of the bands to keep
            img_size: side of the square image after resizing
            augmentation: set to 1 to apply data augmentation
            n_samples: number of samples to use for the training
        """
        # Transforms
        self.to_tensor = transforms.ToTensor()
        # Load dataset
        self.folder_path, self.labels_name, self.labels_class = self.load_dataset(csv_path, random_seed, n_samples)
        # Calculate len
        self.data_len = len(self.folder_path)
        print("Dataset len: ", self.data_len)
        # image size
        self.img_size = img_size
        # bands
        self.bands_indices = torch.BoolTensor(bands_indices)
        # augmentation
        self.augmentation = augmentation

    @staticmethod
    def load_dataset(csv_path: str, random_seed: int, n_samples: int) -> Tuple[list, list, list]:
        """
        Load the dataset from the csv file and shuffle it
        :param csv_path: path to the csv file
        :param random_seed: seed
        :param n_samples: number of samples to consider for the training
        :return: image folder paths, text labels and numeric labels
        """
        # Read the csv file
        data_info = pd.read_csv(csv_path, header=None)
        # First column contains the folder paths
        folder_path = data_info.iloc[:n_samples, 0].tolist()
        # Second column contains the text labels
        labels_name = data_info.iloc[:n_samples, 1].tolist()
        # Third column contains the number labels
        labels_class = data_info.iloc[:n_samples, 2].tolist()
        # shuffle the entries, specify the seed
        tmp_shuffle = list(zip(folder_path, labels_name, labels_class))
        np.random.seed(random_seed)
        np.random.shuffle(tmp_shuffle)
        folder_path, labels_name, labels_class = zip(*tmp_shuffle)
        # the colorization task only uses the image paths, but the labels are returned as well
        return folder_path, labels_name, labels_class

    def split_dataset(self, thresh_test: float, thresh_val: float) -> Union[Tuple[list, list, list], Tuple[list, list]]:
        """
        Split the (already shuffled) indices into train/val/test or train/test sets
        :param thresh_test: fraction of the dataset used as the test set (the first indices)
        :param thresh_val: fraction at which the validation set ends (indices between thresh_test and thresh_val);
                           None to split only in training and test set
        :return: (train, val, test) indices, or (train, test) indices if thresh_val is None
        """
        indices = list(range(self.data_len))
        split_test = int(np.floor(thresh_test * self.data_len))
        if thresh_val is not None:
            split_val = int(np.floor(thresh_val * self.data_len))
            return indices[split_val:], indices[split_test:split_val], indices[:split_test]
        else:
            return indices[split_test:], indices[:split_test]

    @staticmethod
    def split_bands(spectral_img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Split the bands between RGB (B04, B03, B02) and the other 9 bands
        :param spectral_img: complete image, shape (C, H, W)
        :return: the spectral bands and the rgb image
        """
        indices = torch.tensor([0, 4, 5, 6, 7, 8, 9, 10, 11])
        indices_rgb = torch.tensor([3, 2, 1])
        rgb = torch.index_select(input=spectral_img, dim=0, index=indices_rgb)
        spectral_bands = torch.index_select(input=spectral_img, dim=0, index=indices)
        return spectral_bands, rgb

    @staticmethod
    def augmentation_fn(images: torch.Tensor) -> torch.Tensor:
        """
        Apply the same random augmentation to every band: with probability 0.25 each, a rotation
        in [-15, 15) degrees, a vertical flip, a horizontal flip, or no augmentation
        :param images: current image, shape (C, H, W)
        :return: augmented image
        """
        rnd = np.random.random_sample()
        if rnd > 0.75:
            # no augmentation
            return images
        images = torch.unsqueeze(images, dim=1)
        angle = np.random.randint(-15, 15)
        for idx, image in enumerate(images):
            image = TF.to_pil_image(image)
            if rnd < 0.25:
                image = TF.rotate(image, angle)
            elif rnd <= 0.50:
                image = TF.vflip(image)
            else:
                image = TF.hflip(image)
            images[idx] = TF.to_tensor(image)
        return torch.squeeze(images)

    @staticmethod
    def album_aug(images: torch.Tensor) -> torch.Tensor:
        """
        Apply data augmentation with the albumentations library, to speed it up
        :param images: current image, shape (C, H, W)
        :return: augmented image
        """
        angle = np.random.randint(-15, 15)
        transform = A.Compose([
            A.OneOf([
                A.Rotate(limit=angle, always_apply=False, p=0.33),
                A.VerticalFlip(p=0.33),
                A.HorizontalFlip(p=0.33)
A.HorizontalFlip(p=0.33) 137 | ], 138 | p=0.75)] 139 | ) 140 | image_aug = np.transpose(images.numpy(), (1, 2, 0)) 141 | image_aug = transform(image=image_aug)['image'] 142 | return torch.squeeze(torch.from_numpy(image_aug)) 143 | 144 | def __getitem__(self, index: int) -> Tuple[torch.tensor, torch.tensor, torch.tensor]: 145 | # obtain the right folder 146 | imgs_file = self.folder_path[index] 147 | # load torch image 148 | spectral_img = torch.load(imgs_file + '/all_bands_chroma.pt') 149 | # resize the image as specified in the params dsize 150 | spectral_img = torch.squeeze( 151 | nn.functional.interpolate(input=torch.unsqueeze(spectral_img, dim=0), size=self.img_size)) 152 | # take only the bands specified in the init 153 | spectral_img = spectral_img[self.bands_indices] 154 | # eventually apply augmentation 155 | if self.augmentation: 156 | spectral_img = self.augmentation_fn(spectral_img) 157 | # split the bands and convert to CieLab space 158 | spectral_bands, rgb = self.split_bands(spectral_img) 159 | # convert tensor to numpy 160 | rgb = np.transpose(rgb.numpy(), (1, 2, 0)) 161 | L, ab = Color.rgb2lab(rgb) 162 | return spectral_bands, self.to_tensor(L), self.to_tensor(ab) 163 | 164 | def __len__(self) -> int: 165 | return self.data_len 166 | 167 | 168 | if __name__ == "__main__": 169 | argparser = argparse.ArgumentParser(description='BigEarthNet dataset tiff version') 170 | argparser.add_argument('--csv_filename', type=str, 171 | default='BigEarthNet_all_refactored_no_clouds_and_snow_v2_new_path.csv', 172 | required=True, help='csv containing dataset paths') 173 | argparser.add_argument('--n_samples', type=int, default=3000, help='Number of samples to create the csv file') 174 | argparser.add_argument('--augmentation', type=int, default=1, choices=[0, 1], help='set to 1 for use augmenation') 175 | 176 | args = argparser.parse_args() 177 | # Dataset definition 178 | big_earth = BigEarthDatasetTorch(csv_path=args.csv_filename, random_seed=19, 179 | 
bands_indices=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], img_size=128, 180 | augmentation=args.augmentation, n_samples=args.n_samples) 181 | # dataset split 182 | train_idx, val_idx, test_idx = big_earth.split_dataset(0.2, 0.4) 183 | # dataset samplers: each one draws only from the indices of its own split 184 | train_sampler = SubsetRandomSampler(train_idx) 185 | val_sampler = SubsetRandomSampler(val_idx) 186 | test_sampler = SubsetRandomSampler(test_idx) 187 | # dataset loader 188 | train_loader = torch.utils.data.DataLoader(big_earth, batch_size=16, 189 | sampler=train_sampler, num_workers=4) 190 | test_loader = torch.utils.data.DataLoader(big_earth, batch_size=1, 191 | sampler=test_sampler, num_workers=4) 192 | start_time = time.time() 193 | 194 | runs = 5 195 | for i in range(runs): 196 | for idx, (spectral_img, L, ab) in enumerate(train_loader): 197 | print(idx) 198 | 199 | print("Mean time per run over {} runs: {:.2f} s".format(runs, (time.time() - start_time) / runs)) 200 | -------------------------------------------------------------------------------- /Colorization/dataset/quantiles_3000.json: -------------------------------------------------------------------------------- 1 | { 2 | "B01": { 3 | "min_q": 1.0, 4 | "max_q": 1097.0 5 | }, 6 | "B02": { 7 | "min_q": 59.0, 8 | "max_q": 1346.0 9 | }, 10 | "B03": { 11 | "min_q": 71.0, 12 | "max_q": 1709.0 13 | }, 14 | "B04": { 15 | "min_q": 32.0, 16 | "max_q": 2149.0 17 | }, 18 | "B05": { 19 | "min_q": 21.0, 20 | "max_q": 2466.0 21 | }, 22 | "B06": { 23 | "min_q": 7.0, 24 | "max_q": 3898.0 25 | }, 26 | "B07": { 27 | "min_q": 10.0, 28 | "max_q": 4707.0 29 | }, 30 | "B08": { 31 | "min_q": 4.0, 32 | "max_q": 4994.0 33 | }, 34 | "B8A": { 35 | "min_q": 2.0, 36 | "max_q": 4997.0 37 | }, 38 | "B09": { 39 | "min_q": 1.0, 40 | "max_q": 4798.0 41 | }, 42 | "B11": { 43 | "min_q": 6.0, 44 | "max_q": 4144.0 45 | }, 46 | "B12": { 47 | "min_q": 7.0, 48 | "max_q": 3157.0 49 | } 50 | } -------------------------------------------------------------------------------- /Colorization/job_config.py:
-------------------------------------------------------------------------------- 1 | def set_params(params, id_optim): 2 | if id_optim is None: 3 | pass 4 | else: 5 | if id_optim == 0: 6 | params.dataset_nsamples = 5000 7 | params.epochs = 50 8 | params.seed = 42 9 | params.batch_size = 16 10 | params.test_split = 0.2 11 | params.val_split = 0.4 12 | params.backbone = 50 13 | params.dataset = "dataset/BigEarthNet_all_refactored_no_clouds_and_snow_v2_new_path.csv" 14 | params.path_nas = "/nas/softechict-nas-2/svincenzi/colorization_resnet/experiments_resnet50_AE/" 15 | params.log_dir = "Debug/" 16 | params.augmentation = 1 17 | params.decoder_version = 18 18 | params.pretrained = 0 19 | params.input_channels = 9 20 | params.out_channels = 2 21 | params.lr = 0.01 22 | params.optim = "SGD" 23 | params.img_size = 128 24 | params.weight_rec_loss = 100. 25 | params.grad_loss = 0 26 | params.weight_grad_loss = 0.1 27 | params.loss = "L1" 28 | params.scheduler = 1 29 | params.sched_step = 40 30 | params.sched_type = "step" 31 | params.path_model_dict = "" 32 | params.load_checkpoint = 0 33 | params.dropout = 0.3 34 | params.num_workers = 0 35 | 36 | params.log_dir = params.log_dir + "_batch_" + str(params.batch_size) 37 | 38 | return params 39 | -------------------------------------------------------------------------------- /Colorization/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevinc/TheColorOutOfSpace/f8451a6b4d27ac98937eb21da6188e41f6cd254a/Colorization/losses/__init__.py -------------------------------------------------------------------------------- /Colorization/losses/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Loss(nn.Module): 7 | def __init__(self, mode="L1"): 8 | super().__init__() 9 | self.mode = mode 10 | if self.mode == 'L1': 
11 | self.loss = nn.L1Loss(reduction='mean') 12 | else: 13 | self.loss = nn.MSELoss(reduction='mean') 14 | 15 | def forward(self, pred: torch.Tensor, gt: torch.Tensor): 16 | return self.loss(pred, gt) 17 | 18 | @staticmethod 19 | def grad_loss_fn(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor: 20 | B, C, H, W = gt.shape 21 | # Horizontal Sobel filter, created on the same device and with the same dtype as the input 22 | Sx = torch.tensor([[-1, 0, 1], 23 | [-2, 0, 2], 24 | [-1, 0, 1]], dtype=gt.dtype, device=gt.device) 25 | # broadcast the filter to shape (1, C, 3, 3) and compute the conv 26 | Sx = Sx.expand(1, C, 3, 3) 27 | Gt_x = F.conv2d(gt, Sx, padding=1) 28 | pred_x = F.conv2d(pred, Sx, padding=1) 29 | # Vertical Sobel filter 30 | Sy = torch.tensor([[1, 2, 1], 31 | [0, 0, 0], 32 | [-1, -2, -1]], dtype=gt.dtype, device=gt.device) 33 | # broadcast the filter to shape (1, C, 3, 3) and compute the conv 34 | Sy = Sy.expand(1, C, 3, 3) 35 | Gt_y = F.conv2d(gt, Sy, padding=1) 36 | pred_y = F.conv2d(pred, Sy, padding=1) 37 | 38 | loss_grad_x = torch.pow((Gt_x - pred_x), 2).mean() 39 | loss_grad_y = torch.pow((Gt_y - pred_y), 2).mean() 40 | loss_grad = loss_grad_x + loss_grad_y 41 | return loss_grad 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /Colorization/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import logging 5 | import os 6 | import warnings 7 | 8 | import numpy as np 9 | import torch.utils.data 10 | from torch import optim 11 | from torch.utils.data import SubsetRandomSampler 12 | from torch.utils.tensorboard import SummaryWriter 13 | 14 | from Colorization import utils 15 | from Colorization.dataset.dataset_big_earth_torch import BigEarthDatasetTorch 16 | from Colorization.job_config import set_params 17 | from Colorization.losses.loss import Loss 18 | from Colorization.models.Resnet18 import ResNet18 19 | from Colorization.models.Resnet50 import ResNet50 20 | from Colorization.test import test 21 | from Colorization.train import train 22 | 23 |
23 | warnings.filterwarnings("ignore") 24 | 25 | os.environ["OMP_NUM_THREADS"] = "1" 26 | torch.backends.cudnn.benchmark = True 27 | torch.backends.cudnn.enabled = True 28 | 29 | 30 | def init_worker(id_worker): 31 | np.random.seed(torch.initial_seed() % 2 ** 32) # distinct NumPy seed per worker and per epoch, derived from the torch seed 32 | 33 | 34 | def main(args): 35 | # enable cuda if available 36 | args.cuda = args.cuda and torch.cuda.is_available() 37 | device = torch.device("cuda" if args.cuda else "cpu") 38 | 39 | # READ JSON CONFIG FILE 40 | assert os.path.isfile(args.json_config_file), "No json configuration file found at {}".format(args.json_config_file) 41 | params = utils.Params(args.json_config_file) 42 | 43 | # override params according to the job id (see job_config.set_params) 44 | params = set_params(params, args.id_optim) 45 | 46 | # set the torch seed 47 | torch.manual_seed(params.seed) 48 | 49 | # initialize summary writer; every folder is saved inside runs 50 | writer = SummaryWriter(params.path_nas + params.log_dir + '/runs/' + params.log_dir) 51 | 52 | # create dir for log file 53 | if not os.path.exists(params.path_nas + params.log_dir): 54 | os.makedirs(params.path_nas + params.log_dir) 55 | # save the json config file of the model 56 | params.save(os.path.join(params.path_nas + params.log_dir, "params.json")) 57 | 58 | # Set the logger 59 | utils.set_logger(os.path.join(params.path_nas + params.log_dir, "log")) 60 | 61 | # DATASET 62 | # Torch version 63 | big_earth = BigEarthDatasetTorch(csv_path=params.dataset, random_seed=params.seed, bands_indices=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], 64 | img_size=params.img_size, augmentation=params.augmentation, n_samples=params.dataset_nsamples) 65 | 66 | train_idx, val_idx, test_idx = big_earth.split_dataset(params.test_split, params.val_split) 67 | # define the sampler 68 | train_sampler = SubsetRandomSampler(train_idx) 69 | val_sampler = SubsetRandomSampler(val_idx) 70 | test_sampler = SubsetRandomSampler(test_idx) 71 | # define the loader 72 | train_loader = torch.utils.data.DataLoader(big_earth,
batch_size=params.batch_size, 73 | sampler=train_sampler, num_workers=params.num_workers, worker_init_fn=init_worker) 74 | val_loader = torch.utils.data.DataLoader(big_earth, batch_size=params.batch_size, 75 | sampler=val_sampler, num_workers=params.num_workers, worker_init_fn=init_worker) 76 | test_loader = torch.utils.data.DataLoader(big_earth, batch_size=params.batch_size, 77 | sampler=test_sampler, num_workers=params.num_workers, worker_init_fn=init_worker) 78 | # MODEL definition 79 | if params.backbone == 50: 80 | model = ResNet50(in_channels=params.input_channels, out_channels=params.out_channels, pretrained=params.pretrained, 81 | dropout=params.dropout, decoder_version=params.decoder_version) 82 | else: 83 | model = ResNet18(in_channels=params.input_channels, out_channels=params.out_channels, pretrained=params.pretrained, 84 | dropout=params.dropout) 85 | 86 | # optionally load a checkpoint 87 | if params.load_checkpoint == 1: 88 | checkpoint = torch.load(params.path_model_dict, map_location=device) 89 | model.load_state_dict(checkpoint['state_dict'], strict=False) 90 | start_epoch = checkpoint['epoch'] + 1 91 | else: 92 | start_epoch = 0 93 | 94 | # CUDA 95 | model.to(device) 96 | 97 | # LOSS ON RECONSTRUCTION 98 | loss_fn = Loss(mode=params.loss) 99 | 100 | # OPTIMIZER 101 | if params.optim == "Adam": 102 | optimizer = optim.Adam(model.parameters(), lr=params.lr, weight_decay=params.weight_decay) 103 | else: 104 | optimizer = optim.SGD(model.parameters(), lr=params.lr, momentum=params.momentum) 105 | 106 | # SCHEDULER 107 | if params.sched_type == "step": 108 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=params.sched_step, gamma=0.1) 109 | else: 110 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=params.sched_milestones, gamma=0.1) 111 | 112 | if params.load_checkpoint: 113 | optimizer.load_state_dict(checkpoint['optim_dict']) 114 | if 'scheduler_dict' in checkpoint: scheduler.load_state_dict(checkpoint['scheduler_dict']) 115 | 116 | for epoch in
range(params.epochs-start_epoch): 117 | # Training 118 | if params.load_checkpoint: 119 | epoch += start_epoch 120 | 121 | logging.info("Epoch {}/{}".format(epoch, params.epochs)) 122 | 123 | train(model=model, train_loader=train_loader, loss_fn=loss_fn, optimizer=optimizer, 124 | device=device, params=params, epoch=epoch, writer=writer) 125 | # validation 126 | if epoch % params.val_step == 0: 127 | logging.info("Starting validation at epoch {}".format(epoch)) 128 | test(model=model, test_loader=val_loader, loss_fn=loss_fn, 129 | device=device, params=params, epoch=epoch, writer=writer) 130 | # scheduler step 131 | if params.scheduler: 132 | scheduler.step() 133 | # Save checkpoint 134 | if epoch % params.save_checkpoint == 0: 135 | if params.scheduler: 136 | state = {'epoch': epoch, 137 | 'state_dict': model.state_dict(), 138 | 'optim_dict': optimizer.state_dict(), 139 | 'scheduler_dict': scheduler.state_dict()} 140 | else: 141 | state = {'epoch': epoch, 142 | 'state_dict': model.state_dict(), 143 | 'optim_dict': optimizer.state_dict()} 144 | path_to_save_chk = params.path_nas + params.log_dir 145 | utils.save_checkpoint(state, 146 | is_best=False, 147 | checkpoint=path_to_save_chk) 148 | 149 | logging.info("Starting final test...") 150 | test(model=model, test_loader=test_loader, loss_fn=loss_fn, 151 | device=device, params=params, epoch=1, writer=writer) 152 | 153 | # CLOSE THE WRITER 154 | writer.close() 155 | 156 | 157 | if __name__ == '__main__': 158 | # command line arguments 159 | parser = argparse.ArgumentParser(description='Colorization') 160 | parser.add_argument('--cuda', action='store_true', default=True, help='enables CUDA training') 161 | parser.add_argument('--json_config_file', default='../Colorization/config/configuration.json', help='name of the json config file') 162 | parser.add_argument('--id_optim', default=0, type=int, help='id_optim parameter') 163 | # read the args 164 | args = parser.parse_args() 165 | main(args) 166 | 167 | 168 |
169 | 170 | -------------------------------------------------------------------------------- /Colorization/models/Decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from Colorization.models.Decoder_utils import conv1x1, Bottleneck, BasicBlock 5 | 6 | 7 | class ResNet(nn.Module): 8 | 9 | def __init__(self, block, layers, zero_init_residual=False, 10 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 11 | norm_layer=None, stride=2, out_channels=2, inplanes = 512): 12 | super(ResNet, self).__init__() 13 | if norm_layer is None: 14 | norm_layer = nn.BatchNorm2d 15 | self._norm_layer = norm_layer 16 | 17 | self.inplanes = inplanes 18 | self.dilation = 1 19 | if replace_stride_with_dilation is None: 20 | # each element in the tuple indicates if we should replace 21 | # the 2x2 stride with a dilated convolution instead 22 | replace_stride_with_dilation = [False, False, False] 23 | if len(replace_stride_with_dilation) != 3: 24 | raise ValueError("replace_stride_with_dilation should be None " 25 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 26 | self.groups = groups 27 | self.base_width = width_per_group 28 | self.stride = stride 29 | self.out_channels = out_channels 30 | 31 | self.avgpool = nn.AdaptiveAvgPool2d((4, 4)) 32 | 33 | self.layer4_dec = self._make_layer(block, 512, layers[3], stride=self.stride, 34 | dilate=replace_stride_with_dilation[2]) 35 | 36 | self.layer3_dec = self._make_layer(block, 256, layers[2], stride=self.stride, 37 | dilate=replace_stride_with_dilation[1]) 38 | 39 | self.layer2_dec = self._make_layer(block, 128, layers[1], stride=self.stride, 40 | dilate=replace_stride_with_dilation[0]) 41 | 42 | self.layer1_dec = self._make_layer(block, 64, layers[0]) 43 | 44 | self.upsample = nn.Upsample(scale_factor=2, mode='bicubic', align_corners=True) 45 | 46 | self.bn1 = norm_layer(int(inplanes / 8)) 47 | self.relu = 
nn.ReLU(inplace=True) 48 | self.conv_final = nn.ConvTranspose2d(int(inplanes / 8), self.out_channels, kernel_size=(3, 3), stride=self.stride, 49 | padding=1, groups=groups, bias=False, output_padding=1) 50 | self.sigmoid = nn.Sigmoid() 51 | 52 | for m in self.modules(): 53 | if isinstance(m, nn.Conv2d): 54 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 55 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 56 | nn.init.constant_(m.weight, 1) 57 | nn.init.constant_(m.bias, 0) 58 | 59 | # Zero-initialize the last BN in each residual branch, 60 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 61 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 62 | if zero_init_residual: 63 | for m in self.modules(): 64 | if isinstance(m, Bottleneck): 65 | nn.init.constant_(m.bn3.weight, 0) 66 | elif isinstance(m, BasicBlock): 67 | nn.init.constant_(m.bn2.weight, 0) 68 | 69 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 70 | norm_layer = self._norm_layer 71 | downsample = None 72 | previous_dilation = self.dilation 73 | if dilate: 74 | self.dilation *= stride 75 | stride = 1 76 | if stride != 1 or self.inplanes != planes * block.expansion: 77 | downsample = nn.Sequential( 78 | conv1x1(self.inplanes, planes * block.expansion, stride), 79 | norm_layer(planes * block.expansion), 80 | ) 81 | 82 | layers = [] 83 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 84 | self.base_width, previous_dilation, norm_layer)) 85 | self.inplanes = planes * block.expansion 86 | for _ in range(1, blocks): 87 | layers.append(block(self.inplanes, planes, groups=self.groups, 88 | base_width=self.base_width, dilation=self.dilation, 89 | norm_layer=norm_layer)) 90 | 91 | return nn.Sequential(*layers) 92 | 93 | def _forward_impl(self, x): 94 | x = self.avgpool(x) 95 | x = self.layer4_dec(x) 96 | x = self.layer3_dec(x) 97 | x = 
self.layer2_dec(x) 98 | x = self.layer1_dec(x) 99 | x = self.upsample(x) 100 | x = self.bn1(x) 101 | x = self.conv_final(x) 102 | x = self.sigmoid(x) 103 | return x 104 | 105 | def forward(self, x): 106 | return self._forward_impl(x) 107 | 108 | 109 | def _resnet(block, layers, stride, out_channels, inplanes, **kwargs): 110 | kwargs['stride'] = stride 111 | kwargs['out_channels'] = out_channels 112 | kwargs['inplanes'] = inplanes 113 | model = ResNet(block, layers, **kwargs) 114 | return model 115 | 116 | 117 | def resnet18_decoder(stride, out_channels, **kwargs): 118 | r"""Decoder built from ResNet-18 blocks (BasicBlock, [2, 2, 2, 2]), adapted from the ResNet-18 model of 119 | `"Deep Residual Learning for Image Recognition" `_ 120 | 121 | Args: 122 | stride (int): stride of the transposed convolutions that upsample the feature maps 123 | out_channels (int): number of output channels (e.g. 2 for the ab chroma channels) 124 | """ 125 | inplanes = 512 126 | return _resnet(BasicBlock, [2, 2, 2, 2], stride, out_channels, inplanes, **kwargs) 127 | 128 | 129 | def resnet50_decoder(stride, out_channels, **kwargs): 130 | r"""Decoder built from ResNet-50 blocks (Bottleneck, [3, 4, 6, 3]), adapted from the ResNet-50 model of 131 | `"Deep Residual Learning for Image Recognition" `_ 132 | 133 | Args: 134 | stride (int): stride of the transposed convolutions that upsample the feature maps 135 | out_channels (int): number of output channels (e.g. 2 for the ab chroma channels) 136 | """ 137 | inplanes = 2048 138 | return _resnet(Bottleneck, [3, 4, 6, 3], stride, out_channels, inplanes, **kwargs) 139 | 140 | 141 | if __name__ == '__main__': 142 | x = torch.ones((16, 512, 1, 1)) 143 | model = resnet18_decoder(stride=2, out_channels=2) 144 | out = model(x) 145 | print("End!") 146 | 147 | 148 | 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /Colorization/models/Decoder_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 6 | """3x3 convolution with padding; a 3x3 transposed convolution (upsampling by stride) when stride != 1""" 7 |
if stride == 1: 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 9 | padding=dilation, groups=groups, bias=False, dilation=dilation) 10 | else: 11 | return nn.ConvTranspose2d(in_planes, out_planes, kernel_size=(3, 3), stride=stride, 12 | padding=dilation, groups=groups, bias=False, output_padding=1) 13 | 14 | 15 | def conv1x1(in_planes, out_planes, stride=1): 16 | """1x1 convolution; a 1x1 transposed convolution (upsampling by stride) when stride != 1""" 17 | if stride == 1: 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 19 | else: 20 | return nn.ConvTranspose2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False, output_padding=1) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | __constants__ = ['downsample'] 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 28 | base_width=64, dilation=1, norm_layer=None): 29 | super(BasicBlock, self).__init__() 30 | if norm_layer is None: 31 | norm_layer = nn.BatchNorm2d 32 | if groups != 1 or base_width != 64: 33 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 34 | if dilation > 1: 35 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 36 | # Both self.conv1 and self.downsample layers upsample the input (transposed convolutions) when stride != 1 37 | self.conv1 = conv3x3(inplanes, planes, stride) 38 | self.bn1 = norm_layer(planes) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.conv2 = conv3x3(planes, planes) 41 | self.bn2 = norm_layer(planes) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | identity = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | 55 | if self.downsample is not None: 56 | identity = self.downsample(x) 57 | 58 | out += identity 59 | out = self.relu(out) 60 | 61 | return out 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | expansion = 4 66 | __constants__ = ['downsample'] 67 | 68 | def
__init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 69 | base_width=64, dilation=1, norm_layer=None): 70 | super(Bottleneck, self).__init__() 71 | if norm_layer is None: 72 | norm_layer = nn.BatchNorm2d 73 | width = int(planes * (base_width / 64.)) * groups 74 | # Both self.conv2 and self.downsample layers upsample the input (transposed convolutions) when stride != 1 75 | self.conv1 = conv1x1(inplanes, width) 76 | self.bn1 = norm_layer(width) 77 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 78 | self.bn2 = norm_layer(width) 79 | self.conv3 = conv1x1(width, planes * self.expansion) 80 | self.bn3 = norm_layer(planes * self.expansion) 81 | self.relu = nn.ReLU(inplace=True) 82 | self.downsample = downsample 83 | self.stride = stride 84 | 85 | def forward(self, x): 86 | identity = x 87 | 88 | out = self.conv1(x) 89 | out = self.bn1(out) 90 | out = self.relu(out) 91 | 92 | out = self.conv2(out) 93 | out = self.bn2(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv3(out) 97 | out = self.bn3(out) 98 | 99 | if self.downsample is not None: 100 | identity = self.downsample(x) 101 | 102 | out += identity 103 | out = self.relu(out) 104 | 105 | return out -------------------------------------------------------------------------------- /Colorization/models/Resnet18.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | 5 | from Colorization.models.Decoder import resnet18_decoder 6 | 7 | 8 | class ResNet18(nn.Module): 9 | def __init__(self, in_channels=1, out_channels=2, pretrained=0, dropout=0.3): 10 | super(ResNet18, self).__init__() 11 | if pretrained: 12 | self.model = models.resnet18(pretrained=True) 13 | else: 14 | self.model = models.resnet18(pretrained=False) 15 | # set the number of input channels 16 | if in_channels != 3: 17 | self.conv_1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 18 |
self.model.conv1 = self.conv_1 19 | # feature extractor definition 20 | self.feature_extractor = nn.Sequential(*list(self.model.children())[:-1]) 21 | # decoder definition 22 | self.decoder = resnet18_decoder(stride=2, out_channels=out_channels) 23 | self.dropout = nn.Dropout(p=dropout) 24 | 25 | def forward(self, spectral: torch.Tensor) -> torch.Tensor: 26 | features = self.feature_extractor(spectral) 27 | features = self.dropout(features) 28 | recon = self.decoder(features) 29 | return recon 30 | 31 | 32 | if __name__ == '__main__': 33 | x = torch.ones((16, 9, 128, 128)).cuda() 34 | ab = torch.ones((16, 2, 128, 128)).cuda() 35 | net = ResNet18(in_channels=9, out_channels=2, pretrained=0).cuda() 36 | recon = net(x) 37 | print("that's all folks!") -------------------------------------------------------------------------------- /Colorization/models/Resnet50.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | 5 | from Colorization.models.Decoder import resnet50_decoder 6 | from Colorization.models.Decoder import resnet18_decoder 7 | 8 | 9 | class ResNet50(nn.Module): 10 | def __init__(self, in_channels=1, out_channels=2, pretrained=0, dropout=0.3, decoder_version=50): 11 | super(ResNet50, self).__init__() 12 | if pretrained: 13 | self.model = models.resnet50(pretrained=True) 14 | else: 15 | self.model = models.resnet50(pretrained=False) 16 | # set the number of input channels 17 | if in_channels != 3: 18 | self.conv_1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 19 | self.model.conv1 = self.conv_1 20 | # feature extractor definition 21 | self.feature_extractor = nn.Sequential(*list(self.model.children())[:-1]) 22 | # set the decoder version 23 | self.decoder_version = decoder_version 24 | # decoder definition 25 | if decoder_version == 50: 26 | self.decoder = resnet50_decoder(stride=2,
out_channels=out_channels) 27 | else: 28 | self.decoder = resnet18_decoder(stride=2, out_channels=out_channels) 29 | # conv to reduce dimension in order to pass from resnet50 encoder to resnet18 decoder 30 | self.conv1x1 = nn.Conv2d(2048, 512, kernel_size=1, stride=1, bias=False) 31 | self.batch_norm = nn.BatchNorm1d(512) 32 | self.dropout = nn.Dropout(p=dropout) 33 | 34 | def forward(self, spectral: torch.Tensor) -> torch.Tensor: 35 | features = self.feature_extractor(spectral) 36 | if self.decoder_version == 18: 37 | features = self.conv1x1(features) 38 | # features = self.batch_norm(features) 39 | recon = self.decoder(features) 40 | return recon 41 | 42 | 43 | if __name__ == '__main__': 44 | x = torch.ones((16, 9, 128, 128)).cuda() 45 | ab = torch.ones((16, 2, 128, 128)).cuda() 46 | net = ResNet50(in_channels=9, out_channels=2, pretrained=0, decoder_version=18).cuda() 47 | recon = net(x) 48 | print("that's all folks!") -------------------------------------------------------------------------------- /Colorization/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevinc/TheColorOutOfSpace/f8451a6b4d27ac98937eb21da6188e41f6cd254a/Colorization/models/__init__.py -------------------------------------------------------------------------------- /Colorization/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import torch 5 | from Colorization.dataset.dataset_big_earth import Color 6 | from tqdm import tqdm 7 | 8 | 9 | def test(model, test_loader, loss_fn, device, params, epoch, writer): 10 | start_time = time.time() 11 | # SET THE MODEL TO EVALUATION MODE 12 | model.eval() 13 | 14 | test_loss = 0. 
15 | 16 | with torch.no_grad(): 17 | with tqdm(total=len(test_loader)) as t: 18 | for batch_idx, (spectral, L, ab) in enumerate(test_loader): 19 | # move input data to GPU 20 | spectral = spectral.to(device) 21 | ab = ab.to(device) 22 | 23 | # FORWARD PASS 24 | out = model(spectral=spectral) 25 | loss = loss_fn(out, ab) * params.weight_rec_loss 26 | 27 | # GRAD LOSS 28 | if params.grad_loss: 29 | grad_loss = loss_fn.grad_loss_fn(out, ab)*params.weight_grad_loss 30 | loss += grad_loss 31 | 32 | test_loss += loss.item() 33 | 34 | # write loss 35 | t.set_postfix(loss='{:05.3f}'.format(loss.item())) 36 | t.update() 37 | 38 | # log on tensorboard 39 | if batch_idx % params.log_interval == 0: 40 | # LOSS LOG 41 | writer.add_scalar('Tot_Loss/test', loss.item(), epoch * len(test_loader) + batch_idx) 42 | if params.grad_loss: 43 | writer.add_scalar('Grad_Loss/test', grad_loss.item(), epoch * len(test_loader) + batch_idx) 44 | # IMAGE LOG 45 | n = min(out.size(0), 8) 46 | recon_rgb_n = Color.lab2rgb(L[:n], out[:n]) 47 | rgb_n = Color.lab2rgb(L[:n], ab[:n]) 48 | comparison = torch.cat([rgb_n[:n], recon_rgb_n[:n]]) 49 | writer.add_images('comparison_original_recon/test', comparison, 50 | epoch * len(test_loader) + batch_idx) 51 | 52 | time_elapsed = time.time() - start_time 53 | logging.info('Test complete in {:.0f}m {:.0f}s. Avg test loss: {:05.3f}'.format( 54 | time_elapsed // 60, time_elapsed % 60, test_loss / len(test_loader))) 55 | 56 | -------------------------------------------------------------------------------- /Colorization/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import torch 5 | from Colorization.dataset.dataset_big_earth import Color 6 | from tqdm import tqdm 7 | 8 | 9 | def train(model, train_loader, loss_fn, optimizer, device, params, epoch, writer): 10 | start_time = time.time() 11 | # SET THE MODEL TO TRAIN MODE 12 | model.train() 13 | 14 | train_loss = 0. 
15 | with tqdm(total=len(train_loader)) as t: 16 | for batch_idx, (spectral, L, ab) in enumerate(train_loader): 17 | # move input data to GPU 18 | spectral = spectral.to(device) 19 | ab = ab.to(device) 20 | # set the gradient to zero 21 | optimizer.zero_grad() 22 | # FORWARD PASS 23 | out = model(spectral=spectral) 24 | # L2 | L1 LOSS ON RECONSTRUCTION 25 | loss = loss_fn(out, ab) * params.weight_rec_loss 26 | # GRAD LOSS 27 | if params.grad_loss: 28 | grad_loss = loss_fn.grad_loss_fn(out, ab) * params.weight_grad_loss 29 | loss += grad_loss 30 | # BACKWARD PASS 31 | loss.backward() 32 | train_loss += loss.item() 33 | # write loss 34 | t.set_postfix(loss='{:05.3f}'.format(loss.item())) 35 | t.update() 36 | # update the params of the model 37 | optimizer.step() 38 | # log on tensorboard 39 | if batch_idx % params.log_interval == 0: 40 | # LOSS LOG 41 | writer.add_scalar('Total_Loss/train', loss.item(), epoch * len(train_loader) + batch_idx) 42 | if params.grad_loss: 43 | writer.add_scalar('Grad_Loss/train', grad_loss.item(), epoch * len(train_loader) + batch_idx) 44 | # IMAGE LOG 45 | n = min(out.size(0), 8) 46 | recon_rgb_n = Color.lab2rgb(L[:n], out[:n]) 47 | rgb_n = Color.lab2rgb(L[:n], ab[:n]) 48 | comparison = torch.cat([rgb_n[:n], recon_rgb_n[:n]]) 49 | writer.add_images('comparison_original_recon/train', comparison, epoch * len(train_loader) + batch_idx) 50 | 51 | time_elapsed = time.time() - start_time 52 | logging.info('Epoch complete in {:.0f}m {:.0f}s. Avg training loss: {:05.3f}'.format( 53 | time_elapsed // 60, time_elapsed % 60, train_loss / len(train_loader))) 54 | -------------------------------------------------------------------------------- /Colorization/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import shutil 5 | 6 | import torch 7 | 8 | 9 | class Params: 10 | """Class that loads hyperparameters from a json file. 
11 | 12 | Example: 13 | ``` 14 | params = Params(json_path) 15 | print(params.learning_rate) 16 | params.learning_rate = 0.5 # change the value of learning_rate in params 17 | ``` 18 | """ 19 | 20 | def __init__(self, json_path): 21 | with open(json_path) as f: 22 | params = json.load(f) 23 | self.__dict__.update(params) 24 | 25 | def save(self, json_path): 26 | with open(json_path, 'w') as f: 27 | json.dump(self.__dict__, f, indent=4) 28 | 29 | def update(self, json_path): 30 | """Loads parameters from json file""" 31 | with open(json_path) as f: 32 | params = json.load(f) 33 | self.__dict__.update(params) 34 | 35 | @property 36 | def dict(self): 37 | """Gives dict-like access to Params instance by `params.dict['learning_rate']`""" 38 | return self.__dict__ 39 | 40 | 41 | class RunningAverage(): 42 | """A simple class that maintains the running average of a quantity 43 | 44 | Example: 45 | ``` 46 | loss_avg = RunningAverage() 47 | loss_avg.update(2) 48 | loss_avg.update(4) 49 | loss_avg() # returns 3.0 50 | ``` 51 | """ 52 | 53 | def __init__(self): 54 | self.steps = 0 55 | self.total = 0 56 | 57 | def update(self, val): 58 | self.total += val 59 | self.steps += 1 60 | 61 | def __call__(self): 62 | return self.total / float(self.steps) 63 | 64 | 65 | def set_logger(log_path): 66 | """Set the logger to log info in terminal and file `log_path`. 67 | 68 | In general, it is useful to have a logger so that every output to the terminal is saved 69 | in a permanent file. Here it is saved to `log_path`.
70 | 71 | Example: 72 | ``` 73 | logging.info("Starting training...") 74 | ``` 75 | 76 | Args: 77 | log_path: (string) where to log 78 | """ 79 | logger = logging.getLogger() 80 | logger.setLevel(logging.INFO) 81 | 82 | if not logger.handlers: 83 | # Logging to a file 84 | file_handler = logging.FileHandler(log_path) 85 | file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 86 | logger.addHandler(file_handler) 87 | 88 | # Logging to console 89 | stream_handler = logging.StreamHandler() 90 | stream_handler.setFormatter(logging.Formatter('%(message)s')) 91 | logger.addHandler(stream_handler) 92 | 93 | 94 | def save_dict_to_json(d, json_path): 95 | """Saves dict of floats in json file 96 | 97 | Args: 98 | d: (dict) of float-castable values (np.float, int, float, etc.) 99 | json_path: (string) path to json file 100 | """ 101 | with open(json_path, 'w') as f: 102 | # We need to convert the values to float for json (it doesn't accept np.array, np.float, ) 103 | d = {k: float(v) for k, v in d.items()} 104 | json.dump(d, f, indent=4) 105 | 106 | 107 | def save_dict_to_json_append(d, json_path): 108 | """Saves dict of floats in json file 109 | 110 | Args: 111 | d: (dict) of float-castable values (np.float, int, float, etc.) 112 | json_path: (string) path to json file 113 | """ 114 | with open(json_path, 'a') as f: 115 | # We need to convert the values to float for json (it doesn't accept np.array, np.float, ) 116 | d = {k: float(v) for k, v in d.items()} 117 | json.dump(d, f, indent=4) 118 | 119 | 120 | def save_checkpoint(state, is_best, checkpoint): 121 | """Saves model and training parameters at checkpoint + 'last.pth.tar'. 
If is_best==True, also saves 122 | checkpoint + 'best.pth.tar' 123 | 124 | Args: 125 | state: (dict) contains model's state_dict, may contain other keys such as epoch, optimizer state_dict 126 | is_best: (bool) True if it is the best model seen so far 127 | checkpoint: (string) folder where parameters are to be saved 128 | """ 129 | filepath = os.path.join(checkpoint, 'last.pth.tar') 130 | if not os.path.exists(checkpoint): 131 | print("Checkpoint Directory does not exist! Making directory {}".format(checkpoint)) 132 | os.makedirs(checkpoint) 133 | else: 134 | print("Checkpoint Directory exists!") 135 | torch.save(state, filepath) 136 | if is_best: 137 | shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar')) 138 | 139 | 140 | def load_checkpoint(checkpoint, model, optimizer=None): 141 | """Loads model parameters (state_dict) from the file `checkpoint`. If optimizer is provided, loads state_dict of 142 | optimizer assuming it is present in checkpoint. 143 | 144 | Args: 145 | checkpoint: (string) filename which needs to be loaded 146 | model: (torch.nn.Module) model for which the parameters are loaded 147 | optimizer: (torch.optim.Optimizer) optional: resume optimizer from checkpoint 148 | """ 149 | if not os.path.exists(checkpoint): 150 | raise FileNotFoundError("File doesn't exist {}".format(checkpoint)) 151 | checkpoint = torch.load(checkpoint) 152 | model.load_state_dict(checkpoint['state_dict']) 153 | 154 | if optimizer: 155 | optimizer.load_state_dict(checkpoint['optim_dict']) 156 | 157 | return checkpoint 158 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 stefano vincenzi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 |
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Multi_label_classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevinc/TheColorOutOfSpace/f8451a6b4d27ac98937eb21da6188e41f6cd254a/Multi_label_classification/__init__.py -------------------------------------------------------------------------------- /Multi_label_classification/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevinc/TheColorOutOfSpace/f8451a6b4d27ac98937eb21da6188e41f6cd254a/Multi_label_classification/config/__init__.py -------------------------------------------------------------------------------- /Multi_label_classification/config/configuration.json: -------------------------------------------------------------------------------- 1 | { 2 | "log_interval": 10, 3 | "path_nas": "/nas/softechict-nas-2/svincenzi/multi_label_baseline/experiments/", 4 | "log_dir": "9_bands", 5 | "val_step": 10, 6 | 7 | "test_split": 0.2, 8 | "val_split": 0.4, 9 | 
"dataset": "dataset/big_earth_without_shitty_imgs_torch.csv", 10 | "dataset_nsamples": 1000, 11 | "bands": [1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0], 12 | "quantiles": "../Colorization/dataset/quantiles_3000.json", 13 | "num_workers": 4, 14 | 15 | "pretrained": 0, 16 | "resnet_version": 18, 17 | "input_channels": 9, 18 | "img_size": 128, 19 | "out_cls": 19, 20 | "batch_size": 32, 21 | "epochs": 50, 22 | "seed": 19, 23 | "change_first_conv": 0, 24 | 25 | 26 | "optim": "SGD", 27 | "lr": 0.01, 28 | "weight_decay": 1e-5, 29 | 30 | "scheduler": 0, 31 | "sched_step": 40, 32 | "sched_type": "step", 33 | "sched_milestones": [10, 60, 80], 34 | 35 | "save_checkpoint": 2, 36 | "load_checkpoint": 0, 37 | "path_model_dict": "/nas/softechict-nas-2/svincenzi/multi_label_baseline/experiments/", 38 | "load_checkpoint_tr": 0, 39 | "path_model_dict_tr": "/nas/softechict-nas-2/svincenzi/" 40 | } -------------------------------------------------------------------------------- /Multi_label_classification/dataset/README.md: -------------------------------------------------------------------------------- 1 | ## Change Labels 2 | Some of the labels in the original version of the BigEarthNet dataset are hard to predict. 3 | To tackle this problem, the paper "BigEarthNet Deep Learning Models with A New Class-Nomenclature for Remote Sensing Image Understanding" proposed a new nomenclature with 19 labels. 4 | Running the command below relabels the csv file created in the colorization phase with this new set of labels.
5 | ``` 6 | python change_labels.py --csv_filename [name of the file] --new_path_csv [name of the new csv file] 7 | ``` 8 | 9 | ## Credits 10 | ```bibtex 11 | @article{sumbul2020bigearthnet, 12 | title={BigEarthNet Deep Learning Models with A New Class-Nomenclature for Remote Sensing Image Understanding}, 13 | author={Sumbul, Gencer and Kang, Jian and Kreuziger, Tristan and Marcelino, Filipe and Costa, Hugo and Benevides, Pedro and Caetano, Mario and Demir, Beg{\"u}m}, 14 | journal={arXiv preprint arXiv:2001.06372}, 15 | year={2020} 16 | } 17 | ``` 18 | -------------------------------------------------------------------------------- /Multi_label_classification/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevinc/TheColorOutOfSpace/f8451a6b4d27ac98937eb21da6188e41f6cd254a/Multi_label_classification/dataset/__init__.py -------------------------------------------------------------------------------- /Multi_label_classification/dataset/change_labels.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | 4 | # OLD LABELS TO DELETE 5 | labels_to_delete_index = [28, 40, 38, 26, 39, 35, 30, 24, 29, 42, 37] 6 | 7 | labels_to_delete_name = ["Road and rail networks and associated land", 8 | "Port areas", 9 | "Airports", 10 | "Mineral extraction sites", 11 | "Dump sites", 12 | "Construction sites", 13 | "Green urban areas", 14 | "Sport and leisure facilities", 15 | "Bare rock", 16 | "Burnt areas", 17 | "Intertidal flats"] 18 | 19 | # New labels assignments between old and new labels 20 | new_labels_assignments = { 21 | 0: 10, 22 | 1: 9, 23 | 2: 2, 24 | 3: 13, 25 | 4: 8, 26 | 5: 6, 27 | 6: 5, 28 | 7: 4, 29 | 8: 17, 30 | 9: 18, 31 | 10: 0, 32 | 11: 7, 33 | 12: 15, 34 | 13: 2, 35 | 14: 1, 36 | 15: 11, 37 | 16: 3, 38 | 17: 12, 39 | 18: 0, 40 | 19: 17, 41 | 20: 3, 42 | 21: 3, 43 | 22: 15, 44 | 23: 12, 45 | 25: 3, 46 
| 27: 2, 47 | 31: 14, 48 | 32: 11, 49 | 33: 16, 50 | 34: 18, 51 | 36: 18, 52 | 41: 16 53 | } 54 | 55 | # New Labels 56 | BigEarthNet19_labels = { 57 | 0: "Urban fabric", 58 | 1: "Industrial or commercial units", 59 | 2: "Arable land", 60 | 3: "Permanent crops", 61 | 4: "Pastures", 62 | 5: "Complex cultivation patterns", 63 | 6: "Land principally occupied by agriculture, with significant areas of natural vegetation", 64 | 7: "Agro-forestry areas", 65 | 8: "Broad-leaved forest", 66 | 9: "Coniferous forest", 67 | 10: "Mixed forest", 68 | 11: "Natural grassland and sparsely vegetated areas", 69 | 12: "Moors, heathland and sclerophyllous vegetation", 70 | 13: "Transitional woodland, shrub", 71 | 14: "Beaches, dunes, sands", 72 | 15: "Inland wetlands", 73 | 16: "Coastal wetlands", 74 | 17: "Inland waters", 75 | 18: "Marine waters" 76 | } 77 | 78 | 79 | class Labels: 80 | def __init__(self): 81 | pass 82 | 83 | @staticmethod 84 | def delete_labels_in_csv(csv_filename: str, new_csv_file: str) -> True: 85 | """ 86 | Function that removes the labels that are no longer present in the new nomenclature 87 | :param csv_filename: old csv file 88 | :param new_csv_file: new csv file 89 | :return: True 90 | """ 91 | csv_file = pd.read_csv(csv_filename, header=None) 92 | csv_tmp = csv_file.copy() 93 | for idx, row in enumerate(csv_file.itertuples(index=True, name='Pandas')): 94 | lab_idx = [l for l in eval(row[3]) if l not in labels_to_delete_index] 95 | lab_name = [l for l in eval(row[2]) if l not in labels_to_delete_name] 96 | csv_tmp.at[idx, 1] = lab_name 97 | csv_tmp.at[idx, 2] = lab_idx 98 | print(idx) 99 | print("write csv file without the labels that are no longer used...") 100 | csv_tmp.to_csv(new_csv_file, header=False, index=False) 101 | return True 102 | 103 | @staticmethod 104 | def change_labels_in_csv(csv_filename: str, new_csv_file: str) -> True: 105 | """ 106 | Function that maps the existing labels to the new 19-label nomenclature 107 | :param csv_filename: intermediate csv file 108 |
:param new_csv_file: new csv file 109 | :return: True 110 | """ 111 | csv_file = pd.read_csv(csv_filename, header=None) 112 | csv_tmp = csv_file.copy() 113 | for idx, row in enumerate(csv_file.itertuples(index=True, name='Pandas')): 114 | lab_idx = [new_labels_assignments[l] for l in eval(row[3])] 115 | lab_name = [BigEarthNet19_labels[l] for l in lab_idx] 116 | csv_tmp.at[idx, 1] = lab_name 117 | csv_tmp.at[idx, 2] = lab_idx 118 | print(idx) 119 | print("write csv file with the new labels...") 120 | csv_tmp.to_csv(new_csv_file, header=False, index=False) 121 | 122 | @staticmethod 123 | def search_empty_labels(csv_filename: str, new_csv_file: str) -> True: 124 | """ 125 | Function that searches for satellite images left without labels and removes them 126 | :param csv_filename: intermediate csv file 127 | :param new_csv_file: new csv file 128 | :return: True 129 | """ 130 | csv_file = pd.read_csv(csv_filename, header=None) 131 | csv_tmp = csv_file.copy() 132 | idx_to_delete = [] 133 | for idx, row in enumerate(csv_file.itertuples(index=True, name='Pandas')): 134 | lab_idx = eval(row[3]) # labels were already remapped by change_labels_in_csv; only emptiness matters here 135 | if not lab_idx: 136 | idx_to_delete.append(idx) 137 | print(idx) 138 | print("Number of entries to delete: ", len(idx_to_delete)) 139 | csv_tmp = csv_tmp.drop(idx_to_delete) 140 | print("write csv file without the entries that have empty labels...") 141 | csv_tmp.to_csv(new_csv_file, header=False, index=False) 142 | 143 | 144 | if __name__ == '__main__': 145 | argparser = argparse.ArgumentParser(description='BigEarthNet change labels') 146 | argparser.add_argument('--csv_filename', type=str, default='BigEarth.csv', help='Name of the csv dataset file') 147 | argparser.add_argument('--new_path_csv', type=str, default=None, help='name of the new csv file') 148 | args = argparser.parse_args() 149 | Labels.delete_labels_in_csv(args.csv_filename, args.new_path_csv) 150 | Labels.change_labels_in_csv(args.new_path_csv,
args.new_path_csv) 151 | Labels.search_empty_labels(args.new_path_csv, args.new_path_csv) 152 | 153 | 154 | 155 | -------------------------------------------------------------------------------- /Multi_label_classification/dataset/dataset_big_earth_mlc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import time 4 | from typing import Tuple 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import SubsetRandomSampler 9 | 10 | from Colorization.dataset.dataset_big_earth import BigEarthDataset 11 | 12 | 13 | class BigEarthDatasetMLC(BigEarthDataset): 14 | def __init__(self, csv_path: str, quantiles: str, random_seed: int, bands: list, 15 | create_torch_dataset=0, n_samples=100000): 16 | BigEarthDataset.__init__(self, csv_path, quantiles, random_seed, bands, 17 | create_torch_dataset, n_samples) 18 | 19 | def __getitem__(self, index: int) -> Tuple[torch.tensor, torch.tensor]: 20 | # obtain the right folder 21 | imgs_file = self.folder_path[index] 22 | imgs_bands = [] 23 | # load image 24 | for b in self.bands: 25 | for filename in glob.iglob(imgs_file + "/*" + b + ".tif"): 26 | band = self.custom_loader(filename, b, self.quantiles) 27 | imgs_bands.append(band) 28 | if len(self.bands) == 3: 29 | spectral_img = np.concatenate(imgs_bands[::-1], axis=2) # inverse order for rgb 30 | else: 31 | spectral_img = np.concatenate(imgs_bands, axis=2) 32 | # create multi-hot labels vector 33 | labels_index = list(map(int, self.labels_class[index][1:-1].split(','))) 34 | labels_class = np.zeros(19) 35 | labels_class[labels_index] = 1 36 | return self.to_tensor(spectral_img), torch.tensor(labels_class) 37 | 38 | 39 | if __name__ == "__main__": 40 | argparser = argparse.ArgumentParser(description='BigEarthNetMLC dataset tiff version') 41 | argparser.add_argument('--csv_filename', type=str, 42 | default='BigEarthNet_all_refactored_no_clouds_and_snow_server.csv', 43 | required=True, help='csv 
containing dataset paths') 44 | argparser.add_argument('--n_samples', type=int, default=3000, help='Number of samples to create the csv file') 45 | 46 | args = argparser.parse_args() 47 | # Dataset definition 48 | big_earth = BigEarthDatasetMLC(csv_path=args.csv_filename, quantiles='../../Colorization/dataset/quantiles_3000.json', 49 | random_seed=19, 50 | bands=["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B09", "B11", 51 | "B12"], 52 | n_samples=args.n_samples) 53 | # dataset split 54 | train_idx, val_idx, test_idx = big_earth.split_dataset(0.2, 0.4) 55 | # dataset sampler 56 | train_sampler = SubsetRandomSampler(train_idx) 57 | val_sampler = SubsetRandomSampler(val_idx) 58 | test_sampler = SubsetRandomSampler(test_idx) 59 | # dataset loader 60 | train_loader = torch.utils.data.DataLoader(big_earth, batch_size=16, 61 | sampler=train_sampler, num_workers=4) 62 | test_loader = torch.utils.data.DataLoader(big_earth, batch_size=1, 63 | sampler=test_sampler, num_workers=0) 64 | start_time = time.time() 65 | 66 | for idx, (spectral_img, labels) in enumerate(train_loader): 67 | print(idx) 68 | 69 | print("time: ", time.time() - start_time) 70 | -------------------------------------------------------------------------------- /Multi_label_classification/dataset/dataset_big_earth_torch_mlc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import time 4 | from typing import Tuple 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import SubsetRandomSampler 10 | 11 | from Colorization.dataset.dataset_big_earth_torch import BigEarthDatasetTorch 12 | 13 | 14 | class BigEarthDatasetTorchMLC(BigEarthDatasetTorch): 15 | def __init__(self, csv_path: str, random_seed: int, bands_indices: list, img_size: int, n_samples=100000): 16 | BigEarthDatasetTorch.__init__(self, csv_path=csv_path, random_seed=random_seed, bands_indices=bands_indices, 17 | 
img_size=img_size, augmentation=0, n_samples=n_samples) 18 | 19 | def __getitem__(self, index: int) -> Tuple[torch.tensor, torch.tensor]: 20 | # obtain the right folder 21 | imgs_file = self.folder_path[index] 22 | # load torch image 23 | spectral_img = torch.load(imgs_file + '/all_bands_chroma.pt') 24 | # resize the image as specified in the params dsize 25 | spectral_img = torch.squeeze( 26 | nn.functional.interpolate(input=torch.unsqueeze(spectral_img, dim=0), size=self.img_size)) 27 | # take only the bands specified in the init 28 | spectral_img = spectral_img[self.bands_indices] 29 | # if RGB: invert the indices as it is saved as BGR 30 | if sum(self.bands_indices) == 3: 31 | spectral_img = torch.flip(spectral_img, [0]) 32 | # create multi-hot labels vector 33 | labels_index = list(map(int, self.labels_class[index][1:-1].split(','))) 34 | labels_class = np.zeros(19) 35 | labels_class[labels_index] = 1 36 | return spectral_img, torch.tensor(labels_class) 37 | 38 | 39 | if __name__ == "__main__": 40 | argparser = argparse.ArgumentParser(description='BigEarthNetTorchMLC dataset version') 41 | argparser.add_argument('--csv_filename', type=str, 42 | default='BigEarthNet_all_refactored_no_clouds_and_snow_server.csv', 43 | required=True, help='csv containing dataset paths') 44 | argparser.add_argument('--n_samples', type=int, default=3000, help='Number of samples to create the csv file') 45 | 46 | args = argparser.parse_args() 47 | # Dataset definition 48 | big_earth = BigEarthDatasetTorchMLC(csv_path=args.csv_filename, random_seed=19, bands_indices=[0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], 49 | img_size=128, n_samples=args.n_samples) 50 | # dataset split 51 | train_idx, val_idx, test_idx = big_earth.split_dataset(0.2, 0.4) 52 | # dataset sampler 53 | train_sampler = SubsetRandomSampler(train_idx) 54 | val_sampler = SubsetRandomSampler(val_idx) 55 | test_sampler = SubsetRandomSampler(test_idx) 56 | # dataset loader 57 | train_loader = 
torch.utils.data.DataLoader(big_earth, batch_size=16, 58 | sampler=train_sampler, num_workers=4) 59 | test_loader = torch.utils.data.DataLoader(big_earth, batch_size=1, 60 | sampler=test_sampler, num_workers=0) 61 | start_time = time.time() 62 | 63 | for idx, (spectral_img, labels) in enumerate(train_loader): 64 | print(idx) 65 | 66 | print("time: ", time.time() - start_time) 67 | -------------------------------------------------------------------------------- /Multi_label_classification/job_config.py: -------------------------------------------------------------------------------- 1 | def set_params(params, id_optim): 2 | if id_optim is None: 3 | pass 4 | else: 5 | if id_optim == 0: 6 | params.dataset_nsamples = 5000 7 | params.seed = 19 8 | params.epochs = 30 9 | params.batch_size = 32 10 | params.resnet_version = 18 11 | params.pretrained = 0 12 | params.out_cls = 19 13 | params.dataset = "dataset/BigEarthNet_all_refactored_no_clouds_and_snow_v2_new_path_new_labels.csv" 14 | params.log_dir = "/9_bands/19Labels/5k_R18/exp_1" 15 | params.bands = [1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0] 16 | params.input_channels = 9 17 | params.lr = 0.1 18 | params.optim = "SGD" 19 | params.img_size = 128 20 | params.scheduler = 1 21 | params.sched_type = "multi" 22 | params.sched_milestones = [10, 40] 23 | params.load_checkpoint = 0 24 | params.path_model_dict = "/nas/softechict-nas-2/svincenzi/colorization_resnet/experiments_resnet18_AE/500k_augmentation_scratch_continue_training/_batch_16/last.pth.tar" 25 | params.load_checkpoint_tr = 0 26 | params.path_model_dict_tr = "" 27 | params.num_workers = 4 28 | elif id_optim == 1: 29 | params.dataset_nsamples = 5000 30 | params.seed = 19 31 | params.epochs = 30 32 | params.batch_size = 32 33 | params.resnet_version = 18 34 | params.pretrained = 0 35 | params.out_cls = 19 36 | params.dataset = "dataset/BigEarthNet_all_refactored_no_clouds_and_snow_v2_new_path_new_labels.csv" 37 | params.log_dir = "/3_bands/19Labels/5k_R18/exp_1" 38 | 
params.bands = [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] 39 | params.input_channels = 3 40 | params.lr = 0.1 41 | params.optim = "SGD" 42 | params.img_size = 128 43 | params.scheduler = 1 44 | params.sched_type = "multi" 45 | params.sched_milestones = [10, 40] 46 | params.load_checkpoint = 0 47 | params.path_model_dict = "/nas/softechict-nas-2/svincenzi/colorization_resnet/experiments_resnet18_AE/500k_augmentation_scratch_continue_training/_batch_16/last.pth.tar" 48 | params.load_checkpoint_tr = 0 49 | params.path_model_dict_tr = "" 50 | params.num_workers = 4 51 | 52 | params.log_dir = params.log_dir + "_batch_" + str(params.batch_size) 53 | 54 | return params 55 | 56 | 57 | -------------------------------------------------------------------------------- /Multi_label_classification/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import copy 5 | import logging 6 | import os 7 | import warnings 8 | 9 | import torch.nn as nn 10 | import torch.utils.data 11 | from torch import optim 12 | from torch.utils.data import SubsetRandomSampler 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | from Colorization import utils 16 | from Multi_label_classification.dataset.dataset_big_earth_torch_mlc import BigEarthDatasetTorchMLC 17 | from Multi_label_classification.job_config import set_params 18 | from Multi_label_classification.metrics.metric import metrics_def 19 | from Multi_label_classification.models.ResnetMLC import ResNetMLC 20 | from Multi_label_classification.test import test 21 | from Multi_label_classification.train import train 22 | 23 | warnings.filterwarnings("ignore") 24 | 25 | os.environ["OMP_NUM_THREADS"] = "1" 26 | torch.backends.cudnn.benchmark = True 27 | torch.backends.cudnn.enabled = True 28 | 29 | 30 | def main(args): 31 | # enable cuda if available 32 | args.cuda = args.cuda and torch.cuda.is_available() 33 | device = 
torch.device("cuda" if args.cuda else "cpu") 34 | 35 | # READ JSON CONFIG FILE 36 | assert os.path.isfile(args.json_config_file), "No json configuration file found at {}".format(args.json_config_file) 37 | params = utils.Params(args.json_config_file) 38 | 39 | # override params according to the job id (see job_config.set_params) 40 | params = set_params(params, args.id_optim) 41 | 42 | # set the torch seed 43 | torch.manual_seed(params.seed) 44 | 45 | # initialize the summary writer; event files are saved inside the runs folder 46 | writer = SummaryWriter(params.path_nas + params.log_dir + '/runs/') 47 | 48 | # create dir for log file 49 | if not os.path.exists(params.path_nas + params.log_dir): 50 | os.makedirs(params.path_nas + params.log_dir) 51 | 52 | # save the json config file of the model 53 | params.save(os.path.join(params.path_nas + params.log_dir, "params.json")) 54 | 55 | # Set the logger 56 | utils.set_logger(os.path.join(params.path_nas + params.log_dir, "log")) 57 | 58 | # DATASET 59 | # Torch version 60 | big_earth = BigEarthDatasetTorchMLC(csv_path=params.dataset, random_seed=params.seed, bands_indices=params.bands, 61 | img_size=params.img_size, n_samples=params.dataset_nsamples) 62 | # Split 63 | train_idx, val_idx, test_idx = big_earth.split_dataset(params.test_split, params.val_split) 64 | 65 | # define the sampler 66 | train_sampler = SubsetRandomSampler(train_idx) 67 | val_sampler = SubsetRandomSampler(val_idx) 68 | test_sampler = SubsetRandomSampler(test_idx) 69 | # define the loader 70 | train_loader = torch.utils.data.DataLoader(big_earth, batch_size=params.batch_size, 71 | sampler=train_sampler, num_workers=params.num_workers) 72 | val_loader = torch.utils.data.DataLoader(big_earth, batch_size=params.batch_size, 73 | sampler=val_sampler, num_workers=params.num_workers) 74 | test_loader = torch.utils.data.DataLoader(big_earth, batch_size=params.batch_size, 75 | sampler=test_sampler, num_workers=params.num_workers) 76 | 77 | # MODEL definition 78 | model =
ResNetMLC(in_channels=params.input_channels, out_cls=params.out_cls, resnet_version=params.resnet_version, 79 | pretrained=params.pretrained, colorization=params.load_checkpoint) 80 | # Colorization checkpoint 81 | if params.load_checkpoint: 82 | checkpoint = torch.load(params.path_model_dict) 83 | model.load_state_dict(checkpoint['state_dict'], strict=False) 84 | # reset first layer when you want to apply colorization on all bands or RGB 85 | if params.change_first_conv: 86 | model.set_weights_conv1() 87 | # Checkpoint of the multi-label model 88 | if params.load_checkpoint_tr == 1: 89 | checkpoint = torch.load(params.path_model_dict_tr) 90 | model.load_state_dict(checkpoint['state_dict'], strict=False) 91 | 92 | # CUDA 93 | model.to(device) 94 | 95 | # loss for multilabel classification 96 | loss_fn = nn.MultiLabelSoftMarginLoss() 97 | 98 | # OPTIMIZER 99 | if params.optim == "Adam": 100 | optimizer = optim.Adam(model.parameters(), lr=params.lr, weight_decay=params.weight_decay) 101 | else: 102 | optimizer = optim.SGD(model.parameters(), lr=params.lr, momentum=0.9) 103 | 104 | # SCHEDULER 105 | if params.sched_type == "step": 106 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=params.sched_step, gamma=0.1) 107 | else: 108 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=params.sched_milestones, gamma=0.1) 109 | 110 | if params.load_checkpoint_tr: 111 | optimizer.load_state_dict(checkpoint['optim_dict']) 112 | scheduler.load_state_dict(checkpoint['scheduler_dict']) 113 | start_epoch = scheduler.last_epoch # start_epoch = checkpoint['epoch'] + 1 114 | else: 115 | start_epoch = 0 116 | 117 | # METRICS 118 | metrics = metrics_def 119 | 120 | # save the best model 121 | best_avg_prec_micro = 0.0 122 | best_model = copy.deepcopy(model.state_dict()) 123 | for epoch in range(params.epochs - start_epoch): 124 | # Training 125 | if params.load_checkpoint_tr: 126 | epoch += start_epoch 127 | logging.info("Starting training for {} 
epoch(s)".format(params.epochs)) 128 | logging.info("Epoch {}/{}".format(epoch, params.epochs)) 129 | train(model=model, train_loader=train_loader, loss_fn=loss_fn, optimizer=optimizer, device=device, metrics=metrics) 130 | # validation 131 | if epoch % params.val_step == 0: 132 | logging.info("Starting validation at epoch {}".format(epoch)) 133 | avg_pr_micro = test(model=model, test_loader=val_loader, loss_fn=loss_fn, 134 | device=device, metrics=metrics) 135 | # save best model params based on avg_pr_micro score on validation set 136 | if avg_pr_micro > best_avg_prec_micro: 137 | best_avg_prec_micro = avg_pr_micro 138 | best_model = copy.deepcopy(model.state_dict()) 139 | state = {'epoch': epoch, 140 | 'state_dict': model.state_dict(), 141 | 'optim_dict': optimizer.state_dict(), 142 | 'scheduler_dict': scheduler.state_dict()} 143 | path_to_save_chk = params.path_nas + params.log_dir 144 | utils.save_checkpoint(state, 145 | is_best=True, # True if this is the model with best metrics 146 | checkpoint=path_to_save_chk) # path to folder 147 | # scheduler step 148 | if params.scheduler: 149 | scheduler.step() 150 | logging.info("lr: {}".format(scheduler.get_last_lr()[0])) 151 | # Save checkpoint 152 | if epoch % params.save_checkpoint == 0: 153 | # no single reliable metric to monitor here, so periodically save the current state of the model
154 | state = {'epoch': epoch, 155 | 'state_dict': model.state_dict(), 156 | 'optim_dict': optimizer.state_dict(), 157 | 'scheduler_dict': scheduler.state_dict()} 158 | path_to_save_chk = params.path_nas + params.log_dir 159 | utils.save_checkpoint(state, 160 | is_best=False, # True if this is the model with best metrics 161 | checkpoint=path_to_save_chk) # path to folder 162 | 163 | logging.info("Starting final test...") 164 | test(model=model, test_loader=test_loader, loss_fn=loss_fn, device=device, metrics=metrics) 165 | 166 | logging.info("Starting final test with best model...") 167 | model.load_state_dict(best_model) 168 | test(model=model, test_loader=test_loader, loss_fn=loss_fn, device=device, metrics=metrics) 169 | 170 | # CLOSE THE WRITER 171 | writer.close() 172 | 173 | 174 | if __name__ == '__main__': 175 | # command line arguments 176 | parser = argparse.ArgumentParser(description='multi_label_classification') 177 | parser.add_argument('--cuda', action='store_true', default=True, help='enables CUDA training') 178 | parser.add_argument('--json_config_file', default='config/configuration.json', help='name of the json config file') 179 | parser.add_argument('--id_optim', default=1, type=int, help='id_optim parameter') 180 | # read the args 181 | args = parser.parse_args() 182 | main(args) 183 | -------------------------------------------------------------------------------- /Multi_label_classification/main_ensemble.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import logging 5 | import os 6 | import warnings 7 | 8 | import torch.nn as nn 9 | import torch.utils.data 10 | from torch.utils.data import SubsetRandomSampler 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | from Colorization import utils 14 | from Multi_label_classification.dataset.dataset_big_earth_torch_mlc import BigEarthDatasetTorchMLC 15 | from 
Multi_label_classification.job_config import set_params 16 | from Multi_label_classification.metrics.metric import metrics_def 17 | from Multi_label_classification.models.Ensemble import EnsembleModel 18 | from Multi_label_classification.models.ResnetMLC import ResNetMLC 19 | from Multi_label_classification.test import test 20 | 21 | warnings.filterwarnings("ignore") 22 | 23 | os.environ["OMP_NUM_THREADS"] = "1" 24 | torch.backends.cudnn.benchmark = True 25 | torch.backends.cudnn.enabled = True 26 | 27 | 28 | def main(args): 29 | # enable cuda if available 30 | args.cuda = args.cuda and torch.cuda.is_available() 31 | device = torch.device("cuda" if args.cuda else "cpu") 32 | 33 | # READ JSON CONFIG FILE 34 | assert os.path.isfile(args.json_config_file), "No json configuration file found at {}".format(args.json_config_file) 35 | params = utils.Params(args.json_config_file) 36 | 37 | # override params according to the job id (see job_config.set_params) 38 | params = set_params(params, args.id_optim) 39 | 40 | # set the torch seed 41 | torch.manual_seed(params.seed) 42 | 43 | # initialize the summary writer; event files are saved inside the runs folder 44 | writer = SummaryWriter(params.path_nas + params.log_dir + '/runs/') 45 | 46 | # create dir for log file 47 | if not os.path.exists(params.path_nas + params.log_dir): 48 | os.makedirs(params.path_nas + params.log_dir) 49 | 50 | # save the json config file of the model 51 | params.save(os.path.join(params.path_nas + params.log_dir, "params.json")) 52 | 53 | # Set the logger 54 | utils.set_logger(os.path.join(params.path_nas + params.log_dir, "log")) 55 | 56 | # DATASET 57 | # Torch version 58 | big_earth = BigEarthDatasetTorchMLC(csv_path=params.dataset, random_seed=params.seed, bands_indices=params.bands, 59 | img_size=params.img_size, n_samples=params.dataset_nsamples) 60 | # Split 61 | train_idx, val_idx, test_idx = big_earth.split_dataset(params.test_split, params.val_split) 62 | 63 | test_sampler = SubsetRandomSampler(test_idx) 64 | # define the loader 65 |
    test_loader = torch.utils.data.DataLoader(big_earth, batch_size=params.batch_size,
                                              sampler=test_sampler, num_workers=params.num_workers)
    # MODEL definitions for the ensemble
    model_rgb = ResNetMLC(in_channels=3, out_cls=params.out_cls, resnet_version=params.resnet_version,
                          pretrained=0, colorization=0)
    model_colorization = ResNetMLC(in_channels=9, out_cls=params.out_cls, resnet_version=params.resnet_version,
                                   pretrained=0, colorization=1)

    # map_location lets checkpoints saved on GPU be loaded on a CPU-only machine
    checkpoint = torch.load(args.rgb_checkpoint, map_location=device)
    model_rgb.load_state_dict(checkpoint['state_dict'], strict=False)

    checkpoint = torch.load(args.spectral_checkpoint, map_location=device)
    model_colorization.load_state_dict(checkpoint['state_dict'], strict=False)

    model = EnsembleModel(model_rgb=model_rgb, model_colorization=model_colorization, device=device)

    # CUDA
    model.to(device)

    # loss for multi-label classification
    loss_fn = nn.MultiLabelSoftMarginLoss()

    # METRICS
    metrics = metrics_def

    logging.info("Starting final test with ensemble model...")
    test(model=model, test_loader=test_loader, loss_fn=loss_fn,
         device=device, metrics=metrics)

    # CLOSE THE WRITER
    writer.close()


if __name__ == '__main__':
    # command line arguments
    parser = argparse.ArgumentParser(description='multi_label_classification')
    parser.add_argument('--cuda', action='store_true', default=True, help='enables CUDA training')
    parser.add_argument('--json_config_file', default='Multi_label_classification/config/configuration.json', help='name of the json config file')
    parser.add_argument('--id_optim', default=0, type=int, help='id_optim parameter')
    parser.add_argument('--rgb_checkpoint', type=str, default=None, help='specify the rgb checkpoint path', required=True)
    parser.add_argument('--spectral_checkpoint', type=str, default=None, help='specify the spectral checkpoint path', required=True)
    # read the args
    args = parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
/Multi_label_classification/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevinc/TheColorOutOfSpace/f8451a6b4d27ac98937eb21da6188e41f6cd254a/Multi_label_classification/metrics/__init__.py
--------------------------------------------------------------------------------
/Multi_label_classification/metrics/metric.py:
--------------------------------------------------------------------------------
from sklearn.metrics import average_precision_score, precision_score, recall_score, f1_score, fbeta_score
from sklearn.metrics import hamming_loss as sk_hamming_loss


def precision(labels, pred, average):
    return precision_score(labels, pred, average=average)


def recall(labels, pred, average):
    return recall_score(labels, pred, average=average)


def hamming_loss(labels, pred):
    # delegate to sklearn through an alias: calling `hamming_loss` here would recurse into this function forever
    return sk_hamming_loss(labels, pred)


def f1(labels, pred, average):
    return f1_score(labels, pred, average=average)


def f2(labels, pred, average):
    return fbeta_score(labels, pred, average=average, beta=2)


def average_precision(labels, score, average):
    return average_precision_score(labels, score, average=average)


metrics_def = {
    'precision': precision,
    'recall': recall,
    'f1': f1,
    'f2': f2
}

--------------------------------------------------------------------------------
/Multi_label_classification/models/Ensemble.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from typing import Tuple


class EnsembleModel(nn.Module):
    def __init__(self, model_rgb, model_colorization, device=0):
        super(EnsembleModel, self).__init__()
        self.device = device
        self.features_rgb = model_rgb.feature_extractor
        self.rgb_classifier = model_rgb.classifier
        self.features_colorization = model_colorization.feature_extractor
        self.colorization_classifier = model_colorization.classifier
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B = x.shape[0]
        # split the input tensor: channels 3, 2, 1 feed the RGB branch, the other nine feed the spectral branch
        indices = torch.tensor([3, 2, 1])
        indices_spectral = torch.tensor([0, 4, 5, 6, 7, 8, 9, 10, 11])
        rgb = torch.index_select(input=x, dim=1, index=indices.to(self.device))
        spectral = torch.index_select(input=x, dim=1, index=indices_spectral.to(self.device))
        # FEATURE EXTRACTION
        # flatten(1) keeps the classifiers valid for both ResNet18 (512) and ResNet50 (2048) features
        features_rgb = self.features_rgb(rgb)
        out_rgb = self.rgb_classifier(torch.flatten(features_rgb, 1))
        features_colorization = self.features_colorization(spectral)
        out_colorization = self.colorization_classifier(torch.flatten(features_colorization, 1))
        # average the logits of the two branches
        out = torch.stack((out_rgb, out_colorization), 2).mean(2)
        # average the sigmoid probabilities of the two branches
        out_sigmoid_rgb = self.sigmoid(out_rgb)
        out_sigmoid_colorization = self.sigmoid(out_colorization)
        out_sigmoid = torch.stack((out_sigmoid_rgb, out_sigmoid_colorization), 2).mean(2)
        return out, out_sigmoid.detach()
--------------------------------------------------------------------------------
/Multi_label_classification/models/ResnetMLC.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision import models
from typing import Tuple


class ResNetMLC(nn.Module):
    def __init__(self, in_channels=1, out_cls=2, resnet_version=18, pretrained=0, colorization=0):
        super(ResNetMLC, self).__init__()

        # backbone: ResNet18 when resnet_version == 18, ResNet50 otherwise (ImageNet weights if pretrained)
        if resnet_version == 18:
            self.model = models.resnet18(pretrained=bool(pretrained))
        else:
            self.model = models.resnet50(pretrained=bool(pretrained))

        # SET THE CORRECT NUMBER OF INPUT CHANNELS
        self.colorization = colorization
        self.in_channels = in_channels
        if self.colorization:
            self.conv_1 = nn.Conv2d(9, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        else:
            self.conv_1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        if in_channels != 3:
            self.model.conv1 = self.conv_1

        # DEFINE THE FEATURE EXTRACTOR (every ResNet layer except the final fc)
        self.feature_extractor = nn.Sequential(*list(self.model.children())[:-1])

        # CLASSIFIER
        num_ftrs = self.model.fc.in_features
        self.classifier = nn.Linear(in_features=num_ftrs, out_features=out_cls)
        self.sigmoid = nn.Sigmoid()

    def set_weights_conv1(self):
        # replace the first convolution with a freshly initialized one taking in_channels inputs
        if self.colorization:
            self.feature_extractor[0] = nn.Conv2d(self.in_channels, 64, kernel_size=(7, 7), stride=(2, 2),
                                                  padding=(3, 3), bias=False)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        features = self.feature_extractor(x)
        # FEATURE SPACE
        B, F = features.shape[:2]
        # MULTI-LABEL CLASSIFICATION
        out = self.classifier(features.view(B, F))
        return out, self.sigmoid(out).detach()
--------------------------------------------------------------------------------
/Multi_label_classification/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevinc/TheColorOutOfSpace/f8451a6b4d27ac98937eb21da6188e41f6cd254a/Multi_label_classification/models/__init__.py
--------------------------------------------------------------------------------
/Multi_label_classification/test.py:
--------------------------------------------------------------------------------
import logging
import time

import torch
from tqdm import tqdm

from Multi_label_classification.metrics.metric import average_precision


def test(model, test_loader, loss_fn, device, metrics):
    start_time = time.time()
    # SET THE MODEL TO EVALUATION MODE
    model.eval()

    test_loss = 0

    # accumulate outputs and labels to compute global scores on the whole test set
    out_sigmoid_epoch = torch.Tensor()
    labels_class_epoch = torch.Tensor()

    with torch.no_grad():
        with tqdm(total=len(test_loader)) as t:
            for batch_idx, (in_bands, labels_class) in enumerate(test_loader):
                # move input data to the selected device
                in_bands = in_bands.to(device)
                labels_class = labels_class.to(device)

                # FORWARD PASS
                out, out_sigmoid = model(in_bands)

                # multi-label classification loss
                loss = loss_fn(out, labels_class)

                test_loss += loss.item()

                # write loss
                t.set_postfix(loss='{:05.3f}'.format(loss.item()))
                t.update()

                # concatenate tensors to compute the overall average precision on the entire test set
                labels_class_epoch = torch.cat((labels_class_epoch, labels_class.float().cpu()), 0)
                out_sigmoid_epoch = torch.cat((out_sigmoid_epoch, out_sigmoid.cpu()), 0)

    # final metrics
    # threshold the sigmoid outputs at 0.5
    out_thresholded_epoch = (out_sigmoid_epoch > 0.5).float()

    # overall metrics on the test set, logged to file
    metrics_calc = {metric: metrics[metric](labels_class_epoch, out_thresholded_epoch, 'micro') for metric in metrics}
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_calc.items())
    logging.info("- Test metrics micro: " + metrics_string)

    metrics_calc = {metric: metrics[metric](labels_class_epoch, out_thresholded_epoch, 'weighted') for metric in metrics}
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_calc.items())
    logging.info("- Test metrics weighted: " + metrics_string)

    # AVERAGE PRECISION SCORE
    avg_pr_micro = average_precision(labels_class_epoch, out_sigmoid_epoch, average='micro')
    avg_pr_macro = average_precision(labels_class_epoch, out_sigmoid_epoch, average='macro')
    avg_pr_weighted = average_precision(labels_class_epoch, out_sigmoid_epoch, average='weighted')
    avg_pr_none = average_precision(labels_class_epoch, out_sigmoid_epoch, average=None)
    # per-class AP keyed by class index, without hard-coding the number of classes
    dict_avg_pr_none = dict(enumerate(avg_pr_none))

    # logging
    logging.info("\n- Test MAP: micro: {:05.3f} ; macro: {:05.3f} ; weighted: {:05.3f}".format(avg_pr_micro, avg_pr_macro, avg_pr_weighted))
    avg_pr_none_str = "\n".join("cls: {} --> val: {:05.3f} ".format(cls, val) for cls, val in dict_avg_pr_none.items())
    logging.info("Test MAP per class:\n" + avg_pr_none_str)
    time_elapsed = time.time() - start_time
    logging.info('Test complete in {:.0f}m {:.0f}s. Avg test loss: {:05.3f}'.format(
        time_elapsed // 60, time_elapsed % 60, test_loss / len(test_loader)))

    return avg_pr_micro
--------------------------------------------------------------------------------
/Multi_label_classification/train.py:
--------------------------------------------------------------------------------
import logging
import time

import torch
from tqdm import tqdm

from Multi_label_classification.metrics.metric import average_precision


def train(model, train_loader, loss_fn, optimizer, device, metrics):
    start_time = time.time()
    # SET THE MODEL TO TRAIN MODE
    model.train()

    train_loss = 0

    # EPOCH METRICS
    out_sigmoid_epoch = torch.Tensor()
    labels_class_epoch = torch.Tensor()

    with tqdm(total=len(train_loader)) as t:
        for batch_idx, (in_bands, labels_class) in enumerate(train_loader):
            # move input data to the selected device
            in_bands = in_bands.to(device)
            labels_class = labels_class.to(device)

            # set the gradients to zero
            optimizer.zero_grad()

            # FORWARD PASS
            out, out_sigmoid = model(in_bands)
            # multi-label classification loss
            loss = loss_fn(out, labels_class)

            # BACKWARD PASS
            loss.backward()

            train_loss += loss.item()

            # write loss
            t.set_postfix(loss='{:05.3f}'.format(loss.item()))
            t.update()

            # update the params of the model
            optimizer.step()

            # concatenate tensors to compute the overall metrics for the entire epoch
            labels_class_epoch = torch.cat((labels_class_epoch, labels_class.float().cpu()), 0)
            out_sigmoid_epoch = torch.cat((out_sigmoid_epoch, out_sigmoid.cpu()), 0)

    # threshold of 0.5 on the sigmoid outputs for the entire epoch
    out_thresholded_epoch = (out_sigmoid_epoch > 0.5).float()

    # overall metrics for the current epoch, logged to file
    metrics_calc = {metric: metrics[metric](labels_class_epoch, out_thresholded_epoch, 'micro') for metric in metrics}
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_calc.items())
    logging.info("- Train metrics: " + metrics_string)

    # MAP
    avg_pr_micro = average_precision(labels_class_epoch, out_sigmoid_epoch, average='micro')
    avg_pr_macro = average_precision(labels_class_epoch, out_sigmoid_epoch, average='macro')
    avg_pr_weighted = average_precision(labels_class_epoch, out_sigmoid_epoch, average='weighted')
    avg_pr_none = average_precision(labels_class_epoch, out_sigmoid_epoch, average=None)
    # per-class AP keyed by class index, without hard-coding the number of classes
    dict_avg_pr_none = dict(enumerate(avg_pr_none))

    logging.info("\n- Train MAP: micro: {:05.3f} ; macro: {:05.3f} ; weighted: {:05.3f}".format(avg_pr_micro, avg_pr_macro, avg_pr_weighted))
    avg_pr_none_str = "\n".join("cls: {} --> val: {:05.3f} ".format(cls, val) for cls, val in dict_avg_pr_none.items())
    logging.info("Train MAP per class:\n" + avg_pr_none_str)

    time_elapsed = time.time() - start_time
    logging.info('Epoch complete in {:.0f}m {:.0f}s. Avg training loss: {:05.3f}'.format(
        time_elapsed // 60, time_elapsed % 60, train_loss / len(train_loader)))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# The color out of space: learning self-supervised representations for Earth Observation imagery
This repository contains the PyTorch code for the paper:

**The color out of space: learning self-supervised representations for Earth Observation imagery**

## Model architecture
![Colorization & Multi-label classification - overview](colorization_framework-1.png)

## Prerequisites
* Python >= 3.7
* PyTorch >= 1.5
* CUDA 10.0

## Dataset
We adopt the BigEarthNet dataset. Refer to the README in the ``Colorization/dataset`` and ``Multi_label_classification/dataset`` folders for further information.

## Models
### Colorization
Different Encoder-Decoder combinations are available:
- *Encoder ResNet18 - Decoder ResNet18*
- *Encoder ResNet50 - Decoder ResNet50*
- *Encoder ResNet50 - Decoder ResNet18*
### Multi Label Classification
The same encoders used in the colorization phase are employed, together with an Ensemble model composed of two encoders with the same architecture, one trained on the RGB bands and the other on all the remaining bands.

## Training
Before running ``main.py`` in either the ``Colorization`` or the ``Multi_label_classification`` folder, you can set the desired parameters in ``job_config.py``, which override those in ``config/configuration.json``.

## Cite
If you have any questions, please contact [stefano.vincenzi@unimore.it](mailto:stefano.vincenzi@unimore.it), or open an issue on this repo.

If you find this repository useful for your research, please cite the following paper:
```bibtex
@inproceedings{vincenzi2020color,
  title={The color out of space: learning self-supervised representations for Earth Observation imagery},
  author={Vincenzi, Stefano and Porrello, Angelo and Buzzega, Pietro and Cipriano, Marco and Fronte, Pietro and Cuccu, Roberto and
          Ippoliti, Carla and Conte, Annamaria and Calderara, Simone},
  booktitle={25th International Conference on Pattern Recognition},
  year={2020}
}
```
--------------------------------------------------------------------------------
/colorization_framework-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevinc/TheColorOutOfSpace/f8451a6b4d27ac98937eb21da6188e41f6cd254a/colorization_framework-1.png
--------------------------------------------------------------------------------
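
The ensemble in `Multi_label_classification/models/Ensemble.py` averages the two branches' logits (fed to `MultiLabelSoftMarginLoss`) and, separately, their sigmoid probabilities (thresholded at 0.5 and scored in `test.py`). The dependency-free sketch below mirrors that fusion for a single sample and computes micro-averaged precision, recall and F1 by pooling counts over every (sample, class) cell, which is how scikit-learn's `average='micro'` works for multi-label inputs. The helpers `fuse` and `micro_prf` are hypothetical and not part of the repository; the real code operates on PyTorch tensors and delegates the metrics to scikit-learn.

```python
import math
from typing import List, Sequence, Tuple


def sigmoid(z: float) -> float:
    # Plain logistic function; fine for moderate logits (math.exp overflows below about -709).
    return 1.0 / (1.0 + math.exp(-z))


def fuse(logits_rgb: Sequence[float], logits_spec: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Hypothetical per-class fusion of two branches for one sample.

    Returns the mean of the logits and the mean of the sigmoid probabilities,
    the two quantities EnsembleModel.forward returns for each sample in a batch.
    """
    mean_logits = [(a + b) / 2 for a, b in zip(logits_rgb, logits_spec)]
    mean_probs = [(sigmoid(a) + sigmoid(b)) / 2 for a, b in zip(logits_rgb, logits_spec)]
    return mean_logits, mean_probs


def micro_prf(y_true: Sequence[Sequence[int]], y_pred: Sequence[Sequence[int]]) -> Tuple[float, float, float]:
    """Hypothetical micro-averaged precision, recall and F1 over binary label matrices."""
    tp = fp = fn = 0
    # Pool the counts over every (sample, class) cell before dividing.
    for t_row, p_row in zip(y_true, y_pred):
        for t, p in zip(t_row, p_row):
            tp += int(t == 1 and p == 1)
            fp += int(t == 0 and p == 1)
            fn += int(t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


if __name__ == "__main__":
    mean_logits, mean_probs = fuse([2.0, -1.0, 0.5], [1.0, -3.0, -1.5])
    labels = [int(p > 0.5) for p in mean_probs]  # 0.5 threshold, as in test.py
    print(mean_logits, labels)  # prints [1.5, -2.0, -0.5] [1, 0, 0]
    print(micro_prf([[1, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]]))
```

Because σ(a) + σ(b) > 1 exactly when a + b > 0, thresholding the averaged probabilities at 0.5 gives the same labels as thresholding the averaged logits at 0. The two fusions differ only in the scores themselves, which matters for ranking-based metrics such as average precision.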