├── README.md
├── image_dataset.py
├── inception_modified.py
├── test_classification_pytorch.py
├── test_segmentation_pytorch.py
├── train_classification_pytorch.py
└── train_segmentation_pytorch.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# deepsolar_pytorch
The PyTorch version of DeepSolar

--------------------------------------------------------------------------------
/image_dataset.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms, utils
import torchvision.transforms.functional as TF

import numpy as np
import json
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import skimage
import skimage.io
import skimage.transform
from PIL import Image
import time
import os
from os.path import join, exists
import copy
import random
from collections import OrderedDict


class ImageFolderModified(Dataset):
    """Like torchvision's ImageFolder, but each sample also carries its file path."""
    def __init__(self, root_dir, transform):
        self.root_dir = root_dir
        self.transform = transform
        self.idx2dir = []
        self.path_list = []
        # Each class lives in its own subdirectory of root_dir.
        for subdir in sorted(os.listdir(self.root_dir)):
            if os.path.isdir(os.path.join(self.root_dir, subdir)):
                self.idx2dir.append(subdir)
        for class_idx, subdir in enumerate(self.idx2dir):
            class_dir = os.path.join(self.root_dir, subdir)
            for f in os.listdir(class_dir):
                if f.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.path_list.append([os.path.join(class_dir, f), class_idx])

    def __len__(self):
        return len(self.path_list)

    def __getitem__(self, idx):
        img_path, class_idx = self.path_list[idx]
        image = Image.open(img_path)
        if not image.mode == 'RGB':
            image = image.convert('RGB')
        image = self.transform(image)
        sample = [image, class_idx, img_path]
        return sample
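For reference, a minimal usage sketch of ImageFolderModified (the dataset path below is a placeholder, not a path from this repo): each sample is a triple of image tensor, class index, and source file path, which the segmentation test script further down relies on.

from torch.utils.data import DataLoader
from torchvision import transforms
from image_dataset import ImageFolderModified

transform = transforms.Compose([
    transforms.Resize(299),  # Inception3 expects 299 x 299 inputs
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
dataset = ImageFolderModified('/path/to/dataset', transform)  # placeholder path
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
for image, class_idx, img_path in loader:
    print(img_path[0], class_idx.item())  # every sample carries its file path
    break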
--------------------------------------------------------------------------------
/inception_modified.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms, utils
import torchvision.transforms.functional as TF

from tqdm import tqdm
import numpy as np
import json
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import skimage
import skimage.io
import skimage.transform
from PIL import Image
import time
import os
from os.path import join, exists
import copy
import random
from collections import OrderedDict
from sklearn.metrics import r2_score


import torch.nn.functional as F
from torchvision.models import Inception3
from collections import namedtuple

_InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])


class InceptionSegmentation(nn.Module):
    def __init__(self, num_outputs=2, level=1):
        super(InceptionSegmentation, self).__init__()
        assert level in [1, 2]
        self.level = level
        self.inception3 = Inception3_modified(num_classes=num_outputs, aux_logits=False, transform_input=False)
        self.convolution1 = nn.Conv2d(288, 512, bias=True, kernel_size=3, padding=1)
        if self.level == 1:
            self.linear1 = nn.Linear(512, num_outputs, bias=False)
        else:
            self.convolution2 = nn.Conv2d(512, 512, bias=True, kernel_size=3, padding=1)
            self.linear2 = nn.Linear(512, num_outputs, bias=False)

    def forward(self, x, testing=False):
        logits, intermediate = self.inception3(x)
        feature_map = self.convolution1(intermediate)  # N x 512 x 35 x 35
        feature_map = F.relu(feature_map)  # N x 512 x 35 x 35
        if self.level == 1:
            y = F.adaptive_avg_pool2d(feature_map, (1, 1))
            y = y.view(y.size(0), -1)  # N x 512
            y = self.linear1(y)  # N x 2
            if testing:
                # Class activation map for the positive class (index 1).
                CAM = self.linear1.weight.data[1, :] * feature_map.permute(0, 2, 3, 1)
                CAM = CAM.sum(dim=3)
        else:
            feature_map = self.convolution2(feature_map)  # N x 512 x 35 x 35
            feature_map = F.relu(feature_map)  # N x 512 x 35 x 35
            y = F.adaptive_avg_pool2d(feature_map, (1, 1))
            y = y.view(y.size(0), -1)  # N x 512
            y = self.linear2(y)  # N x 2
            if testing:
                CAM = self.linear2.weight.data[1, :] * feature_map.permute(0, 2, 3, 1)
                CAM = CAM.sum(dim=3)
        if testing:
            return y, logits, CAM
        else:
            return y

    def load_basic_params(self, model_path, device=torch.device('cpu')):
        """Only load the parameters of the main branch."""
        old_params = torch.load(model_path, map_location=device)
        if model_path[-4:] == '.tar':  # the file is a checkpoint dict, not a bare model state dict
            old_params = old_params['model_state_dict']
        self.inception3.load_state_dict(old_params, strict=False)
        print('Loaded basic model parameters from: ' + model_path)

    def load_existing_params(self, model_path, device=torch.device('cpu')):
        """Load the parameters of the main branch and of the level-1 layers (and, if present, the level-2 layers)."""
        old_params = torch.load(model_path, map_location=device)
        if model_path[-4:] == '.tar':  # the file is a checkpoint dict, not a bare model state dict
            old_params = old_params['model_state_dict']
        self.load_state_dict(old_params, strict=False)
        print('Loaded existing model parameters from: ' + model_path)


class Inception3_modified(Inception3):
    def forward(self, x):
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        intermediate = x.clone()  # kept as the input of the segmentation branch
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        if self.training and self.aux_logits:
            # Not used here: InceptionSegmentation builds this model with aux_logits=False.
            aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 2048 x 1 x 1
        x = F.dropout(x, training=self.training)
        # N x 2048 x 1 x 1
        x = x.view(x.size(0), -1)
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:
            return _InceptionOutputs(x, aux)
        return x, intermediate
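A minimal sketch of how the two heads of InceptionSegmentation behave (the checkpoint name is a placeholder): in normal mode the forward pass returns only the class scores of the segmentation branch, while testing=True additionally returns the main-branch logits and the 35 x 35 class activation map.

import torch
from inception_modified import InceptionSegmentation

model = InceptionSegmentation(num_outputs=2, level=1)
model.load_basic_params('deepsolar_pretrained.pth')  # placeholder checkpoint path
model.eval()

x = torch.randn(1, 3, 299, 299)  # dummy input batch
with torch.no_grad():
    y = model(x)                             # class scores only
    y, logits, CAM = model(x, testing=True)  # plus main-branch logits and CAM
print(y.shape, logits.shape, CAM.shape)      # (1, 2), (1, 2), (1, 35, 35)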
--------------------------------------------------------------------------------
/test_classification_pytorch.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms, utils
import torchvision.transforms.functional as TF

from tqdm import tqdm
import numpy as np
import json
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import skimage
import skimage.io
import skimage.transform
from PIL import Image
import time
import os
from os.path import join, exists
import copy
import random
from collections import OrderedDict
from sklearn.metrics import r2_score

from torch.nn import functional as F
from torchvision.models import Inception3

# Configuration
# directory for loading training/validation/test data
data_dir = '/home/ubuntu/projects/deepsolar/deepsolar_dataset_toy/test'
old_ckpt_path = '/home/ubuntu/projects/deepsolar/deepsolar_pytorch_pretrained/deepsolar_pretrained.pth'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
input_size = 299
batch_size = 32
threshold = 0.2  # threshold probability to identify an image as positive


def metrics(stats):
    """stats: {'TP': TP, 'FP': FP, 'TN': TN, 'FN': FN}
    return: must be a single number"""
    precision = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FP'] + 0.00001)
    recall = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FN'] + 0.00001)
    return 0.5 * (precision + recall)


def test_model(model, dataloader, metrics, threshold):
    stats = {'TP': 0, 'FP': 0, 'TN': 0, 'FN': 0}
    model.eval()
    for inputs, labels in tqdm(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            prob = F.softmax(outputs, dim=1)
            preds = prob[:, 1] >= threshold

        stats['TP'] += torch.sum((preds == 1) * (labels == 1)).cpu().item()
        stats['TN'] += torch.sum((preds == 0) * (labels == 0)).cpu().item()
        stats['FP'] += torch.sum((preds == 1) * (labels == 0)).cpu().item()
        stats['FN'] += torch.sum((preds == 0) * (labels == 1)).cpu().item()

    metric_value = metrics(stats)
    return stats, metric_value


transform_test = transforms.Compose([
    transforms.Resize(input_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

if __name__ == '__main__':
    # data
    dataset_test = datasets.ImageFolder(data_dir, transform_test)
    dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=True, num_workers=4)
    # model
    model = Inception3(num_classes=2, aux_logits=True, transform_input=False)
    model = model.to(device)
    # load old parameters
    checkpoint = torch.load(old_ckpt_path, map_location=device)
    if old_ckpt_path[-4:] == '.tar':  # it is a checkpoint dictionary rather than just model parameters
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    print('Old checkpoint loaded: ' + old_ckpt_path)
    stats, metric_value = test_model(model, dataloader_test, metrics, threshold=threshold)
    precision = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FP'] + 0.00001)
    recall = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FN'] + 0.00001)
    print('metric value: ' + str(metric_value))
    print('precision: ' + str(round(precision, 4)))
    print('recall: ' + str(round(recall, 4)))
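The decision rule preds = prob[:, 1] >= threshold can be seen in isolation on a toy batch (the logits below are invented for illustration); a low threshold such as 0.2 trades precision for recall on the rare positive class.

import torch
import torch.nn.functional as F

outputs = torch.tensor([[3.0, 0.0],   # strongly negative example
                        [0.2, 1.5]])  # positive example
prob = F.softmax(outputs, dim=1)      # each row sums to 1
preds = prob[:, 1] >= 0.2             # -> tensor([False, True])
print(prob, preds)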
--------------------------------------------------------------------------------
/test_segmentation_pytorch.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms, utils
import torchvision.transforms.functional as TF

from tqdm import tqdm
import numpy as np
import json
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import skimage
import skimage.io
import skimage.transform
from PIL import Image
import time
import os
from os.path import join, exists
import copy
import random
from collections import OrderedDict
from sklearn.metrics import r2_score

from torch.nn import functional as F
from torchvision.models import Inception3

from inception_modified import InceptionSegmentation
from image_dataset import ImageFolderModified

# Configuration
# directory for loading training/validation/test data
data_dir = '/home/ubuntu/projects/deepsolar/deepsolar_dataset_toy/test'
old_ckpt_path = '/home/ubuntu/projects/deepsolar/deepsolar_pytorch_pretrained/deepsolar_seg_pretrained.pth'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
input_size = 299
batch_size = 1  # must be 1 for testing segmentation
threshold = 0.5  # threshold probability to identify an image as positive
level = 2


def metrics(stats):
    """stats: {'TP': TP, 'FP': FP, 'TN': TN, 'FN': FN}
    return: must be a single number"""
    precision = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FP'] + 0.00001)
    recall = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FN'] + 0.00001)
    return 0.5 * (precision + recall)


def test_model(model, dataloader, metrics, threshold):
    stats = {'TP': 0, 'FP': 0, 'TN': 0, 'FN': 0}
    model.eval()
    CAM_list = []
    for inputs, labels, paths in tqdm(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.set_grad_enabled(False):
            _, outputs, CAM = model(inputs, testing=True)  # CAM is a 1 x 35 x 35 activation map
            prob = F.softmax(outputs, dim=1)
            preds = prob[:, 1] >= threshold

        CAM = CAM.squeeze(0).cpu().numpy()  # transform tensor into numpy array
        for i in range(preds.size(0)):
            predicted_label = preds[i]
            if predicted_label.cpu().item():
                CAM_list.append((CAM, paths[i]))  # only use the generated CAM if it is predicted to be 1
            else:
                CAM_list.append((np.zeros_like(CAM), paths[i]))  # otherwise the CAM is a totally black one

        stats['TP'] += torch.sum((preds == 1) * (labels == 1)).cpu().item()
        stats['TN'] += torch.sum((preds == 0) * (labels == 0)).cpu().item()
        stats['FP'] += torch.sum((preds == 1) * (labels == 0)).cpu().item()
        stats['FN'] += torch.sum((preds == 0) * (labels == 1)).cpu().item()

    metric_value = metrics(stats)
    return stats, metric_value, CAM_list


transform_test = transforms.Compose([
    transforms.Resize(input_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

if __name__ == '__main__':
    # data
    dataset_test = ImageFolderModified(data_dir, transform_test)
    dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=4)
    # model
    model = InceptionSegmentation(num_outputs=2, level=level)
    model.load_existing_params(old_ckpt_path)

    model = model.to(device)

    stats, metric_value, CAM_list = test_model(model, dataloader_test, metrics, threshold=threshold)
    precision = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FP'] + 0.00001)
    recall = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FN'] + 0.00001)
    print('metric value: ' + str(metric_value))
    print('precision: ' + str(round(precision, 4)))
    print('recall: ' + str(round(recall, 4)))

    with open('CAM_list.pickle', 'wb') as f:  # pickle requires a binary-mode file handle
        pickle.dump(CAM_list, f)
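Each saved CAM is 35 x 35, so it has to be upsampled before being overlaid on its 299 x 299 source image. A sketch of one way to post-process the pickle written above (the output file name is a placeholder):

import pickle
import numpy as np
import skimage.io
import skimage.transform

with open('CAM_list.pickle', 'rb') as f:
    CAM_list = pickle.load(f)

CAM, img_path = CAM_list[0]  # a (35, 35) array and the path of its source image
CAM_up = skimage.transform.resize(CAM, (299, 299))  # smooth spline upsampling
CAM_up = (CAM_up - CAM_up.min()) / (CAM_up.max() - CAM_up.min() + 1e-8)  # rescale to [0, 1]
skimage.io.imsave('CAM_example.png', (CAM_up * 255).astype(np.uint8))  # placeholder name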
--------------------------------------------------------------------------------
/train_classification_pytorch.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms, utils
import torchvision.transforms.functional as TF

from tqdm import tqdm
import numpy as np
import json
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import skimage
import skimage.io
import skimage.transform
from PIL import Image
import time
import os
from os.path import join, exists
import copy
import random
from collections import OrderedDict
from sklearn.metrics import r2_score

from torch.nn import functional as F
from torchvision.models import Inception3


# Configuration
# directory for loading training/validation/test data
data_dir = '/home/ubuntu/projects/deepsolar/deepsolar_dataset_toy'
# path to load old model/checkpoint, "None" if not loading.
old_ckpt_path = '/home/ubuntu/projects/deepsolar/deepsolar_pytorch_pretrained/deepsolar_pretrained.pth'
# directory for saving model/checkpoint
ckpt_save_dir = 'checkpoint/deepsolar_toy'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trainable_params = ['fc']  # layers or modules set to be trainable; "None" if training all layers
model_name = 'deepsolar_toy'  # the prefix of the filename for saving model/checkpoint
return_best = True  # whether to return the best model according to the validation metric
if_early_stop = True  # whether to stop early when the validation metric has not improved for a set number of epochs
input_size = 299  # image size fed into the model
imbalance_rate = 5  # weight given to the positive (rarer) samples in the loss function
learning_rate = 0.01  # learning rate
weight_decay = 0.00  # l2 regularization coefficient
batch_size = 64
num_epochs = 10  # number of epochs to train
lr_decay_rate = 0.7  # learning rate decay rate for each decay step
lr_decay_epochs = 5  # number of epochs for one learning rate decay
early_stop_epochs = 5  # stop training after the validation metric has not improved for this many epochs
save_epochs = 5  # save the model/checkpoint every "save_epochs" epochs
threshold = 0.2  # threshold probability to identify an image as positive


def RandomRotationNew(image):
    angle = random.choice([0, 90, 180, 270])
    image = TF.rotate(image, angle)
    return image


def only_train(model, trainable_params):
    """trainable_params: The list of parameters and modules that are set to be trainable.
    Set requires_grad = False for all parameters not in trainable_params."""
    print('Only the following layers:')
    for name, p in model.named_parameters():
        p.requires_grad = False
        for target in trainable_params:
            if target == name or target in name:
                p.requires_grad = True
                print('  ' + name)
                break


def metrics(stats):
    """
    Self-defined metrics function to evaluate and compare models
    stats: {'TP': TP, 'FP': FP, 'TN': TN, 'FN': FN}
    return: must be a single number
    """
    precision = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FP'] + 0.00001)
    recall = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FN'] + 0.00001)
    return 0.5 * (precision + recall)


def train_model(model, model_name, dataloaders, criterion, optimizer, metrics, num_epochs, threshold=0.5, training_log=None,
                verbose=True, return_best=True, if_early_stop=True, early_stop_epochs=10, scheduler=None,
                save_dir=None, save_epochs=5):
    since = time.time()
    if not training_log:
        training_log = dict()
        training_log['train_loss_history'] = []
        training_log['val_loss_history'] = []
        training_log['val_metric_value_history'] = []
        training_log['current_epoch'] = -1
    current_epoch = training_log['current_epoch'] + 1

    best_model_wts = copy.deepcopy(model.state_dict())
    best_optimizer_wts = copy.deepcopy(optimizer.state_dict())
    best_log = copy.deepcopy(training_log)

    best_metric_value = -np.inf
    nodecrease = 0  # counts epochs in which the validation metric does not improve
    early_stop = False

    for epoch in range(current_epoch, current_epoch + num_epochs):
        if verbose:
            print('Epoch {}/{}'.format(epoch, current_epoch + num_epochs - 1))
            print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            stats = {'TP': 0, 'FP': 0, 'TN': 0, 'FN': 0}

            # Iterate over data.
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    if phase == 'train':
                        # Inception3 returns auxiliary logits in training mode.
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    prob = F.softmax(outputs, dim=1)
                    preds = prob[:, 1] >= threshold

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                stats['TP'] += torch.sum((preds == 1) * (labels == 1)).cpu().item()
                stats['TN'] += torch.sum((preds == 0) * (labels == 0)).cpu().item()
                stats['FP'] += torch.sum((preds == 1) * (labels == 0)).cpu().item()
                stats['FN'] += torch.sum((preds == 0) * (labels == 1)).cpu().item()

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_metric_value = metrics(stats)

            if verbose:
                print('{} Loss: {:.4f} Metrics: {:.4f}'.format(phase, epoch_loss, epoch_metric_value))

            training_log['current_epoch'] = epoch
            if phase == 'val':
                training_log['val_metric_value_history'].append(epoch_metric_value)
                training_log['val_loss_history'].append(epoch_loss)
                # deep copy the model
                if epoch_metric_value > best_metric_value:
                    best_metric_value = epoch_metric_value
                    best_model_wts = copy.deepcopy(model.state_dict())
                    best_optimizer_wts = copy.deepcopy(optimizer.state_dict())
                    best_log = copy.deepcopy(training_log)
                    nodecrease = 0
                else:
                    nodecrease += 1
            else:  # train phase
                training_log['train_loss_history'].append(epoch_loss)
                if scheduler is not None:
                    scheduler.step()

        if nodecrease >= early_stop_epochs:
            early_stop = True

        if save_dir and epoch % save_epochs == 0:
            checkpoint = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'training_log': training_log
            }
            torch.save(checkpoint,
                       os.path.join(save_dir, model_name + '_' + str(training_log['current_epoch']) + '.tar'))

        if if_early_stop and early_stop:
            print('Early stopped!')
            break

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best validation metric value: {:4f}'.format(best_metric_value))

    # load best model weights
    if return_best:
        model.load_state_dict(best_model_wts)
        optimizer.load_state_dict(best_optimizer_wts)
        training_log = best_log

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'training_log': training_log
    }
    torch.save(checkpoint,
               os.path.join(save_dir, model_name + '_' + str(training_log['current_epoch']) + '_last.tar'))

    return model, training_log


data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(input_size),
        transforms.Lambda(RandomRotationNew),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
}


if __name__ == '__main__':
    # data
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
    dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                                       shuffle=True, num_workers=4) for x in ['train', 'val']}

    if not os.path.exists(ckpt_save_dir):  # make sure the checkpoint directory exists
        os.makedirs(ckpt_save_dir)
    # model
    model = Inception3(num_classes=2, aux_logits=True, transform_input=False)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08,
                           weight_decay=weight_decay, amsgrad=True)
    class_weight = torch.tensor([1, imbalance_rate], dtype=torch.float).to(device)
    loss_fn = nn.CrossEntropyLoss(weight=class_weight)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_epochs, gamma=lr_decay_rate)

    # load old parameters
    if old_ckpt_path:
        checkpoint = torch.load(old_ckpt_path, map_location=device)
        if old_ckpt_path[-4:] == '.tar':  # it is a checkpoint dictionary rather than just model parameters
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            training_log = checkpoint['training_log']
        else:
            model.load_state_dict(checkpoint)
            training_log = None
        print('Old checkpoint loaded: ' + old_ckpt_path)
    else:
        training_log = None

    # fix some layers and make others trainable
    if trainable_params:
        only_train(model, trainable_params)

    _, _ = train_model(model, model_name=model_name, dataloaders=dataloaders_dict, criterion=loss_fn,
                       optimizer=optimizer, metrics=metrics, num_epochs=num_epochs, threshold=threshold,
                       training_log=training_log, verbose=True, return_best=return_best,
                       if_early_stop=if_early_stop, early_stop_epochs=early_stop_epochs,
                       scheduler=scheduler, save_dir=ckpt_save_dir, save_epochs=save_epochs)
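Note that only_train matches parameter names by substring, so trainable_params = ['fc'] unfreezes fc.weight and fc.bias while freezing everything else. A toy illustration (the two-layer model is a stand-in, not part of this repo, and only_train from the file above is assumed to be in scope):

import torch.nn as nn

toy = nn.Sequential()
toy.add_module('features', nn.Linear(4, 4))
toy.add_module('fc', nn.Linear(4, 2))

only_train(toy, ['fc'])  # using only_train as defined above
for name, p in toy.named_parameters():
    print(name, p.requires_grad)  # features.*: False, fc.*: True

Because the match is a substring test, a target such as 'fc' would also unfreeze any other parameter whose name happens to contain 'fc'.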
--------------------------------------------------------------------------------
/train_segmentation_pytorch.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms, utils
import torchvision.transforms.functional as TF

from tqdm import tqdm
import numpy as np
import json
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import skimage
import skimage.io
import skimage.transform
from PIL import Image
import time
import os
from os.path import join, exists
import copy
import random
from collections import OrderedDict
from sklearn.metrics import r2_score

from torch.nn import functional as F
from torchvision.models import Inception3

from inception_modified import InceptionSegmentation


# Configuration
# directory for loading training/validation/test data
data_dir = '/home/ubuntu/projects/deepsolar/deepsolar_dataset_toy'
# path to load the basic main-branch model, "None" if not loading.
basic_params_path = '/home/ubuntu/projects/deepsolar/deepsolar_pytorch_pretrained/deepsolar_pretrained.pth'
# path to load old model parameters, "None" if not loading.
old_ckpt_path = 'checkpoint/deepsolar_toy/deepsolar_seg_level1_5.tar'
# directory for saving model/checkpoint
ckpt_save_dir = 'checkpoint/deepsolar_toy'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_name = 'deepsolar_seg_level2'  # the prefix of the filename for saving model/checkpoint
return_best = True  # whether to return the best model according to the validation metric
if_early_stop = True  # whether to stop early when the validation metric has not improved for a set number of epochs
level = 2  # which level of the segmentation branch to train (1 or 2)
input_size = 299  # image size fed into the model
imbalance_rate = 5  # weight given to the positive (rarer) samples in the loss function
learning_rate = 0.01  # learning rate
weight_decay = 0.00  # l2 regularization coefficient
batch_size = 64
num_epochs = 10  # number of epochs to train
lr_decay_rate = 0.7  # learning rate decay rate for each decay step
lr_decay_epochs = 5  # number of epochs for one learning rate decay
early_stop_epochs = 5  # stop training after the validation metric has not improved for this many epochs
save_epochs = 5  # save the model/checkpoint every "save_epochs" epochs
threshold = 0.5  # threshold probability to identify an image as positive


def RandomRotationNew(image):
    angle = random.choice([0, 90, 180, 270])
    image = TF.rotate(image, angle)
    return image


def only_train(model, trainable_params):
    """trainable_params: The list of parameters and modules that are set to be trainable.
    Set requires_grad = False for all parameters not in trainable_params."""
    print('Only the following layers:')
    for name, p in model.named_parameters():
        p.requires_grad = False
        for target in trainable_params:
            if target == name or target in name:
                p.requires_grad = True
                print('  ' + name)
                break


def metrics(stats):
    """
    Self-defined metrics function to evaluate and compare models
    stats: {'TP': TP, 'FP': FP, 'TN': TN, 'FN': FN}
    return: must be a single number
    """
    precision = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FP'] + 0.00001)
    recall = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FN'] + 0.00001)
    return 0.5 * (precision + recall)


def train_model(model, model_name, dataloaders, criterion, optimizer, metrics, num_epochs, threshold=0.5, training_log=None,
                verbose=True, return_best=True, if_early_stop=True, early_stop_epochs=10, scheduler=None,
                save_dir=None, save_epochs=5):
    since = time.time()
    if not training_log:
        training_log = dict()
        training_log['train_loss_history'] = []
        training_log['val_loss_history'] = []
        training_log['val_metric_value_history'] = []
        training_log['current_epoch'] = -1
    current_epoch = training_log['current_epoch'] + 1

    best_model_wts = copy.deepcopy(model.state_dict())
    best_optimizer_wts = copy.deepcopy(optimizer.state_dict())
    best_log = copy.deepcopy(training_log)

    best_metric_value = -np.inf
    nodecrease = 0  # counts epochs in which the validation metric does not improve
    early_stop = False

    for epoch in range(current_epoch, current_epoch + num_epochs):
        if verbose:
            print('Epoch {}/{}'.format(epoch, current_epoch + num_epochs - 1))
            print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            stats = {'TP': 0, 'FP': 0, 'TN': 0, 'FN': 0}

            # Iterate over data.
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    outputs = model(inputs, testing=False)
                    loss = criterion(outputs, labels)

                    prob = F.softmax(outputs, dim=1)
                    preds = prob[:, 1] >= threshold

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                stats['TP'] += torch.sum((preds == 1) * (labels == 1)).cpu().item()
                stats['TN'] += torch.sum((preds == 0) * (labels == 0)).cpu().item()
                stats['FP'] += torch.sum((preds == 1) * (labels == 0)).cpu().item()
                stats['FN'] += torch.sum((preds == 0) * (labels == 1)).cpu().item()

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_metric_value = metrics(stats)

            if verbose:
                print('{} Loss: {:.4f} Metrics: {:.4f}'.format(phase, epoch_loss, epoch_metric_value))

            training_log['current_epoch'] = epoch
            if phase == 'val':
                training_log['val_metric_value_history'].append(epoch_metric_value)
                training_log['val_loss_history'].append(epoch_loss)
                # deep copy the model
                if epoch_metric_value > best_metric_value:
                    best_metric_value = epoch_metric_value
                    best_model_wts = copy.deepcopy(model.state_dict())
                    best_optimizer_wts = copy.deepcopy(optimizer.state_dict())
                    best_log = copy.deepcopy(training_log)
                    nodecrease = 0
                else:
                    nodecrease += 1
            else:  # train phase
                training_log['train_loss_history'].append(epoch_loss)
                if scheduler is not None:
                    scheduler.step()

        if nodecrease >= early_stop_epochs:
            early_stop = True

        if save_dir and epoch % save_epochs == 0:
            checkpoint = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'training_log': training_log
            }
            torch.save(checkpoint,
                       os.path.join(save_dir, model_name + '_' + str(training_log['current_epoch']) + '.tar'))

        if if_early_stop and early_stop:
            print('Early stopped!')
            break

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best validation metric value: {:4f}'.format(best_metric_value))

    # load best model weights
    if return_best:
        model.load_state_dict(best_model_wts)
        optimizer.load_state_dict(best_optimizer_wts)
        training_log = best_log

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'training_log': training_log
    }
    torch.save(checkpoint,
               os.path.join(save_dir, model_name + '_' + str(training_log['current_epoch']) + '_last.tar'))

    return model, training_log


data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(input_size),
        transforms.Lambda(RandomRotationNew),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
}


if __name__ == '__main__':
    assert level in [1, 2]
    # data
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
    dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                                       shuffle=True, num_workers=4) for x in ['train', 'val']}

    if not os.path.exists(ckpt_save_dir):
        os.makedirs(ckpt_save_dir)
    # model
    model = InceptionSegmentation(num_outputs=2, level=level)
    if level == 1 and basic_params_path:
        model.load_basic_params(basic_params_path)
    elif level == 2 and old_ckpt_path:
        model.load_existing_params(old_ckpt_path)

    if level == 1:
        trainable_params = ['convolution1', 'linear1']
    else:
        trainable_params = ['convolution2', 'linear2']
    only_train(model, trainable_params)

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08,
                           weight_decay=weight_decay, amsgrad=True)
    class_weight = torch.tensor([1, imbalance_rate], dtype=torch.float).to(device)
    loss_fn = nn.CrossEntropyLoss(weight=class_weight)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_epochs, gamma=lr_decay_rate)

    training_log = None

    _, _ = train_model(model, model_name=model_name, dataloaders=dataloaders_dict, criterion=loss_fn,
                       optimizer=optimizer, metrics=metrics, num_epochs=num_epochs, threshold=threshold,
                       training_log=training_log, verbose=True, return_best=return_best,
                       if_early_stop=if_early_stop, early_stop_epochs=early_stop_epochs,
                       scheduler=scheduler, save_dir=ckpt_save_dir, save_epochs=save_epochs)

--------------------------------------------------------------------------------
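Taken together with the configuration block, the segmentation branch appears to be trained in two passes; a sketch of the settings that change between them (the checkpoint names follow the defaults above):

# Pass 1: train convolution1/linear1 on top of the pretrained main branch.
level = 1
basic_params_path = '/home/ubuntu/projects/deepsolar/deepsolar_pytorch_pretrained/deepsolar_pretrained.pth'
model_name = 'deepsolar_seg_level1'

# Pass 2: train convolution2/linear2, starting from the pass-1 checkpoint.
level = 2
old_ckpt_path = 'checkpoint/deepsolar_toy/deepsolar_seg_level1_5.tar'
model_name = 'deepsolar_seg_level2'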