├── images ├── architecture.png └── show_example.png ├── LICENSE ├── README.md ├── RS_Dataset.py ├── cloud_generation.py ├── model.py ├── baseline.py └── train.py /images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HMS97/GLNET/HEAD/images/architecture.png -------------------------------------------------------------------------------- /images/show_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HMS97/GLNET/HEAD/images/show_example.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 wuchangsheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Convolutional Neural Networks Based Remote Sensing Scene Classification under Clear and Cloudy Environments 2 | Accepted by ICCVW 2021 3 | 4 | 5 | 6 | ### Remote Sensing scene image classification under clear and cloudy environments. 7 | ![show example](images/show_example.png) 8 | 9 | ### Overview architecture of the proposed GLNet for the RS scene classification under clear and cloudy environments. 10 | ![architecture](images/architecture.png) 11 | 12 | 13 | 14 | ## Required libraries 15 | python 3.6 16 | 17 | pytorch 1.0+ 18 | 19 | numpy 20 | 21 | PIL 22 | 23 | torchvision 24 | 25 | 26 | ## Usage 27 | 1. clone this repo 28 | ``` 29 | git clone https://github.com/wuchangsheng951/GLNET.git 30 | ``` 31 | 32 | 2. download the dataset from [google drive](https://drive.google.com/file/d/1F_68mh40vNLOwila32GBYNHVEZI1HiTT/view?usp=sharing) 33 | 34 | 3. train the baseline model 35 | ``` 36 | python baseline.py 37 | ``` 38 | 4. set the path of the trained baseline model checkpoint in model.py 39 | 40 | 5. run the training by command 41 | ``` 42 | python train.py 43 | ``` 44 | 45 | ## Citation 46 | {Huiming Sun, Yuewei Lin, Qin Zou, Shaoyue Song, Jianwu Fang, Hongkai Yu. Convolutional Neural Networks Based Remote Sensing Scene Classification under Clear and Cloudy Environments. 
IEEE International Conference on Computer Vision Workshop (ICCVW), 2021.} 47 | -------------------------------------------------------------------------------- /RS_Dataset.py: -------------------------------------------------------------------------------- 1 | from path import Path 2 | from torch.utils.data import TensorDataset, DataLoader, Dataset,SubsetRandomSampler 3 | from torchvision.datasets import ImageFolder 4 | import cv2 5 | import numpy as np 6 | from PIL import Image 7 | from random import randint,sample 8 | 9 | class RS_Dataset(ImageFolder): 10 | # split image into 5 parts each part's partion is 0.6 11 | def __init__(self, root, transform=None, partion = 0.6, size = 224 ,): 12 | super(RS_Dataset, self).__init__(root, transform) 13 | self.indices = range(len(self)) 14 | self.transform = transform 15 | self.partion = partion 16 | self.size = size 17 | self.width,self.length = int(self.size*self.partion),int(self.size*self.partion) 18 | 19 | 20 | 21 | 22 | def __getitem__(self, index): 23 | 24 | img = np.array(Image.open(self.imgs[index][0])) 25 | upper_left = cv2.resize(img[:self.width,:self.length],(self.size,self.size)) 26 | upper_right = cv2.resize(img[:self.width,-self.length:],(self.size,self.size)) 27 | bottom_right = cv2.resize(img[-self.width:,-self.length:],(self.size,self.size)) 28 | bottom_left = cv2.resize(img[-self.width:,:self.length],(self.size,self.size)) 29 | mid = img[int((1-self.partion)/2*(img.shape[0])):-int((1-self.partion)/2*(img.shape[0])),int((1-self.partion)/2*(img.shape[1])):-int((1-self.partion)/2*(img.shape[1]))] 30 | cluster_data = upper_left,upper_right,bottom_right,bottom_left,mid 31 | img = self.transform(img) 32 | cluster_data = [self.transform(i) for i in cluster_data] 33 | label = self.imgs[index][1] 34 | 35 | return img, cluster_data, label -------------------------------------------------------------------------------- /cloud_generation.py: 
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib
import numpy as np
import skimage.transform
import cv2


def generate_cloud(im_size_h, im_size_w, k=2):
    """Generate a turbulence (cloud-like) pattern of shape (im_size_h, im_size_w).

    A half-resolution white-noise base pattern is sampled at square windows
    of size k^i, each window is upsampled to the full image size, and the
    layers are blended with weights 1/p so coarse octaves dominate.  The
    result is normalised back into the [0, 255] range.
    """
    # White-noise seed at half resolution.
    base_pattern = np.random.uniform(0, 255, (im_size_h // 2, im_size_w // 2))

    turbulence_pattern = np.zeros((im_size_h, im_size_w))

    # Octave window sizes: k^2, k^3, ... below log2(min side).
    power_range = [k ** i for i in range(2, int(np.log2(min(im_size_h, im_size_w))))]

    for p in power_range:
        quadrant = base_pattern[:p, :p]
        upsampled_pattern = skimage.transform.resize(
            quadrant, (im_size_h, im_size_w), mode='reflect')
        turbulence_pattern += upsampled_pattern / float(p)

    # Normalise so the blend weights sum to 1.
    turbulence_pattern /= sum(1 / float(p) for p in power_range)
    return turbulence_pattern


def add_cloud(file_name, k):
    """Blend a generated cloud layer onto the image stored at `file_name`.

    Returns (cloud_map, clouded_image, fourground_map) where `fourground_map`
    is the per-pixel weight given to the original image and `clouded_image`
    is the uint8 composite.
    """
    img = cv2.imread(file_name)
    im_size_h, im_size_w = np.shape(img)[:2]

    cloud_map = generate_cloud(im_size_h, im_size_w, k)
    fourground_map = (255 - cloud_map) / 255

    # Alpha-blend the cloud over each colour channel.
    # (Bug fix: removed leftover debug `print(img[:, :, 0])`.)
    res = np.zeros(np.shape(img))
    for channel in range(3):
        res[:, :, channel] = img[:, :, channel] * fourground_map + cloud_map

    return cloud_map, res.astype(np.uint8), fourground_map


if __name__ == '__main__':
    # Bug fix: this smoke-test call previously ran unconditionally on import.
    generate_cloud(256, 256, 2)
import torch
from torch import nn
import torchvision.models as models
from torchvision.models import vgg16, alexnet, resnet50
import numpy as np


torch.nn.Module.dump_patches = True


class SiameseNetwork(nn.Module):
    """Two-branch (global/local) network for RS scene classification (GLNet).

    The upper backbone embeds each of the five image crops and the lower
    backbone embeds the full image; the six embeddings are stacked,
    average-pooled and fed to a linear classifier head.
    """

    def __init__(self, base_model='vgg16', num_classes=5, fixed=False, out_features_dim=128):
        super(SiameseNetwork, self).__init__()
        self.base_model = base_model
        self.num_classes = num_classes
        self.fixed = fixed                      # freeze backbone weights when True
        self.out_features_dim = out_features_dim

        self.lower_model = self.make_back_bone(self.base_model)
        self.upper_backbone = self.make_back_bone(self.base_model)
        # Six 128-d embeddings viewed as (6, 32, 4); AvgPool2d(2) gives
        # (6, 16, 2) = 192 features, matching in_features = 6 * 8 * 4.
        self.fc1 = nn.Linear(in_features=6 * 8 * 4, out_features=self.num_classes, bias=True)
        self.prelu = nn.PReLU()
        self.avgpool = nn.AvgPool2d(2)

    def make_back_bone(self, base_model):
        """Return a backbone whose last layer emits `out_features_dim` features.

        Raises ValueError for an unsupported `base_model` (previously the
        method silently returned None, crashing later in forward()).
        """
        if base_model == 'vgg16':
            model = vgg16()
            if self.fixed:
                for param in model.parameters():
                    param.requires_grad = False
            model.classifier[-1] = nn.Linear(in_features=4096, out_features=self.out_features_dim)
            return model

        if base_model == 'alexnet':
            # NOTE(review): loads a locally trained checkpoint — the path must
            # exist; see README step 4.
            model = torch.load('new_saved_models/2020-08-20_alexnet_88.85_cloud_baseline.pth')
            if self.fixed:
                for param in model.parameters():
                    param.requires_grad = False
            model.classifier[-1] = nn.Linear(in_features=4096, out_features=self.out_features_dim)
            return model

        if base_model == 'resnet50':
            # NOTE(review): loads a locally trained checkpoint — the path must
            # exist; see README step 4.
            model = torch.load('new_saved_models/2020-08-21_resnet50_91.57_cloud_baseline.pth')
            if self.fixed:
                for param in model.parameters():
                    param.requires_grad = False
            model.fc = nn.Linear(in_features=2048, out_features=self.out_features_dim)
            return model

        raise ValueError('unsupported base_model: {}'.format(base_model))

    def forward(self, img, cluster_data):
        # Shared upper backbone over the five crops; each embedding gains a
        # singleton dim so they can be concatenated along dim 1.
        crop_embeddings = [torch.unsqueeze(self.upper_backbone(crop), 1) for crop in cluster_data]

        x_upper = torch.cat(crop_embeddings, 1)
        x_upper = x_upper.view(x_upper.shape[0], 5, 32, 4)

        # Lower backbone over the full image.
        x_lower = self.lower_model(img)
        x_lower = torch.unsqueeze(x_lower, 1)
        x_lower = x_lower.view(x_lower.shape[0], 1, 32, 4)

        x = torch.cat((x_upper, x_lower), dim=1)  # (B, 6, 32, 4)
        x = self.avgpool(x)                       # (B, 6, 16, 2)
        x = x.view(x.size(0), -1)                 # (B, 192)
        return self.fc1(x)


if __name__ == '__main__':
    # Bug fix: `SiameseNetwork(7)` passed 7 positionally as `base_model`,
    # which is not a supported backbone name; 7 is the intended class count.
    model = SiameseNetwork(num_classes=7)
target.to(device) 26 | optimizer.zero_grad() 27 | output = model(img) 28 | 29 | loss = criterion(output, target ) 30 | loss.backward() 31 | optimizer.step() 32 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 33 | correct += pred.eq(target.view_as(pred)).sum().item() 34 | 35 | 36 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} , {:.2f} seconds'.format( 37 | epoch, batch_idx * len(img), len(train_loader.dataset), 38 | 100. * batch_idx / len(train_loader), loss.item(),time.time() - t0)) 39 | 40 | 41 | def test(PARAMS, model,criterion, device, test_loader,optimizer,epoch,best_acc): 42 | model.eval() 43 | test_loss = 0 44 | correct = 0 45 | 46 | example_images = [] 47 | with torch.no_grad(): 48 | for batch_idx, (img, target) in enumerate(tqdm(test_loader)): 49 | img, target = img.to(device), target.to(device) 50 | output = model(img) 51 | 52 | test_loss += criterion(output, target).item() # sum up batch loss 53 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 54 | correct += pred.eq(target.view_as(pred)).sum().item() 55 | # Save the first input tensor in each test batch as an example image 56 | 57 | test_loss /= len(test_loader.dataset) 58 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( 59 | test_loss, correct, len(test_loader.dataset), 60 | 100. * correct / len(test_loader.dataset))) 61 | 62 | 63 | current_acc = 100. 
* correct / len(test_loader.dataset) 64 | return current_acc 65 | 66 | def main(): 67 | 68 | 69 | parser = argparse.ArgumentParser(description='manual to this script') 70 | parser.add_argument('--model', type=str, default = 'vgg16') 71 | parser.add_argument('--batch_size', type=int, default=32) 72 | parser.add_argument('--evaluate_model', type=str) 73 | parser.add_argument('--dataset', type=str, default='rsscn7') 74 | 75 | args = parser.parse_args() 76 | 77 | PARAMS = {'DEVICE': torch.device("cuda" if torch.cuda.is_available() else "cpu"), 78 | 'bs': 8, 79 | 'epochs':50, 80 | 'lr': 0.0006, 81 | 'momentum': 0.5, 82 | 'log_interval':10, 83 | 'criterion':'cross_entropy', 84 | 'model_name': args.model, 85 | 'dataset': args.dataset, 86 | } 87 | 88 | 89 | # Training settings 90 | train_transform = transforms.Compose( 91 | [ 92 | transforms.RandomHorizontalFlip(), 93 | transforms.ColorJitter(0.4, 0.4, 0.4), 94 | transforms.Resize((256,256)), 95 | transforms.ToTensor(), 96 | transforms.Normalize([0.4850, 0.4560, 0.4060], [0.2290, 0.2240, 0.2250])]) 97 | test_transform = transforms.Compose( 98 | [ 99 | transforms.Resize((256,256)), 100 | transforms.ToTensor(), 101 | transforms.Normalize([0.4850, 0.4560, 0.4060], [0.2290, 0.2240, 0.2250])]) 102 | 103 | 104 | if args.dataset == 'rsscn7': 105 | 106 | # train_dataset = datasets.ImageFolder(root='data/thick_removal',transform = train_transform) 107 | # test_dataset = datasets.ImageFolder(root='data/thick_removal',transform = test_transform) 108 | 109 | train_dataset = datasets.ImageFolder(root='data/rsscn7/train_dataset/',transform = train_transform) 110 | test_dataset = datasets.ImageFolder(root='data/rsscn7/test_dataset/',transform = test_transform) 111 | elif args.dataset == 'ucm': 112 | 113 | train_dataset = datasets.ImageFolder(root='data/ucm/train_dataset/',transform = train_transform) 114 | test_dataset = datasets.ImageFolder(root='data/ucm/test_dataset/',transform = test_transform) 115 | print(PARAMS) 116 | train_loader 
= DataLoader(train_dataset, batch_size=PARAMS['bs'], shuffle=True, num_workers=4, pin_memory = True ) 117 | test_loader = DataLoader(test_dataset, batch_size=PARAMS['bs'], shuffle=True, num_workers=4, pin_memory = True ) 118 | 119 | 120 | 121 | num_classes = len(train_dataset.classes) 122 | if PARAMS['model_name'] == 'vgg16': 123 | model = models.vgg16(pretrained=True) 124 | model.classifier[-1] = nn.Linear(in_features=4096, out_features=num_classes, bias=True) 125 | elif PARAMS['model_name'] == 'resnet50': 126 | model = models.resnet50(pretrained=True) 127 | model.fc = nn.Linear(in_features=2048, out_features=num_classes, bias=True) 128 | elif PARAMS['model_name'] == 'alexnet': 129 | model = models.alexnet(pretrained=True) 130 | model.classifier[-1] = nn.Linear(in_features=4096, out_features=num_classes, bias=True) 131 | 132 | 133 | 134 | model = model.to(PARAMS['DEVICE']) 135 | optimizer = optim.SGD(model.parameters(), lr=PARAMS['lr'], momentum=PARAMS['momentum']) 136 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 7, gamma = 0.9) 137 | criterion = F.cross_entropy 138 | acc = 0 139 | 140 | if not args.evaluate_model: 141 | for epoch in range(1, PARAMS['epochs'] + 1): 142 | train(PARAMS, model,criterion, PARAMS['DEVICE'], train_loader, optimizer, epoch) 143 | acc = test(PARAMS, model,criterion, PARAMS['DEVICE'], test_loader,optimizer,epoch,acc) 144 | scheduler.step() 145 | torch.save(model.state_dict(), 'saved_models/{}_{}_{}_{}_baseline.pth'.format(args.dataset, date.today(), PARAMS['model_name'], round(acc,2))) 146 | # torch.save(model, 'saved_models/{}_{}_{}_{}_baseline.pth'.format(args.dataset, date.today(), PARAMS['model_name'], round(acc,2))) 147 | else: 148 | model = torch.load(args.evaluate_model) 149 | acc = test(PARAMS, model,criterion, PARAMS['DEVICE'], test_loader, optimizer, 0, acc) 150 | print(f'the evalutaion acc is {acc}') 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torchvision import datasets, transforms 8 | from torch.utils.data import TensorDataset, DataLoader, Dataset,SubsetRandomSampler 9 | from torchvision import models 10 | import time 11 | from RS_Dataset import RS_Dataset 12 | from tqdm import tqdm 13 | import os 14 | import shutil 15 | from datetime import date 16 | import argparse 17 | from torchvision.models import resnet50,alexnet,vgg16 18 | from model import SiameseNetwork 19 | 20 | #offline 21 | 22 | 23 | 24 | def train(PARAMS, model, criterion, device, train_loader, optimizer, epoch): 25 | t0 = time.time() 26 | model.train() 27 | correct = 0 28 | 29 | for batch_idx, (img, cluster, target) in enumerate(tqdm(train_loader)): 30 | img, target = img.to(device), target.to(device) 31 | cluster = [item.to(device) for item in cluster ] 32 | optimizer.zero_grad() 33 | output = model(img,cluster) 34 | # output = model(img) 35 | 36 | loss = criterion(output, target ) 37 | loss.backward() 38 | optimizer.step() 39 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 40 | correct += pred.eq(target.view_as(pred)).sum().item() 41 | 42 | # if batch_idx % config.log_interval == 0: 43 | 44 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} , {:.2f} seconds'.format( 45 | epoch, batch_idx * len(img), len(train_loader.dataset), 46 | 100. * batch_idx / len(train_loader), loss.item(),time.time() - t0)) 47 | 48 | print('train_loss', epoch, loss.data.cpu().numpy()) 49 | print('Train Accuracy', epoch ,100. * correct / len(train_loader.dataset)) 50 | return 100. 
* correct / len(train_loader.dataset) 51 | 52 | 53 | 54 | 55 | def test(PARAMS, model,criterion, device, test_loader,optimizer,epoch,best_acc): 56 | model.eval() 57 | test_loss = 0 58 | correct = 0 59 | 60 | example_images = [] 61 | with torch.no_grad(): 62 | for batch_idx, (img, cluster, target) in enumerate(tqdm(test_loader)): 63 | img, target = img.to(device), target.to(device) 64 | cluster = [item.to(device) for item in cluster ] 65 | output = model(img,cluster) 66 | # output = model(img) 67 | 68 | test_loss += criterion(output, target).item() # sum up batch loss 69 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 70 | correct += pred.eq(target.view_as(pred)).sum().item() 71 | # Save the first input tensor in each test batch as an example image 72 | 73 | test_loss /= len(test_loader.dataset) 74 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( 75 | test_loss, correct, len(test_loader.dataset), 76 | 100. * correct / len(test_loader.dataset))) 77 | print('Test Accuracy ', 100. * correct / len(test_loader.dataset)) 78 | print('Test Loss ', test_loss) 79 | 80 | current_acc = 100. 
def boolean_string(s):
    """Parse the exact strings 'True'/'False' to bool; raise ValueError otherwise."""
    if s not in {'False', 'True'}:
        raise ValueError('Not a valid boolean string')
    return s == 'True'


def main():
    """Train the GLNet SiameseNetwork on the thin-cloud RS dataset."""
    parser = argparse.ArgumentParser(description='manual to this script')
    parser.add_argument('--model', type=str, default='vgg16')
    parser.add_argument('--partion', type=float, default=0.5)
    parser.add_argument('--bs', type=int, default=8)
    parser.add_argument('--fixed', type=boolean_string, default=False)
    parser.add_argument('--Augmentation', type=boolean_string, default=False)
    parser.add_argument('--debug', type=boolean_string, default=False)
    args = parser.parse_args()

    PARAMS = {'DEVICE': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
              'bs': args.bs,
              'epochs': 50,
              'lr': 0.0006,
              'momentum': 0.5,
              'log_interval': 10,
              'criterion': F.cross_entropy,
              'partion': args.partion,
              'model_name': str(args.model),
              'fixed': args.fixed,
              'Augmentation': args.Augmentation,
              }
    # Run tag kept for experiment bookkeeping (e.g. a logging backend).
    tags = PARAMS['model_name'] + '_' + 'fixed_' + str(PARAMS['fixed']) + '_' + 'aug_' + str(PARAMS['Augmentation'])

    # Training settings
    if PARAMS['Augmentation']:
        train_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.4850, 0.4560, 0.4060], [0.2290, 0.2240, 0.2250])])
    else:
        train_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.4850, 0.4560, 0.4060], [0.2290, 0.2240, 0.2250])])
    test_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.4850, 0.4560, 0.4060], [0.2290, 0.2240, 0.2250])])

    # Bug fix: --partion was parsed but never forwarded, so the flag was
    # silently ignored and RS_Dataset always used its internal default (0.6).
    train_dataset = RS_Dataset(root='thin_cloud/train_img', transform=train_transform,
                               partion=PARAMS['partion'])
    test_dataset = RS_Dataset(root='thin_cloud/test_img', transform=test_transform,
                              partion=PARAMS['partion'])

    print(PARAMS)
    train_loader = DataLoader(train_dataset, batch_size=PARAMS['bs'], shuffle=True,
                              num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=PARAMS['bs'], shuffle=True,
                             num_workers=4, pin_memory=True)

    num_classes = len(train_dataset.classes)
    # (Removed a redundant second `.to(device)` call on the model.)
    model = SiameseNetwork(base_model=PARAMS['model_name'], num_classes=num_classes,
                           fixed=PARAMS['fixed']).to(PARAMS['DEVICE'])

    optimizer = optim.SGD(model.parameters(), lr=PARAMS['lr'], momentum=PARAMS['momentum'])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.9)
    criterion = F.cross_entropy
    current_acc = 0

    for epoch in range(1, PARAMS['epochs'] + 1):
        train(PARAMS, model, criterion, PARAMS['DEVICE'], train_loader, optimizer, epoch)
        current_acc = test(PARAMS, model, criterion, PARAMS['DEVICE'], test_loader,
                           optimizer, epoch, current_acc)
        scheduler.step()
    torch.save(model, 'new_saved_models/{}_{}_{}_proposed_nodiff.pth'.format(
        date.today(), PARAMS['model_name'], round(current_acc, 2)))


if __name__ == '__main__':
    main()