├── __init__.py ├── test_all ├── __init__.py └── test_classes.py ├── logger ├── __init__.py ├── logger.py ├── logger_config.json └── visualization.py ├── utils ├── __init__.py └── module_util.py ├── .requirements.txt.swp ├── Detection_Results ├── 1_GT.jpg ├── 1_LR.jpg ├── 1_SR.jpg ├── 2_GT.jpg ├── 2_LR.jpg ├── 2_SR.jpg ├── 1_GT_box.jpg ├── 2_GT_box.jpg ├── 2_LR_detect.jpg ├── 1_LR_detection.jpg ├── 1_SR_detection.jpg ├── 2_LR_detect_new.jpg ├── 2_SR_detection.jpg ├── 1_LR_detection_new.jpg └── overall_pipeline.PNG ├── base ├── __init__.py ├── base_model.py ├── base_data_loader.py └── base_trainer.py ├── model ├── .ESRGAN_EESN_Model.py.swp ├── metric.py ├── loss.py ├── lr_scheduler.py ├── gan_base_model.py ├── ESRGANModel.py └── ESRGAN_EESN_Model.py ├── scripts_for_datasets ├── __init__.py ├── cowc_FRCNN_dataset.py ├── COWC_dataset.py ├── COWC_GAN_dataset.py └── COWC_EESRGAN_FRCNN_dataset.py ├── trainer ├── __init__.py ├── trainer.py ├── cowc_trainer.py ├── cowc_GAN_FRCNN_trainer.py ├── cowc_GAN_trainer.py └── FRCNN_trainer.py ├── requirements.txt ├── config.json ├── detection ├── transforms.py ├── train.py ├── group_by_aspect_ratio.py ├── engine.py ├── utils.py ├── coco_utils.py └── coco_eval.py ├── .gitignore ├── test.py ├── train.py ├── config_GAN.json ├── README.md ├── parse_config.py └── data_loader └── data_loaders.py /__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /test_all/__init__.py: -------------------------------------------------------------------------------- 1 | from .test_classes import * 2 | -------------------------------------------------------------------------------- /logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * 2 | from .visualization import * -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | from .module_util import * 3 | -------------------------------------------------------------------------------- /.requirements.txt.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/.requirements.txt.swp -------------------------------------------------------------------------------- /Detection_Results/1_GT.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/1_GT.jpg -------------------------------------------------------------------------------- /Detection_Results/1_LR.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/1_LR.jpg -------------------------------------------------------------------------------- /Detection_Results/1_SR.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/1_SR.jpg -------------------------------------------------------------------------------- /Detection_Results/2_GT.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/2_GT.jpg 
-------------------------------------------------------------------------------- /Detection_Results/2_LR.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/2_LR.jpg -------------------------------------------------------------------------------- /Detection_Results/2_SR.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/2_SR.jpg -------------------------------------------------------------------------------- /Detection_Results/1_GT_box.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/1_GT_box.jpg -------------------------------------------------------------------------------- /Detection_Results/2_GT_box.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/2_GT_box.jpg -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_data_loader import * 2 | from .base_model import * 3 | from .base_trainer import * 4 | -------------------------------------------------------------------------------- /model/.ESRGAN_EESN_Model.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/model/.ESRGAN_EESN_Model.py.swp -------------------------------------------------------------------------------- /Detection_Results/2_LR_detect.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/2_LR_detect.jpg -------------------------------------------------------------------------------- /Detection_Results/1_LR_detection.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/1_LR_detection.jpg -------------------------------------------------------------------------------- /Detection_Results/1_SR_detection.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/1_SR_detection.jpg -------------------------------------------------------------------------------- /Detection_Results/2_LR_detect_new.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/2_LR_detect_new.jpg -------------------------------------------------------------------------------- /Detection_Results/2_SR_detection.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/2_SR_detection.jpg -------------------------------------------------------------------------------- /Detection_Results/1_LR_detection_new.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/1_LR_detection_new.jpg -------------------------------------------------------------------------------- 
/Detection_Results/overall_pipeline.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/overall_pipeline.PNG -------------------------------------------------------------------------------- /scripts_for_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .COWC_dataset import * 2 | from .COWC_GAN_dataset import * 3 | from .cowc_FRCNN_dataset import * 4 | from .COWC_EESRGAN_FRCNN_dataset import * 5 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import * 2 | from .cowc_trainer import * 3 | from .cowc_GAN_trainer import * 4 | from .FRCNN_trainer import * 5 | from .cowc_GAN_FRCNN_trainer import * 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv_python_headless==4.1.1.26 2 | tqdm==4.36.1 3 | matplotlib==3.1.1 4 | albumentations==0.3.3 5 | torchvision==0.3.0 6 | pandas==0.25.1 7 | numpy==1.16.5 8 | pytest==5.2.0 9 | torch==1.1.0 10 | kornia==0.1.4.post2 11 | dataclasses==0.7 12 | Pillow==7.1.1 13 | pycocotools==2.0.0 14 | tensorboardX==2.0 15 | -------------------------------------------------------------------------------- /model/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def accuracy(output, target): 5 | with torch.no_grad(): 6 | pred = torch.argmax(output, dim=1) 7 | assert pred.shape[0] == len(target) 8 | correct = 0 9 | correct += torch.sum(pred == target).item() 10 | return correct / len(target) 11 | 12 | 13 | def top_k_acc(output, target, k=3): 14 | with torch.no_grad(): 15 | pred = torch.topk(output, k, dim=1)[1] 16 | assert pred.shape[0] == len(target) 17 | correct = 0 18 | for i in range(k): 19 | correct += torch.sum(pred[:, i] == target).item() 20 | return correct / len(target) 21 | -------------------------------------------------------------------------------- /base/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | from abc import abstractmethod 4 | 5 | 6 | class BaseModel(nn.Module): 7 | """ 8 | Base class for all models 9 | """ 10 | @abstractmethod 11 | def forward(self, *inputs): 12 | """ 13 | Forward pass logic 14 | 15 | :return: Model output 16 | """ 17 | raise NotImplementedError 18 | 19 | def __str__(self): 20 | """ 21 | Model prints with number of trainable parameters 22 | """ 23 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 24 | params = sum([np.prod(p.size()) for p in model_parameters]) 25 | return super().__str__() + '\nTrainable parameters: {}'.format(params) 26 | -------------------------------------------------------------------------------- /logger/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | from pathlib import Path 4 | from utils import read_json 5 | 6 | 7 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO): 8 | """ 9 | Setup logging configuration 10 | """ 11 | log_config = Path(log_config) 12 | if log_config.is_file(): 13 | config = read_json(log_config) 14 | # 
modify logging paths based on run config 15 | for _, handler in config['handlers'].items(): 16 | if 'filename' in handler: 17 | handler['filename'] = str(save_dir / handler['filename']) 18 | 19 | logging.config.dictConfig(config) 20 | else: 21 | print("Warning: logging configuration file is not found in {}.".format(log_config)) 22 | logging.basicConfig(level=default_level) 23 | -------------------------------------------------------------------------------- /logger/logger_config.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "version": 1, 4 | "disable_existing_loggers": false, 5 | "formatters": { 6 | "simple": {"format": "%(message)s"}, 7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"} 8 | }, 9 | "handlers": { 10 | "console": { 11 | "class": "logging.StreamHandler", 12 | "level": "DEBUG", 13 | "formatter": "simple", 14 | "stream": "ext://sys.stdout" 15 | }, 16 | "info_file_handler": { 17 | "class": "logging.handlers.RotatingFileHandler", 18 | "level": "INFO", 19 | "formatter": "datetime", 20 | "filename": "info.log", 21 | "maxBytes": 10485760, 22 | "backupCount": 20, "encoding": "utf8" 23 | } 24 | }, 25 | "root": { 26 | "level": "INFO", 27 | "handlers": [ 28 | "console", 29 | "info_file_handler" 30 | ] 31 | } 32 | } -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ImagePatchClassifier_TinyNet", 3 | "n_gpu": 1, 4 | 5 | "arch": { 6 | "type": "ImagePatchClassifier", 7 | "args": {} 8 | }, 9 | "data_loader": { 10 | "type": "COWCDataLoader", 11 | "args":{ 12 | "data_dir": "/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/Train-Mixed-class-2/", 13 | "batch_size": 1, 14 | "shuffle": true, 15 | "validation_split": 0.1, 16 | "num_workers": 2 17 | } 18 | }, 19 | "optimizer": { 20 | "type": "SGD", 21 | "args":{ 22 | "lr": 0.0001, 23 | "weight_decay": 0, 24 | "momentum": 0.9 25 | } 26 | }, 27 | "loss": "cross_entropy", 28 | "metrics": [ 29 | "accuracy" 30 | ], 31 | "lr_scheduler": { 32 | "type": "StepLR", 33 | "args": { 34 | "step_size": 50, 35 | "gamma": 0.1 36 | } 37 | }, 38 | "trainer": { 39 | "epochs": 100, 40 | 41 | "save_dir": "saved/", 42 | "save_period": 1, 43 | "verbosity": 2, 44 | 45 | "monitor": "min val_loss", 46 | "early_stop": 10, 47 | 48 | "tensorboard": true 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /detection/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | from torchvision.transforms import functional as F 5 | 6 | 7 | def _flip_coco_person_keypoints(kps, width): 8 | flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] 9 | flipped_data = kps[:, flip_inds] 10 | flipped_data[..., 0] = width - flipped_data[..., 0] 11 | # Maintain COCO convention that if visibility == 0, then x, y = 0 12 | inds = flipped_data[..., 2] == 0 13 | flipped_data[inds] = 0 14 | return flipped_data 15 | 16 | 17 | class Compose(object): 18 | def __init__(self, transforms): 19 | self.transforms = transforms 20 | 21 | def __call__(self, image, target): 22 | for t in self.transforms: 23 | image, target = t(image, target) 24 | return image, target 25 | 26 | 27 | class RandomHorizontalFlip(object): 28 | def __init__(self, prob): 29 | self.prob = prob 30 | 31 | def __call__(self, image, target): 
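# (editorial note) with probability `prob`, flip the image left-right and mirror the boxes (and masks/keypoints, if present) so the annotations stay aligned with the flipped image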
32 | if random.random() < self.prob: 33 | height, width = image.shape[-2:] 34 | image = image.flip(-1) 35 | bbox = target["boxes"] 36 | bbox[:, [0, 2]] = width - bbox[:, [2, 0]] 37 | target["boxes"] = bbox 38 | if "masks" in target: 39 | target["masks"] = target["masks"].flip(-1) 40 | if "keypoints" in target: 41 | keypoints = target["keypoints"] 42 | keypoints = _flip_coco_person_keypoints(keypoints, width) 43 | target["keypoints"] = keypoints 44 | return image, target 45 | 46 | 47 | class ToTensor(object): 48 | def __call__(self, image, target): 49 | image = F.to_tensor(image) 50 | return image, target 51 | -------------------------------------------------------------------------------- /test_all/test_classes.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | import shutil 4 | import torch 5 | from parse_config import ConfigParser 6 | from utils import read_json, write_json 7 | from scripts_for_datasets import COWCDataset 8 | # run tests 9 | # python -m pytest test_all/ 10 | # python -m pytest test_all/ -s ==> to see print statements 11 | class TestCOWCDataset(): 12 | 13 | def test_image_annot_equality(self): 14 | # Test code for init method 15 | # Testing the dataset size and similarity 16 | config = read_json('config.json') 17 | config = ConfigParser(config) 18 | data_dir = config['data_loader']['args']['data_dir'] 19 | shutil.rmtree("./saved")#removing /saved directory, everytime created by ConfigParser 20 | a = COWCDataset(data_dir) 21 | for img, annot in zip(a.imgs, a.annotation): 22 | if os.path.splitext(img)[0] != os.path.splitext(annot)[0]: 23 | print("problem") 24 | print(len(a.annotation)) 25 | assert len(a.imgs) == len(a.annotation), "NOT equal" 26 | 27 | def test_zero_annotation(self): 28 | # Test for checking number of image without bounding box 29 | config = read_json('config.json') 30 | config = ConfigParser(config) 31 | data_dir = config['data_loader']['args']['data_dir'] 32 | shutil.rmtree("./saved")#removing /saved directory, everytime created by ConfigParser 33 | a = COWCDataset(data_dir) 34 | zero_annotation = 0 35 | for i in range(len(a.annotation)): 36 | zero_annotation_get = a[i] 37 | if zero_annotation_get['object'].item() == 0: 38 | zero_annotation += 1 39 | print("Number of image withot bbox: "+str(zero_annotation)) 40 | #use assert zero_annotation == 0 if all image contain bounding box 41 | assert zero_annotation != 0, "Image exists with bounding box" 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # input data, saved log, checkpoints 104 | data/ 105 | input/ 106 | saved/ 107 | saved_ESRGAN/ 108 | saved_EEGAN_separate/ 109 | datasets/ 110 | 111 | # editor, os cache directory 112 | .vscode/ 113 | .idea/ 114 | __MACOSX/ 115 | -------------------------------------------------------------------------------- /base/base_data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.dataloader import default_collate 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | 6 | 7 | class BaseDataLoader(DataLoader): 8 | """ 9 | Base class for all data loaders 10 | """ 11 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate): 12 | self.validation_split = validation_split 13 | self.shuffle = shuffle 14 | 15 | self.batch_idx = 0 16 | self.n_samples = len(dataset) 17 | 18 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) 19 | 20 | self.init_kwargs = { 21 | 'dataset': dataset, 22 | 'batch_size': batch_size, 23 | 'shuffle': self.shuffle, 24 | 'collate_fn': collate_fn, 25 | 'num_workers': num_workers 26 | } 27 | super().__init__(sampler=self.sampler, **self.init_kwargs) 28 | 29 | def _split_sampler(self, split): 30 | if split == 0.0: 31 | return None, None 32 | 33 | idx_full = np.arange(self.n_samples) 34 | 35 | np.random.seed(0) 36 | np.random.shuffle(idx_full) 37 | 38 | if isinstance(split, int): 39 | assert split > 0 40 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset." 
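# (editorial note) an integer `split` is taken as the absolute number of held-out validation samples; a float (e.g. 0.1) is handled below as a fraction of the dataset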
41 | len_valid = split 42 | else: 43 | len_valid = int(self.n_samples * split) 44 | 45 | valid_idx = idx_full[0:len_valid] 46 | train_idx = np.delete(idx_full, np.arange(0, len_valid)) 47 | 48 | train_sampler = SubsetRandomSampler(train_idx) 49 | valid_sampler = SubsetRandomSampler(valid_idx) 50 | 51 | # turn off shuffle option which is mutually exclusive with sampler 52 | self.shuffle = False 53 | self.n_samples = len(train_idx) 54 | 55 | return train_sampler, valid_sampler 56 | 57 | def split_validation(self): 58 | if self.valid_sampler is None: 59 | return None 60 | else: 61 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs) 62 | -------------------------------------------------------------------------------- /utils/module_util.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Taken from https://github.com/xinntao/BasicSR/blob/master/codes/models/modules/module_util.py 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | import torch.nn.functional as F 8 | 9 | 10 | def initialize_weights(net_l, scale=1): 11 | if not isinstance(net_l, list): 12 | net_l = [net_l] 13 | for net in net_l: 14 | for m in net.modules(): 15 | if isinstance(m, nn.Conv2d): 16 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 17 | m.weight.data *= scale # for residual block 18 | if m.bias is not None: 19 | m.bias.data.zero_() 20 | elif isinstance(m, nn.Linear): 21 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 22 | m.weight.data *= scale 23 | if m.bias is not None: 24 | m.bias.data.zero_() 25 | elif isinstance(m, nn.BatchNorm2d): 26 | init.constant_(m.weight, 1) 27 | init.constant_(m.bias.data, 0.0) 28 | 29 | 30 | def make_layer(block, n_layers): 31 | layers = [] 32 | for _ in range(n_layers): 33 | layers.append(block()) 34 | return nn.Sequential(*layers) 35 | 36 | 37 | class ResidualBlock_noBN(nn.Module): 38 | '''Residual block w/o BN 39 | ---Conv-ReLU-Conv-+- 40 | |________________| 41 | ''' 42 | 43 | def __init__(self, nf=64): 44 | super(ResidualBlock_noBN, self).__init__() 45 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 46 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 47 | 48 | # initialization 49 | initialize_weights([self.conv1, self.conv2], 0.1) 50 | 51 | def forward(self, x): 52 | identity = x 53 | out = F.relu(self.conv1(x), inplace=True) 54 | out = self.conv2(out) 55 | return identity + out 56 | 57 | 58 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): 59 | """Warp an image or feature map with optical flow 60 | Args: 61 | x (Tensor): size (N, C, H, W) 62 | flow (Tensor): size (N, H, W, 2), normal value 63 | interp_mode (str): 'nearest' or 'bilinear' 64 | padding_mode (str): 'zeros' or 'border' or 'reflection' 65 | Returns: 66 | Tensor: warped image or feature map 67 | """ 68 | assert x.size()[-2:] == flow.size()[1:3] 69 | B, C, H, W = x.size() 70 | # mesh grid 71 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 72 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 73 | grid.requires_grad = False 74 | grid = grid.type_as(x) 75 | vgrid = grid + flow 76 | # scale grid to [-1,1] 77 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 78 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 79 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 80 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) 81 | return output 82 | 
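A minimal usage sketch for the helpers above (an editorial illustration, not part of the repository; it assumes the repo root is on PYTHONPATH and uses made-up tensor shapes):

import functools
import torch
from utils.module_util import make_layer, ResidualBlock_noBN, flow_warp

# make_layer expects a zero-argument block constructor, so bind nf with functools.partial
trunk = make_layer(functools.partial(ResidualBlock_noBN, nf=64), n_layers=5)
feat = trunk(torch.randn(1, 64, 32, 32))   # (N, nf, H, W) in, same shape out

# warp the feature map with an all-zero flow field; shapes follow the flow_warp docstring
flow = torch.zeros(1, 32, 32, 2)           # (N, H, W, 2)
warped = flow_warp(feat, flow)             # with zero flow the output approximates the input
print(feat.shape, warped.shape)            # both torch.Size([1, 64, 32, 32])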
-------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | 6 | def nll_loss(output, target): 7 | return F.nll_loss(output, target) 8 | 9 | def cross_entropy(output, target): 10 | return F.cross_entropy(output, target) 11 | 12 | 13 | class CharbonnierLoss(nn.Module): 14 | """Charbonnier Loss (L1)""" 15 | 16 | def __init__(self, eps=1e-6): 17 | super(CharbonnierLoss, self).__init__() 18 | self.eps = eps 19 | 20 | def forward(self, x, y): 21 | diff = x - y 22 | loss = torch.mean(torch.sqrt(diff * diff + self.eps)) 23 | return loss 24 | 25 | 26 | # Define GAN loss: [vanilla | lsgan | wgan-gp] 27 | class GANLoss(nn.Module): 28 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): 29 | super(GANLoss, self).__init__() 30 | self.gan_type = gan_type.lower() 31 | self.real_label_val = real_label_val 32 | self.fake_label_val = fake_label_val 33 | 34 | if self.gan_type == 'gan' or self.gan_type == 'ragan': 35 | self.loss = nn.BCEWithLogitsLoss() 36 | elif self.gan_type == 'lsgan': 37 | self.loss = nn.MSELoss() 38 | elif self.gan_type == 'wgan-gp': 39 | 40 | def wgan_loss(input, target): 41 | # target is boolean 42 | return -1 * input.mean() if target else input.mean() 43 | 44 | self.loss = wgan_loss 45 | else: 46 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) 47 | 48 | def get_target_label(self, input, target_is_real): 49 | if self.gan_type == 'wgan-gp': 50 | return target_is_real 51 | if target_is_real: 52 | return torch.empty_like(input).fill_(self.real_label_val) 53 | else: 54 | return torch.empty_like(input).fill_(self.fake_label_val) 55 | 56 | def forward(self, input, target_is_real): 57 | target_label = self.get_target_label(input, target_is_real) 58 | loss = self.loss(input, target_label) 59 | return loss 60 | 61 | 62 | class GradientPenaltyLoss(nn.Module): 63 | def __init__(self, device=torch.device('cpu')): 64 | super(GradientPenaltyLoss, self).__init__() 65 | self.register_buffer('grad_outputs', torch.Tensor()) 66 | self.grad_outputs = self.grad_outputs.to(device) 67 | 68 | def get_grad_outputs(self, input): 69 | if self.grad_outputs.size() != input.size(): 70 | self.grad_outputs.resize_(input.size()).fill_(1.0) 71 | return self.grad_outputs 72 | 73 | def forward(self, interp, interp_crit): 74 | grad_outputs = self.get_grad_outputs(interp_crit) 75 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, 76 | grad_outputs=grad_outputs, create_graph=True, 77 | retain_graph=True, only_inputs=True)[0] 78 | grad_interp = grad_interp.view(grad_interp.size(0), -1) 79 | grad_interp_norm = grad_interp.norm(2, dim=1) 80 | 81 | loss = ((grad_interp_norm - 1)**2).mean() 82 | return loss 83 | -------------------------------------------------------------------------------- /logger/visualization.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from datetime import datetime 3 | 4 | 5 | class TensorboardWriter(): 6 | def __init__(self, log_dir, logger, enabled): 7 | self.writer = None 8 | self.selected_module = "" 9 | 10 | if enabled: 11 | log_dir = str(log_dir) 12 | 13 | # Retrieve vizualization writer. 
14 | succeeded = False 15 | for module in ["torch.utils.tensorboard", "tensorboardX"]: 16 | try: 17 | self.writer = importlib.import_module(module).SummaryWriter(log_dir) 18 | succeeded = True 19 | break 20 | except ImportError: 21 | succeeded = False 22 | self.selected_module = module 23 | 24 | if not succeeded: 25 | message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \ 26 | "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \ 27 | "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \ 28 | "the 'config.json' file." 29 | logger.warning(message) 30 | 31 | self.step = 0 32 | self.mode = '' 33 | 34 | self.tb_writer_ftns = { 35 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 36 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding' 37 | } 38 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'} 39 | self.timer = datetime.now() 40 | 41 | def set_step(self, step, mode='train'): 42 | self.mode = mode 43 | self.step = step 44 | if step == 0: 45 | self.timer = datetime.now() 46 | else: 47 | duration = datetime.now() - self.timer 48 | self.add_scalar('steps_per_sec', 1 / duration.total_seconds()) 49 | self.timer = datetime.now() 50 | 51 | def __getattr__(self, name): 52 | """ 53 | If visualization is configured to use: 54 | return add_data() methods of tensorboard with additional information (step, tag) added. 55 | Otherwise: 56 | return a blank function handle that does nothing 57 | """ 58 | if name in self.tb_writer_ftns: 59 | add_data = getattr(self.writer, name, None) 60 | 61 | def wrapper(tag, data, *args, **kwargs): 62 | if add_data is not None: 63 | # add mode(train/valid) tag 64 | if name not in self.tag_mode_exceptions: 65 | tag = '{}/{}'.format(tag, self.mode) 66 | add_data(tag, data, self.step, *args, **kwargs) 67 | return wrapper 68 | else: 69 | # default action for returning methods defined in this class, set_step() for instance. 70 | try: 71 | attr = object.__getattr__(name) 72 | except AttributeError: 73 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) 74 | return attr 75 | -------------------------------------------------------------------------------- /scripts_for_datasets/cowc_FRCNN_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import torch 4 | import numpy as np 5 | import glob 6 | import cv2 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | class COWCFRCNNDataset(Dataset): 10 | def __init__(self, root=None, image_height=256, image_width=256, transforms = None): 11 | self.root = root 12 | #take all under same folder for train and test split. 
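# (editorial note) `root` is expected to contain image (.jpg/.png) and annotation (.txt) pairs with matching base names; each annotation line is YOLO-style: class x_center y_center width height, normalized to [0, 1]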
13 | self.transforms = transforms 14 | self.image_height = image_height 15 | self.image_width = image_width 16 | #sort all images for indexing, filter out check.jpgs 17 | self.imgs = list(sorted(set(glob.glob(self.root+"*.jpg") + 18 | glob.glob(self.root+"*.png")) - 19 | set(glob.glob(self.root+"*check.jpg")))) 20 | self.annotation = list(sorted(glob.glob(self.root+"*.txt"))) 21 | 22 | def __getitem__(self, idx): 23 | #get the paths 24 | img_path = os.path.join(self.root, self.imgs[idx]) 25 | annotation_path = os.path.join(self.root, self.annotation[idx]) 26 | img = cv2.imread(img_path,1) #read color image height*width*channel=3 27 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 28 | #get the bounding box 29 | boxes = list() 30 | with open(annotation_path) as f: 31 | for line in f: 32 | values = (line.split()) 33 | if "\ufeff" in values[0]: 34 | values[0] = values[0][-1] 35 | #get coordinates withing height width range 36 | x = float(values[1])*self.image_width 37 | y = float(values[2])*self.image_height 38 | width = float(values[3])*self.image_width 39 | height = float(values[4])*self.image_height 40 | #creating bounding boxes that would not touch the image edges 41 | x_min = 1 if x - width/2 <= 0 else int(x - width/2) 42 | x_max = self.image_width-1 if x + width/2 >= self.image_width-1 else int(x + width/2) 43 | y_min = 1 if y - height/2 <= 0 else int(y - height/2) 44 | y_max = self.image_height-1 if y + height/2 >= self.image_height-1 else int(y + height/2) 45 | 46 | x_min = int(x_min) 47 | x_max = int(x_max) 48 | y_min = int(y_min) 49 | y_max = int(y_max) 50 | 51 | boxes.append([x_min, y_min, x_max, y_max]) 52 | #print(boxes) 53 | 54 | boxes = torch.as_tensor(boxes, dtype=torch.float32) 55 | # there is only one class 56 | labels = torch.ones((len(boxes),), dtype=torch.int64) 57 | image_id = torch.tensor([idx]) 58 | area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) 59 | # suppose all instances are not crowd 60 | iscrowd = torch.zeros((len(boxes),), dtype=torch.int64) 61 | #create dictionary to access the values 62 | target = {} 63 | target["boxes"] = boxes 64 | target["labels"] = labels 65 | target["image_id"] = image_id 66 | target["area"] = area 67 | target["iscrowd"] = iscrowd 68 | 69 | if self.transforms is not None: 70 | img, target = self.transforms(img, target) 71 | 72 | #return img, target 73 | return img, target, annotation_path 74 | 75 | def __len__(self): 76 | return len(self.imgs) 77 | -------------------------------------------------------------------------------- /model/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from collections import defaultdict 4 | import torch 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | 7 | 8 | class MultiStepLR_Restart(_LRScheduler): 9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, 10 | clear_state=False, last_epoch=-1): 11 | self.milestones = Counter(milestones) 12 | self.gamma = gamma 13 | self.clear_state = clear_state 14 | self.restarts = restarts if restarts else [0] 15 | self.restart_weights = weights if weights else [1] 16 | assert len(self.restarts) == len( 17 | self.restart_weights), 'restarts and their weights do not match.' 
18 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) 19 | 20 | def get_lr(self): 21 | if self.last_epoch in self.restarts: 22 | if self.clear_state: 23 | self.optimizer.state = defaultdict(dict) 24 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 25 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 26 | if self.last_epoch not in self.milestones: 27 | return [group['lr'] for group in self.optimizer.param_groups] 28 | return [ 29 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 30 | for group in self.optimizer.param_groups 31 | ] 32 | 33 | 34 | class CosineAnnealingLR_Restart(_LRScheduler): 35 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): 36 | self.T_period = T_period 37 | self.T_max = self.T_period[0] # current T period 38 | self.eta_min = eta_min 39 | self.restarts = restarts if restarts else [0] 40 | self.restart_weights = weights if weights else [1] 41 | self.last_restart = 0 42 | assert len(self.restarts) == len( 43 | self.restart_weights), 'restarts and their weights do not match.' 44 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) 45 | 46 | def get_lr(self): 47 | if self.last_epoch == 0: 48 | return self.base_lrs 49 | elif self.last_epoch in self.restarts: 50 | self.last_restart = self.last_epoch 51 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] 52 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 53 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 54 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: 55 | return [ 56 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 57 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 58 | ] 59 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / 60 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * 61 | (group['lr'] - self.eta_min) + self.eta_min 62 | for group in self.optimizer.param_groups] 63 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from tqdm import tqdm 4 | import data_loader.data_loaders as module_data 5 | import model.loss as module_loss 6 | import model.metric as module_metric 7 | import model.model as module_arch 8 | from parse_config import ConfigParser 9 | from trainer import COWCFRCNNTrainer, COWCGANTrainer, COWCGANFrcnnTrainer 10 | ''' 11 | python test.py -c config_GAN.json 12 | ''' 13 | 14 | def main(config): 15 | 16 | data_loader = module_data.COWCGANFrcnnDataLoader('/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/HR/x4/valid_img/', 17 | '/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/LR/x4/valid_img/', 1, training=False) 18 | tester = COWCGANFrcnnTrainer(config=config, data_loader=data_loader) 19 | tester.test() 20 | 21 | ''' 22 | tester = COWCFRCNNTrainer(config=config) 23 | tester.test() 24 | ''' 25 | ''' 26 | logger = config.get_logger('test') 27 | 28 | # setup data_loader instances 29 | data_loader = getattr(module_data, config['data_loader']['type'])( 30 | config['data_loader']['args']['data_dir'], 31 | batch_size=512, 32 | shuffle=False, 33 | validation_split=0.0, 34 | training=False, 35 | 
num_workers=2 36 | ) 37 | 38 | # build model architecture 39 | model = config.init_obj('arch', module_arch) 40 | logger.info(model) 41 | 42 | # get function handles of loss and metrics 43 | loss_fn = getattr(module_loss, config['loss']) 44 | metric_fns = [getattr(module_metric, met) for met in config['metrics']] 45 | 46 | logger.info('Loading checkpoint: {} ...'.format(config.resume)) 47 | checkpoint = torch.load(config.resume) 48 | state_dict = checkpoint['state_dict'] 49 | if config['n_gpu'] > 1: 50 | model = torch.nn.DataParallel(model) 51 | model.load_state_dict(state_dict) 52 | 53 | # prepare model for testing 54 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 55 | model = model.to(device) 56 | model.eval() 57 | 58 | total_loss = 0.0 59 | total_metrics = torch.zeros(len(metric_fns)) 60 | 61 | with torch.no_grad(): 62 | for i, (data, target) in enumerate(tqdm(data_loader)): 63 | data, target = data.to(device), target.to(device) 64 | output = model(data) 65 | 66 | # 67 | # save sample images, or do something with output here 68 | # 69 | 70 | # computing loss, metrics on test set 71 | loss = loss_fn(output, target) 72 | batch_size = data.shape[0] 73 | total_loss += loss.item() * batch_size 74 | for i, metric in enumerate(metric_fns): 75 | total_metrics[i] += metric(output, target) * batch_size 76 | 77 | n_samples = len(data_loader.sampler) 78 | log = {'loss': total_loss / n_samples} 79 | log.update({ 80 | met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns) 81 | }) 82 | logger.info(log) 83 | ''' 84 | 85 | 86 | if __name__ == '__main__': 87 | args = argparse.ArgumentParser(description='PyTorch Template') 88 | args.add_argument('-c', '--config', default=None, type=str, 89 | help='config file path (default: None)') 90 | args.add_argument('-r', '--resume', default=None, type=str, 91 | help='path to latest checkpoint (default: None)') 92 | args.add_argument('-d', '--device', default=None, type=str, 93 | help='indices of GPUs to enable (default: all)') 94 | 95 | config = ConfigParser.from_args(args) 96 | main(config) 97 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import collections 4 | import torch 5 | import os 6 | import numpy as np 7 | import data_loader.data_loaders as module_data 8 | import model.loss as module_loss 9 | import model.metric as module_metric 10 | import model.model as module_arch 11 | from parse_config import ConfigParser 12 | from trainer import COWCTrainer 13 | from trainer import COWCGANTrainer 14 | from trainer import COWCFRCNNTrainer 15 | from trainer import COWCGANFrcnnTrainer 16 | from utils import setup_logger, dict2str 17 | ''' 18 | python train.py -c config_GAN.json 19 | ''' 20 | 21 | # fix random seeds for reproducibility 22 | SEED = 123 23 | torch.manual_seed(SEED) 24 | torch.backends.cudnn.deterministic = True 25 | torch.backends.cudnn.benchmark = False 26 | np.random.seed(SEED) 27 | 28 | def main(config): 29 | #logger = config.get_logger('train') 30 | # config loggers. 
Before it, the log will not work 31 | setup_logger('base', config['path']['log'], 'train_' + config['name'], level=logging.INFO, 32 | screen=True, tofile=True) 33 | setup_logger('val', config['path']['log'], 'val_' + config['name'], level=logging.INFO, 34 | screen=True, tofile=True) 35 | logger = logging.getLogger('base') 36 | #logger.info(dict2str(config)) 37 | 38 | 39 | # setup data_loader instances 40 | data_loader = config.init_obj('data_loader', module_data) 41 | #change later this valid_data_loader using init_obj 42 | valid_data_loader = module_data.COWCGANFrcnnDataLoader('/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/HR/x4/valid_img/', 43 | '/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/LR/x4/valid_img/', 1, training = False) 44 | 45 | # build model architecture, then print to console 46 | #model = config.init_obj('arch', module_arch) 47 | #logger.info(model) 48 | 49 | # get function handles of loss and metrics 50 | #criterion = getattr(module_loss, config['loss']) 51 | #metrics = [getattr(module_metric, met) for met in config['metrics']] 52 | 53 | # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler 54 | #trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 55 | #optimizer = config.init_obj('optimizer', torch.optim, trainable_params) 56 | 57 | #lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer) 58 | ''' 59 | trainer = COWCGANTrainer(model, criterion, metrics, optimizer, 60 | config=config, 61 | data_loader=data_loader, 62 | valid_data_loader=valid_data_loader, 63 | lr_scheduler=lr_scheduler) 64 | ''' 65 | ''' 66 | trainer = COWCGANTrainer(config=config,data_loader=data_loader, 67 | valid_data_loader=valid_data_loader 68 | ) 69 | ''' 70 | 71 | trainer = COWCGANFrcnnTrainer(config=config, data_loader=data_loader, 72 | valid_data_loader=valid_data_loader) 73 | trainer.train() 74 | ''' 75 | trainer = COWCFRCNNTrainer(config=config) 76 | trainer.train() 77 | ''' 78 | if __name__ == '__main__': 79 | args = argparse.ArgumentParser(description='PyTorch Template') 80 | args.add_argument('-c', '--config', default=None, type=str, 81 | help='config file path (default: None)') 82 | args.add_argument('-r', '--resume', default=None, type=str, 83 | help='path to latest checkpoint (default: None)') 84 | args.add_argument('-d', '--device', default=None, type=str, 85 | help='indices of GPUs to enable (default: all)') 86 | 87 | # custom cli options to modify configuration from default values given in json file. 
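# (editorial note, illustrative values) e.g. `python train.py -c config_GAN.json --lr 0.0002 --bs 4` should override optimizer;args;lr and data_loader;args;batch_size from the JSON via the targets defined below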
88 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') 89 | options = [ 90 | CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'), 91 | CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size') 92 | ] 93 | config = ConfigParser.from_args(args, options) 94 | main(config) 95 | -------------------------------------------------------------------------------- /scripts_for_datasets/COWC_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import torch 4 | import numpy as np 5 | import glob 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | from torch.utils.data import Dataset, DataLoader 9 | from torchvision import transforms, utils 10 | 11 | # Ignore warnings 12 | import warnings 13 | warnings.filterwarnings("ignore") 14 | 15 | 16 | class COWCDataset(Dataset): 17 | def __init__(self, root, image_height=256, image_width=256, transform = None): 18 | self.root = root 19 | #take all under same folder for train and test split. 20 | self.transform = transform 21 | self.image_height = image_height 22 | self.image_width = image_width 23 | #sort all images for indexing, filter out check.jpgs 24 | self.imgs = list(sorted(set(glob.glob(self.root+"*.jpg")) - set(glob.glob(self.root+"*check.jpg")))) 25 | self.annotation = list(sorted(glob.glob(self.root+"*.txt"))) 26 | 27 | def __getitem__(self, idx): 28 | #get the paths 29 | img_path = os.path.join(self.root, self.imgs[idx]) 30 | annotation_path = os.path.join(self.root, self.annotation[idx]) 31 | img = cv2.imread(img_path,1) #read color image height*width*channel=3 32 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 33 | #get the bounding box 34 | boxes = list() 35 | label_car_type = list() 36 | with open(annotation_path) as f: 37 | for line in f: 38 | values = (line.split()) 39 | if "\ufeff" in values[0]: 40 | values[0] = values[0][-1] 41 | obj_class = int(values[0]) 42 | #image without bounding box - in txt file, line starts with 0 and only contains only 0 43 | if obj_class == 0: 44 | boxes.append([0, 0, 1, 1]) 45 | labels = np.ones(len(boxes)) # all are cars 46 | label_car_type.append(obj_class) 47 | #create dictionary to access the values 48 | target = {} 49 | target['object'] = 0 50 | target['image'] = img 51 | target['bboxes'] = boxes 52 | target['labels'] = labels 53 | target['label_car_type'] = label_car_type 54 | target['idx'] = idx 55 | break 56 | else: 57 | #get coordinates withing height width range 58 | x = float(values[1])*self.image_width 59 | y = float(values[2])*self.image_height 60 | width = float(values[3])*self.image_width 61 | height = float(values[4])*self.image_height 62 | #creating bounding boxes that would not touch the image edges 63 | x_min = 1 if x - width/2 <= 0 else int(x - width/2) 64 | x_max = 255 if x + width/2 >= 256 else int(x + width/2) 65 | y_min = 1 if y - height/2 <= 0 else int(y - height/2) 66 | y_max = 255 if y + height/2 >= 256 else int(y + height/2) 67 | 68 | boxes.append([x_min, y_min, x_max, y_max]) 69 | label_car_type.append(obj_class) 70 | 71 | if obj_class != 0: 72 | labels = np.ones(len(boxes)) # all are cars 73 | #create dictionary to access the values 74 | target = {} 75 | target['object'] = 1 76 | target['image'] = img 77 | target['bboxes'] = boxes 78 | target['labels'] = labels 79 | target['label_car_type'] = label_car_type 80 | target['idx'] = idx 81 | 82 | if self.transform is None: 83 | #convert to tensor 84 | target = 
self.convert_to_tensor(**target) 85 | return target 86 | #transform 87 | else: 88 | transformed = self.transform(**target) 89 | #print(transformed['image'], transformed['bboxes'], transformed['labels'], transformed['idx']) 90 | target = self.convert_to_tensor(**transformed) 91 | return target 92 | 93 | def __len__(self): 94 | return len(self.imgs) 95 | 96 | def convert_to_tensor(self, **target): 97 | #convert to tensor 98 | target['object'] = torch.tensor(target['object'], dtype=torch.int64) 99 | target['image'] = torch.from_numpy(target['image'].transpose((2, 0, 1))) 100 | target['bboxes'] = torch.as_tensor(target['bboxes'], dtype=torch.int64) 101 | target['labels'] = torch.ones(len(target['bboxes']), dtype=torch.int64) 102 | target['label_car_type'] = torch.as_tensor(target['label_car_type'], dtype=torch.int64) 103 | target['image_id'] = torch.tensor([target['idx']]) 104 | 105 | return target 106 | -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.utils import make_grid 4 | from base import BaseTrainer 5 | from utils import inf_loop, MetricTracker 6 | 7 | 8 | class Trainer(BaseTrainer): 9 | """ 10 | Trainer class 11 | """ 12 | def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader, 13 | valid_data_loader=None, lr_scheduler=None, len_epoch=None): 14 | super().__init__(model, criterion, metric_ftns, optimizer, config) 15 | self.config = config 16 | self.data_loader = data_loader 17 | if len_epoch is None: 18 | # epoch-based training 19 | self.len_epoch = len(self.data_loader) 20 | else: 21 | # iteration-based training 22 | self.data_loader = inf_loop(data_loader) 23 | self.len_epoch = len_epoch 24 | self.valid_data_loader = valid_data_loader 25 | self.do_validation = self.valid_data_loader is not None 26 | self.lr_scheduler = lr_scheduler 27 | self.log_step = int(np.sqrt(data_loader.batch_size)) 28 | 29 | self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer) 30 | self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer) 31 | 32 | def _train_epoch(self, epoch): 33 | """ 34 | Training logic for an epoch 35 | 36 | :param epoch: Integer, current training epoch. 37 | :return: A log that contains average loss and metric in this epoch. 
38 | """ 39 | self.model.train() 40 | self.train_metrics.reset() 41 | for batch_idx, (data, target) in enumerate(self.data_loader): 42 | data, target = data.to(self.device), target.to(self.device) 43 | 44 | self.optimizer.zero_grad() 45 | output = self.model(data) 46 | loss = self.criterion(output, target) 47 | loss.backward() 48 | self.optimizer.step() 49 | 50 | self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx) 51 | self.train_metrics.update('loss', loss.item()) 52 | for met in self.metric_ftns: 53 | self.train_metrics.update(met.__name__, met(output, target)) 54 | 55 | if batch_idx % self.log_step == 0: 56 | self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format( 57 | epoch, 58 | self._progress(batch_idx), 59 | loss.item())) 60 | self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) 61 | 62 | if batch_idx == self.len_epoch: 63 | break 64 | log = self.train_metrics.result() 65 | 66 | if self.do_validation: 67 | val_log = self._valid_epoch(epoch) 68 | log.update(**{'val_'+k : v for k, v in val_log.items()}) 69 | 70 | if self.lr_scheduler is not None: 71 | self.lr_scheduler.step() 72 | return log 73 | 74 | def _valid_epoch(self, epoch): 75 | """ 76 | Validate after training an epoch 77 | 78 | :param epoch: Integer, current training epoch. 79 | :return: A log that contains information about validation 80 | """ 81 | self.model.eval() 82 | self.valid_metrics.reset() 83 | with torch.no_grad(): 84 | for batch_idx, (data, target) in enumerate(self.valid_data_loader): 85 | data, target = data.to(self.device), target.to(self.device) 86 | 87 | output = self.model(data) 88 | loss = self.criterion(output, target) 89 | 90 | self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid') 91 | self.valid_metrics.update('loss', loss.item()) 92 | for met in self.metric_ftns: 93 | self.valid_metrics.update(met.__name__, met(output, target)) 94 | self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) 95 | 96 | # add histogram of model parameters to the tensorboard 97 | for name, p in self.model.named_parameters(): 98 | self.writer.add_histogram(name, p, bins='auto') 99 | return self.valid_metrics.result() 100 | 101 | def _progress(self, batch_idx): 102 | base = '[{}/{} ({:.0f}%)]' 103 | if hasattr(self.data_loader, 'n_samples'): 104 | current = batch_idx * self.data_loader.batch_size 105 | total = self.data_loader.n_samples 106 | else: 107 | current = batch_idx 108 | total = self.len_epoch 109 | return base.format(current, total, 100.0 * current / total) 110 | -------------------------------------------------------------------------------- /model/gan_base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.parallel import DistributedDataParallel 6 | 7 | 8 | class BaseModel(): 9 | def __init__(self, config, device): 10 | self.config = config 11 | self.device = device 12 | self.schedulers = [] 13 | self.optimizers = [] 14 | 15 | def feed_data(self, data): 16 | pass 17 | 18 | def optimize_parameters(self): 19 | pass 20 | 21 | def get_current_visuals(self): 22 | pass 23 | 24 | def get_current_losses(self): 25 | pass 26 | 27 | def print_network(self): 28 | pass 29 | 30 | def save(self, label): 31 | pass 32 | 33 | def load(self): 34 | pass 35 | 36 | def _set_lr(self, lr_groups_l): 37 | ''' set learning rate for warmup, 38 | lr_groups_l: list for lr_groups. 
each for a optimizer''' 39 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): 40 | for param_group, lr in zip(optimizer.param_groups, lr_groups): 41 | param_group['lr'] = lr 42 | 43 | def _get_init_lr(self): 44 | # get the initial lr, which is set by the scheduler 45 | init_lr_groups_l = [] 46 | for optimizer in self.optimizers: 47 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) 48 | return init_lr_groups_l 49 | 50 | def update_learning_rate(self, cur_iter, warmup_iter=-1): 51 | for scheduler in self.schedulers: 52 | scheduler.step() 53 | #### set up warm up learning rate 54 | if cur_iter < warmup_iter: 55 | # get initial lr for each group 56 | init_lr_g_l = self._get_init_lr() 57 | # modify warming-up learning rates 58 | warm_up_lr_l = [] 59 | for init_lr_g in init_lr_g_l: 60 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) 61 | # set learning rate 62 | self._set_lr(warm_up_lr_l) 63 | 64 | def get_current_learning_rate(self): 65 | # return self.schedulers[0].get_lr()[0] 66 | return self.optimizers[0].param_groups[0]['lr'] 67 | 68 | def get_network_description(self, network): 69 | '''Get the string and total parameters of the network''' 70 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 71 | network = network.module 72 | s = str(network) 73 | n = sum(map(lambda x: x.numel(), network.parameters())) 74 | return s, n 75 | 76 | def save_network(self, network, network_label, iter_label): 77 | save_filename = '{}_{}.pth'.format(iter_label, network_label) 78 | save_path = os.path.join(self.config['path']['models'], save_filename) 79 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 80 | network = network.module 81 | state_dict = network.state_dict() 82 | for key, param in state_dict.items(): 83 | state_dict[key] = param.cpu() 84 | torch.save(state_dict, save_path) 85 | 86 | def load_network(self, load_path, network, strict=True): 87 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 88 | network = network.module 89 | load_net = torch.load(load_path) 90 | load_net_clean = OrderedDict() # remove unnecessary 'module.' 
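# (editorial note) checkpoints saved from nn.DataParallel / DistributedDataParallel wrappers prefix every key with 'module.'; strip it so the bare network can load the weights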
91 | for k, v in load_net.items(): 92 | if k.startswith('module.'): 93 | load_net_clean[k[7:]] = v 94 | else: 95 | load_net_clean[k] = v 96 | network.load_state_dict(load_net_clean, strict=strict) 97 | 98 | def save_training_state(self, epoch, iter_step): 99 | '''Saves training state during training, which will be used for resuming''' 100 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []} 101 | for s in self.schedulers: 102 | state['schedulers'].append(s.state_dict()) 103 | for o in self.optimizers: 104 | state['optimizers'].append(o.state_dict()) 105 | save_filename = '{}.state'.format(iter_step) 106 | save_path = os.path.join(self.config['path']['training_state'], save_filename) 107 | torch.save(state, save_path) 108 | 109 | def resume_training(self, resume_state): 110 | '''Resume the optimizers and schedulers for training''' 111 | resume_optimizers = resume_state['optimizers'] 112 | resume_schedulers = resume_state['schedulers'] 113 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' 114 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' 115 | for i, o in enumerate(resume_optimizers): 116 | self.optimizers[i].load_state_dict(o) 117 | for i, s in enumerate(resume_schedulers): 118 | self.schedulers[i].load_state_dict(s) 119 | -------------------------------------------------------------------------------- /config_GAN.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "RRDB_ESRGANx4", 3 | "n_gpu": 1, 4 | "model": "srgan", 5 | "distortion": "sr", 6 | "scale": 4, 7 | "use_tb_logger": true, 8 | 9 | "network_G": { 10 | "which_model_G": "RRDBNet", 11 | "in_nc": 3, 12 | "out_nc": 3, 13 | 14 | "nf": 64, 15 | "nb": 23, 16 | "args": {} 17 | }, 18 | "network_D": { 19 | "which_model_G": "discriminator_vgg_128", 20 | "in_nc": 3, 21 | "nf": 64, 22 | "args": {} 23 | }, 24 | "data_loader": { 25 | "type": "COWCGANFrcnnDataLoader", 26 | "args":{ 27 | "data_dir_GT": "/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/HR/x4/", 28 | "data_dir_LQ": "/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/LR/x4/", 29 | "batch_size": 2, 30 | "shuffle": true, 31 | "validation_split": 0.0, 32 | "num_workers": 2 33 | } 34 | }, 35 | "optimizer": { 36 | "type": "SGD", 37 | "args":{ 38 | "lr_G": 0.0001, 39 | "weight_decay_G": 0, 40 | "beta1_G": 0.9, 41 | "beta2_G": 0.99, 42 | 43 | "lr_D": 0.0001, 44 | "weight_decay_D": 0, 45 | "beta1_D": 0.9, 46 | "beta2_D": 0.99 47 | } 48 | }, 49 | "loss": "cross_entropy", 50 | "metrics": [ 51 | "accuracy" 52 | ], 53 | "lr_scheduler": { 54 | "type": "MultiStepLR", 55 | "args": { 56 | "lr_steps": [50000, 100000, 200000, 300000], 57 | "lr_gamma": 0.5, 58 | "T_period": [250000, 250000, 250000, 250000], 59 | "restarts": [250000, 500000, 750000], 60 | "restart_weights": [1, 1, 1], 61 | "eta_min": 0.0000001 62 | } 63 | }, 64 | "train": { 65 | "niter": 400000, 66 | "warmup_iter": -1, 67 | "pixel_criterion": "l1", 68 | "pixel_weight": 0.01, 69 | "feature_criterion": "l1", 70 | "feature_weight": 1, 71 | 72 | "gan_type": "ragan", 73 | "gan_weight": 0.001, 74 | "D_update_ratio": 1, 75 | "D_init_iters": 0, 76 | "manual_seed": 10, 77 | "val_freq": 1000, 78 | 79 | "save_dir": "saved/", 80 | "save_period": 1, 81 | "verbosity": 2, 82 | 83 | "monitor": "min val_loss", 84 | "early_stop": 10, 85 | 86 | "tensorboard": true 87 | }, 88 | "path": { 89 | "models": 
"saved/pretrained_models_EESRGAN_FRCNN", 90 | "FRCNN_model": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/FRCNN_model_LR_LR_cowc/", 91 | "pretrain_model_G": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/pretrained_models_EESRGAN_FRCNN/170000_G.pth", 92 | "pretrain_model_D": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/pretrained_models_EESRGAN_FRCNN/170000_D.pth", 93 | "pretrain_model_FRCNN": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/pretrained_models_EESRGAN_FRCNN/170000_FRCNN.pth", 94 | "pretrain_model_FRCNN_LR_LR": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/FRCNN_model_LR_LR_cowc/0_FRCNN_LR_LR.pth", 95 | "training_state": "saved/training_state", 96 | "strict_load": true, 97 | "resume_state": "~", 98 | "val_images": "saved/val_images", 99 | "output_images": "saved/val_images_cars_new", 100 | "log": "saved/logs", 101 | "data_dir_Valid": "/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/LR/x4/valid_img/", 102 | "data_dir_F_SR": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/Final_SR_images_test/", 103 | "data_dir_SR": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/SR_images_test/", 104 | "data_dir_SR_combined": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/combined_SR_images_216000/", 105 | "data_dir_E_SR_1": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/enhanced_SR_images_1/", 106 | "data_dir_E_SR_2": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/enhanced_SR_images_2/", 107 | "data_dir_E_SR_3": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/enhanced_SR_images_3/", 108 | "data_dir_Bic": "/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/HR/x4/valid_img/", 109 | "data_dir_LR_train": "/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/LR/x4/", 110 | "data_dir_Bic_valid": "/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/Bic/x4/valid_img/", 111 | "Test_Result_LR_LR_COWC": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/Test_Result_LR_LR_COWC/", 112 | "Test_Result_SR": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/Test_Result_SR/" 113 | }, 114 | "logger": { 115 | "print_freq": 100, 116 | "save_checkpoint_freq": 1000 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /scripts_for_datasets/COWC_GAN_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import torch 4 | import numpy as np 5 | import glob 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | from torch.utils.data import Dataset, DataLoader 9 | from torchvision import transforms, utils 10 | 11 | # Ignore warnings 12 | import warnings 13 | warnings.filterwarnings("ignore") 14 | 15 | 16 | class COWCGANDataset(Dataset): 17 | def __init__(self, data_dir_gt, data_dir_lq, image_height=256, image_width=256, transform = None): 18 | self.data_dir_gt = data_dir_gt 19 | self.data_dir_lq = data_dir_lq 20 | #take all under same folder for train and test split. 
21 | self.transform = transform 22 | self.image_height = image_height 23 | self.image_width = image_width 24 | #sort all images for indexing, filter out check.jpgs 25 | self.imgs_gt = list(sorted(glob.glob(self.data_dir_gt+"*.jpg"))) 26 | self.imgs_lq = list(sorted(glob.glob(self.data_dir_lq+"*.jpg"))) 27 | self.annotation = list(sorted(glob.glob(self.data_dir_lq+"*.txt"))) 28 | 29 | def __getitem__(self, idx): 30 | #get the paths 31 | img_path_gt = os.path.join(self.data_dir_gt, self.imgs_gt[idx]) 32 | img_path_lq = os.path.join(self.data_dir_lq, self.imgs_lq[idx]) 33 | annotation_path = os.path.join(self.data_dir_lq, self.annotation[idx]) 34 | img_gt = cv2.imread(img_path_gt,1) #read color image height*width*channel=3 35 | img_lq = cv2.imread(img_path_lq,1) #read color image height*width*channel=3 36 | img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2RGB) 37 | img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2RGB) 38 | #get the bounding box 39 | boxes = list() 40 | label_car_type = list() 41 | with open(annotation_path) as f: 42 | for line in f: 43 | values = (line.split()) 44 | if "\ufeff" in values[0]: 45 | values[0] = values[0][-1] 46 | obj_class = int(values[0]) 47 | #image without bounding box - in txt file, line starts with 0 and only contains only 0 48 | if obj_class == 0: 49 | boxes.append([0, 0, 1, 1]) 50 | labels = np.ones(len(boxes)) # all are cars 51 | label_car_type.append(obj_class) 52 | #create dictionary to access the values 53 | target = {} 54 | target['object'] = 0 55 | target['image_lq'] = img_lq 56 | target['image'] = img_gt 57 | target['bboxes'] = boxes 58 | target['labels'] = labels 59 | target['label_car_type'] = label_car_type 60 | target['idx'] = idx 61 | target['LQ_path'] = img_path_lq 62 | break 63 | else: 64 | #get coordinates withing height width range 65 | x = float(values[1])*self.image_width 66 | y = float(values[2])*self.image_height 67 | width = float(values[3])*self.image_width 68 | height = float(values[4])*self.image_height 69 | #creating bounding boxes that would not touch the image edges 70 | x_min = 1 if x - width/2 <= 0 else int(x - width/2) 71 | x_max = 255 if x + width/2 >= 256 else int(x + width/2) 72 | y_min = 1 if y - height/2 <= 0 else int(y - height/2) 73 | y_max = 255 if y + height/2 >= 256 else int(y + height/2) 74 | 75 | boxes.append([x_min, y_min, x_max, y_max]) 76 | label_car_type.append(obj_class) 77 | 78 | if obj_class != 0: 79 | labels = np.ones(len(boxes)) # all are cars 80 | #create dictionary to access the values 81 | target = {} 82 | target['object'] = 1 83 | target['image_lq'] = img_lq 84 | target['image'] = img_gt 85 | target['bboxes'] = boxes 86 | target['labels'] = labels 87 | target['label_car_type'] = label_car_type 88 | target['idx'] = idx 89 | target['LQ_path'] = img_path_lq 90 | 91 | if self.transform is None: 92 | #convert to tensor 93 | target = self.convert_to_tensor(**target) 94 | return target 95 | #transform 96 | else: 97 | transformed = self.transform(**target) 98 | #print(transformed['image'], transformed['bboxes'], transformed['labels'], transformed['idx']) 99 | target = self.convert_to_tensor(**transformed) 100 | return target 101 | 102 | def __len__(self): 103 | return len(self.imgs_lq) 104 | 105 | def convert_to_tensor(self, **target): 106 | #convert to tensor 107 | target['object'] = torch.tensor(target['object'], dtype=torch.int64) 108 | target['image_lq'] = torch.from_numpy(target['image_lq'].transpose((2, 0, 1))) 109 | target['image'] = torch.from_numpy(target['image'].transpose((2, 0, 1))) 110 | target['bboxes'] = 
torch.as_tensor(target['bboxes'], dtype=torch.int64) 111 | target['labels'] = torch.ones(len(target['bboxes']), dtype=torch.int64) 112 | target['label_car_type'] = torch.as_tensor(target['label_car_type'], dtype=torch.int64) 113 | target['image_id'] = torch.tensor([target['idx']]) 114 | 115 | return target 116 | -------------------------------------------------------------------------------- /trainer/cowc_trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.utils import make_grid 4 | from base import BaseTrainer 5 | from utils import inf_loop, MetricTracker, visualize_bbox, visualize 6 | 7 | 8 | class COWCTrainer(BaseTrainer): 9 | """ 10 | Trainer class 11 | """ 12 | def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader, 13 | valid_data_loader=None, lr_scheduler=None, len_epoch=None): 14 | super().__init__(model, criterion, metric_ftns, optimizer, config) 15 | self.config = config 16 | self.data_loader = data_loader 17 | 18 | if len_epoch is None: 19 | # epoch-based training 20 | self.len_epoch = len(self.data_loader) 21 | else: 22 | # iteration-based training 23 | self.data_loader = inf_loop(data_loader) 24 | 25 | self.len_epoch = len_epoch 26 | self.valid_data_loader = valid_data_loader 27 | self.do_validation = self.valid_data_loader is not None 28 | self.lr_scheduler = lr_scheduler 29 | self.log_step = int(np.sqrt(data_loader.batch_size)) 30 | 31 | self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer) 32 | self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer) 33 | 34 | def _train_epoch(self, epoch): 35 | """ 36 | Training logic for an epoch 37 | 38 | for visualization use the following code (use batch size = 1): 39 | category_id_to_name = {1: 'car'} 40 | for batch_idx, dataset_dict in enumerate(self.data_loader): 41 | if dataset_dict['object'][0].item() == 0: 42 | print(dataset_dict) 43 | visualize(dataset_dict, category_id_to_name) --> see this method in util 44 | 45 | image size: torch.Size([10, 3, 256, 256]) if batch_size = 10 46 | 47 | :param epoch: Integer, current training epoch. 48 | :return: A log that contains average loss and metric in this epoch. 
49 | """ 50 | 51 | self.model.train() 52 | self.train_metrics.reset() 53 | for batch_idx, dataset_dict in enumerate(self.data_loader): 54 | #print(dataset_dict['image'].size()) 55 | 56 | data, target = dataset_dict['image'].to(self.device), \ 57 | dataset_dict['object'].to(self.device) 58 | 59 | self.optimizer.zero_grad() 60 | output = self.model(data) 61 | loss = self.criterion(output, target) 62 | loss.backward() 63 | self.optimizer.step() 64 | 65 | self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx) 66 | self.train_metrics.update('loss', loss.item()) 67 | for met in self.metric_ftns: 68 | self.train_metrics.update(met.__name__, met(output, target)) 69 | 70 | if batch_idx % self.log_step == 0: 71 | self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format( 72 | epoch, 73 | self._progress(batch_idx), 74 | loss.item())) 75 | self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) 76 | 77 | if batch_idx == self.len_epoch: 78 | break 79 | log = self.train_metrics.result() 80 | 81 | if self.do_validation: 82 | val_log = self._valid_epoch(epoch) 83 | log.update(**{'val_'+k : v for k, v in val_log.items()}) 84 | 85 | if self.lr_scheduler is not None: 86 | self.lr_scheduler.step() 87 | return log 88 | 89 | def _valid_epoch(self, epoch): 90 | """ 91 | Validate after training an epoch 92 | 93 | :param epoch: Integer, current training epoch. 94 | :return: A log that contains information about validation 95 | """ 96 | self.model.eval() 97 | self.valid_metrics.reset() 98 | with torch.no_grad(): 99 | for batch_idx, dataset_dict in enumerate(self.valid_data_loader): 100 | data, target = dataset_dict['image'].to(self.device), \ 101 | dataset_dict['object'].to(self.device) 102 | 103 | output = self.model(data) 104 | loss = self.criterion(output, target) 105 | 106 | self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid') 107 | self.valid_metrics.update('loss', loss.item()) 108 | for met in self.metric_ftns: 109 | self.valid_metrics.update(met.__name__, met(output, target)) 110 | self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) 111 | 112 | # add histogram of model parameters to the tensorboard 113 | for name, p in self.model.named_parameters(): 114 | self.writer.add_histogram(name, p, bins='auto') 115 | return self.valid_metrics.result() 116 | 117 | def _progress(self, batch_idx): 118 | base = '[{}/{} ({:.0f}%)]' 119 | if hasattr(self.data_loader, 'n_samples'): 120 | current = batch_idx * self.data_loader.batch_size 121 | total = self.data_loader.n_samples 122 | else: 123 | current = batch_idx 124 | total = self.len_epoch 125 | return base.format(current, total, 100.0 * current / total) 126 | -------------------------------------------------------------------------------- /scripts_for_datasets/COWC_EESRGAN_FRCNN_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import torch 4 | import numpy as np 5 | import glob 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | from torch.utils.data import Dataset, DataLoader 9 | from torchvision import transforms, utils 10 | 11 | # Ignore warnings 12 | import warnings 13 | warnings.filterwarnings("ignore") 14 | 15 | 16 | class COWCGANFrcnnDataset(Dataset): 17 | def __init__(self, data_dir_gt, data_dir_lq, image_height=256, image_width=256, transform = None): 18 | self.data_dir_gt = data_dir_gt 19 | self.data_dir_lq = data_dir_lq 20 | #take all under same folder for 
train and test split. 21 | self.transform = transform 22 | self.image_height = image_height 23 | self.image_width = image_width 24 | #sort all images for indexing, filter out check.jpgs 25 | self.imgs_gt = list(sorted(glob.glob(self.data_dir_gt+"*.jpg"))) 26 | self.imgs_lq = list(sorted(glob.glob(self.data_dir_lq+"*.jpg"))) 27 | self.annotation = list(sorted(glob.glob(self.data_dir_lq+"*.txt"))) 28 | 29 | def __getitem__(self, idx): 30 | #get the paths 31 | img_path_gt = os.path.join(self.data_dir_gt, self.imgs_gt[idx]) 32 | img_path_lq = os.path.join(self.data_dir_lq, self.imgs_lq[idx]) 33 | annotation_path = os.path.join(self.data_dir_lq, self.annotation[idx]) 34 | img_gt = cv2.imread(img_path_gt,1) #read color image height*width*channel=3 35 | img_lq = cv2.imread(img_path_lq,1) #read color image height*width*channel=3 36 | img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2RGB) 37 | img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2RGB) 38 | #get the bounding box 39 | boxes = list() 40 | label_car_type = list() 41 | with open(annotation_path) as f: 42 | for line in f: 43 | values = (line.split()) 44 | if "\ufeff" in values[0]: 45 | values[0] = values[0][-1] 46 | obj_class = int(values[0]) 47 | #image without bounding box - in txt file, line starts with 0 and only contains only 0 48 | if obj_class == 0: 49 | boxes.append([0, 0, 1, 1]) 50 | labels = np.ones(len(boxes)) # all are cars 51 | label_car_type.append(obj_class) 52 | #create dictionary to access the values 53 | target = {} 54 | target['object'] = 0 55 | target['image_lq'] = img_lq 56 | target['image'] = img_gt 57 | target['bboxes'] = boxes 58 | target['labels'] = labels 59 | target['label_car_type'] = label_car_type 60 | target['image_id'] = idx 61 | target['LQ_path'] = img_path_lq 62 | target["area"] = 0 63 | target["iscrowd"] = 0 64 | break 65 | else: 66 | #get coordinates withing height width range 67 | x = float(values[1])*self.image_width 68 | y = float(values[2])*self.image_height 69 | width = float(values[3])*self.image_width 70 | height = float(values[4])*self.image_height 71 | #creating bounding boxes that would not touch the image edges 72 | x_min = 1 if x - width/2 <= 0 else int(x - width/2) 73 | x_max = 255 if x + width/2 >= 256 else int(x + width/2) 74 | y_min = 1 if y - height/2 <= 0 else int(y - height/2) 75 | y_max = 255 if y + height/2 >= 256 else int(y + height/2) 76 | 77 | boxes.append([x_min, y_min, x_max, y_max]) 78 | label_car_type.append(obj_class) 79 | 80 | if obj_class != 0: 81 | labels = np.ones(len(boxes)) # all are cars 82 | boxes_for_calc = torch.as_tensor(boxes, dtype=torch.int64) 83 | area = (boxes_for_calc[:, 3] - boxes_for_calc[:, 1]) * (boxes_for_calc[:, 2] - boxes_for_calc[:, 0]) 84 | iscrowd = torch.zeros((len(boxes),), dtype=torch.int64) 85 | #create dictionary to access the values 86 | target = {} 87 | target['object'] = 1 88 | target['image_lq'] = img_lq 89 | target['image'] = img_gt 90 | target['bboxes'] = boxes 91 | target['labels'] = labels 92 | target['label_car_type'] = label_car_type 93 | target['image_id'] = idx 94 | target['LQ_path'] = img_path_lq 95 | target["area"] = area 96 | target["iscrowd"] = iscrowd 97 | 98 | if self.transform is None: 99 | #convert to tensor 100 | image, target = self.convert_to_tensor(**target) 101 | return image, target 102 | #transform 103 | else: 104 | transformed = self.transform(**target) 105 | #print(transformed['image'], transformed['bboxes'], transformed['labels'], transformed['idx']) 106 | image, target = self.convert_to_tensor(**transformed) 107 | return image, 
target 108 | 109 | def __len__(self): 110 | return len(self.imgs_lq) 111 | 112 | def convert_to_tensor(self, **target): 113 | #convert to tensor 114 | target['object'] = torch.tensor(target['object'], dtype=torch.int64) 115 | target['image_lq'] = torch.from_numpy(target['image_lq'].transpose((2, 0, 1))) 116 | target['image'] = torch.from_numpy(target['image'].transpose((2, 0, 1))) 117 | target['boxes'] = torch.tensor(target['bboxes'], dtype=torch.float32) 118 | target['labels'] = torch.ones(len(target['labels']), dtype=torch.int64) 119 | target['label_car_type'] = torch.tensor(target['label_car_type'], dtype=torch.int64) 120 | target['image_id'] = torch.tensor([target['image_id']]) 121 | target["area"] = torch.tensor(target['area']) 122 | target["iscrowd"] = torch.tensor(target['iscrowd']) 123 | 124 | image = {} 125 | image['object'] = target['object'] 126 | image['image_lq'] = target['image_lq'] 127 | image['image'] = target['image'] 128 | 129 | image['LQ_path'] = target['LQ_path'] 130 | 131 | del target['object'] 132 | del target['image_lq'] 133 | del target['image'] 134 | del target['bboxes'] 135 | del target['LQ_path'] 136 | 137 | return image, target 138 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EESRGAN 2 | ## Model Architecture 3 | 4 | ## Enhancement and Detection 5 | |Low Resolution
Image & Detection|Super Resolved
Image & Detection|High Resolution Ground Truth
Image & Bounding Box| 6 | | --- | --- | --- | 7 | |||| 8 | |||| 9 | |||| 10 | |||| 11 | ## Dependencies and Installation 12 | - Python 3 (Anaconda is recommended) 13 | - PyTorch >= 1.0 14 | - NVIDIA GPU + CUDA 15 | - Python packages: `pip install -r path/to/requirements.txt` 16 | ## Training 17 | `python train.py -c config_GAN.json` 18 | ## Testing 19 | `python test.py -c config_GAN.json` 20 | ## Dataset 21 | Download the dataset from [here](https://gdo152.llnl.gov/cowc/download/cowc-m/datasets/). 22 | [Here](https://github.com/LLNL/cowc/tree/master/COWC-M) is a GitHub repo for creating custom image patches. 23 | Download the pre-made dataset from [here](https://gdo152.llnl.gov/cowc/download/cowc-m/datasets/DetectionPatches_256x256.tgz); [this](https://github.com/Jakaria08/EESRGAN/blob/1f93130d8e99166e7bc4d1640329450feec9ff9c/scripts_for_datasets/scripts_GAN_HR-LR.py#L24) script can then be used with the pre-made dataset to create high-resolution, low-resolution and bicubic images. Make sure to copy the annotation files (.txt) into the HR, LR and Bic folders. 24 | ## Edit the JSON File 25 | The directories in the following JSON snippet need to be changed to match your own setup. For details, see [config_GAN.json](https://github.com/Jakaria08/EESRGAN/blob/master/config_GAN.json); pretrained weights are available on [Google Drive](https://drive.google.com/drive/folders/15xN_TKKTUpQ5EVdZWJ2aZUa4Y-u-Mt0f?usp=sharing). 26 | ```json 27 | { 28 | "data_loader": { 29 | "type": "COWCGANFrcnnDataLoader", 30 | "args":{ 31 | "data_dir_GT": "/Directory for High-Resolution Ground Truth images/", 32 | "data_dir_LQ": "/Directory for 4x downsampled Low-Resolution images from the above High-Resolution images/" 33 | } 34 | }, 35 | 36 | "path": { 37 | "models": "saved/save_your_model_in_this_directory/", 38 | "pretrain_model_G": "Pretrained_model_path_for_train_test/170000_G.pth", 39 | "pretrain_model_D": "Pretrained_model_path_for_train_test/170000_D.pth", 40 | "pretrain_model_FRCNN": "Pretrained_model_path_for_train_test/170000_FRCNN.pth", 41 | "data_dir_Valid": "/Low_resolution_test_validation_image_directory/", 42 | "Test_Result_SR": "Directory_to_store_test_results/" 43 | } 44 | } 45 | 46 | ``` 47 | ## Paper 48 | Find the published version on [Remote Sensing](https://www.mdpi.com/2072-4292/12/9/1432). 49 | Find the preprints of the related paper on [preprints.org](https://www.preprints.org/manuscript/202003.0313/v1), [arxiv.org](https://arxiv.org/abs/2003.09085) and [researchgate.net](https://www.researchgate.net/publication/340095015_Small-Object_Detection_in_Remote_Sensing_Images_with_End-to-End_Edge-Enhanced_GAN_and_Object_Detector_Network). 50 | ### Abstract 51 | The detection performance of small objects in remote sensing images has not been satisfactory compared to large objects, especially in low-resolution and noisy images. A generative adversarial network (GAN)-based model called enhanced super-resolution GAN (ESRGAN) showed remarkable image enhancement performance, but reconstructed images usually miss high-frequency edge information. Therefore, object detection performance showed degradation for small objects on recovered noisy and low-resolution remote sensing images. Inspired by the success of edge enhanced GAN (EEGAN) and ESRGAN, we applied a new edge-enhanced super-resolution GAN (EESRGAN) to improve the quality of remote sensing images and used different detector networks in an end-to-end manner where detector loss was backpropagated into the EESRGAN to improve the detection performance.
We proposed an architecture with three components: ESRGAN, EEN, and Detection network. We used residual-in-residual dense blocks (RRDB) for both the ESRGAN and EEN, and for the detector network, we used a faster region-based convolutional network (FRCNN) (two-stage detector) and a single-shot multibox detector (SSD) (one stage detector). Extensive experiments on a public (car overhead with context) dataset and another self-assembled (oil and gas storage tank) satellite dataset showed superior performance of our method compared to the standalone state-of-the-art object detectors. 52 | ### Keywords 53 | object detection; faster region-based convolutional neural network (FRCNN); single-shot multibox detector (SSD); super-resolution; remote sensing imagery; edge enhancement; satellites 54 | ## Related Repository 55 | Some code segments are based on [ESRGAN](https://github.com/xinntao/BasicSR) 56 | ## Citation 57 | ### BibTex 58 | `@article{rabbi2020small,`\ 59 | `title={Small-Object Detection in Remote Sensing Images with End-to-End Edge-Enhanced GAN and Object Detector Network},`\ 60 | `author={Rabbi, Jakaria and Ray, Nilanjan and Schubert, Matthias and Chowdhury, Subir and Chao, Dennis},`\ 61 | `journal={Remote Sensing},`\ 62 | `volume={12},`\ 63 | `number={9},`\ 64 | `pages={1432},`\ 65 | `year={2020}`\ 66 | `publisher={Multidisciplinary Digital Publishing Institute}`\ 67 | `}` 68 | ### Chicago 69 | `Rabbi, Jakaria; Ray, Nilanjan; Schubert, Matthias; Chowdhury, Subir; Chao, Dennis. 2020. "Small-Object Detection in Remote Sensing Images with End-to-End Edge-Enhanced GAN and Object Detector Network." Remote Sens. 12, no. 9: 1432.` 70 | ## To Do 71 | - Refactor and clean the code. 72 | - Add more command line option for training and testing to run different configuration. 73 | - Fix bug and write important tests. 74 | -------------------------------------------------------------------------------- /parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from pathlib import Path 4 | from functools import reduce, partial 5 | from operator import getitem 6 | from datetime import datetime 7 | from logger import setup_logging 8 | from utils import read_json, write_json 9 | 10 | 11 | class ConfigParser: 12 | def __init__(self, config, resume=None, modification=None, run_id=None): 13 | """ 14 | class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving 15 | and logging module. 16 | :param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example. 17 | :param resume: String, path to the checkpoint being loaded. 18 | :param modification: Dict keychain:value, specifying position values to be replaced from config dict. 19 | :param run_id: Unique Identifier for training processes. Used to save checkpoints and training log. Timestamp is being used as default 20 | """ 21 | # load config file and apply modification 22 | self._config = _update_config(config, modification) 23 | self.resume = resume 24 | ''' 25 | # set save_dir where trained model and log will be saved. 
26 | save_dir = Path(self.config['train']['save_dir']) 27 | 28 | exper_name = self.config['name'] 29 | if run_id is None: # use timestamp as default run-id 30 | run_id = datetime.now().strftime(r'%m%d_%H%M%S') 31 | self._save_dir = save_dir / 'models' / exper_name / run_id 32 | self._log_dir = save_dir / 'log' / exper_name / run_id 33 | 34 | # make directory for saving checkpoints and log. 35 | exist_ok = run_id == '' 36 | self.save_dir.mkdir(parents=True, exist_ok=exist_ok) 37 | self.log_dir.mkdir(parents=True, exist_ok=exist_ok) 38 | 39 | # save updated config file to the checkpoint dir 40 | write_json(self.config, self.save_dir / 'config.json') 41 | ''' 42 | # configure logging module 43 | #setup_logging(self.log_dir) 44 | self.log_levels = { 45 | 0: logging.WARNING, 46 | 1: logging.INFO, 47 | 2: logging.DEBUG 48 | } 49 | 50 | @classmethod 51 | def from_args(cls, args, options=''): 52 | """ 53 | Initialize this class from some cli arguments. Used in train, test. 54 | """ 55 | for opt in options: 56 | args.add_argument(*opt.flags, default=None, type=opt.type) 57 | if not isinstance(args, tuple): 58 | args = args.parse_args() 59 | 60 | if args.device is not None: 61 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 62 | if args.resume is not None: 63 | resume = Path(args.resume) 64 | cfg_fname = resume.parent / 'config.json' 65 | else: 66 | msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 67 | assert args.config is not None, msg_no_cfg 68 | resume = None 69 | cfg_fname = Path(args.config) 70 | 71 | config = read_json(cfg_fname) 72 | if args.config and resume: 73 | # update new config for fine-tuning 74 | config.update(read_json(args.config)) 75 | 76 | # parse custom cli options into dictionary 77 | modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options} 78 | return cls(config, resume, modification) 79 | 80 | def init_obj(self, name, module, *args, **kwargs): 81 | """ 82 | Finds a function handle with the name given as 'type' in config, and returns the 83 | instance initialized with corresponding arguments given. 84 | 85 | `object = config.init_obj('name', module, a, b=1)` 86 | is equivalent to 87 | `object = module.name(a, b=1)` 88 | """ 89 | module_name = self[name]['type'] 90 | module_args = dict(self[name]['args']) 91 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 92 | module_args.update(kwargs) 93 | return getattr(module, module_name)(*args, **module_args) 94 | 95 | def init_ftn(self, name, module, *args, **kwargs): 96 | """ 97 | Finds a function handle with the name given as 'type' in config, and returns the 98 | function with given arguments fixed with functools.partial. 99 | 100 | `function = config.init_ftn('name', module, a, b=1)` 101 | is equivalent to 102 | `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`. 103 | """ 104 | module_name = self[name]['type'] 105 | module_args = dict(self[name]['args']) 106 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 107 | module_args.update(kwargs) 108 | return partial(getattr(module, module_name), *args, **module_args) 109 | 110 | def __getitem__(self, name): 111 | """Access items like ordinary dict.""" 112 | return self.config[name] 113 | 114 | def get_logger(self, name, verbosity=2): 115 | msg_verbosity = 'verbosity option {} is invalid. 
Valid options are {}.'.format(verbosity, self.log_levels.keys()) 116 | assert verbosity in self.log_levels, msg_verbosity 117 | logger = logging.getLogger(name) 118 | logger.setLevel(self.log_levels[verbosity]) 119 | return logger 120 | 121 | # setting read-only attributes 122 | @property 123 | def config(self): 124 | return self._config 125 | 126 | @property 127 | def save_dir(self): 128 | return self._save_dir 129 | 130 | @property 131 | def log_dir(self): 132 | return self._log_dir 133 | 134 | # helper functions to update config dict with custom cli options 135 | def _update_config(config, modification): 136 | if modification is None: 137 | return config 138 | 139 | for k, v in modification.items(): 140 | if v is not None: 141 | _set_by_path(config, k, v) 142 | return config 143 | 144 | def _get_opt_name(flags): 145 | for flg in flags: 146 | if flg.startswith('--'): 147 | return flg.replace('--', '') 148 | return flags[0].replace('--', '') 149 | 150 | def _set_by_path(tree, keys, value): 151 | """Set a value in a nested object in tree by sequence of keys.""" 152 | keys = keys.split(';') 153 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 154 | 155 | def _get_by_path(tree, keys): 156 | """Access a nested object in tree by sequence of keys.""" 157 | return reduce(getitem, keys, tree) 158 | -------------------------------------------------------------------------------- /data_loader/data_loaders.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms, utils 2 | from base import BaseDataLoader 3 | from scripts_for_datasets import COWCDataset, COWCGANDataset, COWCFRCNNDataset, COWCGANFrcnnDataset 4 | 5 | from albumentations import ( 6 | HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90, 7 | Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue, 8 | IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, 9 | IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, 10 | BboxParams, RandomCrop, Normalize, Resize, VerticalFlip 11 | ) 12 | 13 | from albumentations.pytorch import ToTensor 14 | from utils import collate_fn 15 | #from detection.utils import collate_fn 16 | 17 | 18 | class MnistDataLoader(BaseDataLoader): 19 | """ 20 | MNIST data loading demo using BaseDataLoader 21 | """ 22 | def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True): 23 | trsfm = transforms.Compose([ 24 | transforms.ToTensor(), 25 | transforms.Normalize((0.1307,), (0.3081,)) 26 | ]) 27 | self.data_dir = data_dir 28 | self.dataset = datasets.MNIST(self.data_dir, train=training, download=True, transform=trsfm) 29 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) 30 | 31 | class COWCDataLoader(BaseDataLoader): 32 | """ 33 | COWC data loading using BaseDataLoader 34 | """ 35 | def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True): 36 | #data transformation 37 | #According to this link: https://discuss.pytorch.org/t/normalization-of-input-image/34814/8 38 | #satellite image 0.5 is good otherwise calculate mean and std for the whole dataset. 
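# The Normalize() mean/std constants used below are precomputed for the COWC Potsdam patches.
# A minimal, self-contained sketch (an assumption-level helper, not the repo's own util method)
# of how such per-channel statistics can be recomputed over a folder of RGB *.jpg patches:
#
#     import glob
#     import cv2
#     import numpy as np
#
#     def channel_stats(image_dir):
#         """Return per-channel (mean, std) over all *.jpg images in image_dir, scaled to [0, 1]."""
#         sums, sq_sums, count = np.zeros(3), np.zeros(3), 0
#         for path in glob.glob(image_dir + "*.jpg"):
#             img = cv2.cvtColor(cv2.imread(path, 1), cv2.COLOR_BGR2RGB) / 255.0
#             pixels = img.reshape(-1, 3)
#             sums += pixels.sum(axis=0)
#             sq_sums += (pixels ** 2).sum(axis=0)
#             count += pixels.shape[0]
#         mean = sums / count
#         std = np.sqrt(sq_sums / count - mean ** 2)
#         return mean, std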
39 | #calculted mean and std using method from util 40 | data_transforms = Compose([ 41 | Resize(256, 256), 42 | HorizontalFlip(), 43 | OneOf([ 44 | IAAAdditiveGaussianNoise(), 45 | GaussNoise(), 46 | ], p=0.2), 47 | OneOf([ 48 | CLAHE(clip_limit=2), 49 | IAASharpen(), 50 | IAAEmboss(), 51 | RandomBrightnessContrast(), 52 | ], p=0.3), 53 | HueSaturationValue(p=0.3), 54 | Normalize( #mean std for potsdam dataset from COWC [Calculate also for spot6] 55 | mean=[0.3442, 0.3708, 0.3476], 56 | std=[0.1232, 0.1230, 0.1284] 57 | ) 58 | ], 59 | bbox_params=BboxParams( 60 | format='pascal_voc', 61 | min_area=0, 62 | min_visibility=0, 63 | label_fields=['labels']) 64 | ) 65 | 66 | 67 | self.data_dir = data_dir 68 | self.dataset = COWCDataset(self.data_dir, transform=data_transforms) 69 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=collate_fn) 70 | 71 | class COWCGANDataLoader(BaseDataLoader): 72 | """ 73 | COWC data loading using BaseDataLoader 74 | """ 75 | def __init__(self, data_dir_GT, data_dir_LQ, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True): 76 | #data transformation 77 | #According to this link: https://discuss.pytorch.org/t/normalization-of-input-image/34814/8 78 | #satellite image 0.5 is good otherwise calculate mean and std for the whole dataset. 79 | #calculted mean and std using method from util 80 | ''' 81 | Data transform for GAN training 82 | ''' 83 | data_transforms_train = Compose([ 84 | HorizontalFlip(), 85 | Normalize( #mean std for potsdam dataset from COWC [Calculate also for spot6] 86 | mean=[0.3442, 0.3708, 0.3476], 87 | std=[0.1232, 0.1230, 0.1284] 88 | ) 89 | ], 90 | additional_targets={ 91 | 'image_lq':'image' 92 | }, 93 | bbox_params=BboxParams( 94 | format='pascal_voc', 95 | min_area=0, 96 | min_visibility=0, 97 | label_fields=['labels']) 98 | ) 99 | 100 | data_transforms_test = Compose([ 101 | Normalize( #mean std for potsdam dataset from COWC [Calculate also for spot6] 102 | mean=[0.3442, 0.3708, 0.3476], 103 | std=[0.1232, 0.1230, 0.1284] 104 | )], 105 | additional_targets={ 106 | 'image_lq':'image' 107 | }) 108 | 109 | self.data_dir_gt = data_dir_GT 110 | self.data_dir_lq = data_dir_LQ 111 | 112 | if training == True: 113 | self.dataset = COWCGANDataset(self.data_dir_gt, self.data_dir_lq, transform=data_transforms_train) 114 | else: 115 | self.dataset = COWCGANDataset(self.data_dir_gt, self.data_dir_lq, transform=data_transforms_test) 116 | self.length = len(self.dataset) 117 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=collate_fn) 118 | 119 | class COWCGANFrcnnDataLoader(BaseDataLoader): 120 | """ 121 | COWC data loading using BaseDataLoader 122 | """ 123 | def __init__(self, data_dir_GT, data_dir_LQ, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True): 124 | #data transformation 125 | #According to this link: https://discuss.pytorch.org/t/normalization-of-input-image/34814/8 126 | #satellite image 0.5 is good otherwise calculate mean and std for the whole dataset. 
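# For reference, a hypothetical call site (the real wiring lives in train.py, not reproduced here)
# showing how this loader is typically constructed from config_GAN.json via ConfigParser.init_obj:
#
#     import data_loader.data_loaders as module_data
#     loader = config.init_obj('data_loader', module_data)
#     # ...which, given the shipped config_GAN.json, is equivalent to:
#     loader = module_data.COWCGANFrcnnDataLoader(
#         data_dir_GT="/path/to/HR/x4/", data_dir_LQ="/path/to/LR/x4/",
#         batch_size=2, shuffle=True, validation_split=0.0, num_workers=2)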
127 | #calculted mean and std using method from util 128 | ''' 129 | Data transform for GAN training 130 | ''' 131 | data_transforms_train = Compose([ 132 | HorizontalFlip(), 133 | Normalize( #mean std for potsdam dataset from COWC [Calculate also for spot6] 134 | mean=[0.3442, 0.3708, 0.3476], 135 | std=[0.1232, 0.1230, 0.1284] 136 | ) 137 | ], 138 | additional_targets={ 139 | 'image_lq':'image' 140 | }, 141 | bbox_params=BboxParams( 142 | format='pascal_voc', 143 | min_area=0, 144 | min_visibility=0, 145 | label_fields=['labels']) 146 | ) 147 | 148 | data_transforms_test = Compose([ 149 | Normalize( #mean std for potsdam dataset from COWC [Calculate also for spot6] 150 | mean=[0.3442, 0.3708, 0.3476], 151 | std=[0.1232, 0.1230, 0.1284] 152 | )], 153 | additional_targets={ 154 | 'image_lq':'image' 155 | }) 156 | 157 | self.data_dir_gt = data_dir_GT 158 | self.data_dir_lq = data_dir_LQ 159 | 160 | if training == True: 161 | self.dataset = COWCGANFrcnnDataset(self.data_dir_gt, self.data_dir_lq, transform=data_transforms_train) 162 | else: 163 | self.dataset = COWCGANFrcnnDataset(self.data_dir_gt, self.data_dir_lq, transform=data_transforms_test) 164 | self.length = len(self.dataset) 165 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=collate_fn) 166 | -------------------------------------------------------------------------------- /base/base_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod 3 | from numpy import inf 4 | from logger import TensorboardWriter 5 | 6 | 7 | class BaseTrainer: 8 | """ 9 | Base class for all trainers 10 | """ 11 | def __init__(self, model, criterion, metric_ftns, optimizer, config): 12 | self.config = config 13 | self.logger = config.get_logger('trainer', config['trainer']['verbosity']) 14 | 15 | # setup GPU device if available, move model into configured device 16 | self.device, device_ids = self._prepare_device(config['n_gpu']) 17 | self.model = model.to(self.device) 18 | if len(device_ids) > 1: 19 | self.model = torch.nn.DataParallel(model, device_ids=device_ids) 20 | 21 | self.criterion = criterion 22 | self.metric_ftns = metric_ftns 23 | self.optimizer = optimizer 24 | 25 | cfg_trainer = config['trainer'] 26 | self.epochs = cfg_trainer['epochs'] 27 | self.save_period = cfg_trainer['save_period'] 28 | self.monitor = cfg_trainer.get('monitor', 'off') 29 | 30 | # configuration to monitor model performance and save best 31 | if self.monitor == 'off': 32 | self.mnt_mode = 'off' 33 | self.mnt_best = 0 34 | else: 35 | self.mnt_mode, self.mnt_metric = self.monitor.split() 36 | assert self.mnt_mode in ['min', 'max'] 37 | 38 | self.mnt_best = inf if self.mnt_mode == 'min' else -inf 39 | self.early_stop = cfg_trainer.get('early_stop', inf) 40 | 41 | self.start_epoch = 1 42 | 43 | self.checkpoint_dir = config.save_dir 44 | 45 | # setup visualization writer instance 46 | self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard']) 47 | 48 | if config.resume is not None: 49 | self._resume_checkpoint(config.resume) 50 | 51 | @abstractmethod 52 | def _train_epoch(self, epoch): 53 | """ 54 | Training logic for an epoch 55 | 56 | :param epoch: Current epoch number 57 | """ 58 | raise NotImplementedError 59 | 60 | def train(self): 61 | """ 62 | Full training logic 63 | """ 64 | not_improved_count = 0 65 | for epoch in range(self.start_epoch, self.epochs + 1): 66 | result = self._train_epoch(epoch) 67 | 68 | # 
save logged informations into log dict 69 | log = {'epoch': epoch} 70 | log.update(result) 71 | 72 | # print logged informations to the screen 73 | for key, value in log.items(): 74 | self.logger.info(' {:15s}: {}'.format(str(key), value)) 75 | 76 | # evaluate model performance according to configured metric, save best checkpoint as model_best 77 | best = False 78 | if self.mnt_mode != 'off': 79 | try: 80 | # check whether model performance improved or not, according to specified metric(mnt_metric) 81 | improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \ 82 | (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best) 83 | except KeyError: 84 | self.logger.warning("Warning: Metric '{}' is not found. " 85 | "Model performance monitoring is disabled.".format(self.mnt_metric)) 86 | self.mnt_mode = 'off' 87 | improved = False 88 | 89 | if improved: 90 | self.mnt_best = log[self.mnt_metric] 91 | not_improved_count = 0 92 | best = True 93 | else: 94 | not_improved_count += 1 95 | 96 | if not_improved_count > self.early_stop: 97 | self.logger.info("Validation performance didn\'t improve for {} epochs. " 98 | "Training stops.".format(self.early_stop)) 99 | break 100 | 101 | if epoch % self.save_period == 0: 102 | self._save_checkpoint(epoch, save_best=best) 103 | 104 | def _prepare_device(self, n_gpu_use): 105 | """ 106 | setup GPU device if available, move model into configured device 107 | """ 108 | n_gpu = torch.cuda.device_count() 109 | if n_gpu_use > 0 and n_gpu == 0: 110 | self.logger.warning("Warning: There\'s no GPU available on this machine," 111 | "training will be performed on CPU.") 112 | n_gpu_use = 0 113 | if n_gpu_use > n_gpu: 114 | self.logger.warning("Warning: The number of GPU\'s configured to use is {}, but only {} are available " 115 | "on this machine.".format(n_gpu_use, n_gpu)) 116 | n_gpu_use = n_gpu 117 | device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') 118 | list_ids = list(range(n_gpu_use)) 119 | return device, list_ids 120 | 121 | def _save_checkpoint(self, epoch, save_best=False): 122 | """ 123 | Saving checkpoints 124 | 125 | :param epoch: current epoch number 126 | :param log: logging information of the epoch 127 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth' 128 | """ 129 | arch = type(self.model).__name__ 130 | state = { 131 | 'arch': arch, 132 | 'epoch': epoch, 133 | 'state_dict': self.model.state_dict(), 134 | 'optimizer': self.optimizer.state_dict(), 135 | 'monitor_best': self.mnt_best, 136 | 'config': self.config 137 | } 138 | filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch)) 139 | torch.save(state, filename) 140 | self.logger.info("Saving checkpoint: {} ...".format(filename)) 141 | if save_best: 142 | best_path = str(self.checkpoint_dir / 'model_best.pth') 143 | torch.save(state, best_path) 144 | self.logger.info("Saving current best: model_best.pth ...") 145 | 146 | def _resume_checkpoint(self, resume_path): 147 | """ 148 | Resume from saved checkpoints 149 | 150 | :param resume_path: Checkpoint path to be resumed 151 | """ 152 | resume_path = str(resume_path) 153 | self.logger.info("Loading checkpoint: {} ...".format(resume_path)) 154 | checkpoint = torch.load(resume_path) 155 | self.start_epoch = checkpoint['epoch'] + 1 156 | self.mnt_best = checkpoint['monitor_best'] 157 | 158 | # load architecture params from checkpoint. 
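# The checkpoint layout matches what _save_checkpoint writes above: a dict holding 'arch', 'epoch',
# 'state_dict', 'optimizer', 'monitor_best' and 'config'. The checks below compare the stored
# 'config' entries against the currently loaded configuration before restoring the model weights
# and the optimizer state.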
159 | if checkpoint['config']['arch'] != self.config['arch']: 160 | self.logger.warning("Warning: Architecture configuration given in config file is different from that of " 161 | "checkpoint. This may yield an exception while state_dict is being loaded.") 162 | self.model.load_state_dict(checkpoint['state_dict']) 163 | 164 | # load optimizer state from checkpoint only when optimizer type is not changed. 165 | if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']: 166 | self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. " 167 | "Optimizer parameters not being resumed.") 168 | else: 169 | self.optimizer.load_state_dict(checkpoint['optimizer']) 170 | 171 | self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)) 172 | -------------------------------------------------------------------------------- /detection/train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import time 4 | 5 | import torch 6 | import torch.utils.data 7 | from torch import nn 8 | import torchvision 9 | import torchvision.models.detection 10 | import torchvision.models.detection.mask_rcnn 11 | 12 | from torchvision import transforms 13 | 14 | from coco_utils import get_coco, get_coco_kp 15 | 16 | from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups 17 | from engine import train_one_epoch, evaluate 18 | 19 | import utils 20 | import transforms as T 21 | 22 | 23 | def get_dataset(name, image_set, transform): 24 | paths = { 25 | "coco": ('/datasets01/COCO/022719/', get_coco, 91), 26 | "coco_kp": ('/datasets01/COCO/022719/', get_coco_kp, 2) 27 | } 28 | p, ds_fn, num_classes = paths[name] 29 | 30 | ds = ds_fn(p, image_set=image_set, transforms=transform) 31 | return ds, num_classes 32 | 33 | 34 | def get_transform(train): 35 | transforms = [] 36 | transforms.append(T.ToTensor()) 37 | if train: 38 | transforms.append(T.RandomHorizontalFlip(0.5)) 39 | return T.Compose(transforms) 40 | 41 | 42 | def main(args): 43 | utils.init_distributed_mode(args) 44 | print(args) 45 | 46 | device = torch.device(args.device) 47 | 48 | # Data loading code 49 | print("Loading data") 50 | 51 | dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True)) 52 | dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False)) 53 | 54 | print("Creating data loaders") 55 | if args.distributed: 56 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 57 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) 58 | else: 59 | train_sampler = torch.utils.data.RandomSampler(dataset) 60 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 61 | 62 | if args.aspect_ratio_group_factor >= 0: 63 | group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor) 64 | train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size) 65 | else: 66 | train_batch_sampler = torch.utils.data.BatchSampler( 67 | train_sampler, args.batch_size, drop_last=True) 68 | 69 | data_loader = torch.utils.data.DataLoader( 70 | dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, 71 | collate_fn=utils.collate_fn) 72 | 73 | data_loader_test = torch.utils.data.DataLoader( 74 | dataset_test, batch_size=1, 75 | sampler=test_sampler, num_workers=args.workers, 76 | collate_fn=utils.collate_fn) 77 | 78 | print("Creating model") 79 
| model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, 80 | pretrained=args.pretrained) 81 | model.to(device) 82 | 83 | model_without_ddp = model 84 | if args.distributed: 85 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 86 | model_without_ddp = model.module 87 | 88 | params = [p for p in model.parameters() if p.requires_grad] 89 | optimizer = torch.optim.SGD( 90 | params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 91 | 92 | # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) 93 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma) 94 | 95 | if args.resume: 96 | checkpoint = torch.load(args.resume, map_location='cpu') 97 | model_without_ddp.load_state_dict(checkpoint['model']) 98 | optimizer.load_state_dict(checkpoint['optimizer']) 99 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 100 | 101 | if args.test_only: 102 | evaluate(model, data_loader_test, device=device) 103 | return 104 | 105 | print("Start training") 106 | start_time = time.time() 107 | for epoch in range(args.epochs): 108 | if args.distributed: 109 | train_sampler.set_epoch(epoch) 110 | train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq) 111 | lr_scheduler.step() 112 | if args.output_dir: 113 | utils.save_on_master({ 114 | 'model': model_without_ddp.state_dict(), 115 | 'optimizer': optimizer.state_dict(), 116 | 'lr_scheduler': lr_scheduler.state_dict(), 117 | 'args': args}, 118 | os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) 119 | 120 | # evaluate after every epoch 121 | evaluate(model, data_loader_test, device=device) 122 | 123 | total_time = time.time() - start_time 124 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 125 | print('Training time {}'.format(total_time_str)) 126 | 127 | 128 | if __name__ == "__main__": 129 | import argparse 130 | parser = argparse.ArgumentParser(description='PyTorch Detection Training') 131 | 132 | parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset') 133 | parser.add_argument('--dataset', default='coco', help='dataset') 134 | parser.add_argument('--model', default='maskrcnn_resnet50_fpn', help='model') 135 | parser.add_argument('--device', default='cuda', help='device') 136 | parser.add_argument('-b', '--batch-size', default=2, type=int) 137 | parser.add_argument('--epochs', default=13, type=int, metavar='N', 138 | help='number of total epochs to run') 139 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 140 | help='number of data loading workers (default: 16)') 141 | parser.add_argument('--lr', default=0.02, type=float, help='initial learning rate') 142 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 143 | help='momentum') 144 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 145 | metavar='W', help='weight decay (default: 1e-4)', 146 | dest='weight_decay') 147 | parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs') 148 | parser.add_argument('--lr-steps', default=[8, 11], nargs='+', type=int, help='decrease lr every step-size epochs') 149 | parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') 150 | parser.add_argument('--print-freq', default=20, type=int, help='print frequency') 151 | 
parser.add_argument('--output-dir', default='.', help='path where to save') 152 | parser.add_argument('--resume', default='', help='resume from checkpoint') 153 | parser.add_argument('--aspect-ratio-group-factor', default=0, type=int) 154 | parser.add_argument( 155 | "--test-only", 156 | dest="test_only", 157 | help="Only test the model", 158 | action="store_true", 159 | ) 160 | parser.add_argument( 161 | "--pretrained", 162 | dest="pretrained", 163 | help="Use pre-trained models from the modelzoo", 164 | action="store_true", 165 | ) 166 | 167 | # distributed training parameters 168 | parser.add_argument('--world-size', default=1, type=int, 169 | help='number of distributed processes') 170 | parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') 171 | 172 | args = parser.parse_args() 173 | 174 | if args.output_dir: 175 | utils.mkdir(args.output_dir) 176 | 177 | main(args) 178 | -------------------------------------------------------------------------------- /detection/group_by_aspect_ratio.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | from collections import defaultdict 3 | import copy 4 | import numpy as np 5 | 6 | import torch 7 | import torch.utils.data 8 | from torch.utils.data.sampler import BatchSampler, Sampler 9 | from torch.utils.model_zoo import tqdm 10 | import torchvision 11 | 12 | from PIL import Image 13 | 14 | 15 | class GroupedBatchSampler(BatchSampler): 16 | """ 17 | Wraps another sampler to yield a mini-batch of indices. 18 | It enforces that the batch only contain elements from the same group. 19 | It also tries to provide mini-batches which follows an ordering which is 20 | as close as possible to the ordering from the original sampler. 21 | Arguments: 22 | sampler (Sampler): Base sampler. 23 | group_ids (list[int]): If the sampler produces indices in range [0, N), 24 | `group_ids` must be a list of `N` ints which contains the group id of each sample. 25 | The group ids must be a continuous set of integers starting from 26 | 0, i.e. they must be in the range [0, num_groups). 27 | batch_size (int): Size of mini-batch. 
28 | """ 29 | def __init__(self, sampler, group_ids, batch_size): 30 | if not isinstance(sampler, Sampler): 31 | raise ValueError( 32 | "sampler should be an instance of " 33 | "torch.utils.data.Sampler, but got sampler={}".format(sampler) 34 | ) 35 | self.sampler = sampler 36 | self.group_ids = group_ids 37 | self.batch_size = batch_size 38 | 39 | def __iter__(self): 40 | buffer_per_group = defaultdict(list) 41 | samples_per_group = defaultdict(list) 42 | 43 | num_batches = 0 44 | for idx in self.sampler: 45 | group_id = self.group_ids[idx] 46 | buffer_per_group[group_id].append(idx) 47 | samples_per_group[group_id].append(idx) 48 | if len(buffer_per_group[group_id]) == self.batch_size: 49 | yield buffer_per_group[group_id] 50 | num_batches += 1 51 | del buffer_per_group[group_id] 52 | assert len(buffer_per_group[group_id]) < self.batch_size 53 | 54 | # now we have run out of elements that satisfy 55 | # the group criteria, let's return the remaining 56 | # elements so that the size of the sampler is 57 | # deterministic 58 | expected_num_batches = len(self) 59 | num_remaining = expected_num_batches - num_batches 60 | if num_remaining > 0: 61 | # for the remaining batches, take first the buffers with largest number 62 | # of elements 63 | for group_id, _ in sorted(buffer_per_group.items(), 64 | key=lambda x: len(x[1]), reverse=True): 65 | remaining = self.batch_size - len(buffer_per_group[group_id]) 66 | buffer_per_group[group_id].extend( 67 | samples_per_group[group_id][:remaining]) 68 | assert len(buffer_per_group[group_id]) == self.batch_size 69 | yield buffer_per_group[group_id] 70 | num_remaining -= 1 71 | if num_remaining == 0: 72 | break 73 | assert num_remaining == 0 74 | 75 | def __len__(self): 76 | return len(self.sampler) // self.batch_size 77 | 78 | 79 | def _compute_aspect_ratios_slow(dataset, indices=None): 80 | print("Your dataset doesn't support the fast path for " 81 | "computing the aspect ratios, so will iterate over " 82 | "the full dataset and load every image instead. 
" 83 | "This might take some time...") 84 | if indices is None: 85 | indices = range(len(dataset)) 86 | 87 | class SubsetSampler(Sampler): 88 | def __init__(self, indices): 89 | self.indices = indices 90 | 91 | def __iter__(self): 92 | return iter(self.indices) 93 | 94 | def __len__(self): 95 | return len(self.indices) 96 | 97 | sampler = SubsetSampler(indices) 98 | data_loader = torch.utils.data.DataLoader( 99 | dataset, batch_size=1, sampler=sampler, 100 | num_workers=14, # you might want to increase it for faster processing 101 | collate_fn=lambda x: x[0]) 102 | aspect_ratios = [] 103 | with tqdm(total=len(dataset)) as pbar: 104 | for i, (img, _) in enumerate(data_loader): 105 | pbar.update(1) 106 | height, width = img.shape[-2:] 107 | aspect_ratio = float(height) / float(width) 108 | aspect_ratios.append(aspect_ratio) 109 | return aspect_ratios 110 | 111 | 112 | def _compute_aspect_ratios_custom_dataset(dataset, indices=None): 113 | if indices is None: 114 | indices = range(len(dataset)) 115 | aspect_ratios = [] 116 | for i in indices: 117 | height, width = dataset.get_height_and_width(i) 118 | aspect_ratio = float(height) / float(width) 119 | aspect_ratios.append(aspect_ratio) 120 | return aspect_ratios 121 | 122 | 123 | def _compute_aspect_ratios_coco_dataset(dataset, indices=None): 124 | if indices is None: 125 | indices = range(len(dataset)) 126 | aspect_ratios = [] 127 | for i in indices: 128 | img_info = dataset.coco.imgs[dataset.ids[i]] 129 | aspect_ratio = float(img_info["height"]) / float(img_info["width"]) 130 | aspect_ratios.append(aspect_ratio) 131 | return aspect_ratios 132 | 133 | 134 | def _compute_aspect_ratios_voc_dataset(dataset, indices=None): 135 | if indices is None: 136 | indices = range(len(dataset)) 137 | aspect_ratios = [] 138 | for i in indices: 139 | # this doesn't load the data into memory, because PIL loads it lazily 140 | width, height = Image.open(dataset.images[i]).size 141 | aspect_ratio = float(height) / float(width) 142 | aspect_ratios.append(aspect_ratio) 143 | return aspect_ratios 144 | 145 | 146 | def _compute_aspect_ratios_subset_dataset(dataset, indices=None): 147 | if indices is None: 148 | indices = range(len(dataset)) 149 | 150 | ds_indices = [dataset.indices[i] for i in indices] 151 | return compute_aspect_ratios(dataset.dataset, ds_indices) 152 | 153 | 154 | def compute_aspect_ratios(dataset, indices=None): 155 | if hasattr(dataset, "get_height_and_width"): 156 | return _compute_aspect_ratios_custom_dataset(dataset, indices) 157 | 158 | if isinstance(dataset, torchvision.datasets.CocoDetection): 159 | return _compute_aspect_ratios_coco_dataset(dataset, indices) 160 | 161 | if isinstance(dataset, torchvision.datasets.VOCDetection): 162 | return _compute_aspect_ratios_voc_dataset(dataset, indices) 163 | 164 | if isinstance(dataset, torch.utils.data.Subset): 165 | return _compute_aspect_ratios_subset_dataset(dataset, indices) 166 | 167 | # slow path 168 | return _compute_aspect_ratios_slow(dataset, indices) 169 | 170 | 171 | def _quantize(x, bins): 172 | bins = copy.deepcopy(bins) 173 | bins = sorted(bins) 174 | quantized = list(map(lambda y: bisect.bisect_right(bins, y), x)) 175 | return quantized 176 | 177 | 178 | def create_aspect_ratio_groups(dataset, k=0): 179 | aspect_ratios = compute_aspect_ratios(dataset) 180 | bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0] 181 | groups = _quantize(aspect_ratios, bins) 182 | # count number of elements per group 183 | counts = np.unique(groups, return_counts=True)[1] 184 | fbins 
= [0] + bins + [np.inf] 185 | print("Using {} as bins for aspect ratio quantization".format(fbins)) 186 | print("Count of instances per bin: {}".format(counts)) 187 | return groups 188 | -------------------------------------------------------------------------------- /trainer/cowc_GAN_FRCNN_trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import torch 4 | import math 5 | import os 6 | import model.ESRGANModel as ESRGAN 7 | import model.ESRGAN_EESN_FRCNN_Model as ESRGAN_EESN 8 | from scripts_for_datasets import COWCDataset, COWCGANDataset 9 | from torchvision.utils import make_grid 10 | from base import BaseTrainer 11 | from utils import inf_loop, MetricTracker, visualize_bbox, visualize, calculate_psnr, save_img, tensor2img, mkdir 12 | 13 | logger = logging.getLogger('base') 14 | ''' 15 | python train.py -c config_GAN.json 16 | modified from ESRGAN repo 17 | ''' 18 | 19 | class COWCGANFrcnnTrainer: 20 | """ 21 | Trainer class 22 | """ 23 | def __init__(self, config, data_loader, valid_data_loader=None): 24 | self.config = config 25 | self.data_loader = data_loader 26 | 27 | self.valid_data_loader = valid_data_loader 28 | self.do_validation = self.valid_data_loader is not None 29 | n_gpu = torch.cuda.device_count() 30 | self.device = torch.device('cuda:0' if n_gpu > 0 else 'cpu') 31 | self.train_size = int(math.ceil(self.data_loader.length / int(config['data_loader']['args']['batch_size']))) 32 | self.total_iters = int(config['train']['niter']) 33 | self.total_epochs = int(math.ceil(self.total_iters / self.train_size)) 34 | print(self.total_epochs) 35 | self.model = ESRGAN_EESN.ESRGAN_EESN_FRCNN_Model(config,self.device) 36 | 37 | def test(self): 38 | self.model.test(self.data_loader, train=False, testResult=True) 39 | 40 | def train(self): 41 | ''' 42 | Training logic for an epoch 43 | for visualization use the following code (use batch size = 1): 44 | 45 | category_id_to_name = {1: 'car'} 46 | for batch_idx, dataset_dict in enumerate(self.data_loader): 47 | if dataset_dict['idx'][0] == 10: 48 | print(dataset_dict) 49 | visualize(dataset_dict, category_id_to_name) #--> see this method in util 50 | 51 | #image size: torch.Size([10, 3, 256, 256]) if batch_size = 10 52 | ''' 53 | logger.info('Number of train images: {:,d}, iters: {:,d}'.format( 54 | self.data_loader.length, self.train_size)) 55 | logger.info('Total epochs needed: {:d} for iters {:,d}'.format( 56 | self.total_epochs, self.total_iters)) 57 | # tensorboard logger 58 | if self.config['use_tb_logger'] and 'debug' not in self.config['name']: 59 | version = float(torch.__version__[0:3]) 60 | if version >= 1.1: # PyTorch 1.1 61 | from torch.utils.tensorboard import SummaryWriter 62 | else: 63 | logger.info( 64 | 'You are using PyTorch {}. 
Tensorboard will use [tensorboardX]'.format(version)) 65 | from tensorboardX import SummaryWriter 66 | tb_logger = SummaryWriter(log_dir='saved/tb_logger/' + self.config['name']) 67 | ## Todo : resume capability 68 | current_step = 0 69 | start_epoch = 0 70 | 71 | #### training 72 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) 73 | for epoch in range(start_epoch, self.total_epochs + 1): 74 | for _, (image, targets) in enumerate(self.data_loader): 75 | current_step += 1 76 | if current_step > self.total_iters: 77 | break 78 | #### update learning rate 79 | self.model.update_learning_rate(current_step, warmup_iter=self.config['train']['warmup_iter']) 80 | 81 | #### training 82 | self.model.feed_data(image, targets) 83 | self.model.optimize_parameters(current_step) 84 | 85 | #### log 86 | if current_step % self.config['logger']['print_freq'] == 0: 87 | logs = self.model.get_current_log() 88 | message = ' '.format( 89 | epoch, current_step, self.model.get_current_learning_rate()) 90 | for k, v in logs.items(): 91 | message += '{:s}: {:.4e} '.format(k, v) 92 | # tensorboard logger 93 | if self.config['use_tb_logger'] and 'debug' not in self.config['name']: 94 | tb_logger.add_scalar(k, v, current_step) 95 | 96 | logger.info(message) 97 | 98 | # validation 99 | if current_step % self.config['train']['val_freq'] == 0: 100 | self.model.test(self.valid_data_loader) 101 | 102 | #### save models and training states 103 | if current_step % self.config['logger']['save_checkpoint_freq'] == 0: 104 | logger.info('Saving models and training states.') 105 | self.model.save(current_step) 106 | self.model.save_training_state(epoch, current_step) 107 | 108 | #saving SR_images 109 | for _, (image, targets) in enumerate(self.valid_data_loader): 110 | #print(image) 111 | img_name = os.path.splitext(os.path.basename(image['LQ_path'][0]))[0] 112 | img_dir = os.path.join(self.config['path']['val_images'], img_name) 113 | mkdir(img_dir) 114 | 115 | self.model.feed_data(image, targets) 116 | self.model.test(self.valid_data_loader, train=False) 117 | 118 | visuals = self.model.get_current_visuals() 119 | sr_img = tensor2img(visuals['SR']) # uint8 120 | gt_img = tensor2img(visuals['GT']) # uint8 121 | lap_learned = tensor2img(visuals['lap_learned']) # uint8 122 | lap = tensor2img(visuals['lap']) # uint8 123 | lap_HR = tensor2img(visuals['lap_HR']) # uint8 124 | final_SR = tensor2img(visuals['final_SR']) # uint8 125 | 126 | # Save SR images for reference 127 | save_img_path = os.path.join(img_dir, 128 | '{:s}_{:d}_SR.png'.format(img_name, current_step)) 129 | save_img(sr_img, save_img_path) 130 | # Save GT images for reference 131 | save_img_path = os.path.join(img_dir, 132 | '{:s}_{:d}_GT.png'.format(img_name, current_step)) 133 | save_img(gt_img, save_img_path) 134 | # Save final_SR images for reference 135 | save_img_path = os.path.join(img_dir, 136 | '{:s}_{:d}_final_SR.png'.format(img_name, current_step)) 137 | save_img(final_SR, save_img_path) 138 | # Save lap_learned images for reference 139 | save_img_path = os.path.join(img_dir, 140 | '{:s}_{:d}_lap_learned.png'.format(img_name, current_step)) 141 | save_img(lap_learned, save_img_path) 142 | # Save lap images for reference 143 | save_img_path = os.path.join(img_dir, 144 | '{:s}_{:d}_lap.png'.format(img_name, current_step)) 145 | save_img(lap, save_img_path) 146 | # Save lap images for reference 147 | save_img_path = os.path.join(img_dir, 148 | '{:s}_{:d}_lap_HR.png'.format(img_name, current_step)) 149 | 
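# Note: lap, lap_learned and lap_HR come from get_current_visuals() of the
# ESRGAN_EESN_FRCNN model and appear to be the Laplacian edge maps of the
# intermediate SR output, the edges predicted by the EESN branch, and the HR
# ground truth respectively; they are written out here purely for visual
# inspection alongside the SR, GT and final_SR images.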
save_img(lap_HR, save_img_path) 150 | 151 | 152 | logger.info('Saving the final model.') 153 | self.model.save('latest') 154 | logger.info('End of training.') 155 | -------------------------------------------------------------------------------- /detection/engine.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | import time 4 | import torch 5 | import os 6 | import numpy as np 7 | import cv2 8 | 9 | import torchvision.models.detection.mask_rcnn 10 | 11 | from .coco_utils import get_coco_api_from_dataset, get_coco_api_from_dataset_base 12 | from .coco_eval import CocoEvaluator 13 | from .utils import MetricLogger, SmoothedValue, warmup_lr_scheduler, reduce_dict 14 | from utils import tensor2img 15 | 16 | 17 | def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq): 18 | model.train() 19 | metric_logger = MetricLogger(delimiter=" ") 20 | metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}')) 21 | header = 'Epoch: [{}]'.format(epoch) 22 | 23 | lr_scheduler = None 24 | if epoch == 0: 25 | warmup_factor = 1. / 1000 26 | warmup_iters = min(1000, len(data_loader) - 1) 27 | 28 | lr_scheduler = warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor) 29 | 30 | for images, targets in metric_logger.log_every(data_loader, print_freq, header): 31 | images = list(image.to(device) for image in images) 32 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 33 | #print(images) 34 | loss_dict = model(images, targets) 35 | 36 | losses = sum(loss for loss in loss_dict.values()) 37 | 38 | # reduce losses over all GPUs for logging purposes 39 | loss_dict_reduced = reduce_dict(loss_dict) 40 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 41 | 42 | loss_value = losses_reduced.item() 43 | 44 | if not math.isfinite(loss_value): 45 | print("Loss is {}, stopping training".format(loss_value)) 46 | print(loss_dict_reduced) 47 | sys.exit(1) 48 | 49 | optimizer.zero_grad() 50 | losses.backward() 51 | optimizer.step() 52 | 53 | if lr_scheduler is not None: 54 | lr_scheduler.step() 55 | 56 | metric_logger.update(loss=losses_reduced, **loss_dict_reduced) 57 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 58 | 59 | 60 | def _get_iou_types(model): 61 | model_without_ddp = model 62 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 63 | model_without_ddp = model.module 64 | iou_types = ["bbox"] 65 | if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN): 66 | iou_types.append("segm") 67 | if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN): 68 | iou_types.append("keypoints") 69 | return iou_types 70 | 71 | ''' 72 | Draw boxes on the test images 73 | ''' 74 | def draw_detection_boxes(new_class_conf_box, config, file_name, image): 75 | source_image_path = os.path.join(config['path']['output_images'], file_name, file_name+'_112000_final_SR.png') 76 | dest_image_path = os.path.join(config['path']['Test_Result_SR'], file_name+'.png') 77 | image = cv2.imread(source_image_path, 1) 78 | #print(new_class_conf_box) 79 | #print(len(new_class_conf_box)) 80 | for i in range(len(new_class_conf_box)): 81 | clas,con,x1,y1,x2,y2 = new_class_conf_box[i] 82 | cv2.rectangle(image, (x1, y1), (x2, y2), (0,0,255), 4) 83 | font = cv2.FONT_HERSHEY_SIMPLEX 84 | cv2.putText(image, 'Car: '+ str((int(con*100))) + '%', (x1+5, y1+8), font, 0.2,(0,255,0),1,cv2.LINE_AA) 85 | 86 | cv2.imwrite(dest_image_path, image) 87 | 88 | ''' 89 | for 
generating test boxes 90 | ''' 91 | def get_prediction(outputs, file_path, config, file_name, image, threshold=0.5): 92 | new_class_conf_box = list() 93 | pred_class = [i for i in list(outputs[0]['labels'].detach().cpu().numpy())] # Get the Prediction Score 94 | text_boxes = [ [i[0], i[1], i[2], i[3] ] for i in list(outputs[0]['boxes'].detach().cpu().numpy())] # Bounding boxes 95 | pred_score = list(outputs[0]['scores'].detach().cpu().numpy()) 96 | #print(pred_score) 97 | for i in range(len(text_boxes)): 98 | new_class_conf_box.append([pred_class[i], pred_score[i], int(text_boxes[i][0]), int(text_boxes[i][1]), int(text_boxes[i][2]), int(text_boxes[i][3])]) 99 | draw_detection_boxes(new_class_conf_box, config, file_name, image) 100 | new_class_conf_box1 = np.matrix(new_class_conf_box) 101 | #print(new_class_conf_box) 102 | if(len(new_class_conf_box))>0: 103 | np.savetxt(file_path, new_class_conf_box1, fmt="%i %1.3f %i %i %i %i") 104 | 105 | 106 | @torch.no_grad() 107 | def evaluate_save(model_G, model_FRCNN, data_loader, device, config): 108 | i = 0 109 | print("Detection started........") 110 | for image, targets in data_loader: 111 | image['image_lq'] = image['image_lq'].to(device) 112 | 113 | _, img, _, _ = model_G(image['image_lq']) 114 | img_count = img.size()[0] 115 | images = [img[i] for i in range(img_count)] 116 | outputs = model_FRCNN(images) 117 | file_name = os.path.splitext(os.path.basename(image['LQ_path'][0]))[0] 118 | file_path = os.path.join(config['path']['Test_Result_SR'], file_name+'.txt') 119 | i=i+1 120 | print(i) 121 | img = img[0].detach()[0].float().cpu() 122 | img = tensor2img(img) 123 | get_prediction(outputs, file_path, config, file_name, img) 124 | print("successfully generated the results!") 125 | 126 | ''' 127 | This evaluate method is changed to pass the generator network and evalute 128 | the FRCNN with generated SR images 129 | ''' 130 | @torch.no_grad() 131 | def evaluate(model_G, model_FRCNN, data_loader, device): 132 | n_threads = torch.get_num_threads() 133 | # FIXME remove this and make paste_masks_in_image run on the GPU 134 | torch.set_num_threads(1) 135 | cpu_device = torch.device("cpu") 136 | #model.eval() 137 | metric_logger = MetricLogger(delimiter=" ") 138 | header = 'Test:' 139 | 140 | coco = get_coco_api_from_dataset(data_loader.dataset) 141 | iou_types = _get_iou_types(model_FRCNN) 142 | coco_evaluator = CocoEvaluator(coco, iou_types) 143 | 144 | for image, targets in metric_logger.log_every(data_loader, 100, header): 145 | image['image_lq'] = image['image_lq'].to(device) 146 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 147 | 148 | torch.cuda.synchronize() 149 | model_time = time.time() 150 | _, image, _, _ = model_G(image['image_lq']) 151 | img_count = image.size()[0] 152 | image = [image[i] for i in range(img_count)] 153 | outputs = model_FRCNN(image) 154 | 155 | outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs] 156 | model_time = time.time() - model_time 157 | 158 | res = {target["image_id"].item(): output for target, output in zip(targets, outputs)} 159 | evaluator_time = time.time() 160 | coco_evaluator.update(res) 161 | evaluator_time = time.time() - evaluator_time 162 | metric_logger.update(model_time=model_time, evaluator_time=evaluator_time) 163 | 164 | # gather the stats from all processes 165 | metric_logger.synchronize_between_processes() 166 | print("Averaged stats:", metric_logger) 167 | coco_evaluator.synchronize_between_processes() 168 | 169 | # accumulate predictions from all 
images 170 | coco_evaluator.accumulate() 171 | coco_evaluator.summarize() 172 | torch.set_num_threads(n_threads) 173 | return coco_evaluator 174 | 175 | @torch.no_grad() 176 | def evaluate_base(model, data_loader, device): 177 | n_threads = torch.get_num_threads() 178 | # FIXME remove this and make paste_masks_in_image run on the GPU 179 | torch.set_num_threads(1) 180 | cpu_device = torch.device("cpu") 181 | model.eval() 182 | metric_logger = MetricLogger(delimiter=" ") 183 | header = 'Test:' 184 | 185 | coco = get_coco_api_from_dataset_base(data_loader.dataset) 186 | iou_types = _get_iou_types(model) 187 | coco_evaluator = CocoEvaluator(coco, iou_types) 188 | 189 | for image, targets in metric_logger.log_every(data_loader, 100, header): 190 | image = list(img.to(device) for img in image) 191 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 192 | 193 | torch.cuda.synchronize() 194 | model_time = time.time() 195 | outputs = model(image) 196 | #print(outputs) 197 | 198 | outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs] 199 | model_time = time.time() - model_time 200 | 201 | res = {target["image_id"].item(): output for target, output in zip(targets, outputs)} 202 | evaluator_time = time.time() 203 | coco_evaluator.update(res) 204 | evaluator_time = time.time() - evaluator_time 205 | metric_logger.update(model_time=model_time, evaluator_time=evaluator_time) 206 | 207 | # gather the stats from all processes 208 | metric_logger.synchronize_between_processes() 209 | print("Averaged stats:", metric_logger) 210 | coco_evaluator.synchronize_between_processes() 211 | 212 | # accumulate predictions from all images 213 | coco_evaluator.accumulate() 214 | coco_evaluator.summarize() 215 | torch.set_num_threads(n_threads) 216 | return coco_evaluator 217 | -------------------------------------------------------------------------------- /trainer/cowc_GAN_trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import torch 4 | import math 5 | import os 6 | import model.ESRGANModel as ESRGAN 7 | import model.ESRGAN_EESN_Model as ESRGAN_EESN 8 | from scripts_for_datasets import COWCDataset, COWCGANDataset 9 | from torchvision.utils import make_grid 10 | from base import BaseTrainer 11 | from utils import inf_loop, MetricTracker, visualize_bbox, visualize, calculate_psnr, save_img, tensor2img, mkdir 12 | 13 | logger = logging.getLogger('base') 14 | ''' 15 | python train.py -c config_GAN.json 16 | modified from ESRGAN repo 17 | ''' 18 | 19 | class COWCGANTrainer: 20 | """ 21 | Trainer class 22 | """ 23 | def __init__(self, config, data_loader, valid_data_loader=None): 24 | self.config = config 25 | self.data_loader = data_loader 26 | 27 | self.valid_data_loader = valid_data_loader 28 | self.do_validation = self.valid_data_loader is not None 29 | n_gpu = torch.cuda.device_count() 30 | self.device = torch.device('cuda:1' if n_gpu > 0 else 'cpu') 31 | self.train_size = int(math.ceil(self.data_loader.length / int(config['data_loader']['args']['batch_size']))) 32 | self.total_iters = int(config['train']['niter']) 33 | self.total_epochs = int(math.ceil(self.total_iters / self.train_size)) 34 | print(self.total_epochs) 35 | self.model = ESRGAN_EESN.ESRGAN_EESN_Model(config,self.device) 36 | 37 | def test(self): 38 | for _, test_data in enumerate(self.data_loader): 39 | #print(val_data) 40 | img_name = os.path.splitext(os.path.basename(test_data['LQ_path'][0]))[0] 41 | img_dir = 
"/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/" 42 | 43 | self.model.feed_data(test_data) 44 | self.model.test() 45 | 46 | visuals = self.model.get_current_visuals() 47 | sr_img = tensor2img(visuals['SR']) # uint8 48 | final_SR = tensor2img(visuals['final_SR']) # uint8 49 | 50 | # Save SR images for reference 51 | save_img_path = os.path.join(img_dir, 'combined_SR_images', img_name+'.png') 52 | save_img(sr_img, save_img_path) 53 | # Save final_SR images for reference 54 | save_img_path = os.path.join(img_dir, 'final_SR_images', img_name+'.png') 55 | save_img(final_SR, save_img_path) 56 | 57 | def train(self): 58 | ''' 59 | Training logic for an epoch 60 | for visualization use the following code (use batch size = 1): 61 | 62 | category_id_to_name = {1: 'car'} 63 | for batch_idx, dataset_dict in enumerate(self.data_loader): 64 | if dataset_dict['idx'][0] == 10: 65 | print(dataset_dict) 66 | visualize(dataset_dict, category_id_to_name) #--> see this method in util 67 | 68 | #image size: torch.Size([10, 3, 256, 256]) if batch_size = 10 69 | ''' 70 | logger.info('Number of train images: {:,d}, iters: {:,d}'.format( 71 | self.data_loader.length, self.train_size)) 72 | logger.info('Total epochs needed: {:d} for iters {:,d}'.format( 73 | self.total_epochs, self.total_iters)) 74 | # tensorboard logger 75 | if self.config['use_tb_logger'] and 'debug' not in self.config['name']: 76 | version = float(torch.__version__[0:3]) 77 | if version >= 1.1: # PyTorch 1.1 78 | from torch.utils.tensorboard import SummaryWriter 79 | else: 80 | logger.info( 81 | 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version)) 82 | from tensorboardX import SummaryWriter 83 | tb_logger = SummaryWriter(log_dir='saved/tb_logger/' + self.config['name']) 84 | 85 | current_step = 0 86 | start_epoch = 0 87 | 88 | #### training 89 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) 90 | for epoch in range(start_epoch, self.total_epochs + 1): 91 | for _, train_data in enumerate(self.data_loader): 92 | current_step += 1 93 | if current_step > self.total_iters: 94 | break 95 | #### update learning rate 96 | self.model.update_learning_rate(current_step, warmup_iter=self.config['train']['warmup_iter']) 97 | 98 | #### training 99 | self.model.feed_data(train_data) 100 | self.model.optimize_parameters(current_step) 101 | 102 | #### log 103 | if current_step % self.config['logger']['print_freq'] == 0: 104 | logs = self.model.get_current_log() 105 | message = ' '.format( 106 | epoch, current_step, self.model.get_current_learning_rate()) 107 | for k, v in logs.items(): 108 | message += '{:s}: {:.4e} '.format(k, v) 109 | # tensorboard logger 110 | if self.config['use_tb_logger'] and 'debug' not in self.config['name']: 111 | tb_logger.add_scalar(k, v, current_step) 112 | 113 | logger.info(message) 114 | 115 | # validation 116 | if current_step % self.config['train']['val_freq'] == 0: 117 | avg_psnr = 0.0 118 | idx = 0 119 | for val_data in self.valid_data_loader: 120 | idx += 1 121 | #print(val_data) 122 | img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0] 123 | img_dir = os.path.join(self.config['path']['val_images'], img_name) 124 | mkdir(img_dir) 125 | 126 | self.model.feed_data(val_data) 127 | self.model.test() 128 | 129 | visuals = self.model.get_current_visuals() 130 | sr_img = tensor2img(visuals['SR']) # uint8 131 | gt_img = tensor2img(visuals['GT']) # uint8 132 | lap_learned = tensor2img(visuals['lap_learned']) # uint8 133 | lap = 
tensor2img(visuals['lap']) # uint8 134 | lap_HR = tensor2img(visuals['lap_HR']) # uint8 135 | final_SR = tensor2img(visuals['final_SR']) # uint8 136 | 137 | # Save SR images for reference 138 | save_img_path = os.path.join(img_dir, 139 | '{:s}_{:d}_SR.png'.format(img_name, current_step)) 140 | save_img(sr_img, save_img_path) 141 | # Save GT images for reference 142 | save_img_path = os.path.join(img_dir, 143 | '{:s}_{:d}_GT.png'.format(img_name, current_step)) 144 | save_img(gt_img, save_img_path) 145 | # Save final_SR images for reference 146 | save_img_path = os.path.join(img_dir, 147 | '{:s}_{:d}_final_SR.png'.format(img_name, current_step)) 148 | save_img(final_SR, save_img_path) 149 | # Save lap_learned images for reference 150 | save_img_path = os.path.join(img_dir, 151 | '{:s}_{:d}_lap_learned.png'.format(img_name, current_step)) 152 | save_img(lap_learned, save_img_path) 153 | # Save lap images for reference 154 | save_img_path = os.path.join(img_dir, 155 | '{:s}_{:d}_lap.png'.format(img_name, current_step)) 156 | save_img(lap, save_img_path) 157 | # Save lap images for reference 158 | save_img_path = os.path.join(img_dir, 159 | '{:s}_{:d}_lap_HR.png'.format(img_name, current_step)) 160 | save_img(lap_HR, save_img_path) 161 | 162 | 163 | # calculate PSNR 164 | crop_size = self.config['scale'] 165 | gt_img = gt_img / 255. 166 | sr_img = sr_img / 255. 167 | cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] 168 | cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] 169 | avg_psnr += calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255) 170 | 171 | avg_psnr = avg_psnr / idx 172 | 173 | # log 174 | logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) 175 | logger_val = logging.getLogger('val') # validation logger 176 | logger_val.info(' psnr: {:.4e}'.format( 177 | epoch, current_step, avg_psnr)) 178 | # tensorboard logger 179 | if self.config['use_tb_logger'] and 'debug' not in self.config['name']: 180 | tb_logger.add_scalar('psnr', avg_psnr, current_step) 181 | 182 | #### save models and training states 183 | if current_step % self.config['logger']['save_checkpoint_freq'] == 0: 184 | logger.info('Saving models and training states.') 185 | self.model.save(current_step) 186 | self.model.save_training_state(epoch, current_step) 187 | 188 | 189 | logger.info('Saving the final model.') 190 | self.model.save('latest') 191 | logger.info('End of training.') 192 | -------------------------------------------------------------------------------- /detection/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from collections import defaultdict, deque 4 | import datetime 5 | import pickle 6 | import time 7 | 8 | import torch 9 | import torch.distributed as dist 10 | 11 | import errno 12 | import os 13 | 14 | 15 | class SmoothedValue(object): 16 | """Track a series of values and provide access to smoothed values over a 17 | window or the global series average. 18 | """ 19 | 20 | def __init__(self, window_size=20, fmt=None): 21 | if fmt is None: 22 | fmt = "{median:.4f} ({global_avg:.4f})" 23 | self.deque = deque(maxlen=window_size) 24 | self.total = 0.0 25 | self.count = 0 26 | self.fmt = fmt 27 | 28 | def update(self, value, n=1): 29 | self.deque.append(value) 30 | self.count += n 31 | self.total += value * n 32 | 33 | def synchronize_between_processes(self): 34 | """ 35 | Warning: does not synchronize the deque! 
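Only `count` and `total` are all-reduced across processes, so `global_avg`
becomes a true global average; `median`, `avg`, `max` and `value` keep
reporting per-process window statistics.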
36 | """ 37 | if not is_dist_avail_and_initialized(): 38 | return 39 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 40 | dist.barrier() 41 | dist.all_reduce(t) 42 | t = t.tolist() 43 | self.count = int(t[0]) 44 | self.total = t[1] 45 | 46 | @property 47 | def median(self): 48 | d = torch.tensor(list(self.deque)) 49 | return d.median().item() 50 | 51 | @property 52 | def avg(self): 53 | d = torch.tensor(list(self.deque), dtype=torch.float32) 54 | return d.mean().item() 55 | 56 | @property 57 | def global_avg(self): 58 | return self.total / self.count 59 | 60 | @property 61 | def max(self): 62 | return max(self.deque) 63 | 64 | @property 65 | def value(self): 66 | return self.deque[-1] 67 | 68 | def __str__(self): 69 | return self.fmt.format( 70 | median=self.median, 71 | avg=self.avg, 72 | global_avg=self.global_avg, 73 | max=self.max, 74 | value=self.value) 75 | 76 | 77 | def all_gather(data): 78 | """ 79 | Run all_gather on arbitrary picklable data (not necessarily tensors) 80 | Args: 81 | data: any picklable object 82 | Returns: 83 | list[data]: list of data gathered from each rank 84 | """ 85 | world_size = get_world_size() 86 | if world_size == 1: 87 | return [data] 88 | 89 | # serialized to a Tensor 90 | buffer = pickle.dumps(data) 91 | storage = torch.ByteStorage.from_buffer(buffer) 92 | tensor = torch.ByteTensor(storage).to("cuda") 93 | 94 | # obtain Tensor size of each rank 95 | local_size = torch.tensor([tensor.numel()], device="cuda") 96 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 97 | dist.all_gather(size_list, local_size) 98 | size_list = [int(size.item()) for size in size_list] 99 | max_size = max(size_list) 100 | 101 | # receiving Tensor from all ranks 102 | # we pad the tensor because torch all_gather does not support 103 | # gathering tensors of different shapes 104 | tensor_list = [] 105 | for _ in size_list: 106 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 107 | if local_size != max_size: 108 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 109 | tensor = torch.cat((tensor, padding), dim=0) 110 | dist.all_gather(tensor_list, tensor) 111 | 112 | data_list = [] 113 | for size, tensor in zip(size_list, tensor_list): 114 | buffer = tensor.cpu().numpy().tobytes()[:size] 115 | data_list.append(pickle.loads(buffer)) 116 | 117 | return data_list 118 | 119 | 120 | def reduce_dict(input_dict, average=True): 121 | """ 122 | Args: 123 | input_dict (dict): all the values will be reduced 124 | average (bool): whether to do average or sum 125 | Reduce the values in the dictionary from all processes so that all processes 126 | have the averaged results. Returns a dict with the same fields as 127 | input_dict, after reduction. 
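Keys are sorted before reduction so every rank stacks the values in the same
order; the values must be tensors of identical shape (scalar losses in practice).
Illustrative usage on a single process (world_size == 1, dict returned as-is):
    losses = {'loss_box': torch.tensor(0.3), 'loss_cls': torch.tensor(0.7)}
    reduced = reduce_dict(losses)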
128 | """ 129 | world_size = get_world_size() 130 | if world_size < 2: 131 | return input_dict 132 | with torch.no_grad(): 133 | names = [] 134 | values = [] 135 | # sort the keys so that they are consistent across processes 136 | for k in sorted(input_dict.keys()): 137 | names.append(k) 138 | values.append(input_dict[k]) 139 | values = torch.stack(values, dim=0) 140 | dist.all_reduce(values) 141 | if average: 142 | values /= world_size 143 | reduced_dict = {k: v for k, v in zip(names, values)} 144 | return reduced_dict 145 | 146 | 147 | class MetricLogger(object): 148 | def __init__(self, delimiter="\t"): 149 | self.meters = defaultdict(SmoothedValue) 150 | self.delimiter = delimiter 151 | 152 | def update(self, **kwargs): 153 | for k, v in kwargs.items(): 154 | if isinstance(v, torch.Tensor): 155 | v = v.item() 156 | assert isinstance(v, (float, int)) 157 | self.meters[k].update(v) 158 | 159 | def __getattr__(self, attr): 160 | if attr in self.meters: 161 | return self.meters[attr] 162 | if attr in self.__dict__: 163 | return self.__dict__[attr] 164 | raise AttributeError("'{}' object has no attribute '{}'".format( 165 | type(self).__name__, attr)) 166 | 167 | def __str__(self): 168 | loss_str = [] 169 | for name, meter in self.meters.items(): 170 | loss_str.append( 171 | "{}: {}".format(name, str(meter)) 172 | ) 173 | return self.delimiter.join(loss_str) 174 | 175 | def synchronize_between_processes(self): 176 | for meter in self.meters.values(): 177 | meter.synchronize_between_processes() 178 | 179 | def add_meter(self, name, meter): 180 | self.meters[name] = meter 181 | 182 | def log_every(self, iterable, print_freq, header=None): 183 | i = 0 184 | if not header: 185 | header = '' 186 | start_time = time.time() 187 | end = time.time() 188 | iter_time = SmoothedValue(fmt='{avg:.4f}') 189 | data_time = SmoothedValue(fmt='{avg:.4f}') 190 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 191 | log_msg = self.delimiter.join([ 192 | header, 193 | '[{0' + space_fmt + '}/{1}]', 194 | 'eta: {eta}', 195 | '{meters}', 196 | 'time: {time}', 197 | 'data: {data}', 198 | 'max mem: {memory:.0f}' 199 | ]) 200 | MB = 1024.0 * 1024.0 201 | for obj in iterable: 202 | data_time.update(time.time() - end) 203 | yield obj 204 | iter_time.update(time.time() - end) 205 | if i % print_freq == 0 or i == len(iterable) - 1: 206 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 207 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 208 | print(log_msg.format( 209 | i, len(iterable), eta=eta_string, 210 | meters=str(self), 211 | time=str(iter_time), data=str(data_time), 212 | memory=torch.cuda.max_memory_allocated() / MB)) 213 | i += 1 214 | end = time.time() 215 | total_time = time.time() - start_time 216 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 217 | print('{} Total time: {} ({:.4f} s / it)'.format( 218 | header, total_time_str, total_time / len(iterable))) 219 | 220 | 221 | def collate_fn(batch): 222 | return tuple(zip(*batch)) 223 | 224 | 225 | def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor): 226 | 227 | def f(x): 228 | if x >= warmup_iters: 229 | return 1 230 | alpha = float(x) / warmup_iters 231 | return warmup_factor * (1 - alpha) + alpha 232 | 233 | return torch.optim.lr_scheduler.LambdaLR(optimizer, f) 234 | 235 | 236 | def mkdir(path): 237 | try: 238 | os.makedirs(path) 239 | except OSError as e: 240 | if e.errno != errno.EEXIST: 241 | raise 242 | 243 | 244 | def setup_for_distributed(is_master): 245 | """ 246 | This function disables 
printing when not in master process 247 | """ 248 | import builtins as __builtin__ 249 | builtin_print = __builtin__.print 250 | 251 | def print(*args, **kwargs): 252 | force = kwargs.pop('force', False) 253 | if is_master or force: 254 | builtin_print(*args, **kwargs) 255 | 256 | __builtin__.print = print 257 | 258 | 259 | def is_dist_avail_and_initialized(): 260 | if not dist.is_available(): 261 | return False 262 | if not dist.is_initialized(): 263 | return False 264 | return True 265 | 266 | 267 | def get_world_size(): 268 | if not is_dist_avail_and_initialized(): 269 | return 1 270 | return dist.get_world_size() 271 | 272 | 273 | def get_rank(): 274 | if not is_dist_avail_and_initialized(): 275 | return 0 276 | return dist.get_rank() 277 | 278 | 279 | def is_main_process(): 280 | return get_rank() == 0 281 | 282 | 283 | def save_on_master(*args, **kwargs): 284 | if is_main_process(): 285 | torch.save(*args, **kwargs) 286 | 287 | 288 | def init_distributed_mode(args): 289 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 290 | args.rank = int(os.environ["RANK"]) 291 | args.world_size = int(os.environ['WORLD_SIZE']) 292 | args.gpu = int(os.environ['LOCAL_RANK']) 293 | elif 'SLURM_PROCID' in os.environ: 294 | args.rank = int(os.environ['SLURM_PROCID']) 295 | args.gpu = args.rank % torch.cuda.device_count() 296 | else: 297 | print('Not using distributed mode') 298 | args.distributed = False 299 | return 300 | 301 | args.distributed = True 302 | 303 | torch.cuda.set_device(args.gpu) 304 | args.dist_backend = 'nccl' 305 | print('| distributed init (rank {}): {}'.format( 306 | args.rank, args.dist_url), flush=True) 307 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 308 | world_size=args.world_size, rank=args.rank) 309 | torch.distributed.barrier() 310 | setup_for_distributed(args.rank == 0) 311 | -------------------------------------------------------------------------------- /detection/coco_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from PIL import Image 4 | 5 | import torch 6 | import torch.utils.data 7 | import torchvision 8 | 9 | from pycocotools import mask as coco_mask 10 | from pycocotools.coco import COCO 11 | 12 | from torchvision import transforms as T 13 | 14 | 15 | class FilterAndRemapCocoCategories(object): 16 | def __init__(self, categories, remap=True): 17 | self.categories = categories 18 | self.remap = remap 19 | 20 | def __call__(self, image, target): 21 | anno = target["annotations"] 22 | anno = [obj for obj in anno if obj["category_id"] in self.categories] 23 | if not self.remap: 24 | target["annotations"] = anno 25 | return image, target 26 | anno = copy.deepcopy(anno) 27 | for obj in anno: 28 | obj["category_id"] = self.categories.index(obj["category_id"]) 29 | target["annotations"] = anno 30 | return image, target 31 | 32 | 33 | def convert_coco_poly_to_mask(segmentations, height, width): 34 | masks = [] 35 | for polygons in segmentations: 36 | rles = coco_mask.frPyObjects(polygons, height, width) 37 | mask = coco_mask.decode(rles) 38 | if len(mask.shape) < 3: 39 | mask = mask[..., None] 40 | mask = torch.as_tensor(mask, dtype=torch.uint8) 41 | mask = mask.any(dim=2) 42 | masks.append(mask) 43 | if masks: 44 | masks = torch.stack(masks, dim=0) 45 | else: 46 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 47 | return masks 48 | 49 | 50 | class ConvertCocoPolysToMask(object): 51 | def __call__(self, image, target): 52 | w, h 
= image.size 53 | 54 | image_id = target["image_id"] 55 | image_id = torch.tensor([image_id]) 56 | 57 | anno = target["annotations"] 58 | 59 | anno = [obj for obj in anno if obj['iscrowd'] == 0] 60 | 61 | boxes = [obj["bbox"] for obj in anno] 62 | # guard against no boxes via resizing 63 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 64 | boxes[:, 2:] += boxes[:, :2] 65 | boxes[:, 0::2].clamp_(min=0, max=w) 66 | boxes[:, 1::2].clamp_(min=0, max=h) 67 | 68 | classes = [obj["category_id"] for obj in anno] 69 | classes = torch.tensor(classes, dtype=torch.int64) 70 | 71 | segmentations = [obj["segmentation"] for obj in anno] 72 | masks = convert_coco_poly_to_mask(segmentations, h, w) 73 | 74 | keypoints = None 75 | if anno and "keypoints" in anno[0]: 76 | keypoints = [obj["keypoints"] for obj in anno] 77 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32) 78 | num_keypoints = keypoints.shape[0] 79 | if num_keypoints: 80 | keypoints = keypoints.view(num_keypoints, -1, 3) 81 | 82 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 83 | boxes = boxes[keep] 84 | classes = classes[keep] 85 | masks = masks[keep] 86 | if keypoints is not None: 87 | keypoints = keypoints[keep] 88 | 89 | target = {} 90 | target["boxes"] = boxes 91 | target["labels"] = classes 92 | target["masks"] = masks 93 | target["image_id"] = image_id 94 | if keypoints is not None: 95 | target["keypoints"] = keypoints 96 | 97 | # for conversion to coco api 98 | area = torch.tensor([obj["area"] for obj in anno]) 99 | iscrowd = torch.tensor([obj["iscrowd"] for obj in anno]) 100 | target["area"] = area 101 | target["iscrowd"] = iscrowd 102 | 103 | return image, target 104 | 105 | 106 | def _coco_remove_images_without_annotations(dataset, cat_list=None): 107 | def _has_only_empty_bbox(anno): 108 | return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) 109 | 110 | def _count_visible_keypoints(anno): 111 | return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) 112 | 113 | min_keypoints_per_image = 10 114 | 115 | def _has_valid_annotation(anno): 116 | # if it's empty, there is no annotation 117 | if len(anno) == 0: 118 | return False 119 | # if all boxes have close to zero area, there is no annotation 120 | if _has_only_empty_bbox(anno): 121 | return False 122 | # keypoints task have a slight different critera for considering 123 | # if an annotation is valid 124 | if "keypoints" not in anno[0]: 125 | return True 126 | # for keypoint detection tasks, only consider valid images those 127 | # containing at least min_keypoints_per_image 128 | if _count_visible_keypoints(anno) >= min_keypoints_per_image: 129 | return True 130 | return False 131 | 132 | assert isinstance(dataset, torchvision.datasets.CocoDetection) 133 | ids = [] 134 | for ds_idx, img_id in enumerate(dataset.ids): 135 | ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) 136 | anno = dataset.coco.loadAnns(ann_ids) 137 | if cat_list: 138 | anno = [obj for obj in anno if obj["category_id"] in cat_list] 139 | if _has_valid_annotation(anno): 140 | ids.append(ds_idx) 141 | 142 | dataset = torch.utils.data.Subset(dataset, ids) 143 | return dataset 144 | 145 | 146 | def convert_to_coco_api(ds): 147 | coco_ds = COCO() 148 | ann_id = 0 149 | dataset = {'images': [], 'categories': [], 'annotations': []} 150 | categories = set() 151 | for img_idx in range(len(ds)): 152 | # find better way to get target 153 | # targets = ds.get_annotations(img_idx) 154 | img, targets = ds[img_idx] 155 | image_id = 
targets["image_id"].item() 156 | img_dict = {} 157 | img_dict['id'] = image_id 158 | img_dict['height'] = img['image_lq'].shape[-2] 159 | img_dict['width'] = img['image_lq'].shape[-1] 160 | dataset['images'].append(img_dict) 161 | bboxes = targets["boxes"] 162 | bboxes[:, 2:] -= bboxes[:, :2] 163 | bboxes = bboxes.tolist() 164 | labels = targets['labels'].tolist() 165 | areas = targets['area'].tolist() 166 | iscrowd = targets['iscrowd'].tolist() 167 | if 'masks' in targets: 168 | masks = targets['masks'] 169 | # make masks Fortran contiguous for coco_mask 170 | masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1) 171 | if 'keypoints' in targets: 172 | keypoints = targets['keypoints'] 173 | keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist() 174 | num_objs = len(bboxes) 175 | for i in range(num_objs): 176 | ann = {} 177 | ann['image_id'] = image_id 178 | ann['bbox'] = bboxes[i] 179 | ann['category_id'] = labels[i] 180 | categories.add(labels[i]) 181 | ann['area'] = areas[i] 182 | ann['iscrowd'] = iscrowd[i] 183 | ann['id'] = ann_id 184 | if 'masks' in targets: 185 | ann["segmentation"] = coco_mask.encode(masks[i].numpy()) 186 | if 'keypoints' in targets: 187 | ann['keypoints'] = keypoints[i] 188 | ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3]) 189 | dataset['annotations'].append(ann) 190 | ann_id += 1 191 | dataset['categories'] = [{'id': i} for i in sorted(categories)] 192 | coco_ds.dataset = dataset 193 | coco_ds.createIndex() 194 | return coco_ds 195 | 196 | def convert_to_coco_api_base(ds): 197 | coco_ds = COCO() 198 | ann_id = 0 199 | dataset = {'images': [], 'categories': [], 'annotations': []} 200 | categories = set() 201 | for img_idx in range(len(ds)): 202 | # find better way to get target 203 | # targets = ds.get_annotations(img_idx) 204 | img, targets = ds[img_idx] 205 | image_id = targets["image_id"].item() 206 | img_dict = {} 207 | img_dict['id'] = image_id 208 | img_dict['height'] = img.shape[-2] 209 | img_dict['width'] = img.shape[-1] 210 | dataset['images'].append(img_dict) 211 | bboxes = targets["boxes"] 212 | bboxes[:, 2:] -= bboxes[:, :2] 213 | bboxes = bboxes.tolist() 214 | labels = targets['labels'].tolist() 215 | areas = targets['area'].tolist() 216 | iscrowd = targets['iscrowd'].tolist() 217 | if 'masks' in targets: 218 | masks = targets['masks'] 219 | # make masks Fortran contiguous for coco_mask 220 | masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1) 221 | if 'keypoints' in targets: 222 | keypoints = targets['keypoints'] 223 | keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist() 224 | num_objs = len(bboxes) 225 | for i in range(num_objs): 226 | ann = {} 227 | ann['image_id'] = image_id 228 | ann['bbox'] = bboxes[i] 229 | ann['category_id'] = labels[i] 230 | categories.add(labels[i]) 231 | ann['area'] = areas[i] 232 | ann['iscrowd'] = iscrowd[i] 233 | ann['id'] = ann_id 234 | if 'masks' in targets: 235 | ann["segmentation"] = coco_mask.encode(masks[i].numpy()) 236 | if 'keypoints' in targets: 237 | ann['keypoints'] = keypoints[i] 238 | ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3]) 239 | dataset['annotations'].append(ann) 240 | ann_id += 1 241 | dataset['categories'] = [{'id': i} for i in sorted(categories)] 242 | coco_ds.dataset = dataset 243 | coco_ds.createIndex() 244 | return coco_ds 245 | 246 | 247 | def get_coco_api_from_dataset(dataset): 248 | for i in range(10): 249 | if isinstance(dataset, torchvision.datasets.CocoDetection): 250 | break 251 | if isinstance(dataset, 
torch.utils.data.Subset): 252 | dataset = dataset.dataset 253 | if isinstance(dataset, torchvision.datasets.CocoDetection): 254 | return dataset.coco 255 | return convert_to_coco_api(dataset) 256 | 257 | def get_coco_api_from_dataset_base(dataset): 258 | for i in range(10): 259 | if isinstance(dataset, torchvision.datasets.CocoDetection): 260 | break 261 | if isinstance(dataset, torch.utils.data.Subset): 262 | dataset = dataset.dataset 263 | if isinstance(dataset, torchvision.datasets.CocoDetection): 264 | return dataset.coco 265 | return convert_to_coco_api_base(dataset) 266 | 267 | 268 | class CocoDetection(torchvision.datasets.CocoDetection): 269 | def __init__(self, img_folder, ann_file, transforms): 270 | super(CocoDetection, self).__init__(img_folder, ann_file) 271 | self._transforms = transforms 272 | 273 | def __getitem__(self, idx): 274 | img, target = super(CocoDetection, self).__getitem__(idx) 275 | image_id = self.ids[idx] 276 | target = dict(image_id=image_id, annotations=target) 277 | if self._transforms is not None: 278 | img, target = self._transforms(img, target) 279 | return img, target 280 | 281 | 282 | def get_coco(root, image_set, transforms, mode='instances'): 283 | anno_file_template = "{}_{}2017.json" 284 | PATHS = { 285 | "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))), 286 | "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))), 287 | # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))) 288 | } 289 | 290 | t = [ConvertCocoPolysToMask()] 291 | 292 | if transforms is not None: 293 | t.append(transforms) 294 | transforms = T.Compose(t) 295 | 296 | img_folder, ann_file = PATHS[image_set] 297 | img_folder = os.path.join(root, img_folder) 298 | ann_file = os.path.join(root, ann_file) 299 | 300 | dataset = CocoDetection(img_folder, ann_file, transforms=transforms) 301 | 302 | if image_set == "train": 303 | dataset = _coco_remove_images_without_annotations(dataset) 304 | 305 | # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)]) 306 | 307 | return dataset 308 | 309 | 310 | def get_coco_kp(root, image_set, transforms): 311 | return get_coco(root, image_set, transforms, mode="person_keypoints") 312 | -------------------------------------------------------------------------------- /model/ESRGANModel.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | import model.model as model 6 | import model.lr_scheduler as lr_scheduler 7 | from model.loss import GANLoss 8 | from .gan_base_model import BaseModel 9 | from torch.nn.parallel import DataParallel 10 | 11 | logger = logging.getLogger('base') 12 | # Taken from ESRGAN BASICSR repository and modified 13 | class ESRGANModel(BaseModel): 14 | def __init__(self, config, device): 15 | super(ESRGANModel, self).__init__(config, device) 16 | self.configG = config['network_G'] 17 | self.configD = config['network_D'] 18 | self.configT = config['train'] 19 | self.configO = config['optimizer']['args'] 20 | self.configS = config['lr_scheduler'] 21 | self.device = device 22 | #Generator 23 | self.netG = model.RRDBNet(in_nc=self.configG['in_nc'], out_nc=self.configG['out_nc'], 24 | nf=self.configG['nf'], nb=self.configG['nb']) 25 | self.netG = self.netG.to(self.device) 26 | self.netG = DataParallel(self.netG) 27 | #descriminator 28 | self.netD = 
model.Discriminator_VGG_128(in_nc=self.configD['in_nc'], nf=self.configD['nf']) 29 | self.netD = self.netD.to(self.device) 30 | self.netD = DataParallel(self.netD) 31 | 32 | self.netG.train() 33 | self.netD.train() 34 | #print(self.configT['pixel_weight']) 35 | # G pixel loss 36 | if self.configT['pixel_weight'] > 0.0: 37 | l_pix_type = self.configT['pixel_criterion'] 38 | if l_pix_type == 'l1': 39 | self.cri_pix = nn.L1Loss().to(self.device) 40 | elif l_pix_type == 'l2': 41 | self.cri_pix = nn.MSELoss().to(self.device) 42 | else: 43 | raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) 44 | self.l_pix_w = self.configT['pixel_weight'] 45 | else: 46 | self.cri_pix = None 47 | 48 | # G feature loss 49 | #print(self.configT['feature_weight']+1) 50 | if self.configT['feature_weight'] > 0: 51 | l_fea_type = self.configT['feature_criterion'] 52 | if l_fea_type == 'l1': 53 | self.cri_fea = nn.L1Loss().to(self.device) 54 | elif l_fea_type == 'l2': 55 | self.cri_fea = nn.MSELoss().to(self.device) 56 | else: 57 | raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) 58 | self.l_fea_w = self.configT['feature_weight'] 59 | else: 60 | self.cri_fea = None 61 | if self.cri_fea: # load VGG perceptual loss 62 | self.netF = model.VGGFeatureExtractor(feature_layer=34, 63 | use_input_norm=True, device=self.device) 64 | self.netF = self.netF.to(self.device) 65 | self.netF = DataParallel(self.netF) 66 | self.netF.eval() 67 | 68 | # GD gan loss 69 | self.cri_gan = GANLoss(self.configT['gan_type'], 1.0, 0.0).to(self.device) 70 | self.l_gan_w = self.configT['gan_weight'] 71 | # D_update_ratio and D_init_iters 72 | self.D_update_ratio = self.configT['D_update_ratio'] if self.configT['D_update_ratio'] else 1 73 | self.D_init_iters = self.configT['D_init_iters'] if self.configT['D_init_iters'] else 0 74 | 75 | 76 | # optimizers 77 | # G 78 | wd_G = self.configO['weight_decay_G'] if self.configO['weight_decay_G'] else 0 79 | optim_params = [] 80 | for k, v in self.netG.named_parameters(): # can optimize for a part of the model 81 | if v.requires_grad: 82 | optim_params.append(v) 83 | 84 | self.optimizer_G = torch.optim.Adam(optim_params, lr=self.configO['lr_G'], 85 | weight_decay=wd_G, 86 | betas=(self.configO['beta1_G'], self.configO['beta2_G'])) 87 | self.optimizers.append(self.optimizer_G) 88 | # D 89 | wd_D = self.configO['weight_decay_D'] if self.configO['weight_decay_D'] else 0 90 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.configO['lr_D'], 91 | weight_decay=wd_D, 92 | betas=(self.configO['beta1_D'], self.configO['beta2_D'])) 93 | self.optimizers.append(self.optimizer_D) 94 | 95 | # schedulers 96 | if self.configS['type'] == 'MultiStepLR': 97 | for optimizer in self.optimizers: 98 | self.schedulers.append( 99 | lr_scheduler.MultiStepLR_Restart(optimizer, self.configS['args']['lr_steps'], 100 | restarts=self.configS['args']['restarts'], 101 | weights=self.configS['args']['restart_weights'], 102 | gamma=self.configS['args']['lr_gamma'], 103 | clear_state=False)) 104 | elif self.configS['type'] == 'CosineAnnealingLR_Restart': 105 | for optimizer in self.optimizers: 106 | self.schedulers.append( 107 | lr_scheduler.CosineAnnealingLR_Restart( 108 | optimizer, self.configS['args']['T_period'], eta_min=self.configS['args']['eta_min'], 109 | restarts=self.configS['args']['restarts'], weights=self.configS['args']['restart_weights'])) 110 | else: 111 | raise NotImplementedError('MultiStepLR learning rate scheme is enough.') 112 | 
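# One scheduler is appended per optimizer (G and D) above. Purely illustrative
# 'lr_scheduler' config fragment -- the key names follow the self.configS reads
# above, but the values are assumptions, not taken from this repo:
#   "lr_scheduler": {"type": "MultiStepLR",
#                    "args": {"lr_steps": [50000, 100000, 200000, 300000],
#                             "restarts": null, "restart_weights": null,
#                             "lr_gamma": 0.5}}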
print(self.configS['args']['restarts']) 113 | self.log_dict = OrderedDict() 114 | 115 | self.print_network() # print network 116 | self.load() # load G and D if needed 117 | ''' 118 | The main repo did not use collate_fn and image read has different flags 119 | and also used np.ascontiguousarray() 120 | Might change my code if problem happens 121 | ''' 122 | def feed_data(self, data): 123 | self.var_L = data['image_lq'].to(self.device) 124 | self.var_H = data['image'].to(self.device) 125 | input_ref = data['ref'] if 'ref' in data else data['image'] 126 | self.var_ref = input_ref.to(self.device) 127 | 128 | def optimize_parameters(self, step): 129 | #Generator 130 | for p in self.netD.parameters(): 131 | p.requires_grad = False 132 | self.optimizer_G.zero_grad() 133 | self.fake_H = self.netG(self.var_L) 134 | 135 | l_g_total = 0 136 | if step % self.D_update_ratio == 0 and step > self.D_init_iters: 137 | if self.cri_pix: #pixel loss 138 | l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) 139 | l_g_total += l_g_pix 140 | if self.cri_fea: # feature loss 141 | real_fea = self.netF(self.var_H).detach() #don't want to backpropagate this, need proper explanation 142 | fake_fea = self.netF(self.fake_H) #In netF normalize=False, check it 143 | l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) 144 | l_g_total += l_g_fea 145 | 146 | pred_g_fake = self.netD(self.fake_H) 147 | if self.configT['gan_type'] == 'gan': 148 | l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) 149 | elif self.configT['gan_type'] == 'ragan': 150 | pred_d_real = self.netD(self.var_ref).detach() 151 | l_g_gan = self.l_gan_w * ( 152 | self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + 153 | self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 154 | l_g_total += l_g_gan 155 | 156 | l_g_total.backward() 157 | self.optimizer_G.step() 158 | 159 | #descriminator 160 | for p in self.netD.parameters(): 161 | p.requires_grad = True 162 | 163 | self.optimizer_D.zero_grad() 164 | l_d_total = 0 165 | pred_d_real = self.netD(self.var_ref) 166 | pred_d_fake = self.netD(self.fake_H.detach()) #to avoid BP to Generator 167 | if self.configT['gan_type'] == 'gan': 168 | l_d_real = self.cri_gan(pred_d_real, True) 169 | l_d_fake = self.cri_gan(pred_d_fake, False) 170 | l_d_total = l_d_real + l_d_fake 171 | elif self.configT['gan_type'] == 'ragan': 172 | l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) 173 | l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) 174 | l_d_total = (l_d_real + l_d_fake) / 2 175 | 176 | l_d_total.backward() 177 | self.optimizer_D.step() 178 | 179 | # set log 180 | if step % self.D_update_ratio == 0 and step > self.D_init_iters: 181 | if self.cri_pix: 182 | self.log_dict['l_g_pix'] = l_g_pix.item() 183 | if self.cri_fea: 184 | self.log_dict['l_g_fea'] = l_g_fea.item() 185 | self.log_dict['l_g_gan'] = l_g_gan.item() 186 | 187 | self.log_dict['l_d_real'] = l_d_real.item() 188 | self.log_dict['l_d_fake'] = l_d_fake.item() 189 | self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) 190 | self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) 191 | 192 | def test(self): 193 | self.netG.eval() 194 | with torch.no_grad(): 195 | self.fake_H = self.netG(self.var_L) 196 | self.netG.train() 197 | 198 | def get_current_log(self): 199 | return self.log_dict 200 | 201 | def get_current_visuals(self, need_GT=True): 202 | out_dict = OrderedDict() 203 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu() 204 | out_dict['SR'] = 
self.fake_H.detach()[0].float().cpu() 205 | if need_GT: 206 | out_dict['GT'] = self.var_H.detach()[0].float().cpu() 207 | return out_dict 208 | 209 | def print_network(self): 210 | # Generator 211 | s, n = self.get_network_description(self.netG) 212 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): 213 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 214 | self.netG.module.__class__.__name__) 215 | else: 216 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 217 | 218 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 219 | logger.info(s) 220 | 221 | # Discriminator 222 | s, n = self.get_network_description(self.netD) 223 | if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD, 224 | DistributedDataParallel): 225 | net_struc_str = '{} - {}'.format(self.netD.__class__.__name__, 226 | self.netD.module.__class__.__name__) 227 | else: 228 | net_struc_str = '{}'.format(self.netD.__class__.__name__) 229 | 230 | logger.info('Network D structure: {}, with parameters: {:,d}'.format( 231 | net_struc_str, n)) 232 | logger.info(s) 233 | 234 | if self.cri_fea: # F, Perceptual Network 235 | s, n = self.get_network_description(self.netF) 236 | if isinstance(self.netF, nn.DataParallel) or isinstance( 237 | self.netF, DistributedDataParallel): 238 | net_struc_str = '{} - {}'.format(self.netF.__class__.__name__, 239 | self.netF.module.__class__.__name__) 240 | else: 241 | net_struc_str = '{}'.format(self.netF.__class__.__name__) 242 | 243 | logger.info('Network F structure: {}, with parameters: {:,d}'.format( 244 | net_struc_str, n)) 245 | logger.info(s) 246 | 247 | def load(self): 248 | load_path_G = self.config['path']['pretrain_model_G'] 249 | if load_path_G: 250 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) 251 | self.load_network(load_path_G, self.netG, self.config['path']['strict_load']) 252 | load_path_D = self.config['path']['pretrain_model_D'] 253 | if load_path_D: 254 | logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) 255 | self.load_network(load_path_D, self.netD, self.config['path']['strict_load']) 256 | 257 | def save(self, iter_step): 258 | self.save_network(self.netG, 'G', iter_step) 259 | self.save_network(self.netD, 'D', iter_step) 260 | -------------------------------------------------------------------------------- /detection/coco_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tempfile 3 | 4 | import numpy as np 5 | import copy 6 | import time 7 | import torch 8 | import torch._six 9 | 10 | from pycocotools.cocoeval import COCOeval 11 | from pycocotools.coco import COCO 12 | import pycocotools.mask as mask_util 13 | 14 | from collections import defaultdict 15 | 16 | from .utils import all_gather 17 | 18 | 19 | class CocoEvaluator(object): 20 | def __init__(self, coco_gt, iou_types): 21 | assert isinstance(iou_types, (list, tuple)) 22 | coco_gt = copy.deepcopy(coco_gt) 23 | self.coco_gt = coco_gt 24 | 25 | self.iou_types = iou_types 26 | self.coco_eval = {} 27 | for iou_type in iou_types: 28 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) 29 | 30 | self.img_ids = [] 31 | self.eval_imgs = {k: [] for k in iou_types} 32 | 33 | def update(self, predictions): 34 | img_ids = list(np.unique(list(predictions.keys()))) 35 | self.img_ids.extend(img_ids) 36 | 37 | for iou_type in self.iou_types: 38 | results = self.prepare(predictions, iou_type) 39 | coco_dt 
= loadRes(self.coco_gt, results) if results else COCO() 40 | coco_eval = self.coco_eval[iou_type] 41 | 42 | coco_eval.cocoDt = coco_dt 43 | coco_eval.params.imgIds = list(img_ids) 44 | img_ids, eval_imgs = evaluate(coco_eval) 45 | 46 | self.eval_imgs[iou_type].append(eval_imgs) 47 | 48 | def synchronize_between_processes(self): 49 | for iou_type in self.iou_types: 50 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) 51 | create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) 52 | 53 | def accumulate(self): 54 | for coco_eval in self.coco_eval.values(): 55 | coco_eval.accumulate() 56 | 57 | def summarize(self): 58 | for iou_type, coco_eval in self.coco_eval.items(): 59 | print("IoU metric: {}".format(iou_type)) 60 | coco_eval.summarize() 61 | 62 | def prepare(self, predictions, iou_type): 63 | if iou_type == "bbox": 64 | return self.prepare_for_coco_detection(predictions) 65 | elif iou_type == "segm": 66 | return self.prepare_for_coco_segmentation(predictions) 67 | elif iou_type == "keypoints": 68 | return self.prepare_for_coco_keypoint(predictions) 69 | else: 70 | raise ValueError("Unknown iou type {}".format(iou_type)) 71 | 72 | def prepare_for_coco_detection(self, predictions): 73 | coco_results = [] 74 | for original_id, prediction in predictions.items(): 75 | if len(prediction) == 0: 76 | continue 77 | 78 | boxes = prediction["boxes"] 79 | boxes = convert_to_xywh(boxes).tolist() 80 | scores = prediction["scores"].tolist() 81 | labels = prediction["labels"].tolist() 82 | 83 | coco_results.extend( 84 | [ 85 | { 86 | "image_id": original_id, 87 | "category_id": labels[k], 88 | "bbox": box, 89 | "score": scores[k], 90 | } 91 | for k, box in enumerate(boxes) 92 | ] 93 | ) 94 | return coco_results 95 | 96 | def prepare_for_coco_segmentation(self, predictions): 97 | coco_results = [] 98 | for original_id, prediction in predictions.items(): 99 | if len(prediction) == 0: 100 | continue 101 | 102 | scores = prediction["scores"] 103 | labels = prediction["labels"] 104 | masks = prediction["masks"] 105 | 106 | masks = masks > 0.5 107 | 108 | scores = prediction["scores"].tolist() 109 | labels = prediction["labels"].tolist() 110 | 111 | rles = [ 112 | mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0] 113 | for mask in masks 114 | ] 115 | for rle in rles: 116 | rle["counts"] = rle["counts"].decode("utf-8") 117 | 118 | coco_results.extend( 119 | [ 120 | { 121 | "image_id": original_id, 122 | "category_id": labels[k], 123 | "segmentation": rle, 124 | "score": scores[k], 125 | } 126 | for k, rle in enumerate(rles) 127 | ] 128 | ) 129 | return coco_results 130 | 131 | def prepare_for_coco_keypoint(self, predictions): 132 | coco_results = [] 133 | for original_id, prediction in predictions.items(): 134 | if len(prediction) == 0: 135 | continue 136 | 137 | boxes = prediction["boxes"] 138 | boxes = convert_to_xywh(boxes).tolist() 139 | scores = prediction["scores"].tolist() 140 | labels = prediction["labels"].tolist() 141 | keypoints = prediction["keypoints"] 142 | keypoints = keypoints.flatten(start_dim=1).tolist() 143 | 144 | coco_results.extend( 145 | [ 146 | { 147 | "image_id": original_id, 148 | "category_id": labels[k], 149 | 'keypoints': keypoint, 150 | "score": scores[k], 151 | } 152 | for k, keypoint in enumerate(keypoints) 153 | ] 154 | ) 155 | return coco_results 156 | 157 | 158 | def convert_to_xywh(boxes): 159 | xmin, ymin, xmax, ymax = boxes.unbind(1) 160 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), 
dim=1) 161 | 162 | 163 | def merge(img_ids, eval_imgs): 164 | all_img_ids = all_gather(img_ids) 165 | all_eval_imgs = all_gather(eval_imgs) 166 | 167 | merged_img_ids = [] 168 | for p in all_img_ids: 169 | merged_img_ids.extend(p) 170 | 171 | merged_eval_imgs = [] 172 | for p in all_eval_imgs: 173 | merged_eval_imgs.append(p) 174 | 175 | merged_img_ids = np.array(merged_img_ids) 176 | merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) 177 | 178 | # keep only unique (and in sorted order) images 179 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) 180 | merged_eval_imgs = merged_eval_imgs[..., idx] 181 | 182 | return merged_img_ids, merged_eval_imgs 183 | 184 | 185 | def create_common_coco_eval(coco_eval, img_ids, eval_imgs): 186 | img_ids, eval_imgs = merge(img_ids, eval_imgs) 187 | img_ids = list(img_ids) 188 | eval_imgs = list(eval_imgs.flatten()) 189 | 190 | coco_eval.evalImgs = eval_imgs 191 | coco_eval.params.imgIds = img_ids 192 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params) 193 | 194 | 195 | ################################################################# 196 | # From pycocotools, just removed the prints and fixed 197 | # a Python3 bug about unicode not defined 198 | ################################################################# 199 | 200 | # Ideally, pycocotools wouldn't have hard-coded prints 201 | # so that we could avoid copy-pasting those two functions 202 | 203 | def createIndex(self): 204 | # create index 205 | # print('creating index...') 206 | anns, cats, imgs = {}, {}, {} 207 | imgToAnns, catToImgs = defaultdict(list), defaultdict(list) 208 | if 'annotations' in self.dataset: 209 | for ann in self.dataset['annotations']: 210 | imgToAnns[ann['image_id']].append(ann) 211 | anns[ann['id']] = ann 212 | 213 | if 'images' in self.dataset: 214 | for img in self.dataset['images']: 215 | imgs[img['id']] = img 216 | 217 | if 'categories' in self.dataset: 218 | for cat in self.dataset['categories']: 219 | cats[cat['id']] = cat 220 | 221 | if 'annotations' in self.dataset and 'categories' in self.dataset: 222 | for ann in self.dataset['annotations']: 223 | catToImgs[ann['category_id']].append(ann['image_id']) 224 | 225 | # print('index created!') 226 | 227 | # create class members 228 | self.anns = anns 229 | self.imgToAnns = imgToAnns 230 | self.catToImgs = catToImgs 231 | self.imgs = imgs 232 | self.cats = cats 233 | 234 | 235 | maskUtils = mask_util 236 | 237 | 238 | def loadRes(self, resFile): 239 | """ 240 | Load result file and return a result api object. 
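This is a module-level copy of pycocotools' COCO.loadRes with the prints
removed, so `self` is the ground-truth COCO object passed in explicitly,
e.g. loadRes(coco_gt, results) as in CocoEvaluator.update above.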
241 | :param resFile (str) : file name of result file 242 | :return: res (obj) : result api object 243 | """ 244 | res = COCO() 245 | res.dataset['images'] = [img for img in self.dataset['images']] 246 | 247 | # print('Loading and preparing results...') 248 | # tic = time.time() 249 | if isinstance(resFile, torch._six.string_classes): 250 | anns = json.load(open(resFile)) 251 | elif type(resFile) == np.ndarray: 252 | anns = self.loadNumpyAnnotations(resFile) 253 | else: 254 | anns = resFile 255 | assert type(anns) == list, 'results in not an array of objects' 256 | annsImgIds = [ann['image_id'] for ann in anns] 257 | assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \ 258 | 'Results do not correspond to current coco set' 259 | if 'caption' in anns[0]: 260 | imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns]) 261 | res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds] 262 | for id, ann in enumerate(anns): 263 | ann['id'] = id + 1 264 | elif 'bbox' in anns[0] and not anns[0]['bbox'] == []: 265 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 266 | for id, ann in enumerate(anns): 267 | bb = ann['bbox'] 268 | x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]] 269 | if 'segmentation' not in ann: 270 | ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]] 271 | ann['area'] = bb[2] * bb[3] 272 | ann['id'] = id + 1 273 | ann['iscrowd'] = 0 274 | elif 'segmentation' in anns[0]: 275 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 276 | for id, ann in enumerate(anns): 277 | # now only support compressed RLE format as segmentation results 278 | ann['area'] = maskUtils.area(ann['segmentation']) 279 | if 'bbox' not in ann: 280 | ann['bbox'] = maskUtils.toBbox(ann['segmentation']) 281 | ann['id'] = id + 1 282 | ann['iscrowd'] = 0 283 | elif 'keypoints' in anns[0]: 284 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 285 | for id, ann in enumerate(anns): 286 | s = ann['keypoints'] 287 | x = s[0::3] 288 | y = s[1::3] 289 | x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y) 290 | ann['area'] = (x1 - x0) * (y1 - y0) 291 | ann['id'] = id + 1 292 | ann['bbox'] = [x0, y0, x1 - x0, y1 - y0] 293 | # print('DONE (t={:0.2f}s)'.format(time.time()- tic)) 294 | 295 | res.dataset['annotations'] = anns 296 | createIndex(res) 297 | return res 298 | 299 | 300 | def evaluate(self): 301 | ''' 302 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs 303 | :return: None 304 | ''' 305 | # tic = time.time() 306 | # print('Running per image evaluation...') 307 | p = self.params 308 | # add backward compatibility if useSegm is specified in params 309 | if p.useSegm is not None: 310 | p.iouType = 'segm' if p.useSegm == 1 else 'bbox' 311 | print('useSegm (deprecated) is not None. 
Running {} evaluation'.format(p.iouType)) 312 | # print('Evaluate annotation type *{}*'.format(p.iouType)) 313 | p.imgIds = list(np.unique(p.imgIds)) 314 | if p.useCats: 315 | p.catIds = list(np.unique(p.catIds)) 316 | p.maxDets = sorted(p.maxDets) 317 | self.params = p 318 | 319 | self._prepare() 320 | # loop through images, area range, max detection number 321 | catIds = p.catIds if p.useCats else [-1] 322 | 323 | if p.iouType == 'segm' or p.iouType == 'bbox': 324 | computeIoU = self.computeIoU 325 | elif p.iouType == 'keypoints': 326 | computeIoU = self.computeOks 327 | self.ious = { 328 | (imgId, catId): computeIoU(imgId, catId) 329 | for imgId in p.imgIds 330 | for catId in catIds} 331 | 332 | evaluateImg = self.evaluateImg 333 | maxDet = p.maxDets[-1] 334 | evalImgs = [ 335 | evaluateImg(imgId, catId, areaRng, maxDet) 336 | for catId in catIds 337 | for areaRng in p.areaRng 338 | for imgId in p.imgIds 339 | ] 340 | # this is NOT in the pycocotools code, but could be done outside 341 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) 342 | self._paramsEval = copy.deepcopy(self.params) 343 | # toc = time.time() 344 | # print('DONE (t={:0.2f}s).'.format(toc-tic)) 345 | return p.imgIds, evalImgs 346 | 347 | ################################################################# 348 | # end of straight copy from pycocotools, just removing the prints 349 | ################################################################# 350 | -------------------------------------------------------------------------------- /trainer/FRCNN_trainer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | quick and dirty test, need to change later 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | import torchvision 8 | import os 9 | import cv2 10 | from collections import OrderedDict 11 | from torch.nn.parallel import DataParallel, DistributedDataParallel 12 | from torchvision.models.detection.faster_rcnn import FastRCNNPredictor 13 | from torch.utils.data import Dataset, DataLoader 14 | from torchvision import utils 15 | from detection.engine import train_one_epoch, evaluate_base 16 | from detection.utils import collate_fn 17 | from scripts_for_datasets import COWCFRCNNDataset 18 | from detection.transforms import ToTensor, RandomHorizontalFlip, Compose 19 | from matplotlib import pyplot as plt 20 | from utils import tensor2img 21 | 22 | class COWCFRCNNTrainer: 23 | """ 24 | Trainer class 25 | """ 26 | def __init__(self, config): 27 | self.config = config 28 | 29 | n_gpu = torch.cuda.device_count() 30 | self.device = torch.device('cuda:0' if n_gpu > 0 else 'cpu') 31 | 32 | def get_transform(self, train): 33 | transforms = [] 34 | # converts the image, a PIL image, into a PyTorch Tensor 35 | transforms.append(ToTensor()) 36 | if train: 37 | # during training, randomly flip the training images 38 | # and ground-truth for data augmentation 39 | transforms.append(RandomHorizontalFlip(0.5)) 40 | return Compose(transforms) 41 | 42 | def data_loaders(self): 43 | # use our dataset and defined transformations 44 | dataset = COWCFRCNNDataset(root=self.config['path']['data_dir_LR_train'], 45 | image_height=64, image_width=64, transforms=self.get_transform(train=True)) 46 | dataset_test = COWCFRCNNDataset(root=self.config['path']['data_dir_Valid'], 47 | image_height=64, image_width=64, transforms=self.get_transform(train=False)) 48 | dataset_test_SR = COWCFRCNNDataset(root=self.config['path']['data_dir_SR'], 49 | 
transforms=self.get_transform(train=False)) 50 | dataset_test_SR_combined = COWCFRCNNDataset(root=self.config['path']['data_dir_SR_combined'], 51 | transforms=self.get_transform(train=False)) 52 | dataset_test_E_SR_1 = COWCFRCNNDataset(root=self.config['path']['data_dir_E_SR_1'], 53 | transforms=self.get_transform(train=False)) 54 | dataset_test_E_SR_2 = COWCFRCNNDataset(root=self.config['path']['data_dir_E_SR_2'], 55 | transforms=self.get_transform(train=False)) 56 | dataset_test_E_SR_3 = COWCFRCNNDataset(root=self.config['path']['data_dir_E_SR_3'], 57 | transforms=self.get_transform(train=False)) 58 | dataset_test_F_SR = COWCFRCNNDataset(root=self.config['path']['data_dir_F_SR'], 59 | transforms=self.get_transform(train=False)) 60 | dataset_test_Bic = COWCFRCNNDataset(root=self.config['path']['data_dir_Bic'], 61 | transforms=self.get_transform(train=False)) 62 | 63 | # define training and validation data loaders 64 | data_loader = torch.utils.data.DataLoader( 65 | dataset, batch_size=2, shuffle=True, num_workers=4, 66 | collate_fn=collate_fn) 67 | 68 | data_loader_test = torch.utils.data.DataLoader( 69 | dataset_test, batch_size=1, shuffle=False, num_workers=4, 70 | collate_fn=collate_fn) 71 | 72 | data_loader_test_SR = torch.utils.data.DataLoader( 73 | dataset_test_SR, batch_size=1, shuffle=False, num_workers=4, 74 | collate_fn=collate_fn) 75 | 76 | data_loader_test_SR_combined = torch.utils.data.DataLoader( 77 | dataset_test_SR_combined, batch_size=1, shuffle=False, num_workers=4, 78 | collate_fn=collate_fn) 79 | 80 | data_loader_test_E_SR_1 = torch.utils.data.DataLoader( 81 | dataset_test_E_SR_1, batch_size=1, shuffle=False, num_workers=4, 82 | collate_fn=collate_fn) 83 | 84 | data_loader_test_E_SR_2 = torch.utils.data.DataLoader( 85 | dataset_test_E_SR_2, batch_size=1, shuffle=False, num_workers=4, 86 | collate_fn=collate_fn) 87 | 88 | data_loader_test_E_SR_3 = torch.utils.data.DataLoader( 89 | dataset_test_E_SR_3, batch_size=1, shuffle=False, num_workers=4, 90 | collate_fn=collate_fn) 91 | 92 | data_loader_test_F_SR = torch.utils.data.DataLoader( 93 | dataset_test_F_SR, batch_size=1, shuffle=False, num_workers=4, 94 | collate_fn=collate_fn) 95 | 96 | data_loader_test_Bic = torch.utils.data.DataLoader( 97 | dataset_test_Bic, batch_size=1, shuffle=False, num_workers=4, 98 | collate_fn=collate_fn) 99 | 100 | return data_loader, data_loader_test, data_loader_test_SR, data_loader_test_SR_combined, \ 101 | data_loader_test_E_SR_1, data_loader_test_E_SR_2, data_loader_test_E_SR_3, \ 102 | data_loader_test_F_SR, data_loader_test_Bic 103 | 104 | def save_model(self, network, network_label, iter_label): 105 | save_filename = '{}_{}.pth'.format(iter_label, network_label) 106 | save_path = os.path.join(self.config['path']['FRCNN_model'], save_filename) 107 | 108 | state_dict = network.state_dict() 109 | for key, param in state_dict.items(): 110 | state_dict[key] = param.cpu() 111 | torch.save(state_dict, save_path) 112 | 113 | def load_model(self, load_path, network, strict=True): 114 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 115 | network = network.module 116 | load_net = torch.load(load_path) 117 | load_net_clean = OrderedDict() # remove unnecessary 'module.' 
118 | for k, v in load_net.items(): 119 | if k.startswith('module.'): 120 | load_net_clean[k[7:]] = v 121 | else: 122 | load_net_clean[k] = v 123 | network.load_state_dict(load_net_clean, strict=strict) 124 | print("model_loaded") 125 | 126 | ''' 127 | Draw boxes on the test images 128 | ''' 129 | def draw_detection_boxes(self, new_class_conf_box, file_path): 130 | source_image_path = os.path.join(self.config['path']['data_dir_Bic_valid'], os.path.splitext(os.path.basename(file_path))[0]+'.jpg') 131 | dest_image_path = os.path.splitext(file_path)[0]+'.jpg' 132 | #print(dest_image_path) 133 | image = cv2.imread(source_image_path,1) 134 | #print(new_class_conf_box) 135 | #print(len(new_class_conf_box)) 136 | for i in range(len(new_class_conf_box)): 137 | clas,con,x1,y1,x2,y2 = new_class_conf_box[i] 138 | cv2.rectangle(image, (x1, y1), (x2, y2), (0,0,255), 4) 139 | font = cv2.FONT_HERSHEY_SIMPLEX 140 | cv2.putText(image, 'Car: '+ str((int(con*100))) + '%', ((x1)+5, (y1)+8), font, 0.2,(0,255,0),1,cv2.LINE_AA) 141 | 142 | cv2.imwrite(dest_image_path, image) 143 | 144 | ''' 145 | for generating test boxes 146 | ''' 147 | def get_prediction(self, model, images, annotation_path, threshold=0.5): 148 | new_class_conf_box = list() 149 | image = list(img.to(self.device) for img in images) 150 | outputs = model(image) 151 | file_path = os.path.join(self.config['path']['Test_Result_LR_LR_COWC'], os.path.basename(annotation_path)) 152 | #print(file_path) 153 | pred_class = [i for i in list(outputs[0]['labels'].detach().cpu().numpy())] # Get the predicted labels 154 | text_boxes = [ [i[0], i[1], i[2], i[3] ] for i in list(outputs[0]['boxes'].detach().cpu().numpy())] # Bounding boxes 155 | pred_score = list(outputs[0]['scores'].detach().cpu().numpy()) 156 | #print(pred_score) 157 | for i in range(len(text_boxes)): 158 | if pred_score[i]<0.8: 159 | continue 160 | new_class_conf_box.append([pred_class[i], pred_score[i], int(text_boxes[i][0]*4), int(text_boxes[i][1]*4), int(text_boxes[i][2]*4), int(text_boxes[i][3]*4)]) 161 | self.draw_detection_boxes(new_class_conf_box, file_path) 162 | new_class_conf_box1 = np.matrix(new_class_conf_box) 163 | #print(new_class_conf_box) 164 | if(len(new_class_conf_box))>0: 165 | np.savetxt(file_path, new_class_conf_box1, fmt="%i %1.3f %i %i %i %i") 166 | 167 | #get test results 168 | def test(self): 169 | 170 | # load a model pre-trained on COCO 171 | model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) 172 | 173 | # replace the classifier with a new one, that has 174 | # num_classes which is user-defined 175 | num_classes = 2 # 1 class (car) + background 176 | # get number of input features for the classifier 177 | in_features = model.roi_heads.box_predictor.cls_score.in_features 178 | # replace the pre-trained head with a new one 179 | model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) 180 | 181 | model.to(self.device) 182 | 183 | self.load_model(self.config['path']['pretrain_model_FRCNN_LR_LR'], model) 184 | 185 | _, data_loader_test, data_loader_test_SR, data_loader_test_SR_combined, \ 186 | data_loader_test_E_SR_1, data_loader_test_E_SR_2, data_loader_test_E_SR_3, \ 187 | data_loader_test_F_SR, data_loader_test_Bic = self.data_loaders() 188 | 189 | print("test lengths of the data loaders.............") 190 | print(len(data_loader_test)) 191 | model.eval() 192 | i = 0 193 | print("Detection started........") 194 | for image, targets, annotation_path in data_loader_test: 195 | annotation_path =
''.join(annotation_path) 196 | self.get_prediction(model, image, annotation_path) 197 | #evaluate_base(model, data_loader_test_Bic, device=self.device) 198 | i=i+1 199 | print(i) 200 | print("successfully generated the results!") 201 | ''' 202 | print(len(data_loader_test_SR)) 203 | print(len(data_loader_test_SR_combined)) 204 | print(len(data_loader_test_E_SR_1)) 205 | print(len(data_loader_test_E_SR_2)) 206 | print(len(data_loader_test_E_SR_3)) 207 | print(len(data_loader_test_F_SR)) 208 | print(len(data_loader_test_Bic)) 209 | print("test HR images..............................") 210 | evaluate_base(model, data_loader_test, device=self.device) 211 | print("test SR images..............................") 212 | evaluate_base(model, data_loader_test_SR, device=self.device) 213 | print("test SR combined images..............................") 214 | evaluate_base(model, data_loader_test_SR_combined, device=self.device) 215 | print("test Enhanced SR 1 images.....................") 216 | evaluate_base(model, data_loader_test_E_SR_1, device=self.device) 217 | print("test Enhanced SR 2 images.....................") 218 | evaluate_base(model, data_loader_test_E_SR_2, device=self.device) 219 | print("test Enhanced SR 3 images.....................") 220 | evaluate_base(model, data_loader_test_E_SR_3, device=self.device) 221 | print("test Final SR images.........................") 222 | evaluate_base(model, data_loader_test_F_SR, device=self.device) 223 | print("test Bicubic images..........................") 224 | evaluate_base(model, data_loader_test_Bic, device=self.device) 225 | ''' 226 | def train(self): 227 | # load a model pre-trained on COCO 228 | model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) 229 | 230 | # replace the classifier with a new one, that has 231 | # num_classes which is user-defined 232 | num_classes = 2 # 1 class (car) + background 233 | # get number of input features for the classifier 234 | in_features = model.roi_heads.box_predictor.cls_score.in_features 235 | # replace the pre-trained head with a new one 236 | model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) 237 | 238 | model.to(self.device) 239 | #self.load_model(self.config['path']['pretrain_model_FRCNN_LR_LR'], model) 240 | 241 | # construct an optimizer 242 | params = [p for p in model.parameters() if p.requires_grad] 243 | optimizer = torch.optim.SGD(params, lr=0.005, 244 | momentum=0.9, weight_decay=0.0005) 245 | 246 | # and a learning rate scheduler which decreases the learning rate by 247 | # 10x every 3 epochs 248 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 249 | step_size=3, 250 | gamma=0.1) 251 | 252 | data_loader, data_loader_test, _, _, _, _, _, _, _ = self.data_loaders() 253 | # train for num_epochs epochs 254 | num_epochs = 1000 255 | 256 | for epoch in range(num_epochs): 257 | # train for one epoch, printing every 10 iterations 258 | train_one_epoch(model, optimizer, data_loader, self.device, epoch, print_freq=10) 259 | # update the learning rate 260 | lr_scheduler.step() 261 | # evaluate on the test dataset 262 | evaluate_base(model, data_loader_test, device=self.device) 263 | if epoch % 1 == 0: 264 | self.save_model(model, 'FRCNN_LR_LR', epoch) 265 | -------------------------------------------------------------------------------- /model/ESRGAN_EESN_Model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | import torch 4 | import
torch.nn as nn 5 | import model.model as model 6 | import model.lr_scheduler as lr_scheduler 7 | from model.loss import GANLoss, CharbonnierLoss 8 | from .gan_base_model import BaseModel 9 | from torch.nn.parallel import DataParallel, DistributedDataParallel 10 | 11 | logger = logging.getLogger('base') 12 | # Taken from ESRGAN BASICSR repository and modified 13 | class ESRGAN_EESN_Model(BaseModel): 14 | def __init__(self, config, device): 15 | super(ESRGAN_EESN_Model, self).__init__(config, device) 16 | self.configG = config['network_G'] 17 | self.configD = config['network_D'] 18 | self.configT = config['train'] 19 | self.configO = config['optimizer']['args'] 20 | self.configS = config['lr_scheduler'] 21 | self.device = device 22 | #Generator 23 | self.netG = model.ESRGAN_EESN(in_nc=self.configG['in_nc'], out_nc=self.configG['out_nc'], 24 | nf=self.configG['nf'], nb=self.configG['nb']) 25 | self.netG = self.netG.to(self.device) 26 | self.netG = DataParallel(self.netG, device_ids=[1,0]) 27 | 28 | #descriminator 29 | self.netD = model.Discriminator_VGG_128(in_nc=self.configD['in_nc'], nf=self.configD['nf']) 30 | self.netD = self.netD.to(self.device) 31 | self.netD = DataParallel(self.netD, device_ids=[1,0]) 32 | 33 | self.netG.train() 34 | self.netD.train() 35 | #print(self.configT['pixel_weight']) 36 | # G CharbonnierLoss for final output SR and GT HR 37 | self.cri_charbonnier = CharbonnierLoss().to(device) 38 | # G pixel loss 39 | if self.configT['pixel_weight'] > 0.0: 40 | l_pix_type = self.configT['pixel_criterion'] 41 | if l_pix_type == 'l1': 42 | self.cri_pix = nn.L1Loss().to(self.device) 43 | elif l_pix_type == 'l2': 44 | self.cri_pix = nn.MSELoss().to(self.device) 45 | else: 46 | raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) 47 | self.l_pix_w = self.configT['pixel_weight'] 48 | else: 49 | self.cri_pix = None 50 | 51 | # G feature loss 52 | #print(self.configT['feature_weight']+1) 53 | if self.configT['feature_weight'] > 0: 54 | l_fea_type = self.configT['feature_criterion'] 55 | if l_fea_type == 'l1': 56 | self.cri_fea = nn.L1Loss().to(self.device) 57 | elif l_fea_type == 'l2': 58 | self.cri_fea = nn.MSELoss().to(self.device) 59 | else: 60 | raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) 61 | self.l_fea_w = self.configT['feature_weight'] 62 | else: 63 | self.cri_fea = None 64 | if self.cri_fea: # load VGG perceptual loss 65 | self.netF = model.VGGFeatureExtractor(feature_layer=34, 66 | use_input_norm=True, device=self.device) 67 | self.netF = self.netF.to(self.device) 68 | self.netF = DataParallel(self.netF, device_ids=[1,0]) 69 | self.netF.eval() 70 | 71 | # GD gan loss 72 | self.cri_gan = GANLoss(self.configT['gan_type'], 1.0, 0.0).to(self.device) 73 | self.l_gan_w = self.configT['gan_weight'] 74 | # D_update_ratio and D_init_iters 75 | self.D_update_ratio = self.configT['D_update_ratio'] if self.configT['D_update_ratio'] else 1 76 | self.D_init_iters = self.configT['D_init_iters'] if self.configT['D_init_iters'] else 0 77 | 78 | 79 | # optimizers 80 | # G 81 | wd_G = self.configO['weight_decay_G'] if self.configO['weight_decay_G'] else 0 82 | optim_params = [] 83 | for k, v in self.netG.named_parameters(): # can optimize for a part of the model 84 | if v.requires_grad: 85 | optim_params.append(v) 86 | 87 | self.optimizer_G = torch.optim.Adam(optim_params, lr=self.configO['lr_G'], 88 | weight_decay=wd_G, 89 | betas=(self.configO['beta1_G'], self.configO['beta2_G'])) 90 | self.optimizers.append(self.optimizer_G) 91 
| 92 | # D 93 | wd_D = self.configO['weight_decay_D'] if self.configO['weight_decay_D'] else 0 94 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.configO['lr_D'], 95 | weight_decay=wd_D, 96 | betas=(self.configO['beta1_D'], self.configO['beta2_D'])) 97 | self.optimizers.append(self.optimizer_D) 98 | 99 | # schedulers 100 | if self.configS['type'] == 'MultiStepLR': 101 | for optimizer in self.optimizers: 102 | self.schedulers.append( 103 | lr_scheduler.MultiStepLR_Restart(optimizer, self.configS['args']['lr_steps'], 104 | restarts=self.configS['args']['restarts'], 105 | weights=self.configS['args']['restart_weights'], 106 | gamma=self.configS['args']['lr_gamma'], 107 | clear_state=False)) 108 | elif self.configS['type'] == 'CosineAnnealingLR_Restart': 109 | for optimizer in self.optimizers: 110 | self.schedulers.append( 111 | lr_scheduler.CosineAnnealingLR_Restart( 112 | optimizer, self.configS['args']['T_period'], eta_min=self.configS['args']['eta_min'], 113 | restarts=self.configS['args']['restarts'], weights=self.configS['args']['restart_weights'])) 114 | else: 115 | raise NotImplementedError('MultiStepLR learning rate scheme is enough.') 116 | print(self.configS['args']['restarts']) 117 | self.log_dict = OrderedDict() 118 | 119 | self.print_network() # print network 120 | self.load() # load G and D if needed 121 | ''' 122 | The main repo did not use collate_fn and image read has different flags 123 | and also used np.ascontiguousarray() 124 | Might change my code if problem happens 125 | ''' 126 | def feed_data(self, data): 127 | self.var_L = data['image_lq'].to(self.device) 128 | self.var_H = data['image'].to(self.device) 129 | input_ref = data['ref'] if 'ref' in data else data['image'] 130 | self.var_ref = input_ref.to(self.device) 131 | 132 | def optimize_parameters(self, step): 133 | #Generator 134 | for p in self.netD.parameters(): 135 | p.requires_grad = False 136 | self.optimizer_G.zero_grad() 137 | self.fake_H, self.final_SR, _, _ = self.netG(self.var_L) 138 | 139 | l_g_total = 0 140 | if step % self.D_update_ratio == 0 and step > self.D_init_iters: 141 | if self.cri_pix: #pixel loss 142 | l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) 143 | l_g_total += l_g_pix 144 | if self.cri_fea: # feature loss 145 | real_fea = self.netF(self.var_H).detach() #don't want to backpropagate this, need proper explanation 146 | fake_fea = self.netF(self.fake_H) #In netF normalize=False, check it 147 | l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) 148 | l_g_total += l_g_fea 149 | 150 | pred_g_fake = self.netD(self.fake_H) 151 | if self.configT['gan_type'] == 'gan': 152 | l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) 153 | elif self.configT['gan_type'] == 'ragan': 154 | pred_d_real = self.netD(self.var_ref).detach() 155 | l_g_gan = self.l_gan_w * ( 156 | self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + 157 | self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 158 | l_g_total += l_g_gan 159 | #EESN calculate loss 160 | if self.cri_charbonnier: # charbonnier pixel loss HR and SR 161 | l_e_charbonnier = 5 * self.cri_charbonnier(self.final_SR, self.var_H) #change the weight to empirically 162 | l_g_total += l_e_charbonnier 163 | 164 | l_g_total.backward() 165 | self.optimizer_G.step() 166 | 167 | #descriminator 168 | for p in self.netD.parameters(): 169 | p.requires_grad = True 170 | 171 | self.optimizer_D.zero_grad() 172 | l_d_total = 0 173 | pred_d_real = self.netD(self.var_ref) 174 | pred_d_fake = 
self.netD(self.fake_H.detach()) #to avoid BP to Generator 175 | if self.configT['gan_type'] == 'gan': 176 | l_d_real = self.cri_gan(pred_d_real, True) 177 | l_d_fake = self.cri_gan(pred_d_fake, False) 178 | l_d_total = l_d_real + l_d_fake 179 | elif self.configT['gan_type'] == 'ragan': 180 | l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) 181 | l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) 182 | l_d_total = (l_d_real + l_d_fake) / 2 # thinking of adding final sr d loss 183 | 184 | l_d_total.backward() 185 | self.optimizer_D.step() 186 | 187 | # set log 188 | if step % self.D_update_ratio == 0 and step > self.D_init_iters: 189 | if self.cri_pix: 190 | self.log_dict['l_g_pix'] = l_g_pix.item() 191 | if self.cri_fea: 192 | self.log_dict['l_g_fea'] = l_g_fea.item() 193 | self.log_dict['l_g_gan'] = l_g_gan.item() 194 | self.log_dict['l_e_charbonnier'] = l_e_charbonnier.item() 195 | 196 | self.log_dict['l_d_real'] = l_d_real.item() 197 | self.log_dict['l_d_fake'] = l_d_fake.item() 198 | self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) 199 | self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) 200 | 201 | def test(self): 202 | self.netG.eval() 203 | with torch.no_grad(): 204 | self.fake_H, self.final_SR, self.x_learned_lap_fake, self.x_lap = self.netG(self.var_L) 205 | _, _, _, self.x_lap_HR = self.netG(self.var_H) 206 | self.netG.train() 207 | 208 | def get_current_log(self): 209 | return self.log_dict 210 | 211 | def get_current_visuals(self, need_GT=True): 212 | out_dict = OrderedDict() 213 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu() 214 | #out_dict['SR'] = self.fake_H.detach()[0].float().cpu() 215 | out_dict['SR'] = self.fake_H.detach()[0].float().cpu() 216 | out_dict['lap_learned'] = self.x_learned_lap_fake.detach()[0].float().cpu() 217 | out_dict['lap'] = self.x_lap.detach()[0].float().cpu() 218 | out_dict['lap_HR'] = self.x_lap_HR.detach()[0].float().cpu() 219 | out_dict['final_SR'] = self.final_SR.detach()[0].float().cpu() 220 | if need_GT: 221 | out_dict['GT'] = self.var_H.detach()[0].float().cpu() 222 | return out_dict 223 | 224 | def print_network(self): 225 | # Generator 226 | s, n = self.get_network_description(self.netG) 227 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): 228 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 229 | self.netG.module.__class__.__name__) 230 | else: 231 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 232 | 233 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 234 | logger.info(s) 235 | 236 | # Discriminator 237 | s, n = self.get_network_description(self.netD) 238 | if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD, 239 | DistributedDataParallel): 240 | net_struc_str = '{} - {}'.format(self.netD.__class__.__name__, 241 | self.netD.module.__class__.__name__) 242 | else: 243 | net_struc_str = '{}'.format(self.netD.__class__.__name__) 244 | 245 | logger.info('Network D structure: {}, with parameters: {:,d}'.format( 246 | net_struc_str, n)) 247 | logger.info(s) 248 | 249 | if self.cri_fea: # F, Perceptual Network 250 | s, n = self.get_network_description(self.netF) 251 | if isinstance(self.netF, nn.DataParallel) or isinstance( 252 | self.netF, DistributedDataParallel): 253 | net_struc_str = '{} - {}'.format(self.netF.__class__.__name__, 254 | self.netF.module.__class__.__name__) 255 | else: 256 | net_struc_str = '{}'.format(self.netF.__class__.__name__) 257 | 
258 | logger.info('Network F structure: {}, with parameters: {:,d}'.format( 259 | net_struc_str, n)) 260 | logger.info(s) 261 | 262 | def load(self): 263 | load_path_G = self.config['path']['pretrain_model_G'] 264 | if load_path_G: 265 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) 266 | self.load_network(load_path_G, self.netG, self.config['path']['strict_load']) 267 | load_path_D = self.config['path']['pretrain_model_D'] 268 | if load_path_D: 269 | logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) 270 | self.load_network(load_path_D, self.netD, self.config['path']['strict_load']) 271 | 272 | def save(self, iter_step): 273 | self.save_network(self.netG, 'G', iter_step) 274 | self.save_network(self.netD, 'D', iter_step) 275 | #self.save_network(self.netG.module.netE, 'E', iter_step) 276 | --------------------------------------------------------------------------------
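Usage sketch (not part of the repository): the listing above shows ESRGAN_EESN_Model but not how it is driven, so the following minimal Python loop illustrates the call order visible in ESRGAN_EESN_Model.py: feed_data(), optimize_parameters(), get_current_log(), and save(). The config layout, the synthetic batches, and every numeric value below are assumptions made only for illustration; in practice the real config_GAN.json and the COWC loaders under data_loader/ and scripts_for_datasets/ would be used, and because the model wraps its networks in DataParallel(device_ids=[1,0]) it expects at least two GPUs.

# Hypothetical driver loop for ESRGAN_EESN_Model -- a minimal sketch, not the
# repository's train.py. Assumes config_GAN.json provides the keys read in
# __init__ (network_G, network_D, train, optimizer.args, lr_scheduler, path)
# and that each batch is a dict with 'image_lq' (LR) and 'image' (HR) tensors,
# which is what feed_data() expects.
import json
import torch
from model.ESRGAN_EESN_Model import ESRGAN_EESN_Model

def synthetic_batches(n_batches, batch_size=2, lr_size=16, scale=4):
    # Random LR/HR pairs used only to exercise the loop; swap in the real COWC loader.
    for _ in range(n_batches):
        yield {
            'image_lq': torch.rand(batch_size, 3, lr_size, lr_size),
            'image': torch.rand(batch_size, 3, lr_size * scale, lr_size * scale),
        }

def main(config_path='config_GAN.json', total_steps=100, save_every=50):
    with open(config_path) as f:
        config = json.load(f)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    sr_model = ESRGAN_EESN_Model(config, device)
    for step, batch in enumerate(synthetic_batches(total_steps), start=1):
        sr_model.feed_data(batch)           # moves 'image_lq'/'image' (and optional 'ref') to the device
        sr_model.optimize_parameters(step)  # one generator update, then one discriminator update
        if step % 10 == 0:
            print(step, sr_model.get_current_log())  # pixel / feature / GAN / Charbonnier losses
        if step % save_every == 0:
            sr_model.save(step)             # checkpoints G and D via the gan_base_model helpers

if __name__ == '__main__':
    main()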