├── __init__.py
├── test_all
│   ├── __init__.py
│   └── test_classes.py
├── logger
│   ├── __init__.py
│   ├── logger.py
│   ├── logger_config.json
│   └── visualization.py
├── utils
│   ├── __init__.py
│   └── module_util.py
├── .requirements.txt.swp
├── Detection_Results
│   ├── 1_GT.jpg
│   ├── 1_LR.jpg
│   ├── 1_SR.jpg
│   ├── 2_GT.jpg
│   ├── 2_LR.jpg
│   ├── 2_SR.jpg
│   ├── 1_GT_box.jpg
│   ├── 2_GT_box.jpg
│   ├── 2_LR_detect.jpg
│   ├── 1_LR_detection.jpg
│   ├── 1_SR_detection.jpg
│   ├── 2_LR_detect_new.jpg
│   ├── 2_SR_detection.jpg
│   ├── 1_LR_detection_new.jpg
│   └── overall_pipeline.PNG
├── base
│   ├── __init__.py
│   ├── base_model.py
│   ├── base_data_loader.py
│   └── base_trainer.py
├── model
│   ├── .ESRGAN_EESN_Model.py.swp
│   ├── metric.py
│   ├── loss.py
│   ├── lr_scheduler.py
│   ├── gan_base_model.py
│   ├── ESRGANModel.py
│   └── ESRGAN_EESN_Model.py
├── scripts_for_datasets
│   ├── __init__.py
│   ├── cowc_FRCNN_dataset.py
│   ├── COWC_dataset.py
│   ├── COWC_GAN_dataset.py
│   └── COWC_EESRGAN_FRCNN_dataset.py
├── trainer
│   ├── __init__.py
│   ├── trainer.py
│   ├── cowc_trainer.py
│   ├── cowc_GAN_FRCNN_trainer.py
│   ├── cowc_GAN_trainer.py
│   └── FRCNN_trainer.py
├── requirements.txt
├── config.json
├── detection
│   ├── transforms.py
│   ├── train.py
│   ├── group_by_aspect_ratio.py
│   ├── engine.py
│   ├── utils.py
│   ├── coco_utils.py
│   └── coco_eval.py
├── .gitignore
├── test.py
├── train.py
├── config_GAN.json
├── README.md
├── parse_config.py
└── data_loader
    └── data_loaders.py
/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/test_all/__init__.py:
--------------------------------------------------------------------------------
1 | from .test_classes import *
2 |
--------------------------------------------------------------------------------
/logger/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import *
2 | from .visualization import *
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .util import *
2 | from .module_util import *
3 |
--------------------------------------------------------------------------------
/.requirements.txt.swp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/.requirements.txt.swp
--------------------------------------------------------------------------------
/Detection_Results/1_GT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/1_GT.jpg
--------------------------------------------------------------------------------
/Detection_Results/1_LR.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/1_LR.jpg
--------------------------------------------------------------------------------
/Detection_Results/1_SR.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/1_SR.jpg
--------------------------------------------------------------------------------
/Detection_Results/2_GT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/2_GT.jpg
--------------------------------------------------------------------------------
/Detection_Results/2_LR.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/2_LR.jpg
--------------------------------------------------------------------------------
/Detection_Results/2_SR.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/2_SR.jpg
--------------------------------------------------------------------------------
/Detection_Results/1_GT_box.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/1_GT_box.jpg
--------------------------------------------------------------------------------
/Detection_Results/2_GT_box.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/2_GT_box.jpg
--------------------------------------------------------------------------------
/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_data_loader import *
2 | from .base_model import *
3 | from .base_trainer import *
4 |
--------------------------------------------------------------------------------
/model/.ESRGAN_EESN_Model.py.swp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/model/.ESRGAN_EESN_Model.py.swp
--------------------------------------------------------------------------------
/Detection_Results/2_LR_detect.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/2_LR_detect.jpg
--------------------------------------------------------------------------------
/Detection_Results/1_LR_detection.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/1_LR_detection.jpg
--------------------------------------------------------------------------------
/Detection_Results/1_SR_detection.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/1_SR_detection.jpg
--------------------------------------------------------------------------------
/Detection_Results/2_LR_detect_new.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/2_LR_detect_new.jpg
--------------------------------------------------------------------------------
/Detection_Results/2_SR_detection.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/2_SR_detection.jpg
--------------------------------------------------------------------------------
/Detection_Results/1_LR_detection_new.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/1_LR_detection_new.jpg
--------------------------------------------------------------------------------
/Detection_Results/overall_pipeline.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jakaria08/EESRGAN/HEAD/Detection_Results/overall_pipeline.PNG
--------------------------------------------------------------------------------
/scripts_for_datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .COWC_dataset import *
2 | from .COWC_GAN_dataset import *
3 | from .cowc_FRCNN_dataset import *
4 | from .COWC_EESRGAN_FRCNN_dataset import *
5 |
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import *
2 | from .cowc_trainer import *
3 | from .cowc_GAN_trainer import *
4 | from .FRCNN_trainer import *
5 | from .cowc_GAN_FRCNN_trainer import *
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv_python_headless==4.1.1.26
2 | tqdm==4.36.1
3 | matplotlib==3.1.1
4 | albumentations==0.3.3
5 | torchvision==0.3.0
6 | pandas==0.25.1
7 | numpy==1.16.5
8 | pytest==5.2.0
9 | torch==1.1.0
10 | kornia==0.1.4.post2
11 | dataclasses==0.7
12 | Pillow==7.1.1
13 | pycocotools==2.0.0
14 | tensorboardX==2.0
15 |
--------------------------------------------------------------------------------
/model/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def accuracy(output, target):
5 | with torch.no_grad():
6 | pred = torch.argmax(output, dim=1)
7 | assert pred.shape[0] == len(target)
8 | correct = 0
9 | correct += torch.sum(pred == target).item()
10 | return correct / len(target)
11 |
12 |
13 | def top_k_acc(output, target, k=3):
14 | with torch.no_grad():
15 | pred = torch.topk(output, k, dim=1)[1]
16 | assert pred.shape[0] == len(target)
17 | correct = 0
18 | for i in range(k):
19 | correct += torch.sum(pred[:, i] == target).item()
20 | return correct / len(target)
21 |
--------------------------------------------------------------------------------
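A minimal usage sketch for these metric helpers, not part of the repository, assuming only that torch is installed and model/metric.py is importable:

    import torch
    from model.metric import accuracy, top_k_acc

    # Dummy logits for a batch of 4 samples over 5 classes, plus integer targets.
    output = torch.randn(4, 5)
    target = torch.tensor([0, 2, 1, 4])

    print(accuracy(output, target))        # fraction of exact top-1 matches
    print(top_k_acc(output, target, k=3))  # fraction of targets found among the top-3 predictions
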
/base/base_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | from abc import abstractmethod
4 |
5 |
6 | class BaseModel(nn.Module):
7 | """
8 | Base class for all models
9 | """
10 | @abstractmethod
11 | def forward(self, *inputs):
12 | """
13 | Forward pass logic
14 |
15 | :return: Model output
16 | """
17 | raise NotImplementedError
18 |
19 | def __str__(self):
20 | """
21 | Model prints with number of trainable parameters
22 | """
23 | model_parameters = filter(lambda p: p.requires_grad, self.parameters())
24 | params = sum([np.prod(p.size()) for p in model_parameters])
25 | return super().__str__() + '\nTrainable parameters: {}'.format(params)
26 |
--------------------------------------------------------------------------------
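An illustrative subclass sketch, not from the repository; the layer sizes are made up, and it only assumes the base package is importable:

    import torch
    import torch.nn as nn
    from base import BaseModel

    class TinyClassifier(BaseModel):
        def __init__(self, num_classes=2):
            super().__init__()
            self.net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, num_classes))

        def forward(self, x):
            return self.net(x)

    model = TinyClassifier()
    print(model)                                    # __str__ appends the trainable-parameter count
    print(model(torch.randn(1, 3, 32, 32)).shape)   # torch.Size([1, 2])
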
/logger/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 | from pathlib import Path
4 | from utils import read_json
5 |
6 |
7 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
8 | """
9 | Setup logging configuration
10 | """
11 | log_config = Path(log_config)
12 | if log_config.is_file():
13 | config = read_json(log_config)
14 | # modify logging paths based on run config
15 | for _, handler in config['handlers'].items():
16 | if 'filename' in handler:
17 | handler['filename'] = str(save_dir / handler['filename'])
18 |
19 | logging.config.dictConfig(config)
20 | else:
21 | print("Warning: logging configuration file is not found in {}.".format(log_config))
22 | logging.basicConfig(level=default_level)
23 |
--------------------------------------------------------------------------------
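A usage sketch for setup_logging, not part of the repository; it assumes the working directory is the project root (so logger/logger_config.json and utils.read_json resolve) and that the save directory can be created:

    import logging
    from pathlib import Path
    from logger import setup_logging

    save_dir = Path('saved/log')
    save_dir.mkdir(parents=True, exist_ok=True)

    setup_logging(save_dir)   # installs the console and rotating-file handlers from logger_config.json
    logging.getLogger(__name__).info('logging initialised')   # goes to stdout and saved/log/info.log
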
/logger/logger_config.json:
--------------------------------------------------------------------------------
1 |
2 | {
3 | "version": 1,
4 | "disable_existing_loggers": false,
5 | "formatters": {
6 | "simple": {"format": "%(message)s"},
7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
8 | },
9 | "handlers": {
10 | "console": {
11 | "class": "logging.StreamHandler",
12 | "level": "DEBUG",
13 | "formatter": "simple",
14 | "stream": "ext://sys.stdout"
15 | },
16 | "info_file_handler": {
17 | "class": "logging.handlers.RotatingFileHandler",
18 | "level": "INFO",
19 | "formatter": "datetime",
20 | "filename": "info.log",
21 | "maxBytes": 10485760,
22 | "backupCount": 20, "encoding": "utf8"
23 | }
24 | },
25 | "root": {
26 | "level": "INFO",
27 | "handlers": [
28 | "console",
29 | "info_file_handler"
30 | ]
31 | }
32 | }
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "ImagePatchClassifier_TinyNet",
3 | "n_gpu": 1,
4 |
5 | "arch": {
6 | "type": "ImagePatchClassifier",
7 | "args": {}
8 | },
9 | "data_loader": {
10 | "type": "COWCDataLoader",
11 | "args":{
12 | "data_dir": "/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/Train-Mixed-class-2/",
13 | "batch_size": 1,
14 | "shuffle": true,
15 | "validation_split": 0.1,
16 | "num_workers": 2
17 | }
18 | },
19 | "optimizer": {
20 | "type": "SGD",
21 | "args":{
22 | "lr": 0.0001,
23 | "weight_decay": 0,
24 | "momentum": 0.9
25 | }
26 | },
27 | "loss": "cross_entropy",
28 | "metrics": [
29 | "accuracy"
30 | ],
31 | "lr_scheduler": {
32 | "type": "StepLR",
33 | "args": {
34 | "step_size": 50,
35 | "gamma": 0.1
36 | }
37 | },
38 | "trainer": {
39 | "epochs": 100,
40 |
41 | "save_dir": "saved/",
42 | "save_period": 1,
43 | "verbosity": 2,
44 |
45 | "monitor": "min val_loss",
46 | "early_stop": 10,
47 |
48 | "tensorboard": true
49 | }
50 | }
51 |
--------------------------------------------------------------------------------
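A small sketch, not part of the repository, of reading this config with the standard library and pulling out nested entries; the project itself wraps this in parse_config.ConfigParser, whose code is not shown in this section:

    import json

    with open('config.json') as f:
        config = json.load(f)

    loader_args = config['data_loader']['args']
    print(config['name'], loader_args['batch_size'], loader_args['validation_split'])
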
/detection/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 | from torchvision.transforms import functional as F
5 |
6 |
7 | def _flip_coco_person_keypoints(kps, width):
8 | flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
9 | flipped_data = kps[:, flip_inds]
10 | flipped_data[..., 0] = width - flipped_data[..., 0]
11 | # Maintain COCO convention that if visibility == 0, then x, y = 0
12 | inds = flipped_data[..., 2] == 0
13 | flipped_data[inds] = 0
14 | return flipped_data
15 |
16 |
17 | class Compose(object):
18 | def __init__(self, transforms):
19 | self.transforms = transforms
20 |
21 | def __call__(self, image, target):
22 | for t in self.transforms:
23 | image, target = t(image, target)
24 | return image, target
25 |
26 |
27 | class RandomHorizontalFlip(object):
28 | def __init__(self, prob):
29 | self.prob = prob
30 |
31 | def __call__(self, image, target):
32 | if random.random() < self.prob:
33 | height, width = image.shape[-2:]
34 | image = image.flip(-1)
35 | bbox = target["boxes"]
36 | bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
37 | target["boxes"] = bbox
38 | if "masks" in target:
39 | target["masks"] = target["masks"].flip(-1)
40 | if "keypoints" in target:
41 | keypoints = target["keypoints"]
42 | keypoints = _flip_coco_person_keypoints(keypoints, width)
43 | target["keypoints"] = keypoints
44 | return image, target
45 |
46 |
47 | class ToTensor(object):
48 | def __call__(self, image, target):
49 | image = F.to_tensor(image)
50 | return image, target
51 |
--------------------------------------------------------------------------------
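A usage sketch for the detection transforms, not from the repository, built on a dummy image and target; it assumes numpy and torchvision are installed and the detection package is importable:

    import numpy as np
    import torch
    from detection.transforms import Compose, RandomHorizontalFlip, ToTensor

    transforms = Compose([ToTensor(), RandomHorizontalFlip(prob=0.5)])

    image = np.zeros((256, 256, 3), dtype=np.uint8)   # H x W x C
    target = {"boxes": torch.tensor([[10., 20., 50., 60.]]),
              "labels": torch.ones((1,), dtype=torch.int64)}

    image, target = transforms(image, target)
    print(image.shape, target["boxes"])               # torch.Size([3, 256, 256]) and (possibly flipped) boxes
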
/test_all/test_classes.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import os
3 | import shutil
4 | import torch
5 | from parse_config import ConfigParser
6 | from utils import read_json, write_json
7 | from scripts_for_datasets import COWCDataset
8 | # run tests
9 | # python -m pytest test_all/
10 | # python -m pytest test_all/ -s ==> to see print statements
11 | class TestCOWCDataset():
12 |
13 | def test_image_annot_equality(self):
14 | # Test code for init method
15 | # Testing the dataset size and similarity
16 | config = read_json('config.json')
17 | config = ConfigParser(config)
18 | data_dir = config['data_loader']['args']['data_dir']
19 | shutil.rmtree("./saved") #removing the /saved directory created by ConfigParser on every run
20 | a = COWCDataset(data_dir)
21 | for img, annot in zip(a.imgs, a.annotation):
22 | if os.path.splitext(img)[0] != os.path.splitext(annot)[0]:
23 | print("problem")
24 | print(len(a.annotation))
25 | assert len(a.imgs) == len(a.annotation), "NOT equal"
26 |
27 | def test_zero_annotation(self):
28 | # Test for checking number of image without bounding box
29 | config = read_json('config.json')
30 | config = ConfigParser(config)
31 | data_dir = config['data_loader']['args']['data_dir']
32 | shutil.rmtree("./saved") #removing the /saved directory created by ConfigParser on every run
33 | a = COWCDataset(data_dir)
34 | zero_annotation = 0
35 | for i in range(len(a.annotation)):
36 | zero_annotation_get = a[i]
37 | if zero_annotation_get['object'].item() == 0:
38 | zero_annotation += 1
39 | print("Number of image withot bbox: "+str(zero_annotation))
40 | #use assert zero_annotation == 0 if all image contain bounding box
41 | assert zero_annotation != 0, "Image exists with bounding box"
42 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | # input data, saved log, checkpoints
104 | data/
105 | input/
106 | saved/
107 | saved_ESRGAN/
108 | saved_EEGAN_separate/
109 | datasets/
110 |
111 | # editor, os cache directory
112 | .vscode/
113 | .idea/
114 | __MACOSX/
115 |
--------------------------------------------------------------------------------
/base/base_data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data.dataloader import default_collate
4 | from torch.utils.data.sampler import SubsetRandomSampler
5 |
6 |
7 | class BaseDataLoader(DataLoader):
8 | """
9 | Base class for all data loaders
10 | """
11 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
12 | self.validation_split = validation_split
13 | self.shuffle = shuffle
14 |
15 | self.batch_idx = 0
16 | self.n_samples = len(dataset)
17 |
18 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
19 |
20 | self.init_kwargs = {
21 | 'dataset': dataset,
22 | 'batch_size': batch_size,
23 | 'shuffle': self.shuffle,
24 | 'collate_fn': collate_fn,
25 | 'num_workers': num_workers
26 | }
27 | super().__init__(sampler=self.sampler, **self.init_kwargs)
28 |
29 | def _split_sampler(self, split):
30 | if split == 0.0:
31 | return None, None
32 |
33 | idx_full = np.arange(self.n_samples)
34 |
35 | np.random.seed(0)
36 | np.random.shuffle(idx_full)
37 |
38 | if isinstance(split, int):
39 | assert split > 0
40 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
41 | len_valid = split
42 | else:
43 | len_valid = int(self.n_samples * split)
44 |
45 | valid_idx = idx_full[0:len_valid]
46 | train_idx = np.delete(idx_full, np.arange(0, len_valid))
47 |
48 | train_sampler = SubsetRandomSampler(train_idx)
49 | valid_sampler = SubsetRandomSampler(valid_idx)
50 |
51 | # turn off shuffle option which is mutually exclusive with sampler
52 | self.shuffle = False
53 | self.n_samples = len(train_idx)
54 |
55 | return train_sampler, valid_sampler
56 |
57 | def split_validation(self):
58 | if self.valid_sampler is None:
59 | return None
60 | else:
61 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
62 |
--------------------------------------------------------------------------------
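A minimal usage sketch for BaseDataLoader, not part of the repository, built on a synthetic TensorDataset so it runs without the COWC data:

    import torch
    from torch.utils.data import TensorDataset
    from base import BaseDataLoader

    dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
    loader = BaseDataLoader(dataset, batch_size=8, shuffle=True,
                            validation_split=0.1, num_workers=0)
    valid_loader = loader.split_validation()   # DataLoader over the held-out 10%

    print(loader.n_samples, len(valid_loader.sampler))   # 90 10
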
/utils/module_util.py:
--------------------------------------------------------------------------------
1 | '''
2 | Taken from https://github.com/xinntao/BasicSR/blob/master/codes/models/modules/module_util.py
3 | '''
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.init as init
7 | import torch.nn.functional as F
8 |
9 |
10 | def initialize_weights(net_l, scale=1):
11 | if not isinstance(net_l, list):
12 | net_l = [net_l]
13 | for net in net_l:
14 | for m in net.modules():
15 | if isinstance(m, nn.Conv2d):
16 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
17 | m.weight.data *= scale # for residual block
18 | if m.bias is not None:
19 | m.bias.data.zero_()
20 | elif isinstance(m, nn.Linear):
21 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
22 | m.weight.data *= scale
23 | if m.bias is not None:
24 | m.bias.data.zero_()
25 | elif isinstance(m, nn.BatchNorm2d):
26 | init.constant_(m.weight, 1)
27 | init.constant_(m.bias.data, 0.0)
28 |
29 |
30 | def make_layer(block, n_layers):
31 | layers = []
32 | for _ in range(n_layers):
33 | layers.append(block())
34 | return nn.Sequential(*layers)
35 |
36 |
37 | class ResidualBlock_noBN(nn.Module):
38 | '''Residual block w/o BN
39 | ---Conv-ReLU-Conv-+-
40 | |________________|
41 | '''
42 |
43 | def __init__(self, nf=64):
44 | super(ResidualBlock_noBN, self).__init__()
45 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
46 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
47 |
48 | # initialization
49 | initialize_weights([self.conv1, self.conv2], 0.1)
50 |
51 | def forward(self, x):
52 | identity = x
53 | out = F.relu(self.conv1(x), inplace=True)
54 | out = self.conv2(out)
55 | return identity + out
56 |
57 |
58 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
59 | """Warp an image or feature map with optical flow
60 | Args:
61 | x (Tensor): size (N, C, H, W)
62 | flow (Tensor): size (N, H, W, 2), normal value
63 | interp_mode (str): 'nearest' or 'bilinear'
64 | padding_mode (str): 'zeros' or 'border' or 'reflection'
65 | Returns:
66 | Tensor: warped image or feature map
67 | """
68 | assert x.size()[-2:] == flow.size()[1:3]
69 | B, C, H, W = x.size()
70 | # mesh grid
71 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
72 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
73 | grid.requires_grad = False
74 | grid = grid.type_as(x)
75 | vgrid = grid + flow
76 | # scale grid to [-1,1]
77 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
78 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
79 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
80 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
81 | return output
82 |
--------------------------------------------------------------------------------
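A small sketch, not from the repository, exercising make_layer, ResidualBlock_noBN and flow_warp on random tensors; it assumes the utils package and its imports resolve cleanly:

    import torch
    from utils.module_util import make_layer, ResidualBlock_noBN, flow_warp

    trunk = make_layer(lambda: ResidualBlock_noBN(nf=16), n_layers=3)
    feat = trunk(torch.randn(1, 16, 32, 32))   # residual trunk, spatial size preserved

    flow = torch.zeros(1, 32, 32, 2)           # zero flow field
    warped = flow_warp(feat, flow)             # bilinear warp of the features by the flow
    print(feat.shape, warped.shape)            # both torch.Size([1, 16, 32, 32])
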
/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
5 |
6 | def nll_loss(output, target):
7 | return F.nll_loss(output, target)
8 |
9 | def cross_entropy(output, target):
10 | return F.cross_entropy(output, target)
11 |
12 |
13 | class CharbonnierLoss(nn.Module):
14 | """Charbonnier Loss (L1)"""
15 |
16 | def __init__(self, eps=1e-6):
17 | super(CharbonnierLoss, self).__init__()
18 | self.eps = eps
19 |
20 | def forward(self, x, y):
21 | diff = x - y
22 | loss = torch.mean(torch.sqrt(diff * diff + self.eps))
23 | return loss
24 |
25 |
26 | # Define GAN loss: [vanilla | lsgan | wgan-gp]
27 | class GANLoss(nn.Module):
28 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
29 | super(GANLoss, self).__init__()
30 | self.gan_type = gan_type.lower()
31 | self.real_label_val = real_label_val
32 | self.fake_label_val = fake_label_val
33 |
34 | if self.gan_type == 'gan' or self.gan_type == 'ragan':
35 | self.loss = nn.BCEWithLogitsLoss()
36 | elif self.gan_type == 'lsgan':
37 | self.loss = nn.MSELoss()
38 | elif self.gan_type == 'wgan-gp':
39 |
40 | def wgan_loss(input, target):
41 | # target is boolean
42 | return -1 * input.mean() if target else input.mean()
43 |
44 | self.loss = wgan_loss
45 | else:
46 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
47 |
48 | def get_target_label(self, input, target_is_real):
49 | if self.gan_type == 'wgan-gp':
50 | return target_is_real
51 | if target_is_real:
52 | return torch.empty_like(input).fill_(self.real_label_val)
53 | else:
54 | return torch.empty_like(input).fill_(self.fake_label_val)
55 |
56 | def forward(self, input, target_is_real):
57 | target_label = self.get_target_label(input, target_is_real)
58 | loss = self.loss(input, target_label)
59 | return loss
60 |
61 |
62 | class GradientPenaltyLoss(nn.Module):
63 | def __init__(self, device=torch.device('cpu')):
64 | super(GradientPenaltyLoss, self).__init__()
65 | self.register_buffer('grad_outputs', torch.Tensor())
66 | self.grad_outputs = self.grad_outputs.to(device)
67 |
68 | def get_grad_outputs(self, input):
69 | if self.grad_outputs.size() != input.size():
70 | self.grad_outputs.resize_(input.size()).fill_(1.0)
71 | return self.grad_outputs
72 |
73 | def forward(self, interp, interp_crit):
74 | grad_outputs = self.get_grad_outputs(interp_crit)
75 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
76 | grad_outputs=grad_outputs, create_graph=True,
77 | retain_graph=True, only_inputs=True)[0]
78 | grad_interp = grad_interp.view(grad_interp.size(0), -1)
79 | grad_interp_norm = grad_interp.norm(2, dim=1)
80 |
81 | loss = ((grad_interp_norm - 1)**2).mean()
82 | return loss
83 |
--------------------------------------------------------------------------------
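A brief sketch, not in the repository, of calling GANLoss and CharbonnierLoss on raw discriminator logits and an image pair:

    import torch
    from model.loss import GANLoss, CharbonnierLoss

    d_fake = torch.randn(4, 1)            # raw discriminator logits for fake samples

    gan_criterion = GANLoss('gan')        # BCEWithLogitsLoss against real/fake label tensors
    g_loss = gan_criterion(d_fake, True)  # generator wants fakes scored as real

    pix_criterion = CharbonnierLoss()
    pix_loss = pix_criterion(torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64))
    print(g_loss.item(), pix_loss.item())
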
/logger/visualization.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from datetime import datetime
3 |
4 |
5 | class TensorboardWriter():
6 | def __init__(self, log_dir, logger, enabled):
7 | self.writer = None
8 | self.selected_module = ""
9 |
10 | if enabled:
11 | log_dir = str(log_dir)
12 |
13 | # Retrieve visualization writer.
14 | succeeded = False
15 | for module in ["torch.utils.tensorboard", "tensorboardX"]:
16 | try:
17 | self.writer = importlib.import_module(module).SummaryWriter(log_dir)
18 | succeeded = True
19 | break
20 | except ImportError:
21 | succeeded = False
22 | self.selected_module = module
23 |
24 | if not succeeded:
25 | message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
26 | "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \
27 | "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \
28 | "the 'config.json' file."
29 | logger.warning(message)
30 |
31 | self.step = 0
32 | self.mode = ''
33 |
34 | self.tb_writer_ftns = {
35 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
36 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
37 | }
38 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
39 | self.timer = datetime.now()
40 |
41 | def set_step(self, step, mode='train'):
42 | self.mode = mode
43 | self.step = step
44 | if step == 0:
45 | self.timer = datetime.now()
46 | else:
47 | duration = datetime.now() - self.timer
48 | self.add_scalar('steps_per_sec', 1 / duration.total_seconds())
49 | self.timer = datetime.now()
50 |
51 | def __getattr__(self, name):
52 | """
53 | If visualization is configured to be used:
54 | return add_data() methods of tensorboard with additional information (step, tag) added.
55 | Otherwise:
56 | return a blank function handle that does nothing
57 | """
58 | if name in self.tb_writer_ftns:
59 | add_data = getattr(self.writer, name, None)
60 |
61 | def wrapper(tag, data, *args, **kwargs):
62 | if add_data is not None:
63 | # add mode(train/valid) tag
64 | if name not in self.tag_mode_exceptions:
65 | tag = '{}/{}'.format(tag, self.mode)
66 | add_data(tag, data, self.step, *args, **kwargs)
67 | return wrapper
68 | else:
69 | # default action for returning methods defined in this class, set_step() for instance.
70 | try:
71 | attr = object.__getattribute__(self, name)
72 | except AttributeError:
73 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
74 | return attr
75 |
--------------------------------------------------------------------------------
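A usage sketch for TensorboardWriter, not in the repository; it assumes either PyTorch >= 1.1 (with tensorboard) or tensorboardX is installed so a SummaryWriter can actually be created:

    import logging
    from logger import TensorboardWriter

    logger = logging.getLogger(__name__)
    writer = TensorboardWriter('saved/log/tb_example', logger, enabled=True)

    for step, loss in enumerate([0.9, 0.7, 0.5]):
        writer.set_step(step, mode='train')
        writer.add_scalar('loss', loss)   # logged under the tag 'loss/train' at the current step
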
/scripts_for_datasets/cowc_FRCNN_dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import torch
4 | import numpy as np
5 | import glob
6 | import cv2
7 | from torch.utils.data import Dataset, DataLoader
8 |
9 | class COWCFRCNNDataset(Dataset):
10 | def __init__(self, root=None, image_height=256, image_width=256, transforms = None):
11 | self.root = root
12 | #take all under same folder for train and test split.
13 | self.transforms = transforms
14 | self.image_height = image_height
15 | self.image_width = image_width
16 | #sort all images for indexing, filter out check.jpgs
17 | self.imgs = list(sorted(set(glob.glob(self.root+"*.jpg") +
18 | glob.glob(self.root+"*.png")) -
19 | set(glob.glob(self.root+"*check.jpg"))))
20 | self.annotation = list(sorted(glob.glob(self.root+"*.txt")))
21 |
22 | def __getitem__(self, idx):
23 | #get the paths
24 | img_path = os.path.join(self.root, self.imgs[idx])
25 | annotation_path = os.path.join(self.root, self.annotation[idx])
26 | img = cv2.imread(img_path,1) #read color image height*width*channel=3
27 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
28 | #get the bounding box
29 | boxes = list()
30 | with open(annotation_path) as f:
31 | for line in f:
32 | values = (line.split())
33 | if "\ufeff" in values[0]:
34 | values[0] = values[0][-1]
35 | #get coordinates within the height/width range
36 | x = float(values[1])*self.image_width
37 | y = float(values[2])*self.image_height
38 | width = float(values[3])*self.image_width
39 | height = float(values[4])*self.image_height
40 | #creating bounding boxes that would not touch the image edges
41 | x_min = 1 if x - width/2 <= 0 else int(x - width/2)
42 | x_max = self.image_width-1 if x + width/2 >= self.image_width-1 else int(x + width/2)
43 | y_min = 1 if y - height/2 <= 0 else int(y - height/2)
44 | y_max = self.image_height-1 if y + height/2 >= self.image_height-1 else int(y + height/2)
45 |
46 | x_min = int(x_min)
47 | x_max = int(x_max)
48 | y_min = int(y_min)
49 | y_max = int(y_max)
50 |
51 | boxes.append([x_min, y_min, x_max, y_max])
52 | #print(boxes)
53 |
54 | boxes = torch.as_tensor(boxes, dtype=torch.float32)
55 | # there is only one class
56 | labels = torch.ones((len(boxes),), dtype=torch.int64)
57 | image_id = torch.tensor([idx])
58 | area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
59 | # suppose all instances are not crowd
60 | iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
61 | #create dictionary to access the values
62 | target = {}
63 | target["boxes"] = boxes
64 | target["labels"] = labels
65 | target["image_id"] = image_id
66 | target["area"] = area
67 | target["iscrowd"] = iscrowd
68 |
69 | if self.transforms is not None:
70 | img, target = self.transforms(img, target)
71 |
72 | #return img, target
73 | return img, target, annotation_path
74 |
75 | def __len__(self):
76 | return len(self.imgs)
77 |
--------------------------------------------------------------------------------
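A hedged usage sketch, not from the repository: the data directory below is a placeholder, and it assumes the scripts_for_datasets package and its dependencies import cleanly; the lambda collate_fn keeps the variable-length targets as tuples:

    import torch
    from scripts_for_datasets import COWCFRCNNDataset
    from detection.transforms import Compose, ToTensor

    root = '/path/to/DetectionPatches_256x256/Potsdam_ISPRS/'   # placeholder path, must end with '/'
    dataset = COWCFRCNNDataset(root=root, transforms=Compose([ToTensor()]))

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=True,
        collate_fn=lambda batch: tuple(zip(*batch)))            # (images, targets, annotation_paths)

    images, targets, annotation_paths = next(iter(loader))
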
/model/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from collections import defaultdict
4 | import torch
5 | from torch.optim.lr_scheduler import _LRScheduler
6 |
7 |
8 | class MultiStepLR_Restart(_LRScheduler):
9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
10 | clear_state=False, last_epoch=-1):
11 | self.milestones = Counter(milestones)
12 | self.gamma = gamma
13 | self.clear_state = clear_state
14 | self.restarts = restarts if restarts else [0]
15 | self.restart_weights = weights if weights else [1]
16 | assert len(self.restarts) == len(
17 | self.restart_weights), 'restarts and their weights do not match.'
18 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
19 |
20 | def get_lr(self):
21 | if self.last_epoch in self.restarts:
22 | if self.clear_state:
23 | self.optimizer.state = defaultdict(dict)
24 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
25 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
26 | if self.last_epoch not in self.milestones:
27 | return [group['lr'] for group in self.optimizer.param_groups]
28 | return [
29 | group['lr'] * self.gamma**self.milestones[self.last_epoch]
30 | for group in self.optimizer.param_groups
31 | ]
32 |
33 |
34 | class CosineAnnealingLR_Restart(_LRScheduler):
35 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
36 | self.T_period = T_period
37 | self.T_max = self.T_period[0] # current T period
38 | self.eta_min = eta_min
39 | self.restarts = restarts if restarts else [0]
40 | self.restart_weights = weights if weights else [1]
41 | self.last_restart = 0
42 | assert len(self.restarts) == len(
43 | self.restart_weights), 'restarts and their weights do not match.'
44 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
45 |
46 | def get_lr(self):
47 | if self.last_epoch == 0:
48 | return self.base_lrs
49 | elif self.last_epoch in self.restarts:
50 | self.last_restart = self.last_epoch
51 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
52 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
53 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
54 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
55 | return [
56 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
57 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
58 | ]
59 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
60 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
61 | (group['lr'] - self.eta_min) + self.eta_min
62 | for group in self.optimizer.param_groups]
63 |
--------------------------------------------------------------------------------
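A short sketch, not from the repository, of driving CosineAnnealingLR_Restart with a throwaway optimizer; the period/restart values are arbitrary:

    import torch
    from model.lr_scheduler import CosineAnnealingLR_Restart

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(params, lr=2e-4)

    scheduler = CosineAnnealingLR_Restart(optimizer, T_period=[50, 50], restarts=[50],
                                          weights=[1], eta_min=1e-7)
    for it in range(100):
        optimizer.step()
        scheduler.step()   # cosine decay over 50 iterations, then a warm restart at iteration 50
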
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from tqdm import tqdm
4 | import data_loader.data_loaders as module_data
5 | import model.loss as module_loss
6 | import model.metric as module_metric
7 | import model.model as module_arch
8 | from parse_config import ConfigParser
9 | from trainer import COWCFRCNNTrainer, COWCGANTrainer, COWCGANFrcnnTrainer
10 | '''
11 | python test.py -c config_GAN.json
12 | '''
13 |
14 | def main(config):
15 |
16 | data_loader = module_data.COWCGANFrcnnDataLoader('/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/HR/x4/valid_img/',
17 | '/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/LR/x4/valid_img/', 1, training=False)
18 | tester = COWCGANFrcnnTrainer(config=config, data_loader=data_loader)
19 | tester.test()
20 |
21 | '''
22 | tester = COWCFRCNNTrainer(config=config)
23 | tester.test()
24 | '''
25 | '''
26 | logger = config.get_logger('test')
27 |
28 | # setup data_loader instances
29 | data_loader = getattr(module_data, config['data_loader']['type'])(
30 | config['data_loader']['args']['data_dir'],
31 | batch_size=512,
32 | shuffle=False,
33 | validation_split=0.0,
34 | training=False,
35 | num_workers=2
36 | )
37 |
38 | # build model architecture
39 | model = config.init_obj('arch', module_arch)
40 | logger.info(model)
41 |
42 | # get function handles of loss and metrics
43 | loss_fn = getattr(module_loss, config['loss'])
44 | metric_fns = [getattr(module_metric, met) for met in config['metrics']]
45 |
46 | logger.info('Loading checkpoint: {} ...'.format(config.resume))
47 | checkpoint = torch.load(config.resume)
48 | state_dict = checkpoint['state_dict']
49 | if config['n_gpu'] > 1:
50 | model = torch.nn.DataParallel(model)
51 | model.load_state_dict(state_dict)
52 |
53 | # prepare model for testing
54 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
55 | model = model.to(device)
56 | model.eval()
57 |
58 | total_loss = 0.0
59 | total_metrics = torch.zeros(len(metric_fns))
60 |
61 | with torch.no_grad():
62 | for i, (data, target) in enumerate(tqdm(data_loader)):
63 | data, target = data.to(device), target.to(device)
64 | output = model(data)
65 |
66 | #
67 | # save sample images, or do something with output here
68 | #
69 |
70 | # computing loss, metrics on test set
71 | loss = loss_fn(output, target)
72 | batch_size = data.shape[0]
73 | total_loss += loss.item() * batch_size
74 | for i, metric in enumerate(metric_fns):
75 | total_metrics[i] += metric(output, target) * batch_size
76 |
77 | n_samples = len(data_loader.sampler)
78 | log = {'loss': total_loss / n_samples}
79 | log.update({
80 | met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)
81 | })
82 | logger.info(log)
83 | '''
84 |
85 |
86 | if __name__ == '__main__':
87 | args = argparse.ArgumentParser(description='PyTorch Template')
88 | args.add_argument('-c', '--config', default=None, type=str,
89 | help='config file path (default: None)')
90 | args.add_argument('-r', '--resume', default=None, type=str,
91 | help='path to latest checkpoint (default: None)')
92 | args.add_argument('-d', '--device', default=None, type=str,
93 | help='indices of GPUs to enable (default: all)')
94 |
95 | config = ConfigParser.from_args(args)
96 | main(config)
97 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import argparse
3 | import collections
4 | import torch
5 | import os
6 | import numpy as np
7 | import data_loader.data_loaders as module_data
8 | import model.loss as module_loss
9 | import model.metric as module_metric
10 | import model.model as module_arch
11 | from parse_config import ConfigParser
12 | from trainer import COWCTrainer
13 | from trainer import COWCGANTrainer
14 | from trainer import COWCFRCNNTrainer
15 | from trainer import COWCGANFrcnnTrainer
16 | from utils import setup_logger, dict2str
17 | '''
18 | python train.py -c config_GAN.json
19 | '''
20 |
21 | # fix random seeds for reproducibility
22 | SEED = 123
23 | torch.manual_seed(SEED)
24 | torch.backends.cudnn.deterministic = True
25 | torch.backends.cudnn.benchmark = False
26 | np.random.seed(SEED)
27 |
28 | def main(config):
29 | #logger = config.get_logger('train')
30 | # config loggers. Before it, the log will not work
31 | setup_logger('base', config['path']['log'], 'train_' + config['name'], level=logging.INFO,
32 | screen=True, tofile=True)
33 | setup_logger('val', config['path']['log'], 'val_' + config['name'], level=logging.INFO,
34 | screen=True, tofile=True)
35 | logger = logging.getLogger('base')
36 | #logger.info(dict2str(config))
37 |
38 |
39 | # setup data_loader instances
40 | data_loader = config.init_obj('data_loader', module_data)
41 | #change later this valid_data_loader using init_obj
42 | valid_data_loader = module_data.COWCGANFrcnnDataLoader('/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/HR/x4/valid_img/',
43 | '/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/LR/x4/valid_img/', 1, training = False)
44 |
45 | # build model architecture, then print to console
46 | #model = config.init_obj('arch', module_arch)
47 | #logger.info(model)
48 |
49 | # get function handles of loss and metrics
50 | #criterion = getattr(module_loss, config['loss'])
51 | #metrics = [getattr(module_metric, met) for met in config['metrics']]
52 |
53 | # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
54 | #trainable_params = filter(lambda p: p.requires_grad, model.parameters())
55 | #optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
56 |
57 | #lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)
58 | '''
59 | trainer = COWCGANTrainer(model, criterion, metrics, optimizer,
60 | config=config,
61 | data_loader=data_loader,
62 | valid_data_loader=valid_data_loader,
63 | lr_scheduler=lr_scheduler)
64 | '''
65 | '''
66 | trainer = COWCGANTrainer(config=config,data_loader=data_loader,
67 | valid_data_loader=valid_data_loader
68 | )
69 | '''
70 |
71 | trainer = COWCGANFrcnnTrainer(config=config, data_loader=data_loader,
72 | valid_data_loader=valid_data_loader)
73 | trainer.train()
74 | '''
75 | trainer = COWCFRCNNTrainer(config=config)
76 | trainer.train()
77 | '''
78 | if __name__ == '__main__':
79 | args = argparse.ArgumentParser(description='PyTorch Template')
80 | args.add_argument('-c', '--config', default=None, type=str,
81 | help='config file path (default: None)')
82 | args.add_argument('-r', '--resume', default=None, type=str,
83 | help='path to latest checkpoint (default: None)')
84 | args.add_argument('-d', '--device', default=None, type=str,
85 | help='indices of GPUs to enable (default: all)')
86 |
87 | # custom cli options to modify configuration from default values given in json file.
88 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
89 | options = [
90 | CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'),
91 | CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size')
92 | ]
93 | config = ConfigParser.from_args(args, options)
94 | main(config)
95 |
--------------------------------------------------------------------------------
/scripts_for_datasets/COWC_dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import torch
4 | import numpy as np
5 | import glob
6 | import cv2
7 | import matplotlib.pyplot as plt
8 | from torch.utils.data import Dataset, DataLoader
9 | from torchvision import transforms, utils
10 |
11 | # Ignore warnings
12 | import warnings
13 | warnings.filterwarnings("ignore")
14 |
15 |
16 | class COWCDataset(Dataset):
17 | def __init__(self, root, image_height=256, image_width=256, transform = None):
18 | self.root = root
19 | #take all under same folder for train and test split.
20 | self.transform = transform
21 | self.image_height = image_height
22 | self.image_width = image_width
23 | #sort all images for indexing, filter out check.jpgs
24 | self.imgs = list(sorted(set(glob.glob(self.root+"*.jpg")) - set(glob.glob(self.root+"*check.jpg"))))
25 | self.annotation = list(sorted(glob.glob(self.root+"*.txt")))
26 |
27 | def __getitem__(self, idx):
28 | #get the paths
29 | img_path = os.path.join(self.root, self.imgs[idx])
30 | annotation_path = os.path.join(self.root, self.annotation[idx])
31 | img = cv2.imread(img_path,1) #read color image height*width*channel=3
32 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
33 | #get the bounding box
34 | boxes = list()
35 | label_car_type = list()
36 | with open(annotation_path) as f:
37 | for line in f:
38 | values = (line.split())
39 | if "\ufeff" in values[0]:
40 | values[0] = values[0][-1]
41 | obj_class = int(values[0])
42 | #image without bounding box - in the txt file, the line starts with 0 and contains only 0
43 | if obj_class == 0:
44 | boxes.append([0, 0, 1, 1])
45 | labels = np.ones(len(boxes)) # all are cars
46 | label_car_type.append(obj_class)
47 | #create dictionary to access the values
48 | target = {}
49 | target['object'] = 0
50 | target['image'] = img
51 | target['bboxes'] = boxes
52 | target['labels'] = labels
53 | target['label_car_type'] = label_car_type
54 | target['idx'] = idx
55 | break
56 | else:
57 | #get coordinates within the height/width range
58 | x = float(values[1])*self.image_width
59 | y = float(values[2])*self.image_height
60 | width = float(values[3])*self.image_width
61 | height = float(values[4])*self.image_height
62 | #creating bounding boxes that would not touch the image edges
63 | x_min = 1 if x - width/2 <= 0 else int(x - width/2)
64 | x_max = 255 if x + width/2 >= 256 else int(x + width/2)
65 | y_min = 1 if y - height/2 <= 0 else int(y - height/2)
66 | y_max = 255 if y + height/2 >= 256 else int(y + height/2)
67 |
68 | boxes.append([x_min, y_min, x_max, y_max])
69 | label_car_type.append(obj_class)
70 |
71 | if obj_class != 0:
72 | labels = np.ones(len(boxes)) # all are cars
73 | #create dictionary to access the values
74 | target = {}
75 | target['object'] = 1
76 | target['image'] = img
77 | target['bboxes'] = boxes
78 | target['labels'] = labels
79 | target['label_car_type'] = label_car_type
80 | target['idx'] = idx
81 |
82 | if self.transform is None:
83 | #convert to tensor
84 | target = self.convert_to_tensor(**target)
85 | return target
86 | #transform
87 | else:
88 | transformed = self.transform(**target)
89 | #print(transformed['image'], transformed['bboxes'], transformed['labels'], transformed['idx'])
90 | target = self.convert_to_tensor(**transformed)
91 | return target
92 |
93 | def __len__(self):
94 | return len(self.imgs)
95 |
96 | def convert_to_tensor(self, **target):
97 | #convert to tensor
98 | target['object'] = torch.tensor(target['object'], dtype=torch.int64)
99 | target['image'] = torch.from_numpy(target['image'].transpose((2, 0, 1)))
100 | target['bboxes'] = torch.as_tensor(target['bboxes'], dtype=torch.int64)
101 | target['labels'] = torch.ones(len(target['bboxes']), dtype=torch.int64)
102 | target['label_car_type'] = torch.as_tensor(target['label_car_type'], dtype=torch.int64)
103 | target['image_id'] = torch.tensor([target['idx']])
104 |
105 | return target
106 |
--------------------------------------------------------------------------------
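A minimal iteration sketch, not from the repository; the directory is a placeholder that must contain the matching .jpg/.txt patch pairs this dataset expects, and it assumes the scripts_for_datasets package and its dependencies import cleanly:

    from scripts_for_datasets import COWCDataset

    root = '/path/to/DetectionPatches_256x256/Potsdam_ISPRS/'   # placeholder path, must end with '/'
    dataset = COWCDataset(root)

    sample = dataset[0]   # dict of tensors: 'image', 'bboxes', 'labels', 'object', ...
    print(sample['image'].shape, sample['bboxes'].shape, sample['object'].item())
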
/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torchvision.utils import make_grid
4 | from base import BaseTrainer
5 | from utils import inf_loop, MetricTracker
6 |
7 |
8 | class Trainer(BaseTrainer):
9 | """
10 | Trainer class
11 | """
12 | def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader,
13 | valid_data_loader=None, lr_scheduler=None, len_epoch=None):
14 | super().__init__(model, criterion, metric_ftns, optimizer, config)
15 | self.config = config
16 | self.data_loader = data_loader
17 | if len_epoch is None:
18 | # epoch-based training
19 | self.len_epoch = len(self.data_loader)
20 | else:
21 | # iteration-based training
22 | self.data_loader = inf_loop(data_loader)
23 | self.len_epoch = len_epoch
24 | self.valid_data_loader = valid_data_loader
25 | self.do_validation = self.valid_data_loader is not None
26 | self.lr_scheduler = lr_scheduler
27 | self.log_step = int(np.sqrt(data_loader.batch_size))
28 |
29 | self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
30 | self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
31 |
32 | def _train_epoch(self, epoch):
33 | """
34 | Training logic for an epoch
35 |
36 | :param epoch: Integer, current training epoch.
37 | :return: A log that contains average loss and metric in this epoch.
38 | """
39 | self.model.train()
40 | self.train_metrics.reset()
41 | for batch_idx, (data, target) in enumerate(self.data_loader):
42 | data, target = data.to(self.device), target.to(self.device)
43 |
44 | self.optimizer.zero_grad()
45 | output = self.model(data)
46 | loss = self.criterion(output, target)
47 | loss.backward()
48 | self.optimizer.step()
49 |
50 | self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
51 | self.train_metrics.update('loss', loss.item())
52 | for met in self.metric_ftns:
53 | self.train_metrics.update(met.__name__, met(output, target))
54 |
55 | if batch_idx % self.log_step == 0:
56 | self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
57 | epoch,
58 | self._progress(batch_idx),
59 | loss.item()))
60 | self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
61 |
62 | if batch_idx == self.len_epoch:
63 | break
64 | log = self.train_metrics.result()
65 |
66 | if self.do_validation:
67 | val_log = self._valid_epoch(epoch)
68 | log.update(**{'val_'+k : v for k, v in val_log.items()})
69 |
70 | if self.lr_scheduler is not None:
71 | self.lr_scheduler.step()
72 | return log
73 |
74 | def _valid_epoch(self, epoch):
75 | """
76 | Validate after training an epoch
77 |
78 | :param epoch: Integer, current training epoch.
79 | :return: A log that contains information about validation
80 | """
81 | self.model.eval()
82 | self.valid_metrics.reset()
83 | with torch.no_grad():
84 | for batch_idx, (data, target) in enumerate(self.valid_data_loader):
85 | data, target = data.to(self.device), target.to(self.device)
86 |
87 | output = self.model(data)
88 | loss = self.criterion(output, target)
89 |
90 | self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
91 | self.valid_metrics.update('loss', loss.item())
92 | for met in self.metric_ftns:
93 | self.valid_metrics.update(met.__name__, met(output, target))
94 | self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
95 |
96 | # add histogram of model parameters to the tensorboard
97 | for name, p in self.model.named_parameters():
98 | self.writer.add_histogram(name, p, bins='auto')
99 | return self.valid_metrics.result()
100 |
101 | def _progress(self, batch_idx):
102 | base = '[{}/{} ({:.0f}%)]'
103 | if hasattr(self.data_loader, 'n_samples'):
104 | current = batch_idx * self.data_loader.batch_size
105 | total = self.data_loader.n_samples
106 | else:
107 | current = batch_idx
108 | total = self.len_epoch
109 | return base.format(current, total, 100.0 * current / total)
110 |
--------------------------------------------------------------------------------
/model/gan_base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.parallel import DistributedDataParallel
6 |
7 |
8 | class BaseModel():
9 | def __init__(self, config, device):
10 | self.config = config
11 | self.device = device
12 | self.schedulers = []
13 | self.optimizers = []
14 |
15 | def feed_data(self, data):
16 | pass
17 |
18 | def optimize_parameters(self):
19 | pass
20 |
21 | def get_current_visuals(self):
22 | pass
23 |
24 | def get_current_losses(self):
25 | pass
26 |
27 | def print_network(self):
28 | pass
29 |
30 | def save(self, label):
31 | pass
32 |
33 | def load(self):
34 | pass
35 |
36 | def _set_lr(self, lr_groups_l):
37 | ''' set learning rate for warmup,
38 | lr_groups_l: list of lr_groups, one for each optimizer'''
39 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
40 | for param_group, lr in zip(optimizer.param_groups, lr_groups):
41 | param_group['lr'] = lr
42 |
43 | def _get_init_lr(self):
44 | # get the initial lr, which is set by the scheduler
45 | init_lr_groups_l = []
46 | for optimizer in self.optimizers:
47 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
48 | return init_lr_groups_l
49 |
50 | def update_learning_rate(self, cur_iter, warmup_iter=-1):
51 | for scheduler in self.schedulers:
52 | scheduler.step()
53 | #### set up warm up learning rate
54 | if cur_iter < warmup_iter:
55 | # get initial lr for each group
56 | init_lr_g_l = self._get_init_lr()
57 | # modify warming-up learning rates
58 | warm_up_lr_l = []
59 | for init_lr_g in init_lr_g_l:
60 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
61 | # set learning rate
62 | self._set_lr(warm_up_lr_l)
63 |
64 | def get_current_learning_rate(self):
65 | # return self.schedulers[0].get_lr()[0]
66 | return self.optimizers[0].param_groups[0]['lr']
67 |
68 | def get_network_description(self, network):
69 | '''Get the string and total parameters of the network'''
70 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
71 | network = network.module
72 | s = str(network)
73 | n = sum(map(lambda x: x.numel(), network.parameters()))
74 | return s, n
75 |
76 | def save_network(self, network, network_label, iter_label):
77 | save_filename = '{}_{}.pth'.format(iter_label, network_label)
78 | save_path = os.path.join(self.config['path']['models'], save_filename)
79 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
80 | network = network.module
81 | state_dict = network.state_dict()
82 | for key, param in state_dict.items():
83 | state_dict[key] = param.cpu()
84 | torch.save(state_dict, save_path)
85 |
86 | def load_network(self, load_path, network, strict=True):
87 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
88 | network = network.module
89 | load_net = torch.load(load_path)
90 | load_net_clean = OrderedDict() # remove unnecessary 'module.'
91 | for k, v in load_net.items():
92 | if k.startswith('module.'):
93 | load_net_clean[k[7:]] = v
94 | else:
95 | load_net_clean[k] = v
96 | network.load_state_dict(load_net_clean, strict=strict)
97 |
98 | def save_training_state(self, epoch, iter_step):
99 | '''Saves training state during training, which will be used for resuming'''
100 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
101 | for s in self.schedulers:
102 | state['schedulers'].append(s.state_dict())
103 | for o in self.optimizers:
104 | state['optimizers'].append(o.state_dict())
105 | save_filename = '{}.state'.format(iter_step)
106 | save_path = os.path.join(self.config['path']['training_state'], save_filename)
107 | torch.save(state, save_path)
108 |
109 | def resume_training(self, resume_state):
110 | '''Resume the optimizers and schedulers for training'''
111 | resume_optimizers = resume_state['optimizers']
112 | resume_schedulers = resume_state['schedulers']
113 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
114 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
115 | for i, o in enumerate(resume_optimizers):
116 | self.optimizers[i].load_state_dict(o)
117 | for i, s in enumerate(resume_schedulers):
118 | self.schedulers[i].load_state_dict(s)
119 |
--------------------------------------------------------------------------------
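A hypothetical subclass sketch, not from the repository, showing how the optimizer/scheduler bookkeeping in this BaseModel is typically wired up; the network and hyper-parameters are made up:

    import torch
    from model.gan_base_model import BaseModel
    from model.lr_scheduler import MultiStepLR_Restart

    class ToyModel(BaseModel):
        def __init__(self, config, device):
            super().__init__(config, device)
            self.net = torch.nn.Linear(4, 4).to(device)
            self.optimizers.append(torch.optim.Adam(self.net.parameters(), lr=1e-4))
            self.schedulers.append(MultiStepLR_Restart(self.optimizers[0], milestones=[100]))

    model = ToyModel(config={}, device=torch.device('cpu'))
    model.update_learning_rate(cur_iter=1, warmup_iter=10)   # linear warm-up towards the initial lr
    print(model.get_current_learning_rate())
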
/config_GAN.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "RRDB_ESRGANx4",
3 | "n_gpu": 1,
4 | "model": "srgan",
5 | "distortion": "sr",
6 | "scale": 4,
7 | "use_tb_logger": true,
8 |
9 | "network_G": {
10 | "which_model_G": "RRDBNet",
11 | "in_nc": 3,
12 | "out_nc": 3,
13 |
14 | "nf": 64,
15 | "nb": 23,
16 | "args": {}
17 | },
18 | "network_D": {
19 | "which_model_G": "discriminator_vgg_128",
20 | "in_nc": 3,
21 | "nf": 64,
22 | "args": {}
23 | },
24 | "data_loader": {
25 | "type": "COWCGANFrcnnDataLoader",
26 | "args":{
27 | "data_dir_GT": "/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/HR/x4/",
28 | "data_dir_LQ": "/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/LR/x4/",
29 | "batch_size": 2,
30 | "shuffle": true,
31 | "validation_split": 0.0,
32 | "num_workers": 2
33 | }
34 | },
35 | "optimizer": {
36 | "type": "SGD",
37 | "args":{
38 | "lr_G": 0.0001,
39 | "weight_decay_G": 0,
40 | "beta1_G": 0.9,
41 | "beta2_G": 0.99,
42 |
43 | "lr_D": 0.0001,
44 | "weight_decay_D": 0,
45 | "beta1_D": 0.9,
46 | "beta2_D": 0.99
47 | }
48 | },
49 | "loss": "cross_entropy",
50 | "metrics": [
51 | "accuracy"
52 | ],
53 | "lr_scheduler": {
54 | "type": "MultiStepLR",
55 | "args": {
56 | "lr_steps": [50000, 100000, 200000, 300000],
57 | "lr_gamma": 0.5,
58 | "T_period": [250000, 250000, 250000, 250000],
59 | "restarts": [250000, 500000, 750000],
60 | "restart_weights": [1, 1, 1],
61 | "eta_min": 0.0000001
62 | }
63 | },
64 | "train": {
65 | "niter": 400000,
66 | "warmup_iter": -1,
67 | "pixel_criterion": "l1",
68 | "pixel_weight": 0.01,
69 | "feature_criterion": "l1",
70 | "feature_weight": 1,
71 |
72 | "gan_type": "ragan",
73 | "gan_weight": 0.001,
74 | "D_update_ratio": 1,
75 | "D_init_iters": 0,
76 | "manual_seed": 10,
77 | "val_freq": 1000,
78 |
79 | "save_dir": "saved/",
80 | "save_period": 1,
81 | "verbosity": 2,
82 |
83 | "monitor": "min val_loss",
84 | "early_stop": 10,
85 |
86 | "tensorboard": true
87 | },
88 | "path": {
89 | "models": "saved/pretrained_models_EESRGAN_FRCNN",
90 | "FRCNN_model": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/FRCNN_model_LR_LR_cowc/",
91 | "pretrain_model_G": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/pretrained_models_EESRGAN_FRCNN/170000_G.pth",
92 | "pretrain_model_D": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/pretrained_models_EESRGAN_FRCNN/170000_D.pth",
93 | "pretrain_model_FRCNN": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/pretrained_models_EESRGAN_FRCNN/170000_FRCNN.pth",
94 | "pretrain_model_FRCNN_LR_LR": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/FRCNN_model_LR_LR_cowc/0_FRCNN_LR_LR.pth",
95 | "training_state": "saved/training_state",
96 | "strict_load": true,
97 | "resume_state": "~",
98 | "val_images": "saved/val_images",
99 | "output_images": "saved/val_images_cars_new",
100 | "log": "saved/logs",
101 | "data_dir_Valid": "/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/LR/x4/valid_img/",
102 | "data_dir_F_SR": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/Final_SR_images_test/",
103 | "data_dir_SR": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/SR_images_test/",
104 | "data_dir_SR_combined": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/combined_SR_images_216000/",
105 | "data_dir_E_SR_1": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/enhanced_SR_images_1/",
106 | "data_dir_E_SR_2": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/enhanced_SR_images_2/",
107 | "data_dir_E_SR_3": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/enhanced_SR_images_3/",
108 | "data_dir_Bic": "/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/HR/x4/valid_img/",
109 | "data_dir_LR_train": "/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/LR/x4/",
110 | "data_dir_Bic_valid": "/home/jakaria/Super_Resolution/Datasets/COWC/DetectionPatches_256x256/Potsdam_ISPRS/Bic/x4/valid_img/",
111 | "Test_Result_LR_LR_COWC": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/Test_Result_LR_LR_COWC/",
112 | "Test_Result_SR": "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/Test_Result_SR/"
113 | },
114 | "logger": {
115 | "print_freq": 100,
116 | "save_checkpoint_freq": 1000
117 | }
118 | }
119 |
--------------------------------------------------------------------------------
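The JSON above is what `-c config_GAN.json` points at. A hedged sketch of reading it directly and instantiating the configured data loader, mirroring what `ConfigParser.init_obj` (shown in parse_config.py later in this dump) does; the working-directory path and the print statement are illustrative only.

```python
# Sketch only: assumes config_GAN.json is in the working directory and that its
# data_dir_GT / data_dir_LQ paths point at existing COWC image patches.
import json
import data_loader.data_loaders as module_data

with open('config_GAN.json') as f:
    config = json.load(f)

dl_cfg = config['data_loader']
# Equivalent to config.init_obj('data_loader', module_data) in parse_config.py:
# look up the class named in 'type' and construct it with the 'args' dict.
loader = getattr(module_data, dl_cfg['type'])(**dl_cfg['args'])
print(len(loader.dataset), 'image pairs at scale x{}'.format(config['scale']))
```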
/scripts_for_datasets/COWC_GAN_dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import torch
4 | import numpy as np
5 | import glob
6 | import cv2
7 | import matplotlib.pyplot as plt
8 | from torch.utils.data import Dataset, DataLoader
9 | from torchvision import transforms, utils
10 |
11 | # Ignore warnings
12 | import warnings
13 | warnings.filterwarnings("ignore")
14 |
15 |
16 | class COWCGANDataset(Dataset):
17 | def __init__(self, data_dir_gt, data_dir_lq, image_height=256, image_width=256, transform = None):
18 | self.data_dir_gt = data_dir_gt
19 | self.data_dir_lq = data_dir_lq
20 | #take all under same folder for train and test split.
21 | self.transform = transform
22 | self.image_height = image_height
23 | self.image_width = image_width
24 | #sort all images for indexing, filter out check.jpgs
25 | self.imgs_gt = list(sorted(glob.glob(self.data_dir_gt+"*.jpg")))
26 | self.imgs_lq = list(sorted(glob.glob(self.data_dir_lq+"*.jpg")))
27 | self.annotation = list(sorted(glob.glob(self.data_dir_lq+"*.txt")))
28 |
29 | def __getitem__(self, idx):
30 | #get the paths
31 | img_path_gt = os.path.join(self.data_dir_gt, self.imgs_gt[idx])
32 | img_path_lq = os.path.join(self.data_dir_lq, self.imgs_lq[idx])
33 | annotation_path = os.path.join(self.data_dir_lq, self.annotation[idx])
34 | img_gt = cv2.imread(img_path_gt,1) #read color image height*width*channel=3
35 | img_lq = cv2.imread(img_path_lq,1) #read color image height*width*channel=3
36 | img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2RGB)
37 | img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2RGB)
38 | #get the bounding box
39 | boxes = list()
40 | label_car_type = list()
41 | with open(annotation_path) as f:
42 | for line in f:
43 | values = (line.split())
44 | if "\ufeff" in values[0]:
45 | values[0] = values[0][-1]
46 | obj_class = int(values[0])
47 |                     #image without bounding box - in the txt file, the line starts with 0 and contains only that single 0
48 | if obj_class == 0:
49 | boxes.append([0, 0, 1, 1])
50 | labels = np.ones(len(boxes)) # all are cars
51 | label_car_type.append(obj_class)
52 | #create dictionary to access the values
53 | target = {}
54 | target['object'] = 0
55 | target['image_lq'] = img_lq
56 | target['image'] = img_gt
57 | target['bboxes'] = boxes
58 | target['labels'] = labels
59 | target['label_car_type'] = label_car_type
60 | target['idx'] = idx
61 | target['LQ_path'] = img_path_lq
62 | break
63 | else:
64 |                         #get coordinates within the height/width range
65 | x = float(values[1])*self.image_width
66 | y = float(values[2])*self.image_height
67 | width = float(values[3])*self.image_width
68 | height = float(values[4])*self.image_height
69 | #creating bounding boxes that would not touch the image edges
70 | x_min = 1 if x - width/2 <= 0 else int(x - width/2)
71 | x_max = 255 if x + width/2 >= 256 else int(x + width/2)
72 | y_min = 1 if y - height/2 <= 0 else int(y - height/2)
73 | y_max = 255 if y + height/2 >= 256 else int(y + height/2)
74 |
75 | boxes.append([x_min, y_min, x_max, y_max])
76 | label_car_type.append(obj_class)
77 |
78 | if obj_class != 0:
79 | labels = np.ones(len(boxes)) # all are cars
80 | #create dictionary to access the values
81 | target = {}
82 | target['object'] = 1
83 | target['image_lq'] = img_lq
84 | target['image'] = img_gt
85 | target['bboxes'] = boxes
86 | target['labels'] = labels
87 | target['label_car_type'] = label_car_type
88 | target['idx'] = idx
89 | target['LQ_path'] = img_path_lq
90 |
91 | if self.transform is None:
92 | #convert to tensor
93 | target = self.convert_to_tensor(**target)
94 | return target
95 | #transform
96 | else:
97 | transformed = self.transform(**target)
98 | #print(transformed['image'], transformed['bboxes'], transformed['labels'], transformed['idx'])
99 | target = self.convert_to_tensor(**transformed)
100 | return target
101 |
102 | def __len__(self):
103 | return len(self.imgs_lq)
104 |
105 | def convert_to_tensor(self, **target):
106 | #convert to tensor
107 | target['object'] = torch.tensor(target['object'], dtype=torch.int64)
108 | target['image_lq'] = torch.from_numpy(target['image_lq'].transpose((2, 0, 1)))
109 | target['image'] = torch.from_numpy(target['image'].transpose((2, 0, 1)))
110 | target['bboxes'] = torch.as_tensor(target['bboxes'], dtype=torch.int64)
111 | target['labels'] = torch.ones(len(target['bboxes']), dtype=torch.int64)
112 | target['label_car_type'] = torch.as_tensor(target['label_car_type'], dtype=torch.int64)
113 | target['image_id'] = torch.tensor([target['idx']])
114 |
115 | return target
116 |
--------------------------------------------------------------------------------
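A short usage sketch of `COWCGANDataset`: each item is a dict holding the LQ/GT image pair plus box annotations, so it can be wrapped in a plain `DataLoader` once a dict-aware collate function is supplied. The directory paths below are placeholders, and `dict_collate` is a hypothetical stand-in for the repo's own `collate_fn` in `utils`.

```python
# Usage sketch; paths are placeholders and dict_collate is a hypothetical
# stand-in for the collate_fn imported from utils in data_loaders.py.
from torch.utils.data import DataLoader
from scripts_for_datasets import COWCGANDataset

def dict_collate(batch):
    # keep per-sample dicts in a list because box counts vary between images
    return list(batch)

dataset = COWCGANDataset('/path/to/HR/x4/', '/path/to/LR/x4/')
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=dict_collate)

sample = dataset[0]                      # dict produced by convert_to_tensor
print(sample['image'].shape)             # GT image, (3, 256, 256)
print(sample['image_lq'].shape)          # LQ image, e.g. (3, 64, 64) for x4 patches
print(sample['bboxes'].shape)            # (num_boxes, 4) in pascal_voc format
```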
/trainer/cowc_trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torchvision.utils import make_grid
4 | from base import BaseTrainer
5 | from utils import inf_loop, MetricTracker, visualize_bbox, visualize
6 |
7 |
8 | class COWCTrainer(BaseTrainer):
9 | """
10 | Trainer class
11 | """
12 | def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader,
13 | valid_data_loader=None, lr_scheduler=None, len_epoch=None):
14 | super().__init__(model, criterion, metric_ftns, optimizer, config)
15 | self.config = config
16 | self.data_loader = data_loader
17 |
18 | if len_epoch is None:
19 | # epoch-based training
20 | self.len_epoch = len(self.data_loader)
21 | else:
22 | # iteration-based training
23 | self.data_loader = inf_loop(data_loader)
24 |
25 | self.len_epoch = len_epoch
26 | self.valid_data_loader = valid_data_loader
27 | self.do_validation = self.valid_data_loader is not None
28 | self.lr_scheduler = lr_scheduler
29 | self.log_step = int(np.sqrt(data_loader.batch_size))
30 |
31 | self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
32 | self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
33 |
34 | def _train_epoch(self, epoch):
35 | """
36 | Training logic for an epoch
37 |
38 | for visualization use the following code (use batch size = 1):
39 | category_id_to_name = {1: 'car'}
40 | for batch_idx, dataset_dict in enumerate(self.data_loader):
41 | if dataset_dict['object'][0].item() == 0:
42 | print(dataset_dict)
43 | visualize(dataset_dict, category_id_to_name) --> see this method in util
44 |
45 | image size: torch.Size([10, 3, 256, 256]) if batch_size = 10
46 |
47 | :param epoch: Integer, current training epoch.
48 | :return: A log that contains average loss and metric in this epoch.
49 | """
50 |
51 | self.model.train()
52 | self.train_metrics.reset()
53 | for batch_idx, dataset_dict in enumerate(self.data_loader):
54 | #print(dataset_dict['image'].size())
55 |
56 | data, target = dataset_dict['image'].to(self.device), \
57 | dataset_dict['object'].to(self.device)
58 |
59 | self.optimizer.zero_grad()
60 | output = self.model(data)
61 | loss = self.criterion(output, target)
62 | loss.backward()
63 | self.optimizer.step()
64 |
65 | self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
66 | self.train_metrics.update('loss', loss.item())
67 | for met in self.metric_ftns:
68 | self.train_metrics.update(met.__name__, met(output, target))
69 |
70 | if batch_idx % self.log_step == 0:
71 | self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
72 | epoch,
73 | self._progress(batch_idx),
74 | loss.item()))
75 | self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
76 |
77 | if batch_idx == self.len_epoch:
78 | break
79 | log = self.train_metrics.result()
80 |
81 | if self.do_validation:
82 | val_log = self._valid_epoch(epoch)
83 | log.update(**{'val_'+k : v for k, v in val_log.items()})
84 |
85 | if self.lr_scheduler is not None:
86 | self.lr_scheduler.step()
87 | return log
88 |
89 | def _valid_epoch(self, epoch):
90 | """
91 | Validate after training an epoch
92 |
93 | :param epoch: Integer, current training epoch.
94 | :return: A log that contains information about validation
95 | """
96 | self.model.eval()
97 | self.valid_metrics.reset()
98 | with torch.no_grad():
99 | for batch_idx, dataset_dict in enumerate(self.valid_data_loader):
100 | data, target = dataset_dict['image'].to(self.device), \
101 | dataset_dict['object'].to(self.device)
102 |
103 | output = self.model(data)
104 | loss = self.criterion(output, target)
105 |
106 | self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
107 | self.valid_metrics.update('loss', loss.item())
108 | for met in self.metric_ftns:
109 | self.valid_metrics.update(met.__name__, met(output, target))
110 | self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
111 |
112 | # add histogram of model parameters to the tensorboard
113 | for name, p in self.model.named_parameters():
114 | self.writer.add_histogram(name, p, bins='auto')
115 | return self.valid_metrics.result()
116 |
117 | def _progress(self, batch_idx):
118 | base = '[{}/{} ({:.0f}%)]'
119 | if hasattr(self.data_loader, 'n_samples'):
120 | current = batch_idx * self.data_loader.batch_size
121 | total = self.data_loader.n_samples
122 | else:
123 | current = batch_idx
124 | total = self.len_epoch
125 | return base.format(current, total, 100.0 * current / total)
126 |
--------------------------------------------------------------------------------
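The trainer above switches between epoch-based and iteration-based training by wrapping the loader in `inf_loop` from `utils`. A minimal sketch of what such a wrapper usually looks like; the repo's own implementation in `utils` is not shown here and may differ.

```python
# Sketch of an infinite loader wrapper in the usual pytorch-template style;
# assumed behaviour, not necessarily identical to the repo's utils.inf_loop.
from itertools import repeat

def inf_loop(data_loader):
    """Yield batches endlessly so len_epoch can be chosen independently of len(loader)."""
    for loader in repeat(data_loader):
        yield from loader
```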
/scripts_for_datasets/COWC_EESRGAN_FRCNN_dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import torch
4 | import numpy as np
5 | import glob
6 | import cv2
7 | import matplotlib.pyplot as plt
8 | from torch.utils.data import Dataset, DataLoader
9 | from torchvision import transforms, utils
10 |
11 | # Ignore warnings
12 | import warnings
13 | warnings.filterwarnings("ignore")
14 |
15 |
16 | class COWCGANFrcnnDataset(Dataset):
17 | def __init__(self, data_dir_gt, data_dir_lq, image_height=256, image_width=256, transform = None):
18 | self.data_dir_gt = data_dir_gt
19 | self.data_dir_lq = data_dir_lq
20 | #take all under same folder for train and test split.
21 | self.transform = transform
22 | self.image_height = image_height
23 | self.image_width = image_width
24 | #sort all images for indexing, filter out check.jpgs
25 | self.imgs_gt = list(sorted(glob.glob(self.data_dir_gt+"*.jpg")))
26 | self.imgs_lq = list(sorted(glob.glob(self.data_dir_lq+"*.jpg")))
27 | self.annotation = list(sorted(glob.glob(self.data_dir_lq+"*.txt")))
28 |
29 | def __getitem__(self, idx):
30 | #get the paths
31 | img_path_gt = os.path.join(self.data_dir_gt, self.imgs_gt[idx])
32 | img_path_lq = os.path.join(self.data_dir_lq, self.imgs_lq[idx])
33 | annotation_path = os.path.join(self.data_dir_lq, self.annotation[idx])
34 | img_gt = cv2.imread(img_path_gt,1) #read color image height*width*channel=3
35 | img_lq = cv2.imread(img_path_lq,1) #read color image height*width*channel=3
36 | img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2RGB)
37 | img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2RGB)
38 | #get the bounding box
39 | boxes = list()
40 | label_car_type = list()
41 | with open(annotation_path) as f:
42 | for line in f:
43 | values = (line.split())
44 | if "\ufeff" in values[0]:
45 | values[0] = values[0][-1]
46 | obj_class = int(values[0])
47 |                     #image without bounding box - in the txt file, the line starts with 0 and contains only that single 0
48 | if obj_class == 0:
49 | boxes.append([0, 0, 1, 1])
50 | labels = np.ones(len(boxes)) # all are cars
51 | label_car_type.append(obj_class)
52 | #create dictionary to access the values
53 | target = {}
54 | target['object'] = 0
55 | target['image_lq'] = img_lq
56 | target['image'] = img_gt
57 | target['bboxes'] = boxes
58 | target['labels'] = labels
59 | target['label_car_type'] = label_car_type
60 | target['image_id'] = idx
61 | target['LQ_path'] = img_path_lq
62 | target["area"] = 0
63 | target["iscrowd"] = 0
64 | break
65 | else:
66 |                         #get coordinates within the height/width range
67 | x = float(values[1])*self.image_width
68 | y = float(values[2])*self.image_height
69 | width = float(values[3])*self.image_width
70 | height = float(values[4])*self.image_height
71 | #creating bounding boxes that would not touch the image edges
72 | x_min = 1 if x - width/2 <= 0 else int(x - width/2)
73 | x_max = 255 if x + width/2 >= 256 else int(x + width/2)
74 | y_min = 1 if y - height/2 <= 0 else int(y - height/2)
75 | y_max = 255 if y + height/2 >= 256 else int(y + height/2)
76 |
77 | boxes.append([x_min, y_min, x_max, y_max])
78 | label_car_type.append(obj_class)
79 |
80 | if obj_class != 0:
81 | labels = np.ones(len(boxes)) # all are cars
82 | boxes_for_calc = torch.as_tensor(boxes, dtype=torch.int64)
83 | area = (boxes_for_calc[:, 3] - boxes_for_calc[:, 1]) * (boxes_for_calc[:, 2] - boxes_for_calc[:, 0])
84 | iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
85 | #create dictionary to access the values
86 | target = {}
87 | target['object'] = 1
88 | target['image_lq'] = img_lq
89 | target['image'] = img_gt
90 | target['bboxes'] = boxes
91 | target['labels'] = labels
92 | target['label_car_type'] = label_car_type
93 | target['image_id'] = idx
94 | target['LQ_path'] = img_path_lq
95 | target["area"] = area
96 | target["iscrowd"] = iscrowd
97 |
98 | if self.transform is None:
99 | #convert to tensor
100 | image, target = self.convert_to_tensor(**target)
101 | return image, target
102 | #transform
103 | else:
104 | transformed = self.transform(**target)
105 | #print(transformed['image'], transformed['bboxes'], transformed['labels'], transformed['idx'])
106 | image, target = self.convert_to_tensor(**transformed)
107 | return image, target
108 |
109 | def __len__(self):
110 | return len(self.imgs_lq)
111 |
112 | def convert_to_tensor(self, **target):
113 | #convert to tensor
114 | target['object'] = torch.tensor(target['object'], dtype=torch.int64)
115 | target['image_lq'] = torch.from_numpy(target['image_lq'].transpose((2, 0, 1)))
116 | target['image'] = torch.from_numpy(target['image'].transpose((2, 0, 1)))
117 | target['boxes'] = torch.tensor(target['bboxes'], dtype=torch.float32)
118 | target['labels'] = torch.ones(len(target['labels']), dtype=torch.int64)
119 | target['label_car_type'] = torch.tensor(target['label_car_type'], dtype=torch.int64)
120 | target['image_id'] = torch.tensor([target['image_id']])
121 | target["area"] = torch.tensor(target['area'])
122 | target["iscrowd"] = torch.tensor(target['iscrowd'])
123 |
124 | image = {}
125 | image['object'] = target['object']
126 | image['image_lq'] = target['image_lq']
127 | image['image'] = target['image']
129 | image['LQ_path'] = target['LQ_path']
130 |
131 | del target['object']
132 | del target['image_lq']
133 | del target['image']
134 | del target['bboxes']
135 | del target['LQ_path']
136 |
137 | return image, target
138 |
--------------------------------------------------------------------------------
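The `target` dict built above ('boxes', 'labels', 'image_id', 'area', 'iscrowd') follows the torchvision detection target format. A hedged sanity-check sketch that feeds a single sample to a Faster R-CNN head; the model choice, paths, and scaling below are illustrative and not taken from the configs in this repository.

```python
# Illustrative check that the (image, target) pairs fit torchvision's detection
# API; num_classes=2 means background + car. Paths are placeholders.
import torchvision
from scripts_for_datasets import COWCGANFrcnnDataset

dataset = COWCGANFrcnnDataset('/path/to/HR/x4/', '/path/to/LR/x4/')
image, target = dataset[0]

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(num_classes=2)
model.train()

# Without a transform the images stay uint8 in [0, 255]; torchvision expects
# float images in [0, 1], so scale the GT image before the forward pass.
img = image['image'].float() / 255.0
loss_dict = model([img], [target])
print({k: round(v.item(), 4) for k, v in loss_dict.items()})
```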
/README.md:
--------------------------------------------------------------------------------
1 | # EESRGAN
2 | ## Model Architecture
3 | *(model architecture diagram)*
4 | ## Enhancement and Detection
5 | |Low Resolution Image & Detection|Super Resolved Image & Detection|High Resolution Ground Truth Image & Bounding Box|
6 | | --- | --- | --- |
7 | |*(LR detection result images)*|*(SR detection result images)*|*(ground-truth images with bounding boxes)*|
11 | ## Dependencies and Installation
12 | - Python 3 (Anaconda is recommended)
13 | - PyTorch >= 1.0
14 | - NVIDIA GPU + CUDA
15 | - Python packages: `pip install -r path/to/requirements.txt`
16 | ## Training
17 | `python train.py -c config_GAN.json`
18 | ## Testing
19 | `python test.py -c config_GAN.json`
20 | ## Dataset
21 | Download the dataset from [here](https://gdo152.llnl.gov/cowc/download/cowc-m/datasets/).
22 | [Here](https://github.com/LLNL/cowc/tree/master/COWC-M) is a GitHub repo to create custom image patches.
23 | Download the pre-made dataset from [here](https://gdo152.llnl.gov/cowc/download/cowc-m/datasets/DetectionPatches_256x256.tgz); [this](https://github.com/Jakaria08/EESRGAN/blob/1f93130d8e99166e7bc4d1640329450feec9ff9c/scripts_for_datasets/scripts_GAN_HR-LR.py#L24) script can be used with the pre-made dataset to create high-resolution, low-resolution, and bicubic images. Make sure to copy the annotation files (.txt) into the HR, LR, and Bic folders.
24 | ## Edit the JSON File
25 | The directories in the following JSON file need to be changed to match the user's setup. For details, see [config_GAN.json](https://github.com/Jakaria08/EESRGAN/blob/master/config_GAN.json). Pretrained weights are available on [Google Drive](https://drive.google.com/drive/folders/15xN_TKKTUpQ5EVdZWJ2aZUa4Y-u-Mt0f?usp=sharing).
26 | ```yaml
27 | {
28 | "data_loader": {
29 | "type": "COWCGANFrcnnDataLoader",
30 | "args":{
31 | "data_dir_GT": "/Directory for High-Resolution Ground Truth images/",
32 | "data_dir_LQ": "/Directory for 4x downsampled Low-Resolution images from the above High-Resolution images/"
33 | }
34 | },
35 |
36 | "path": {
37 | "models": "saved/save_your_model_in_this_directory/",
38 | "pretrain_model_G": "Pretrained_model_path_for_train_test/170000_G.pth",
39 |     "pretrain_model_D": "Pretrained_model_path_for_train_test/170000_D.pth",
40 |     "pretrain_model_FRCNN": "Pretrained_model_path_for_train_test/170000_FRCNN.pth",
41 |     "data_dir_Valid": "/Low_resolution_test_validation_image_directory/",
42 |     "Test_Result_SR": "Directory_to_store_test_results/"
43 | }
44 | }
45 |
46 | ```
47 | ## Paper
48 | Find the published version on [Remote Sensing](https://www.mdpi.com/2072-4292/12/9/1432).
49 | Find the preprints of the related paper on [preprints.org](https://www.preprints.org/manuscript/202003.0313/v1), [arxiv.org](https://arxiv.org/abs/2003.09085) and [researchgate.net](https://www.researchgate.net/publication/340095015_Small-Object_Detection_in_Remote_Sensing_Images_with_End-to-End_Edge-Enhanced_GAN_and_Object_Detector_Network).
50 | ### Abstract
51 | The detection performance of small objects in remote sensing images has not been satisfactory compared to large objects, especially in low-resolution and noisy images. A generative adversarial network (GAN)-based model called enhanced super-resolution GAN (ESRGAN) showed remarkable image enhancement performance, but reconstructed images usually miss high-frequency edge information. Therefore, object detection performance showed degradation for small objects on recovered noisy and low-resolution remote sensing images. Inspired by the success of edge enhanced GAN (EEGAN) and ESRGAN, we applied a new edge-enhanced super-resolution GAN (EESRGAN) to improve the quality of remote sensing images and used different detector networks in an end-to-end manner where detector loss was backpropagated into the EESRGAN to improve the detection performance. We proposed an architecture with three components: ESRGAN, EEN, and Detection network. We used residual-in-residual dense blocks (RRDB) for both the ESRGAN and EEN, and for the detector network, we used a faster region-based convolutional network (FRCNN) (two-stage detector) and a single-shot multibox detector (SSD) (one stage detector). Extensive experiments on a public (car overhead with context) dataset and another self-assembled (oil and gas storage tank) satellite dataset showed superior performance of our method compared to the standalone state-of-the-art object detectors.
52 | ### Keywords
53 | object detection; faster region-based convolutional neural network (FRCNN); single-shot multibox detector (SSD); super-resolution; remote sensing imagery; edge enhancement; satellites
54 | ## Related Repository
55 | Some code segments are based on [ESRGAN](https://github.com/xinntao/BasicSR)
56 | ## Citation
57 | ### BibTex
58 | `@article{rabbi2020small,`\
59 | `title={Small-Object Detection in Remote Sensing Images with End-to-End Edge-Enhanced GAN and Object Detector Network},`\
60 | `author={Rabbi, Jakaria and Ray, Nilanjan and Schubert, Matthias and Chowdhury, Subir and Chao, Dennis},`\
61 | `journal={Remote Sensing},`\
62 | `volume={12},`\
63 | `number={9},`\
64 | `pages={1432},`\
65 | `year={2020},`\
66 | `publisher={Multidisciplinary Digital Publishing Institute}`\
67 | `}`
68 | ### Chicago
69 | `Rabbi, Jakaria; Ray, Nilanjan; Schubert, Matthias; Chowdhury, Subir; Chao, Dennis. 2020. "Small-Object Detection in Remote Sensing Images with End-to-End Edge-Enhanced GAN and Object Detector Network." Remote Sens. 12, no. 9: 1432.`
70 | ## To Do
71 | - Refactor and clean the code.
72 | - Add more command-line options for training and testing to run different configurations.
73 | - Fix bugs and write important tests.
74 |
--------------------------------------------------------------------------------
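Since the config is plain JSON, a quick check that an edited file still parses and that its dataset directories exist can save a failed run. The snippet below is illustrative and assumes it is run from the repository root.

```python
# Quick sanity check for an edited config_GAN.json; purely illustrative.
import json
import os

with open('config_GAN.json') as f:
    cfg = json.load(f)

for key in ('data_dir_GT', 'data_dir_LQ'):
    path = cfg['data_loader']['args'][key]
    print(key, path, 'OK' if os.path.isdir(path) else 'MISSING')
```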
/parse_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from pathlib import Path
4 | from functools import reduce, partial
5 | from operator import getitem
6 | from datetime import datetime
7 | from logger import setup_logging
8 | from utils import read_json, write_json
9 |
10 |
11 | class ConfigParser:
12 | def __init__(self, config, resume=None, modification=None, run_id=None):
13 | """
14 |         Class to parse the configuration JSON file. Handles hyperparameters for training, initialization of modules,
15 |         checkpoint saving, and the logging module.
16 |         :param config: Dict containing configurations and hyperparameters for training, e.g. the contents of `config.json`.
17 |         :param resume: String, path to the checkpoint being loaded.
18 |         :param modification: Dict of keychain:value pairs specifying config values to be replaced.
19 |         :param run_id: Unique identifier for training processes, used to save checkpoints and training logs. A timestamp is used by default.
20 | """
21 | # load config file and apply modification
22 | self._config = _update_config(config, modification)
23 | self.resume = resume
24 | '''
25 | # set save_dir where trained model and log will be saved.
26 | save_dir = Path(self.config['train']['save_dir'])
27 |
28 | exper_name = self.config['name']
29 | if run_id is None: # use timestamp as default run-id
30 | run_id = datetime.now().strftime(r'%m%d_%H%M%S')
31 | self._save_dir = save_dir / 'models' / exper_name / run_id
32 | self._log_dir = save_dir / 'log' / exper_name / run_id
33 |
34 | # make directory for saving checkpoints and log.
35 | exist_ok = run_id == ''
36 | self.save_dir.mkdir(parents=True, exist_ok=exist_ok)
37 | self.log_dir.mkdir(parents=True, exist_ok=exist_ok)
38 |
39 | # save updated config file to the checkpoint dir
40 | write_json(self.config, self.save_dir / 'config.json')
41 | '''
42 | # configure logging module
43 | #setup_logging(self.log_dir)
44 | self.log_levels = {
45 | 0: logging.WARNING,
46 | 1: logging.INFO,
47 | 2: logging.DEBUG
48 | }
49 |
50 | @classmethod
51 | def from_args(cls, args, options=''):
52 | """
53 | Initialize this class from some cli arguments. Used in train, test.
54 | """
55 | for opt in options:
56 | args.add_argument(*opt.flags, default=None, type=opt.type)
57 | if not isinstance(args, tuple):
58 | args = args.parse_args()
59 |
60 | if args.device is not None:
61 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device
62 | if args.resume is not None:
63 | resume = Path(args.resume)
64 | cfg_fname = resume.parent / 'config.json'
65 | else:
66 |             msg_no_cfg = "Configuration file needs to be specified. Add '-c config.json', for example."
67 | assert args.config is not None, msg_no_cfg
68 | resume = None
69 | cfg_fname = Path(args.config)
70 |
71 | config = read_json(cfg_fname)
72 | if args.config and resume:
73 | # update new config for fine-tuning
74 | config.update(read_json(args.config))
75 |
76 | # parse custom cli options into dictionary
77 | modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options}
78 | return cls(config, resume, modification)
79 |
80 | def init_obj(self, name, module, *args, **kwargs):
81 | """
82 | Finds a function handle with the name given as 'type' in config, and returns the
83 | instance initialized with corresponding arguments given.
84 |
85 | `object = config.init_obj('name', module, a, b=1)`
86 | is equivalent to
87 | `object = module.name(a, b=1)`
88 | """
89 | module_name = self[name]['type']
90 | module_args = dict(self[name]['args'])
91 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
92 | module_args.update(kwargs)
93 | return getattr(module, module_name)(*args, **module_args)
94 |
95 | def init_ftn(self, name, module, *args, **kwargs):
96 | """
97 | Finds a function handle with the name given as 'type' in config, and returns the
98 | function with given arguments fixed with functools.partial.
99 |
100 | `function = config.init_ftn('name', module, a, b=1)`
101 | is equivalent to
102 | `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`.
103 | """
104 | module_name = self[name]['type']
105 | module_args = dict(self[name]['args'])
106 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
107 | module_args.update(kwargs)
108 | return partial(getattr(module, module_name), *args, **module_args)
109 |
110 | def __getitem__(self, name):
111 | """Access items like ordinary dict."""
112 | return self.config[name]
113 |
114 | def get_logger(self, name, verbosity=2):
115 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys())
116 | assert verbosity in self.log_levels, msg_verbosity
117 | logger = logging.getLogger(name)
118 | logger.setLevel(self.log_levels[verbosity])
119 | return logger
120 |
121 | # setting read-only attributes
122 | @property
123 | def config(self):
124 | return self._config
125 |
126 | @property
127 | def save_dir(self):
128 | return self._save_dir
129 |
130 | @property
131 | def log_dir(self):
132 | return self._log_dir
133 |
134 | # helper functions to update config dict with custom cli options
135 | def _update_config(config, modification):
136 | if modification is None:
137 | return config
138 |
139 | for k, v in modification.items():
140 | if v is not None:
141 | _set_by_path(config, k, v)
142 | return config
143 |
144 | def _get_opt_name(flags):
145 | for flg in flags:
146 | if flg.startswith('--'):
147 | return flg.replace('--', '')
148 | return flags[0].replace('--', '')
149 |
150 | def _set_by_path(tree, keys, value):
151 | """Set a value in a nested object in tree by sequence of keys."""
152 | keys = keys.split(';')
153 | _get_by_path(tree, keys[:-1])[keys[-1]] = value
154 |
155 | def _get_by_path(tree, keys):
156 | """Access a nested object in tree by sequence of keys."""
157 | return reduce(getitem, keys, tree)
158 |
--------------------------------------------------------------------------------
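The docstrings of `init_obj` and `from_args` above describe the intended entry-point flow; a hedged sketch tying them together. The `CustomArgs` tuple and the `--bs` flag follow common pytorch-template conventions and are assumptions, not necessarily the flags used by this repo's train.py.

```python
# Illustrative entry-point sketch; flag names and the CustomArgs tuple are
# assumed conventions and may differ from the repository's actual train.py.
import argparse
import collections
import data_loader.data_loaders as module_data
from parse_config import ConfigParser

CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')

args = argparse.ArgumentParser(description='EESRGAN training')
args.add_argument('-c', '--config', default=None, type=str, help='config file path')
args.add_argument('-r', '--resume', default=None, type=str, help='checkpoint to resume')
args.add_argument('-d', '--device', default=None, type=str, help='GPU indices to use')

# targets use ';' as the keychain separator, matching _set_by_path above
options = [CustomArgs(['--bs', '--batch_size'], type=int,
                      target='data_loader;args;batch_size')]
config = ConfigParser.from_args(args, options)

# config.init_obj('data_loader', module_data) == module_data.<type>(**<args>)
data_loader = config.init_obj('data_loader', module_data)
```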
/data_loader/data_loaders.py:
--------------------------------------------------------------------------------
1 | from torchvision import datasets, transforms, utils
2 | from base import BaseDataLoader
3 | from scripts_for_datasets import COWCDataset, COWCGANDataset, COWCFRCNNDataset, COWCGANFrcnnDataset
4 |
5 | from albumentations import (
6 | HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
7 | Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
8 | IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine,
9 | IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose,
10 | BboxParams, RandomCrop, Normalize, Resize, VerticalFlip
11 | )
12 |
13 | from albumentations.pytorch import ToTensor
14 | from utils import collate_fn
15 | #from detection.utils import collate_fn
16 |
17 |
18 | class MnistDataLoader(BaseDataLoader):
19 | """
20 | MNIST data loading demo using BaseDataLoader
21 | """
22 | def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True):
23 | trsfm = transforms.Compose([
24 | transforms.ToTensor(),
25 | transforms.Normalize((0.1307,), (0.3081,))
26 | ])
27 | self.data_dir = data_dir
28 | self.dataset = datasets.MNIST(self.data_dir, train=training, download=True, transform=trsfm)
29 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
30 |
31 | class COWCDataLoader(BaseDataLoader):
32 | """
33 | COWC data loading using BaseDataLoader
34 | """
35 | def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True):
36 | #data transformation
37 | #According to this link: https://discuss.pytorch.org/t/normalization-of-input-image/34814/8
38 |         #for satellite images 0.5 is usually fine; otherwise calculate mean and std over the whole dataset.
39 |         #calculated mean and std using a method from utils
40 | data_transforms = Compose([
41 | Resize(256, 256),
42 | HorizontalFlip(),
43 | OneOf([
44 | IAAAdditiveGaussianNoise(),
45 | GaussNoise(),
46 | ], p=0.2),
47 | OneOf([
48 | CLAHE(clip_limit=2),
49 | IAASharpen(),
50 | IAAEmboss(),
51 | RandomBrightnessContrast(),
52 | ], p=0.3),
53 | HueSaturationValue(p=0.3),
54 | Normalize( #mean std for potsdam dataset from COWC [Calculate also for spot6]
55 | mean=[0.3442, 0.3708, 0.3476],
56 | std=[0.1232, 0.1230, 0.1284]
57 | )
58 | ],
59 | bbox_params=BboxParams(
60 | format='pascal_voc',
61 | min_area=0,
62 | min_visibility=0,
63 | label_fields=['labels'])
64 | )
65 |
66 |
67 | self.data_dir = data_dir
68 | self.dataset = COWCDataset(self.data_dir, transform=data_transforms)
69 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=collate_fn)
70 |
71 | class COWCGANDataLoader(BaseDataLoader):
72 | """
73 | COWC data loading using BaseDataLoader
74 | """
75 | def __init__(self, data_dir_GT, data_dir_LQ, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True):
76 | #data transformation
77 | #According to this link: https://discuss.pytorch.org/t/normalization-of-input-image/34814/8
78 |         #for satellite images 0.5 is usually fine; otherwise calculate mean and std over the whole dataset.
79 |         #calculated mean and std using a method from utils
80 | '''
81 | Data transform for GAN training
82 | '''
83 | data_transforms_train = Compose([
84 | HorizontalFlip(),
85 | Normalize( #mean std for potsdam dataset from COWC [Calculate also for spot6]
86 | mean=[0.3442, 0.3708, 0.3476],
87 | std=[0.1232, 0.1230, 0.1284]
88 | )
89 | ],
90 | additional_targets={
91 | 'image_lq':'image'
92 | },
93 | bbox_params=BboxParams(
94 | format='pascal_voc',
95 | min_area=0,
96 | min_visibility=0,
97 | label_fields=['labels'])
98 | )
99 |
100 | data_transforms_test = Compose([
101 | Normalize( #mean std for potsdam dataset from COWC [Calculate also for spot6]
102 | mean=[0.3442, 0.3708, 0.3476],
103 | std=[0.1232, 0.1230, 0.1284]
104 | )],
105 | additional_targets={
106 | 'image_lq':'image'
107 | })
108 |
109 | self.data_dir_gt = data_dir_GT
110 | self.data_dir_lq = data_dir_LQ
111 |
112 | if training == True:
113 | self.dataset = COWCGANDataset(self.data_dir_gt, self.data_dir_lq, transform=data_transforms_train)
114 | else:
115 | self.dataset = COWCGANDataset(self.data_dir_gt, self.data_dir_lq, transform=data_transforms_test)
116 | self.length = len(self.dataset)
117 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=collate_fn)
118 |
119 | class COWCGANFrcnnDataLoader(BaseDataLoader):
120 | """
121 | COWC data loading using BaseDataLoader
122 | """
123 | def __init__(self, data_dir_GT, data_dir_LQ, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True):
124 | #data transformation
125 | #According to this link: https://discuss.pytorch.org/t/normalization-of-input-image/34814/8
126 |         #for satellite images 0.5 is usually fine; otherwise calculate mean and std over the whole dataset.
127 |         #calculated mean and std using a method from utils
128 | '''
129 | Data transform for GAN training
130 | '''
131 | data_transforms_train = Compose([
132 | HorizontalFlip(),
133 | Normalize( #mean std for potsdam dataset from COWC [Calculate also for spot6]
134 | mean=[0.3442, 0.3708, 0.3476],
135 | std=[0.1232, 0.1230, 0.1284]
136 | )
137 | ],
138 | additional_targets={
139 | 'image_lq':'image'
140 | },
141 | bbox_params=BboxParams(
142 | format='pascal_voc',
143 | min_area=0,
144 | min_visibility=0,
145 | label_fields=['labels'])
146 | )
147 |
148 | data_transforms_test = Compose([
149 | Normalize( #mean std for potsdam dataset from COWC [Calculate also for spot6]
150 | mean=[0.3442, 0.3708, 0.3476],
151 | std=[0.1232, 0.1230, 0.1284]
152 | )],
153 | additional_targets={
154 | 'image_lq':'image'
155 | })
156 |
157 | self.data_dir_gt = data_dir_GT
158 | self.data_dir_lq = data_dir_LQ
159 |
160 | if training == True:
161 | self.dataset = COWCGANFrcnnDataset(self.data_dir_gt, self.data_dir_lq, transform=data_transforms_train)
162 | else:
163 | self.dataset = COWCGANFrcnnDataset(self.data_dir_gt, self.data_dir_lq, transform=data_transforms_test)
164 | self.length = len(self.dataset)
165 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=collate_fn)
166 |
--------------------------------------------------------------------------------
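The `additional_targets={'image_lq': 'image'}` entries above make albumentations apply the same spatial transform to the low-quality image as to the ground-truth one, so the LR/HR pair and the boxes stay aligned. A minimal stand-alone sketch of that mechanism follows; the arrays, box, and label are synthetic and illustrative only.

```python
# Stand-alone demonstration of albumentations' additional_targets; the arrays,
# box, and label below are synthetic and not taken from the COWC dataset.
import numpy as np
from albumentations import BboxParams, Compose, HorizontalFlip

transform = Compose(
    [HorizontalFlip(p=1.0)],
    additional_targets={'image_lq': 'image'},   # treat image_lq like the main image
    bbox_params=BboxParams(format='pascal_voc', label_fields=['labels']),
)

hr = np.zeros((256, 256, 3), dtype=np.uint8)    # ground-truth patch
lr = np.zeros((64, 64, 3), dtype=np.uint8)      # x4 low-resolution patch
out = transform(image=hr, image_lq=lr, bboxes=[[10, 10, 50, 50]], labels=[1])
print(out['bboxes'])  # the box is mirrored horizontally along with both images
```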
/base/base_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from abc import abstractmethod
3 | from numpy import inf
4 | from logger import TensorboardWriter
5 |
6 |
7 | class BaseTrainer:
8 | """
9 | Base class for all trainers
10 | """
11 | def __init__(self, model, criterion, metric_ftns, optimizer, config):
12 | self.config = config
13 | self.logger = config.get_logger('trainer', config['trainer']['verbosity'])
14 |
15 | # setup GPU device if available, move model into configured device
16 | self.device, device_ids = self._prepare_device(config['n_gpu'])
17 | self.model = model.to(self.device)
18 | if len(device_ids) > 1:
19 | self.model = torch.nn.DataParallel(model, device_ids=device_ids)
20 |
21 | self.criterion = criterion
22 | self.metric_ftns = metric_ftns
23 | self.optimizer = optimizer
24 |
25 | cfg_trainer = config['trainer']
26 | self.epochs = cfg_trainer['epochs']
27 | self.save_period = cfg_trainer['save_period']
28 | self.monitor = cfg_trainer.get('monitor', 'off')
29 |
30 | # configuration to monitor model performance and save best
31 | if self.monitor == 'off':
32 | self.mnt_mode = 'off'
33 | self.mnt_best = 0
34 | else:
35 | self.mnt_mode, self.mnt_metric = self.monitor.split()
36 | assert self.mnt_mode in ['min', 'max']
37 |
38 | self.mnt_best = inf if self.mnt_mode == 'min' else -inf
39 | self.early_stop = cfg_trainer.get('early_stop', inf)
40 |
41 | self.start_epoch = 1
42 |
43 | self.checkpoint_dir = config.save_dir
44 |
45 | # setup visualization writer instance
46 | self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])
47 |
48 | if config.resume is not None:
49 | self._resume_checkpoint(config.resume)
50 |
51 | @abstractmethod
52 | def _train_epoch(self, epoch):
53 | """
54 | Training logic for an epoch
55 |
56 | :param epoch: Current epoch number
57 | """
58 | raise NotImplementedError
59 |
60 | def train(self):
61 | """
62 | Full training logic
63 | """
64 | not_improved_count = 0
65 | for epoch in range(self.start_epoch, self.epochs + 1):
66 | result = self._train_epoch(epoch)
67 |
68 | # save logged informations into log dict
69 | log = {'epoch': epoch}
70 | log.update(result)
71 |
72 | # print logged informations to the screen
73 | for key, value in log.items():
74 | self.logger.info(' {:15s}: {}'.format(str(key), value))
75 |
76 | # evaluate model performance according to configured metric, save best checkpoint as model_best
77 | best = False
78 | if self.mnt_mode != 'off':
79 | try:
80 | # check whether model performance improved or not, according to specified metric(mnt_metric)
81 | improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
82 | (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
83 | except KeyError:
84 | self.logger.warning("Warning: Metric '{}' is not found. "
85 | "Model performance monitoring is disabled.".format(self.mnt_metric))
86 | self.mnt_mode = 'off'
87 | improved = False
88 |
89 | if improved:
90 | self.mnt_best = log[self.mnt_metric]
91 | not_improved_count = 0
92 | best = True
93 | else:
94 | not_improved_count += 1
95 |
96 | if not_improved_count > self.early_stop:
97 | self.logger.info("Validation performance didn\'t improve for {} epochs. "
98 | "Training stops.".format(self.early_stop))
99 | break
100 |
101 | if epoch % self.save_period == 0:
102 | self._save_checkpoint(epoch, save_best=best)
103 |
104 | def _prepare_device(self, n_gpu_use):
105 | """
106 | setup GPU device if available, move model into configured device
107 | """
108 | n_gpu = torch.cuda.device_count()
109 | if n_gpu_use > 0 and n_gpu == 0:
110 |             self.logger.warning("Warning: There\'s no GPU available on this machine, "
111 |                                 "training will be performed on CPU.")
112 | n_gpu_use = 0
113 | if n_gpu_use > n_gpu:
114 | self.logger.warning("Warning: The number of GPU\'s configured to use is {}, but only {} are available "
115 | "on this machine.".format(n_gpu_use, n_gpu))
116 | n_gpu_use = n_gpu
117 | device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
118 | list_ids = list(range(n_gpu_use))
119 | return device, list_ids
120 |
121 | def _save_checkpoint(self, epoch, save_best=False):
122 | """
123 | Saving checkpoints
124 |
125 | :param epoch: current epoch number
126 | :param log: logging information of the epoch
127 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
128 | """
129 | arch = type(self.model).__name__
130 | state = {
131 | 'arch': arch,
132 | 'epoch': epoch,
133 | 'state_dict': self.model.state_dict(),
134 | 'optimizer': self.optimizer.state_dict(),
135 | 'monitor_best': self.mnt_best,
136 | 'config': self.config
137 | }
138 | filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
139 | torch.save(state, filename)
140 | self.logger.info("Saving checkpoint: {} ...".format(filename))
141 | if save_best:
142 | best_path = str(self.checkpoint_dir / 'model_best.pth')
143 | torch.save(state, best_path)
144 | self.logger.info("Saving current best: model_best.pth ...")
145 |
146 | def _resume_checkpoint(self, resume_path):
147 | """
148 | Resume from saved checkpoints
149 |
150 | :param resume_path: Checkpoint path to be resumed
151 | """
152 | resume_path = str(resume_path)
153 | self.logger.info("Loading checkpoint: {} ...".format(resume_path))
154 | checkpoint = torch.load(resume_path)
155 | self.start_epoch = checkpoint['epoch'] + 1
156 | self.mnt_best = checkpoint['monitor_best']
157 |
158 | # load architecture params from checkpoint.
159 | if checkpoint['config']['arch'] != self.config['arch']:
160 | self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
161 | "checkpoint. This may yield an exception while state_dict is being loaded.")
162 | self.model.load_state_dict(checkpoint['state_dict'])
163 |
164 | # load optimizer state from checkpoint only when optimizer type is not changed.
165 | if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
166 | self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
167 | "Optimizer parameters not being resumed.")
168 | else:
169 | self.optimizer.load_state_dict(checkpoint['optimizer'])
170 |
171 | self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
172 |
--------------------------------------------------------------------------------
/detection/train.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import time
4 |
5 | import torch
6 | import torch.utils.data
7 | from torch import nn
8 | import torchvision
9 | import torchvision.models.detection
10 | import torchvision.models.detection.mask_rcnn
11 |
12 | from torchvision import transforms
13 |
14 | from coco_utils import get_coco, get_coco_kp
15 |
16 | from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
17 | from engine import train_one_epoch, evaluate
18 |
19 | import utils
20 | import transforms as T
21 |
22 |
23 | def get_dataset(name, image_set, transform):
24 | paths = {
25 | "coco": ('/datasets01/COCO/022719/', get_coco, 91),
26 | "coco_kp": ('/datasets01/COCO/022719/', get_coco_kp, 2)
27 | }
28 | p, ds_fn, num_classes = paths[name]
29 |
30 | ds = ds_fn(p, image_set=image_set, transforms=transform)
31 | return ds, num_classes
32 |
33 |
34 | def get_transform(train):
35 | transforms = []
36 | transforms.append(T.ToTensor())
37 | if train:
38 | transforms.append(T.RandomHorizontalFlip(0.5))
39 | return T.Compose(transforms)
40 |
41 |
42 | def main(args):
43 | utils.init_distributed_mode(args)
44 | print(args)
45 |
46 | device = torch.device(args.device)
47 |
48 | # Data loading code
49 | print("Loading data")
50 |
51 | dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True))
52 | dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False))
53 |
54 | print("Creating data loaders")
55 | if args.distributed:
56 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
57 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
58 | else:
59 | train_sampler = torch.utils.data.RandomSampler(dataset)
60 | test_sampler = torch.utils.data.SequentialSampler(dataset_test)
61 |
62 | if args.aspect_ratio_group_factor >= 0:
63 | group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
64 | train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
65 | else:
66 | train_batch_sampler = torch.utils.data.BatchSampler(
67 | train_sampler, args.batch_size, drop_last=True)
68 |
69 | data_loader = torch.utils.data.DataLoader(
70 | dataset, batch_sampler=train_batch_sampler, num_workers=args.workers,
71 | collate_fn=utils.collate_fn)
72 |
73 | data_loader_test = torch.utils.data.DataLoader(
74 | dataset_test, batch_size=1,
75 | sampler=test_sampler, num_workers=args.workers,
76 | collate_fn=utils.collate_fn)
77 |
78 | print("Creating model")
79 | model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes,
80 | pretrained=args.pretrained)
81 | model.to(device)
82 |
83 | model_without_ddp = model
84 | if args.distributed:
85 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
86 | model_without_ddp = model.module
87 |
88 | params = [p for p in model.parameters() if p.requires_grad]
89 | optimizer = torch.optim.SGD(
90 | params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
91 |
92 | # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
93 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
94 |
95 | if args.resume:
96 | checkpoint = torch.load(args.resume, map_location='cpu')
97 | model_without_ddp.load_state_dict(checkpoint['model'])
98 | optimizer.load_state_dict(checkpoint['optimizer'])
99 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
100 |
101 | if args.test_only:
102 | evaluate(model, data_loader_test, device=device)
103 | return
104 |
105 | print("Start training")
106 | start_time = time.time()
107 | for epoch in range(args.epochs):
108 | if args.distributed:
109 | train_sampler.set_epoch(epoch)
110 | train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
111 | lr_scheduler.step()
112 | if args.output_dir:
113 | utils.save_on_master({
114 | 'model': model_without_ddp.state_dict(),
115 | 'optimizer': optimizer.state_dict(),
116 | 'lr_scheduler': lr_scheduler.state_dict(),
117 | 'args': args},
118 | os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
119 |
120 | # evaluate after every epoch
121 | evaluate(model, data_loader_test, device=device)
122 |
123 | total_time = time.time() - start_time
124 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
125 | print('Training time {}'.format(total_time_str))
126 |
127 |
128 | if __name__ == "__main__":
129 | import argparse
130 | parser = argparse.ArgumentParser(description='PyTorch Detection Training')
131 |
132 | parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset')
133 | parser.add_argument('--dataset', default='coco', help='dataset')
134 | parser.add_argument('--model', default='maskrcnn_resnet50_fpn', help='model')
135 | parser.add_argument('--device', default='cuda', help='device')
136 | parser.add_argument('-b', '--batch-size', default=2, type=int)
137 | parser.add_argument('--epochs', default=13, type=int, metavar='N',
138 | help='number of total epochs to run')
139 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
140 |                         help='number of data loading workers (default: 4)')
141 | parser.add_argument('--lr', default=0.02, type=float, help='initial learning rate')
142 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
143 | help='momentum')
144 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
145 | metavar='W', help='weight decay (default: 1e-4)',
146 | dest='weight_decay')
147 | parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs')
148 | parser.add_argument('--lr-steps', default=[8, 11], nargs='+', type=int, help='decrease lr every step-size epochs')
149 | parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
150 | parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
151 | parser.add_argument('--output-dir', default='.', help='path where to save')
152 | parser.add_argument('--resume', default='', help='resume from checkpoint')
153 | parser.add_argument('--aspect-ratio-group-factor', default=0, type=int)
154 | parser.add_argument(
155 | "--test-only",
156 | dest="test_only",
157 | help="Only test the model",
158 | action="store_true",
159 | )
160 | parser.add_argument(
161 | "--pretrained",
162 | dest="pretrained",
163 | help="Use pre-trained models from the modelzoo",
164 | action="store_true",
165 | )
166 |
167 | # distributed training parameters
168 | parser.add_argument('--world-size', default=1, type=int,
169 | help='number of distributed processes')
170 | parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
171 |
172 | args = parser.parse_args()
173 |
174 | if args.output_dir:
175 | utils.mkdir(args.output_dir)
176 |
177 | main(args)
178 |
--------------------------------------------------------------------------------
/detection/group_by_aspect_ratio.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | from collections import defaultdict
3 | import copy
4 | import numpy as np
5 |
6 | import torch
7 | import torch.utils.data
8 | from torch.utils.data.sampler import BatchSampler, Sampler
9 | from torch.utils.model_zoo import tqdm
10 | import torchvision
11 |
12 | from PIL import Image
13 |
14 |
15 | class GroupedBatchSampler(BatchSampler):
16 | """
17 | Wraps another sampler to yield a mini-batch of indices.
18 |     It enforces that the batch only contains elements from the same group.
19 |     It also tries to provide mini-batches which follow an ordering that is
20 |     as close as possible to the ordering from the original sampler.
21 | Arguments:
22 | sampler (Sampler): Base sampler.
23 | group_ids (list[int]): If the sampler produces indices in range [0, N),
24 | `group_ids` must be a list of `N` ints which contains the group id of each sample.
25 | The group ids must be a continuous set of integers starting from
26 | 0, i.e. they must be in the range [0, num_groups).
27 | batch_size (int): Size of mini-batch.
28 | """
29 | def __init__(self, sampler, group_ids, batch_size):
30 | if not isinstance(sampler, Sampler):
31 | raise ValueError(
32 | "sampler should be an instance of "
33 | "torch.utils.data.Sampler, but got sampler={}".format(sampler)
34 | )
35 | self.sampler = sampler
36 | self.group_ids = group_ids
37 | self.batch_size = batch_size
38 |
39 | def __iter__(self):
40 | buffer_per_group = defaultdict(list)
41 | samples_per_group = defaultdict(list)
42 |
43 | num_batches = 0
44 | for idx in self.sampler:
45 | group_id = self.group_ids[idx]
46 | buffer_per_group[group_id].append(idx)
47 | samples_per_group[group_id].append(idx)
48 | if len(buffer_per_group[group_id]) == self.batch_size:
49 | yield buffer_per_group[group_id]
50 | num_batches += 1
51 | del buffer_per_group[group_id]
52 | assert len(buffer_per_group[group_id]) < self.batch_size
53 |
54 | # now we have run out of elements that satisfy
55 | # the group criteria, let's return the remaining
56 | # elements so that the size of the sampler is
57 | # deterministic
58 | expected_num_batches = len(self)
59 | num_remaining = expected_num_batches - num_batches
60 | if num_remaining > 0:
61 | # for the remaining batches, take first the buffers with largest number
62 | # of elements
63 | for group_id, _ in sorted(buffer_per_group.items(),
64 | key=lambda x: len(x[1]), reverse=True):
65 | remaining = self.batch_size - len(buffer_per_group[group_id])
66 | buffer_per_group[group_id].extend(
67 | samples_per_group[group_id][:remaining])
68 | assert len(buffer_per_group[group_id]) == self.batch_size
69 | yield buffer_per_group[group_id]
70 | num_remaining -= 1
71 | if num_remaining == 0:
72 | break
73 | assert num_remaining == 0
74 |
75 | def __len__(self):
76 | return len(self.sampler) // self.batch_size
77 |
78 |
79 | def _compute_aspect_ratios_slow(dataset, indices=None):
80 | print("Your dataset doesn't support the fast path for "
81 | "computing the aspect ratios, so will iterate over "
82 | "the full dataset and load every image instead. "
83 | "This might take some time...")
84 | if indices is None:
85 | indices = range(len(dataset))
86 |
87 | class SubsetSampler(Sampler):
88 | def __init__(self, indices):
89 | self.indices = indices
90 |
91 | def __iter__(self):
92 | return iter(self.indices)
93 |
94 | def __len__(self):
95 | return len(self.indices)
96 |
97 | sampler = SubsetSampler(indices)
98 | data_loader = torch.utils.data.DataLoader(
99 | dataset, batch_size=1, sampler=sampler,
100 | num_workers=14, # you might want to increase it for faster processing
101 | collate_fn=lambda x: x[0])
102 | aspect_ratios = []
103 | with tqdm(total=len(dataset)) as pbar:
104 | for i, (img, _) in enumerate(data_loader):
105 | pbar.update(1)
106 | height, width = img.shape[-2:]
107 | aspect_ratio = float(height) / float(width)
108 | aspect_ratios.append(aspect_ratio)
109 | return aspect_ratios
110 |
111 |
112 | def _compute_aspect_ratios_custom_dataset(dataset, indices=None):
113 | if indices is None:
114 | indices = range(len(dataset))
115 | aspect_ratios = []
116 | for i in indices:
117 | height, width = dataset.get_height_and_width(i)
118 | aspect_ratio = float(height) / float(width)
119 | aspect_ratios.append(aspect_ratio)
120 | return aspect_ratios
121 |
122 |
123 | def _compute_aspect_ratios_coco_dataset(dataset, indices=None):
124 | if indices is None:
125 | indices = range(len(dataset))
126 | aspect_ratios = []
127 | for i in indices:
128 | img_info = dataset.coco.imgs[dataset.ids[i]]
129 | aspect_ratio = float(img_info["height"]) / float(img_info["width"])
130 | aspect_ratios.append(aspect_ratio)
131 | return aspect_ratios
132 |
133 |
134 | def _compute_aspect_ratios_voc_dataset(dataset, indices=None):
135 | if indices is None:
136 | indices = range(len(dataset))
137 | aspect_ratios = []
138 | for i in indices:
139 | # this doesn't load the data into memory, because PIL loads it lazily
140 | width, height = Image.open(dataset.images[i]).size
141 | aspect_ratio = float(height) / float(width)
142 | aspect_ratios.append(aspect_ratio)
143 | return aspect_ratios
144 |
145 |
146 | def _compute_aspect_ratios_subset_dataset(dataset, indices=None):
147 | if indices is None:
148 | indices = range(len(dataset))
149 |
150 | ds_indices = [dataset.indices[i] for i in indices]
151 | return compute_aspect_ratios(dataset.dataset, ds_indices)
152 |
153 |
154 | def compute_aspect_ratios(dataset, indices=None):
155 | if hasattr(dataset, "get_height_and_width"):
156 | return _compute_aspect_ratios_custom_dataset(dataset, indices)
157 |
158 | if isinstance(dataset, torchvision.datasets.CocoDetection):
159 | return _compute_aspect_ratios_coco_dataset(dataset, indices)
160 |
161 | if isinstance(dataset, torchvision.datasets.VOCDetection):
162 | return _compute_aspect_ratios_voc_dataset(dataset, indices)
163 |
164 | if isinstance(dataset, torch.utils.data.Subset):
165 | return _compute_aspect_ratios_subset_dataset(dataset, indices)
166 |
167 | # slow path
168 | return _compute_aspect_ratios_slow(dataset, indices)
169 |
170 |
171 | def _quantize(x, bins):
172 | bins = copy.deepcopy(bins)
173 | bins = sorted(bins)
174 | quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
175 | return quantized
176 |
177 |
178 | def create_aspect_ratio_groups(dataset, k=0):
179 | aspect_ratios = compute_aspect_ratios(dataset)
180 | bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0]
181 | groups = _quantize(aspect_ratios, bins)
182 | # count number of elements per group
183 | counts = np.unique(groups, return_counts=True)[1]
184 | fbins = [0] + bins + [np.inf]
185 | print("Using {} as bins for aspect ratio quantization".format(fbins))
186 | print("Count of instances per bin: {}".format(counts))
187 | return groups
188 |
--------------------------------------------------------------------------------
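A minimal sketch of how the grouping helpers above are typically wired into a detection DataLoader. This is an assumption-laden example: GroupedBatchSampler is taken to be the class defined earlier in group_by_aspect_ratio.py (as in the torchvision reference), collate_fn comes from detection/utils.py, and `dataset` / `batch_size` are placeholders.

import torch

# group images into 2*k+1 aspect-ratio bins so each batch contains similarly shaped images
group_ids = create_aspect_ratio_groups(dataset, k=3)
train_sampler = torch.utils.data.RandomSampler(dataset)
batch_sampler = GroupedBatchSampler(train_sampler, group_ids, batch_size=8)
data_loader = torch.utils.data.DataLoader(
    dataset, batch_sampler=batch_sampler, num_workers=4, collate_fn=collate_fn)

--------------------------------------------------------------------------------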
/trainer/cowc_GAN_FRCNN_trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | import torch
4 | import math
5 | import os
6 | import model.ESRGANModel as ESRGAN
7 | import model.ESRGAN_EESN_FRCNN_Model as ESRGAN_EESN
8 | from scripts_for_datasets import COWCDataset, COWCGANDataset
9 | from torchvision.utils import make_grid
10 | from base import BaseTrainer
11 | from utils import inf_loop, MetricTracker, visualize_bbox, visualize, calculate_psnr, save_img, tensor2img, mkdir
12 |
13 | logger = logging.getLogger('base')
14 | '''
15 | python train.py -c config_GAN.json
16 | modified from ESRGAN repo
17 | '''
18 |
19 | class COWCGANFrcnnTrainer:
20 | """
21 | Trainer class
22 | """
23 | def __init__(self, config, data_loader, valid_data_loader=None):
24 | self.config = config
25 | self.data_loader = data_loader
26 |
27 | self.valid_data_loader = valid_data_loader
28 | self.do_validation = self.valid_data_loader is not None
29 | n_gpu = torch.cuda.device_count()
30 | self.device = torch.device('cuda:0' if n_gpu > 0 else 'cpu')
31 | self.train_size = int(math.ceil(self.data_loader.length / int(config['data_loader']['args']['batch_size'])))
32 | self.total_iters = int(config['train']['niter'])
33 | self.total_epochs = int(math.ceil(self.total_iters / self.train_size))
34 | print(self.total_epochs)
35 | self.model = ESRGAN_EESN.ESRGAN_EESN_FRCNN_Model(config,self.device)
36 |
37 | def test(self):
38 | self.model.test(self.data_loader, train=False, testResult=True)
39 |
40 | def train(self):
41 | '''
42 |         Full training loop over all epochs.
43 |         For visualization, use the following code (with batch size = 1):
44 |
45 | category_id_to_name = {1: 'car'}
46 | for batch_idx, dataset_dict in enumerate(self.data_loader):
47 | if dataset_dict['idx'][0] == 10:
48 | print(dataset_dict)
49 | visualize(dataset_dict, category_id_to_name) #--> see this method in util
50 |
51 | #image size: torch.Size([10, 3, 256, 256]) if batch_size = 10
52 | '''
53 | logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
54 | self.data_loader.length, self.train_size))
55 | logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
56 | self.total_epochs, self.total_iters))
57 | # tensorboard logger
58 | if self.config['use_tb_logger'] and 'debug' not in self.config['name']:
59 | version = float(torch.__version__[0:3])
60 | if version >= 1.1: # PyTorch 1.1
61 | from torch.utils.tensorboard import SummaryWriter
62 | else:
63 | logger.info(
64 | 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
65 | from tensorboardX import SummaryWriter
66 | tb_logger = SummaryWriter(log_dir='saved/tb_logger/' + self.config['name'])
67 |         ## TODO: resume capability
68 | current_step = 0
69 | start_epoch = 0
70 |
71 | #### training
72 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
73 | for epoch in range(start_epoch, self.total_epochs + 1):
74 | for _, (image, targets) in enumerate(self.data_loader):
75 | current_step += 1
76 | if current_step > self.total_iters:
77 | break
78 | #### update learning rate
79 | self.model.update_learning_rate(current_step, warmup_iter=self.config['train']['warmup_iter'])
80 |
81 | #### training
82 | self.model.feed_data(image, targets)
83 | self.model.optimize_parameters(current_step)
84 |
85 | #### log
86 | if current_step % self.config['logger']['print_freq'] == 0:
87 | logs = self.model.get_current_log()
88 |                     message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
89 |                         epoch, current_step, self.model.get_current_learning_rate())
90 | for k, v in logs.items():
91 | message += '{:s}: {:.4e} '.format(k, v)
92 | # tensorboard logger
93 | if self.config['use_tb_logger'] and 'debug' not in self.config['name']:
94 | tb_logger.add_scalar(k, v, current_step)
95 |
96 | logger.info(message)
97 |
98 | # validation
99 | if current_step % self.config['train']['val_freq'] == 0:
100 | self.model.test(self.valid_data_loader)
101 |
102 | #### save models and training states
103 | if current_step % self.config['logger']['save_checkpoint_freq'] == 0:
104 | logger.info('Saving models and training states.')
105 | self.model.save(current_step)
106 | self.model.save_training_state(epoch, current_step)
107 |
108 | #saving SR_images
109 | for _, (image, targets) in enumerate(self.valid_data_loader):
110 | #print(image)
111 | img_name = os.path.splitext(os.path.basename(image['LQ_path'][0]))[0]
112 | img_dir = os.path.join(self.config['path']['val_images'], img_name)
113 | mkdir(img_dir)
114 |
115 | self.model.feed_data(image, targets)
116 | self.model.test(self.valid_data_loader, train=False)
117 |
118 | visuals = self.model.get_current_visuals()
119 | sr_img = tensor2img(visuals['SR']) # uint8
120 | gt_img = tensor2img(visuals['GT']) # uint8
121 | lap_learned = tensor2img(visuals['lap_learned']) # uint8
122 | lap = tensor2img(visuals['lap']) # uint8
123 | lap_HR = tensor2img(visuals['lap_HR']) # uint8
124 | final_SR = tensor2img(visuals['final_SR']) # uint8
125 |
126 | # Save SR images for reference
127 | save_img_path = os.path.join(img_dir,
128 | '{:s}_{:d}_SR.png'.format(img_name, current_step))
129 | save_img(sr_img, save_img_path)
130 | # Save GT images for reference
131 | save_img_path = os.path.join(img_dir,
132 | '{:s}_{:d}_GT.png'.format(img_name, current_step))
133 | save_img(gt_img, save_img_path)
134 | # Save final_SR images for reference
135 | save_img_path = os.path.join(img_dir,
136 | '{:s}_{:d}_final_SR.png'.format(img_name, current_step))
137 | save_img(final_SR, save_img_path)
138 | # Save lap_learned images for reference
139 | save_img_path = os.path.join(img_dir,
140 | '{:s}_{:d}_lap_learned.png'.format(img_name, current_step))
141 | save_img(lap_learned, save_img_path)
142 | # Save lap images for reference
143 | save_img_path = os.path.join(img_dir,
144 | '{:s}_{:d}_lap.png'.format(img_name, current_step))
145 | save_img(lap, save_img_path)
146 |                 # Save lap_HR images for reference
147 | save_img_path = os.path.join(img_dir,
148 | '{:s}_{:d}_lap_HR.png'.format(img_name, current_step))
149 | save_img(lap_HR, save_img_path)
150 |
151 |
152 | logger.info('Saving the final model.')
153 | self.model.save('latest')
154 | logger.info('End of training.')
155 |
--------------------------------------------------------------------------------
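A hedged sketch of how the trainer above is typically driven, mirroring the `python train.py -c config_GAN.json` note in its header. The loading of the config as a plain dict and the construction of train_loader / val_loader (from data_loader/data_loaders.py) are assumptions here, not the repo's exact entry point.

import json

with open('config_GAN.json') as f:
    config = json.load(f)

# train_loader / val_loader are assumed to be COWC GAN-FRCNN data loaders
trainer = COWCGANFrcnnTrainer(config, data_loader=train_loader,
                              valid_data_loader=val_loader)
trainer.train()  # GAN + FRCNN training loop with periodic validation and checkpoints
trainer.test()   # detection-only pass (train=False, testResult=True)

--------------------------------------------------------------------------------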
/detection/engine.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | import time
4 | import torch
5 | import os
6 | import numpy as np
7 | import cv2
8 |
9 | import torchvision.models.detection.mask_rcnn
10 |
11 | from .coco_utils import get_coco_api_from_dataset, get_coco_api_from_dataset_base
12 | from .coco_eval import CocoEvaluator
13 | from .utils import MetricLogger, SmoothedValue, warmup_lr_scheduler, reduce_dict
14 | from utils import tensor2img
15 |
16 |
17 | def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
18 | model.train()
19 | metric_logger = MetricLogger(delimiter=" ")
20 | metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
21 | header = 'Epoch: [{}]'.format(epoch)
22 |
23 | lr_scheduler = None
24 | if epoch == 0:
25 | warmup_factor = 1. / 1000
26 | warmup_iters = min(1000, len(data_loader) - 1)
27 |
28 | lr_scheduler = warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
29 |
30 | for images, targets in metric_logger.log_every(data_loader, print_freq, header):
31 | images = list(image.to(device) for image in images)
32 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
33 | #print(images)
34 | loss_dict = model(images, targets)
35 |
36 | losses = sum(loss for loss in loss_dict.values())
37 |
38 | # reduce losses over all GPUs for logging purposes
39 | loss_dict_reduced = reduce_dict(loss_dict)
40 | losses_reduced = sum(loss for loss in loss_dict_reduced.values())
41 |
42 | loss_value = losses_reduced.item()
43 |
44 | if not math.isfinite(loss_value):
45 | print("Loss is {}, stopping training".format(loss_value))
46 | print(loss_dict_reduced)
47 | sys.exit(1)
48 |
49 | optimizer.zero_grad()
50 | losses.backward()
51 | optimizer.step()
52 |
53 | if lr_scheduler is not None:
54 | lr_scheduler.step()
55 |
56 | metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
57 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
58 |
59 |
60 | def _get_iou_types(model):
61 | model_without_ddp = model
62 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
63 | model_without_ddp = model.module
64 | iou_types = ["bbox"]
65 | if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
66 | iou_types.append("segm")
67 | if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
68 | iou_types.append("keypoints")
69 | return iou_types
70 |
71 | '''
72 | Draw boxes on the test images
73 | '''
74 | def draw_detection_boxes(new_class_conf_box, config, file_name, image):
75 | source_image_path = os.path.join(config['path']['output_images'], file_name, file_name+'_112000_final_SR.png')
76 | dest_image_path = os.path.join(config['path']['Test_Result_SR'], file_name+'.png')
77 | image = cv2.imread(source_image_path, 1)
78 | #print(new_class_conf_box)
79 | #print(len(new_class_conf_box))
80 | for i in range(len(new_class_conf_box)):
81 | clas,con,x1,y1,x2,y2 = new_class_conf_box[i]
82 | cv2.rectangle(image, (x1, y1), (x2, y2), (0,0,255), 4)
83 | font = cv2.FONT_HERSHEY_SIMPLEX
84 | cv2.putText(image, 'Car: '+ str((int(con*100))) + '%', (x1+5, y1+8), font, 0.2,(0,255,0),1,cv2.LINE_AA)
85 |
86 | cv2.imwrite(dest_image_path, image)
87 |
88 | '''
89 | for generating test boxes
90 | '''
91 | def get_prediction(outputs, file_path, config, file_name, image, threshold=0.5):
92 | new_class_conf_box = list()
93 |     pred_class = [i for i in list(outputs[0]['labels'].detach().cpu().numpy())] # Predicted class labels
94 | text_boxes = [ [i[0], i[1], i[2], i[3] ] for i in list(outputs[0]['boxes'].detach().cpu().numpy())] # Bounding boxes
95 | pred_score = list(outputs[0]['scores'].detach().cpu().numpy())
96 | #print(pred_score)
97 | for i in range(len(text_boxes)):
98 | new_class_conf_box.append([pred_class[i], pred_score[i], int(text_boxes[i][0]), int(text_boxes[i][1]), int(text_boxes[i][2]), int(text_boxes[i][3])])
99 | draw_detection_boxes(new_class_conf_box, config, file_name, image)
100 | new_class_conf_box1 = np.matrix(new_class_conf_box)
101 | #print(new_class_conf_box)
102 |     if len(new_class_conf_box) > 0:
103 | np.savetxt(file_path, new_class_conf_box1, fmt="%i %1.3f %i %i %i %i")
104 |
105 |
106 | @torch.no_grad()
107 | def evaluate_save(model_G, model_FRCNN, data_loader, device, config):
108 | i = 0
109 | print("Detection started........")
110 | for image, targets in data_loader:
111 | image['image_lq'] = image['image_lq'].to(device)
112 |
113 | _, img, _, _ = model_G(image['image_lq'])
114 | img_count = img.size()[0]
115 | images = [img[i] for i in range(img_count)]
116 | outputs = model_FRCNN(images)
117 | file_name = os.path.splitext(os.path.basename(image['LQ_path'][0]))[0]
118 | file_path = os.path.join(config['path']['Test_Result_SR'], file_name+'.txt')
119 | i=i+1
120 | print(i)
121 | img = img[0].detach()[0].float().cpu()
122 | img = tensor2img(img)
123 | get_prediction(outputs, file_path, config, file_name, img)
124 | print("successfully generated the results!")
125 |
126 | '''
127 | This evaluate method is changed to pass the generator network and evaluate
128 | the FRCNN with generated SR images
129 | '''
130 | @torch.no_grad()
131 | def evaluate(model_G, model_FRCNN, data_loader, device):
132 | n_threads = torch.get_num_threads()
133 | # FIXME remove this and make paste_masks_in_image run on the GPU
134 | torch.set_num_threads(1)
135 | cpu_device = torch.device("cpu")
136 | #model.eval()
137 | metric_logger = MetricLogger(delimiter=" ")
138 | header = 'Test:'
139 |
140 | coco = get_coco_api_from_dataset(data_loader.dataset)
141 | iou_types = _get_iou_types(model_FRCNN)
142 | coco_evaluator = CocoEvaluator(coco, iou_types)
143 |
144 | for image, targets in metric_logger.log_every(data_loader, 100, header):
145 | image['image_lq'] = image['image_lq'].to(device)
146 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
147 |
148 | torch.cuda.synchronize()
149 | model_time = time.time()
150 | _, image, _, _ = model_G(image['image_lq'])
151 | img_count = image.size()[0]
152 | image = [image[i] for i in range(img_count)]
153 | outputs = model_FRCNN(image)
154 |
155 | outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
156 | model_time = time.time() - model_time
157 |
158 | res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
159 | evaluator_time = time.time()
160 | coco_evaluator.update(res)
161 | evaluator_time = time.time() - evaluator_time
162 | metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
163 |
164 | # gather the stats from all processes
165 | metric_logger.synchronize_between_processes()
166 | print("Averaged stats:", metric_logger)
167 | coco_evaluator.synchronize_between_processes()
168 |
169 | # accumulate predictions from all images
170 | coco_evaluator.accumulate()
171 | coco_evaluator.summarize()
172 | torch.set_num_threads(n_threads)
173 | return coco_evaluator
174 |
175 | @torch.no_grad()
176 | def evaluate_base(model, data_loader, device):
177 | n_threads = torch.get_num_threads()
178 | # FIXME remove this and make paste_masks_in_image run on the GPU
179 | torch.set_num_threads(1)
180 | cpu_device = torch.device("cpu")
181 | model.eval()
182 | metric_logger = MetricLogger(delimiter=" ")
183 | header = 'Test:'
184 |
185 | coco = get_coco_api_from_dataset_base(data_loader.dataset)
186 | iou_types = _get_iou_types(model)
187 | coco_evaluator = CocoEvaluator(coco, iou_types)
188 |
189 | for image, targets in metric_logger.log_every(data_loader, 100, header):
190 | image = list(img.to(device) for img in image)
191 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
192 |
193 | torch.cuda.synchronize()
194 | model_time = time.time()
195 | outputs = model(image)
196 | #print(outputs)
197 |
198 | outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
199 | model_time = time.time() - model_time
200 |
201 | res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
202 | evaluator_time = time.time()
203 | coco_evaluator.update(res)
204 | evaluator_time = time.time() - evaluator_time
205 | metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
206 |
207 | # gather the stats from all processes
208 | metric_logger.synchronize_between_processes()
209 | print("Averaged stats:", metric_logger)
210 | coco_evaluator.synchronize_between_processes()
211 |
212 | # accumulate predictions from all images
213 | coco_evaluator.accumulate()
214 | coco_evaluator.summarize()
215 | torch.set_num_threads(n_threads)
216 | return coco_evaluator
217 |
--------------------------------------------------------------------------------
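A small usage sketch for the evaluation entry points above, assuming an already trained generator (model_G), a Faster R-CNN detector (model_FRCNN), and a COCO-style test loader; all three names are placeholders. A GPU is assumed because the loops above call torch.cuda.synchronize().

device = torch.device('cuda:0')  # evaluate() synchronizes CUDA, so a GPU is assumed

model_G.eval()
model_FRCNN.eval()

# COCO-style mAP on the SR images produced by the generator
coco_evaluator = evaluate(model_G, model_FRCNN, test_loader, device=device)

# or: write per-image detection .txt files and annotated SR images to disk
# evaluate_save(model_G, model_FRCNN, test_loader, device, config)

--------------------------------------------------------------------------------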
/trainer/cowc_GAN_trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | import torch
4 | import math
5 | import os
6 | import model.ESRGANModel as ESRGAN
7 | import model.ESRGAN_EESN_Model as ESRGAN_EESN
8 | from scripts_for_datasets import COWCDataset, COWCGANDataset
9 | from torchvision.utils import make_grid
10 | from base import BaseTrainer
11 | from utils import inf_loop, MetricTracker, visualize_bbox, visualize, calculate_psnr, save_img, tensor2img, mkdir
12 |
13 | logger = logging.getLogger('base')
14 | '''
15 | python train.py -c config_GAN.json
16 | modified from ESRGAN repo
17 | '''
18 |
19 | class COWCGANTrainer:
20 | """
21 | Trainer class
22 | """
23 | def __init__(self, config, data_loader, valid_data_loader=None):
24 | self.config = config
25 | self.data_loader = data_loader
26 |
27 | self.valid_data_loader = valid_data_loader
28 | self.do_validation = self.valid_data_loader is not None
29 | n_gpu = torch.cuda.device_count()
30 | self.device = torch.device('cuda:1' if n_gpu > 0 else 'cpu')
31 | self.train_size = int(math.ceil(self.data_loader.length / int(config['data_loader']['args']['batch_size'])))
32 | self.total_iters = int(config['train']['niter'])
33 | self.total_epochs = int(math.ceil(self.total_iters / self.train_size))
34 | print(self.total_epochs)
35 | self.model = ESRGAN_EESN.ESRGAN_EESN_Model(config,self.device)
36 |
37 | def test(self):
38 | for _, test_data in enumerate(self.data_loader):
39 | #print(val_data)
40 | img_name = os.path.splitext(os.path.basename(test_data['LQ_path'][0]))[0]
41 | img_dir = "/home/jakaria/Super_Resolution/Filter_Enhance_Detect/saved/"
42 |
43 | self.model.feed_data(test_data)
44 | self.model.test()
45 |
46 | visuals = self.model.get_current_visuals()
47 | sr_img = tensor2img(visuals['SR']) # uint8
48 | final_SR = tensor2img(visuals['final_SR']) # uint8
49 |
50 | # Save SR images for reference
51 | save_img_path = os.path.join(img_dir, 'combined_SR_images', img_name+'.png')
52 | save_img(sr_img, save_img_path)
53 | # Save final_SR images for reference
54 | save_img_path = os.path.join(img_dir, 'final_SR_images', img_name+'.png')
55 | save_img(final_SR, save_img_path)
56 |
57 | def train(self):
58 | '''
59 |         Full training loop over all epochs.
60 |         For visualization, use the following code (with batch size = 1):
61 |
62 | category_id_to_name = {1: 'car'}
63 | for batch_idx, dataset_dict in enumerate(self.data_loader):
64 | if dataset_dict['idx'][0] == 10:
65 | print(dataset_dict)
66 | visualize(dataset_dict, category_id_to_name) #--> see this method in util
67 |
68 | #image size: torch.Size([10, 3, 256, 256]) if batch_size = 10
69 | '''
70 | logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
71 | self.data_loader.length, self.train_size))
72 | logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
73 | self.total_epochs, self.total_iters))
74 | # tensorboard logger
75 | if self.config['use_tb_logger'] and 'debug' not in self.config['name']:
76 | version = float(torch.__version__[0:3])
77 | if version >= 1.1: # PyTorch 1.1
78 | from torch.utils.tensorboard import SummaryWriter
79 | else:
80 | logger.info(
81 | 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
82 | from tensorboardX import SummaryWriter
83 | tb_logger = SummaryWriter(log_dir='saved/tb_logger/' + self.config['name'])
84 |
85 | current_step = 0
86 | start_epoch = 0
87 |
88 | #### training
89 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
90 | for epoch in range(start_epoch, self.total_epochs + 1):
91 | for _, train_data in enumerate(self.data_loader):
92 | current_step += 1
93 | if current_step > self.total_iters:
94 | break
95 | #### update learning rate
96 | self.model.update_learning_rate(current_step, warmup_iter=self.config['train']['warmup_iter'])
97 |
98 | #### training
99 | self.model.feed_data(train_data)
100 | self.model.optimize_parameters(current_step)
101 |
102 | #### log
103 | if current_step % self.config['logger']['print_freq'] == 0:
104 | logs = self.model.get_current_log()
105 |                     message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
106 |                         epoch, current_step, self.model.get_current_learning_rate())
107 | for k, v in logs.items():
108 | message += '{:s}: {:.4e} '.format(k, v)
109 | # tensorboard logger
110 | if self.config['use_tb_logger'] and 'debug' not in self.config['name']:
111 | tb_logger.add_scalar(k, v, current_step)
112 |
113 | logger.info(message)
114 |
115 | # validation
116 | if current_step % self.config['train']['val_freq'] == 0:
117 | avg_psnr = 0.0
118 | idx = 0
119 | for val_data in self.valid_data_loader:
120 | idx += 1
121 | #print(val_data)
122 | img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
123 | img_dir = os.path.join(self.config['path']['val_images'], img_name)
124 | mkdir(img_dir)
125 |
126 | self.model.feed_data(val_data)
127 | self.model.test()
128 |
129 | visuals = self.model.get_current_visuals()
130 | sr_img = tensor2img(visuals['SR']) # uint8
131 | gt_img = tensor2img(visuals['GT']) # uint8
132 | lap_learned = tensor2img(visuals['lap_learned']) # uint8
133 | lap = tensor2img(visuals['lap']) # uint8
134 | lap_HR = tensor2img(visuals['lap_HR']) # uint8
135 | final_SR = tensor2img(visuals['final_SR']) # uint8
136 |
137 | # Save SR images for reference
138 | save_img_path = os.path.join(img_dir,
139 | '{:s}_{:d}_SR.png'.format(img_name, current_step))
140 | save_img(sr_img, save_img_path)
141 | # Save GT images for reference
142 | save_img_path = os.path.join(img_dir,
143 | '{:s}_{:d}_GT.png'.format(img_name, current_step))
144 | save_img(gt_img, save_img_path)
145 | # Save final_SR images for reference
146 | save_img_path = os.path.join(img_dir,
147 | '{:s}_{:d}_final_SR.png'.format(img_name, current_step))
148 | save_img(final_SR, save_img_path)
149 | # Save lap_learned images for reference
150 | save_img_path = os.path.join(img_dir,
151 | '{:s}_{:d}_lap_learned.png'.format(img_name, current_step))
152 | save_img(lap_learned, save_img_path)
153 | # Save lap images for reference
154 | save_img_path = os.path.join(img_dir,
155 | '{:s}_{:d}_lap.png'.format(img_name, current_step))
156 | save_img(lap, save_img_path)
157 |                         # Save lap_HR images for reference
158 | save_img_path = os.path.join(img_dir,
159 | '{:s}_{:d}_lap_HR.png'.format(img_name, current_step))
160 | save_img(lap_HR, save_img_path)
161 |
162 |
163 | # calculate PSNR
164 | crop_size = self.config['scale']
165 | gt_img = gt_img / 255.
166 | sr_img = sr_img / 255.
167 | cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
168 | cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
169 | avg_psnr += calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
170 |
171 | avg_psnr = avg_psnr / idx
172 |
173 | # log
174 | logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
175 | logger_val = logging.getLogger('val') # validation logger
176 |                         logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
177 |                             epoch, current_step, avg_psnr))
178 | # tensorboard logger
179 | if self.config['use_tb_logger'] and 'debug' not in self.config['name']:
180 | tb_logger.add_scalar('psnr', avg_psnr, current_step)
181 |
182 | #### save models and training states
183 | if current_step % self.config['logger']['save_checkpoint_freq'] == 0:
184 | logger.info('Saving models and training states.')
185 | self.model.save(current_step)
186 | self.model.save_training_state(epoch, current_step)
187 |
188 |
189 | logger.info('Saving the final model.')
190 | self.model.save('latest')
191 | logger.info('End of training.')
192 |
--------------------------------------------------------------------------------
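For reference, a minimal PSNR sketch consistent with how the validation loop above crops a `scale`-pixel border before comparing SR and GT images. This is only an assumption about what calculate_psnr (imported from utils) computes; the repo's own implementation may differ in detail.

import numpy as np

def psnr_uint8(img1, img2):
    # PSNR between two images in the 0-255 range, applied after the border crop above
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    return float('inf') if mse == 0 else 20.0 * np.log10(255.0 / np.sqrt(mse))

--------------------------------------------------------------------------------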
/detection/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | from collections import defaultdict, deque
4 | import datetime
5 | import pickle
6 | import time
7 |
8 | import torch
9 | import torch.distributed as dist
10 |
11 | import errno
12 | import os
13 |
14 |
15 | class SmoothedValue(object):
16 | """Track a series of values and provide access to smoothed values over a
17 | window or the global series average.
18 | """
19 |
20 | def __init__(self, window_size=20, fmt=None):
21 | if fmt is None:
22 | fmt = "{median:.4f} ({global_avg:.4f})"
23 | self.deque = deque(maxlen=window_size)
24 | self.total = 0.0
25 | self.count = 0
26 | self.fmt = fmt
27 |
28 | def update(self, value, n=1):
29 | self.deque.append(value)
30 | self.count += n
31 | self.total += value * n
32 |
33 | def synchronize_between_processes(self):
34 | """
35 | Warning: does not synchronize the deque!
36 | """
37 | if not is_dist_avail_and_initialized():
38 | return
39 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
40 | dist.barrier()
41 | dist.all_reduce(t)
42 | t = t.tolist()
43 | self.count = int(t[0])
44 | self.total = t[1]
45 |
46 | @property
47 | def median(self):
48 | d = torch.tensor(list(self.deque))
49 | return d.median().item()
50 |
51 | @property
52 | def avg(self):
53 | d = torch.tensor(list(self.deque), dtype=torch.float32)
54 | return d.mean().item()
55 |
56 | @property
57 | def global_avg(self):
58 | return self.total / self.count
59 |
60 | @property
61 | def max(self):
62 | return max(self.deque)
63 |
64 | @property
65 | def value(self):
66 | return self.deque[-1]
67 |
68 | def __str__(self):
69 | return self.fmt.format(
70 | median=self.median,
71 | avg=self.avg,
72 | global_avg=self.global_avg,
73 | max=self.max,
74 | value=self.value)
75 |
76 |
77 | def all_gather(data):
78 | """
79 | Run all_gather on arbitrary picklable data (not necessarily tensors)
80 | Args:
81 | data: any picklable object
82 | Returns:
83 | list[data]: list of data gathered from each rank
84 | """
85 | world_size = get_world_size()
86 | if world_size == 1:
87 | return [data]
88 |
89 | # serialized to a Tensor
90 | buffer = pickle.dumps(data)
91 | storage = torch.ByteStorage.from_buffer(buffer)
92 | tensor = torch.ByteTensor(storage).to("cuda")
93 |
94 | # obtain Tensor size of each rank
95 | local_size = torch.tensor([tensor.numel()], device="cuda")
96 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
97 | dist.all_gather(size_list, local_size)
98 | size_list = [int(size.item()) for size in size_list]
99 | max_size = max(size_list)
100 |
101 | # receiving Tensor from all ranks
102 | # we pad the tensor because torch all_gather does not support
103 | # gathering tensors of different shapes
104 | tensor_list = []
105 | for _ in size_list:
106 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
107 | if local_size != max_size:
108 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
109 | tensor = torch.cat((tensor, padding), dim=0)
110 | dist.all_gather(tensor_list, tensor)
111 |
112 | data_list = []
113 | for size, tensor in zip(size_list, tensor_list):
114 | buffer = tensor.cpu().numpy().tobytes()[:size]
115 | data_list.append(pickle.loads(buffer))
116 |
117 | return data_list
118 |
119 |
120 | def reduce_dict(input_dict, average=True):
121 | """
122 | Args:
123 | input_dict (dict): all the values will be reduced
124 | average (bool): whether to do average or sum
125 | Reduce the values in the dictionary from all processes so that all processes
126 | have the averaged results. Returns a dict with the same fields as
127 | input_dict, after reduction.
128 | """
129 | world_size = get_world_size()
130 | if world_size < 2:
131 | return input_dict
132 | with torch.no_grad():
133 | names = []
134 | values = []
135 | # sort the keys so that they are consistent across processes
136 | for k in sorted(input_dict.keys()):
137 | names.append(k)
138 | values.append(input_dict[k])
139 | values = torch.stack(values, dim=0)
140 | dist.all_reduce(values)
141 | if average:
142 | values /= world_size
143 | reduced_dict = {k: v for k, v in zip(names, values)}
144 | return reduced_dict
145 |
146 |
147 | class MetricLogger(object):
148 | def __init__(self, delimiter="\t"):
149 | self.meters = defaultdict(SmoothedValue)
150 | self.delimiter = delimiter
151 |
152 | def update(self, **kwargs):
153 | for k, v in kwargs.items():
154 | if isinstance(v, torch.Tensor):
155 | v = v.item()
156 | assert isinstance(v, (float, int))
157 | self.meters[k].update(v)
158 |
159 | def __getattr__(self, attr):
160 | if attr in self.meters:
161 | return self.meters[attr]
162 | if attr in self.__dict__:
163 | return self.__dict__[attr]
164 | raise AttributeError("'{}' object has no attribute '{}'".format(
165 | type(self).__name__, attr))
166 |
167 | def __str__(self):
168 | loss_str = []
169 | for name, meter in self.meters.items():
170 | loss_str.append(
171 | "{}: {}".format(name, str(meter))
172 | )
173 | return self.delimiter.join(loss_str)
174 |
175 | def synchronize_between_processes(self):
176 | for meter in self.meters.values():
177 | meter.synchronize_between_processes()
178 |
179 | def add_meter(self, name, meter):
180 | self.meters[name] = meter
181 |
182 | def log_every(self, iterable, print_freq, header=None):
183 | i = 0
184 | if not header:
185 | header = ''
186 | start_time = time.time()
187 | end = time.time()
188 | iter_time = SmoothedValue(fmt='{avg:.4f}')
189 | data_time = SmoothedValue(fmt='{avg:.4f}')
190 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
191 | log_msg = self.delimiter.join([
192 | header,
193 | '[{0' + space_fmt + '}/{1}]',
194 | 'eta: {eta}',
195 | '{meters}',
196 | 'time: {time}',
197 | 'data: {data}',
198 | 'max mem: {memory:.0f}'
199 | ])
200 | MB = 1024.0 * 1024.0
201 | for obj in iterable:
202 | data_time.update(time.time() - end)
203 | yield obj
204 | iter_time.update(time.time() - end)
205 | if i % print_freq == 0 or i == len(iterable) - 1:
206 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
207 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
208 | print(log_msg.format(
209 | i, len(iterable), eta=eta_string,
210 | meters=str(self),
211 | time=str(iter_time), data=str(data_time),
212 | memory=torch.cuda.max_memory_allocated() / MB))
213 | i += 1
214 | end = time.time()
215 | total_time = time.time() - start_time
216 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
217 | print('{} Total time: {} ({:.4f} s / it)'.format(
218 | header, total_time_str, total_time / len(iterable)))
219 |
220 |
221 | def collate_fn(batch):
222 | return tuple(zip(*batch))
223 |
224 |
225 | def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
226 |
227 | def f(x):
228 | if x >= warmup_iters:
229 | return 1
230 | alpha = float(x) / warmup_iters
231 | return warmup_factor * (1 - alpha) + alpha
232 |
233 | return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
234 |
235 |
236 | def mkdir(path):
237 | try:
238 | os.makedirs(path)
239 | except OSError as e:
240 | if e.errno != errno.EEXIST:
241 | raise
242 |
243 |
244 | def setup_for_distributed(is_master):
245 | """
246 | This function disables printing when not in master process
247 | """
248 | import builtins as __builtin__
249 | builtin_print = __builtin__.print
250 |
251 | def print(*args, **kwargs):
252 | force = kwargs.pop('force', False)
253 | if is_master or force:
254 | builtin_print(*args, **kwargs)
255 |
256 | __builtin__.print = print
257 |
258 |
259 | def is_dist_avail_and_initialized():
260 | if not dist.is_available():
261 | return False
262 | if not dist.is_initialized():
263 | return False
264 | return True
265 |
266 |
267 | def get_world_size():
268 | if not is_dist_avail_and_initialized():
269 | return 1
270 | return dist.get_world_size()
271 |
272 |
273 | def get_rank():
274 | if not is_dist_avail_and_initialized():
275 | return 0
276 | return dist.get_rank()
277 |
278 |
279 | def is_main_process():
280 | return get_rank() == 0
281 |
282 |
283 | def save_on_master(*args, **kwargs):
284 | if is_main_process():
285 | torch.save(*args, **kwargs)
286 |
287 |
288 | def init_distributed_mode(args):
289 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
290 | args.rank = int(os.environ["RANK"])
291 | args.world_size = int(os.environ['WORLD_SIZE'])
292 | args.gpu = int(os.environ['LOCAL_RANK'])
293 | elif 'SLURM_PROCID' in os.environ:
294 | args.rank = int(os.environ['SLURM_PROCID'])
295 | args.gpu = args.rank % torch.cuda.device_count()
296 | else:
297 | print('Not using distributed mode')
298 | args.distributed = False
299 | return
300 |
301 | args.distributed = True
302 |
303 | torch.cuda.set_device(args.gpu)
304 | args.dist_backend = 'nccl'
305 | print('| distributed init (rank {}): {}'.format(
306 | args.rank, args.dist_url), flush=True)
307 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
308 | world_size=args.world_size, rank=args.rank)
309 | torch.distributed.barrier()
310 | setup_for_distributed(args.rank == 0)
311 |
--------------------------------------------------------------------------------
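A tiny, self-contained usage sketch for the logging helpers above; the loss and learning-rate values are illustrative only, and the printed line is an example of the format, not a recorded output.

logger = MetricLogger(delimiter="  ")
logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
for step in range(100):
    logger.update(loss=0.5 / (step + 1), lr=1e-4)  # placeholder values
print(str(logger))  # e.g. "loss: 0.0055 (0.0259)  lr: 0.000100"

--------------------------------------------------------------------------------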
/detection/coco_utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | from PIL import Image
4 |
5 | import torch
6 | import torch.utils.data
7 | import torchvision
8 |
9 | from pycocotools import mask as coco_mask
10 | from pycocotools.coco import COCO
11 |
12 | from torchvision import transforms as T
13 |
14 |
15 | class FilterAndRemapCocoCategories(object):
16 | def __init__(self, categories, remap=True):
17 | self.categories = categories
18 | self.remap = remap
19 |
20 | def __call__(self, image, target):
21 | anno = target["annotations"]
22 | anno = [obj for obj in anno if obj["category_id"] in self.categories]
23 | if not self.remap:
24 | target["annotations"] = anno
25 | return image, target
26 | anno = copy.deepcopy(anno)
27 | for obj in anno:
28 | obj["category_id"] = self.categories.index(obj["category_id"])
29 | target["annotations"] = anno
30 | return image, target
31 |
32 |
33 | def convert_coco_poly_to_mask(segmentations, height, width):
34 | masks = []
35 | for polygons in segmentations:
36 | rles = coco_mask.frPyObjects(polygons, height, width)
37 | mask = coco_mask.decode(rles)
38 | if len(mask.shape) < 3:
39 | mask = mask[..., None]
40 | mask = torch.as_tensor(mask, dtype=torch.uint8)
41 | mask = mask.any(dim=2)
42 | masks.append(mask)
43 | if masks:
44 | masks = torch.stack(masks, dim=0)
45 | else:
46 | masks = torch.zeros((0, height, width), dtype=torch.uint8)
47 | return masks
48 |
49 |
50 | class ConvertCocoPolysToMask(object):
51 | def __call__(self, image, target):
52 | w, h = image.size
53 |
54 | image_id = target["image_id"]
55 | image_id = torch.tensor([image_id])
56 |
57 | anno = target["annotations"]
58 |
59 | anno = [obj for obj in anno if obj['iscrowd'] == 0]
60 |
61 | boxes = [obj["bbox"] for obj in anno]
62 | # guard against no boxes via resizing
63 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
64 | boxes[:, 2:] += boxes[:, :2]
65 | boxes[:, 0::2].clamp_(min=0, max=w)
66 | boxes[:, 1::2].clamp_(min=0, max=h)
67 |
68 | classes = [obj["category_id"] for obj in anno]
69 | classes = torch.tensor(classes, dtype=torch.int64)
70 |
71 | segmentations = [obj["segmentation"] for obj in anno]
72 | masks = convert_coco_poly_to_mask(segmentations, h, w)
73 |
74 | keypoints = None
75 | if anno and "keypoints" in anno[0]:
76 | keypoints = [obj["keypoints"] for obj in anno]
77 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
78 | num_keypoints = keypoints.shape[0]
79 | if num_keypoints:
80 | keypoints = keypoints.view(num_keypoints, -1, 3)
81 |
82 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
83 | boxes = boxes[keep]
84 | classes = classes[keep]
85 | masks = masks[keep]
86 | if keypoints is not None:
87 | keypoints = keypoints[keep]
88 |
89 | target = {}
90 | target["boxes"] = boxes
91 | target["labels"] = classes
92 | target["masks"] = masks
93 | target["image_id"] = image_id
94 | if keypoints is not None:
95 | target["keypoints"] = keypoints
96 |
97 | # for conversion to coco api
98 | area = torch.tensor([obj["area"] for obj in anno])
99 | iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
100 | target["area"] = area
101 | target["iscrowd"] = iscrowd
102 |
103 | return image, target
104 |
105 |
106 | def _coco_remove_images_without_annotations(dataset, cat_list=None):
107 | def _has_only_empty_bbox(anno):
108 | return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
109 |
110 | def _count_visible_keypoints(anno):
111 | return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
112 |
113 | min_keypoints_per_image = 10
114 |
115 | def _has_valid_annotation(anno):
116 | # if it's empty, there is no annotation
117 | if len(anno) == 0:
118 | return False
119 | # if all boxes have close to zero area, there is no annotation
120 | if _has_only_empty_bbox(anno):
121 | return False
122 |         # the keypoints task has slightly different criteria for considering
123 |         # whether an annotation is valid
124 | if "keypoints" not in anno[0]:
125 | return True
126 | # for keypoint detection tasks, only consider valid images those
127 | # containing at least min_keypoints_per_image
128 | if _count_visible_keypoints(anno) >= min_keypoints_per_image:
129 | return True
130 | return False
131 |
132 | assert isinstance(dataset, torchvision.datasets.CocoDetection)
133 | ids = []
134 | for ds_idx, img_id in enumerate(dataset.ids):
135 | ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
136 | anno = dataset.coco.loadAnns(ann_ids)
137 | if cat_list:
138 | anno = [obj for obj in anno if obj["category_id"] in cat_list]
139 | if _has_valid_annotation(anno):
140 | ids.append(ds_idx)
141 |
142 | dataset = torch.utils.data.Subset(dataset, ids)
143 | return dataset
144 |
145 |
146 | def convert_to_coco_api(ds):
147 | coco_ds = COCO()
148 | ann_id = 0
149 | dataset = {'images': [], 'categories': [], 'annotations': []}
150 | categories = set()
151 | for img_idx in range(len(ds)):
152 | # find better way to get target
153 | # targets = ds.get_annotations(img_idx)
154 | img, targets = ds[img_idx]
155 | image_id = targets["image_id"].item()
156 | img_dict = {}
157 | img_dict['id'] = image_id
158 | img_dict['height'] = img['image_lq'].shape[-2]
159 | img_dict['width'] = img['image_lq'].shape[-1]
160 | dataset['images'].append(img_dict)
161 | bboxes = targets["boxes"]
162 | bboxes[:, 2:] -= bboxes[:, :2]
163 | bboxes = bboxes.tolist()
164 | labels = targets['labels'].tolist()
165 | areas = targets['area'].tolist()
166 | iscrowd = targets['iscrowd'].tolist()
167 | if 'masks' in targets:
168 | masks = targets['masks']
169 | # make masks Fortran contiguous for coco_mask
170 | masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
171 | if 'keypoints' in targets:
172 | keypoints = targets['keypoints']
173 | keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
174 | num_objs = len(bboxes)
175 | for i in range(num_objs):
176 | ann = {}
177 | ann['image_id'] = image_id
178 | ann['bbox'] = bboxes[i]
179 | ann['category_id'] = labels[i]
180 | categories.add(labels[i])
181 | ann['area'] = areas[i]
182 | ann['iscrowd'] = iscrowd[i]
183 | ann['id'] = ann_id
184 | if 'masks' in targets:
185 | ann["segmentation"] = coco_mask.encode(masks[i].numpy())
186 | if 'keypoints' in targets:
187 | ann['keypoints'] = keypoints[i]
188 | ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3])
189 | dataset['annotations'].append(ann)
190 | ann_id += 1
191 | dataset['categories'] = [{'id': i} for i in sorted(categories)]
192 | coco_ds.dataset = dataset
193 | coco_ds.createIndex()
194 | return coco_ds
195 |
196 | def convert_to_coco_api_base(ds):
197 | coco_ds = COCO()
198 | ann_id = 0
199 | dataset = {'images': [], 'categories': [], 'annotations': []}
200 | categories = set()
201 | for img_idx in range(len(ds)):
202 | # find better way to get target
203 | # targets = ds.get_annotations(img_idx)
204 | img, targets = ds[img_idx]
205 | image_id = targets["image_id"].item()
206 | img_dict = {}
207 | img_dict['id'] = image_id
208 | img_dict['height'] = img.shape[-2]
209 | img_dict['width'] = img.shape[-1]
210 | dataset['images'].append(img_dict)
211 | bboxes = targets["boxes"]
212 | bboxes[:, 2:] -= bboxes[:, :2]
213 | bboxes = bboxes.tolist()
214 | labels = targets['labels'].tolist()
215 | areas = targets['area'].tolist()
216 | iscrowd = targets['iscrowd'].tolist()
217 | if 'masks' in targets:
218 | masks = targets['masks']
219 | # make masks Fortran contiguous for coco_mask
220 | masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
221 | if 'keypoints' in targets:
222 | keypoints = targets['keypoints']
223 | keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
224 | num_objs = len(bboxes)
225 | for i in range(num_objs):
226 | ann = {}
227 | ann['image_id'] = image_id
228 | ann['bbox'] = bboxes[i]
229 | ann['category_id'] = labels[i]
230 | categories.add(labels[i])
231 | ann['area'] = areas[i]
232 | ann['iscrowd'] = iscrowd[i]
233 | ann['id'] = ann_id
234 | if 'masks' in targets:
235 | ann["segmentation"] = coco_mask.encode(masks[i].numpy())
236 | if 'keypoints' in targets:
237 | ann['keypoints'] = keypoints[i]
238 | ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3])
239 | dataset['annotations'].append(ann)
240 | ann_id += 1
241 | dataset['categories'] = [{'id': i} for i in sorted(categories)]
242 | coco_ds.dataset = dataset
243 | coco_ds.createIndex()
244 | return coco_ds
245 |
246 |
247 | def get_coco_api_from_dataset(dataset):
248 | for i in range(10):
249 | if isinstance(dataset, torchvision.datasets.CocoDetection):
250 | break
251 | if isinstance(dataset, torch.utils.data.Subset):
252 | dataset = dataset.dataset
253 | if isinstance(dataset, torchvision.datasets.CocoDetection):
254 | return dataset.coco
255 | return convert_to_coco_api(dataset)
256 |
257 | def get_coco_api_from_dataset_base(dataset):
258 | for i in range(10):
259 | if isinstance(dataset, torchvision.datasets.CocoDetection):
260 | break
261 | if isinstance(dataset, torch.utils.data.Subset):
262 | dataset = dataset.dataset
263 | if isinstance(dataset, torchvision.datasets.CocoDetection):
264 | return dataset.coco
265 | return convert_to_coco_api_base(dataset)
266 |
267 |
268 | class CocoDetection(torchvision.datasets.CocoDetection):
269 | def __init__(self, img_folder, ann_file, transforms):
270 | super(CocoDetection, self).__init__(img_folder, ann_file)
271 | self._transforms = transforms
272 |
273 | def __getitem__(self, idx):
274 | img, target = super(CocoDetection, self).__getitem__(idx)
275 | image_id = self.ids[idx]
276 | target = dict(image_id=image_id, annotations=target)
277 | if self._transforms is not None:
278 | img, target = self._transforms(img, target)
279 | return img, target
280 |
281 |
282 | def get_coco(root, image_set, transforms, mode='instances'):
283 | anno_file_template = "{}_{}2017.json"
284 | PATHS = {
285 | "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
286 | "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
287 | # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
288 | }
289 |
290 | t = [ConvertCocoPolysToMask()]
291 |
292 | if transforms is not None:
293 | t.append(transforms)
294 | transforms = T.Compose(t)
295 |
296 | img_folder, ann_file = PATHS[image_set]
297 | img_folder = os.path.join(root, img_folder)
298 | ann_file = os.path.join(root, ann_file)
299 |
300 | dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
301 |
302 | if image_set == "train":
303 | dataset = _coco_remove_images_without_annotations(dataset)
304 |
305 | # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])
306 |
307 | return dataset
308 |
309 |
310 | def get_coco_kp(root, image_set, transforms):
311 | return get_coco(root, image_set, transforms, mode="person_keypoints")
312 |
--------------------------------------------------------------------------------
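A short sketch contrasting the two wrappers above: get_coco_api_from_dataset expects samples whose image is a dict carrying an 'image_lq' tensor (the GAN data format), while get_coco_api_from_dataset_base expects a plain image tensor; both produce a COCO ground-truth object that feeds CocoEvaluator the same way. The loader names are placeholders, and the import assumes the detection package is importable from the repo root.

from detection.coco_eval import CocoEvaluator

# GAN-style dataset: each sample is ({'image_lq': tensor, ...}, target)
coco_gt = get_coco_api_from_dataset(gan_test_loader.dataset)

# plain detection dataset: each sample is (image_tensor, target)
# coco_gt = get_coco_api_from_dataset_base(plain_test_loader.dataset)

evaluator = CocoEvaluator(coco_gt, iou_types=["bbox"])

--------------------------------------------------------------------------------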
/model/ESRGANModel.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import OrderedDict
3 | import torch
4 | import torch.nn as nn
5 | import model.model as model
6 | import model.lr_scheduler as lr_scheduler
7 | from model.loss import GANLoss
8 | from .gan_base_model import BaseModel
9 | from torch.nn.parallel import DataParallel, DistributedDataParallel
10 |
11 | logger = logging.getLogger('base')
12 | # Taken from ESRGAN BASICSR repository and modified
13 | class ESRGANModel(BaseModel):
14 | def __init__(self, config, device):
15 | super(ESRGANModel, self).__init__(config, device)
16 | self.configG = config['network_G']
17 | self.configD = config['network_D']
18 | self.configT = config['train']
19 | self.configO = config['optimizer']['args']
20 | self.configS = config['lr_scheduler']
21 | self.device = device
22 | #Generator
23 | self.netG = model.RRDBNet(in_nc=self.configG['in_nc'], out_nc=self.configG['out_nc'],
24 | nf=self.configG['nf'], nb=self.configG['nb'])
25 | self.netG = self.netG.to(self.device)
26 | self.netG = DataParallel(self.netG)
27 |         # discriminator
28 | self.netD = model.Discriminator_VGG_128(in_nc=self.configD['in_nc'], nf=self.configD['nf'])
29 | self.netD = self.netD.to(self.device)
30 | self.netD = DataParallel(self.netD)
31 |
32 | self.netG.train()
33 | self.netD.train()
34 | #print(self.configT['pixel_weight'])
35 | # G pixel loss
36 | if self.configT['pixel_weight'] > 0.0:
37 | l_pix_type = self.configT['pixel_criterion']
38 | if l_pix_type == 'l1':
39 | self.cri_pix = nn.L1Loss().to(self.device)
40 | elif l_pix_type == 'l2':
41 | self.cri_pix = nn.MSELoss().to(self.device)
42 | else:
43 | raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
44 | self.l_pix_w = self.configT['pixel_weight']
45 | else:
46 | self.cri_pix = None
47 |
48 | # G feature loss
49 | #print(self.configT['feature_weight']+1)
50 | if self.configT['feature_weight'] > 0:
51 | l_fea_type = self.configT['feature_criterion']
52 | if l_fea_type == 'l1':
53 | self.cri_fea = nn.L1Loss().to(self.device)
54 | elif l_fea_type == 'l2':
55 | self.cri_fea = nn.MSELoss().to(self.device)
56 | else:
57 | raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
58 | self.l_fea_w = self.configT['feature_weight']
59 | else:
60 | self.cri_fea = None
61 | if self.cri_fea: # load VGG perceptual loss
62 | self.netF = model.VGGFeatureExtractor(feature_layer=34,
63 | use_input_norm=True, device=self.device)
64 | self.netF = self.netF.to(self.device)
65 | self.netF = DataParallel(self.netF)
66 | self.netF.eval()
67 |
68 | # GD gan loss
69 | self.cri_gan = GANLoss(self.configT['gan_type'], 1.0, 0.0).to(self.device)
70 | self.l_gan_w = self.configT['gan_weight']
71 | # D_update_ratio and D_init_iters
72 | self.D_update_ratio = self.configT['D_update_ratio'] if self.configT['D_update_ratio'] else 1
73 | self.D_init_iters = self.configT['D_init_iters'] if self.configT['D_init_iters'] else 0
74 |
75 |
76 | # optimizers
77 | # G
78 | wd_G = self.configO['weight_decay_G'] if self.configO['weight_decay_G'] else 0
79 | optim_params = []
80 | for k, v in self.netG.named_parameters(): # can optimize for a part of the model
81 | if v.requires_grad:
82 | optim_params.append(v)
83 |
84 | self.optimizer_G = torch.optim.Adam(optim_params, lr=self.configO['lr_G'],
85 | weight_decay=wd_G,
86 | betas=(self.configO['beta1_G'], self.configO['beta2_G']))
87 | self.optimizers.append(self.optimizer_G)
88 | # D
89 | wd_D = self.configO['weight_decay_D'] if self.configO['weight_decay_D'] else 0
90 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.configO['lr_D'],
91 | weight_decay=wd_D,
92 | betas=(self.configO['beta1_D'], self.configO['beta2_D']))
93 | self.optimizers.append(self.optimizer_D)
94 |
95 | # schedulers
96 | if self.configS['type'] == 'MultiStepLR':
97 | for optimizer in self.optimizers:
98 | self.schedulers.append(
99 | lr_scheduler.MultiStepLR_Restart(optimizer, self.configS['args']['lr_steps'],
100 | restarts=self.configS['args']['restarts'],
101 | weights=self.configS['args']['restart_weights'],
102 | gamma=self.configS['args']['lr_gamma'],
103 | clear_state=False))
104 | elif self.configS['type'] == 'CosineAnnealingLR_Restart':
105 | for optimizer in self.optimizers:
106 | self.schedulers.append(
107 | lr_scheduler.CosineAnnealingLR_Restart(
108 | optimizer, self.configS['args']['T_period'], eta_min=self.configS['args']['eta_min'],
109 | restarts=self.configS['args']['restarts'], weights=self.configS['args']['restart_weights']))
110 | else:
111 |             raise NotImplementedError('Learning rate scheme [{:s}] is not supported.'.format(self.configS['type']))
112 | print(self.configS['args']['restarts'])
113 | self.log_dict = OrderedDict()
114 |
115 | self.print_network() # print network
116 | self.load() # load G and D if needed
117 | '''
118 |     The main repo did not use collate_fn; it also read images with different
119 |     flags and used np.ascontiguousarray().
120 |     This code may need to change if problems show up.
121 | '''
122 | def feed_data(self, data):
123 | self.var_L = data['image_lq'].to(self.device)
124 | self.var_H = data['image'].to(self.device)
125 | input_ref = data['ref'] if 'ref' in data else data['image']
126 | self.var_ref = input_ref.to(self.device)
127 |
128 | def optimize_parameters(self, step):
129 | #Generator
130 | for p in self.netD.parameters():
131 | p.requires_grad = False
132 | self.optimizer_G.zero_grad()
133 | self.fake_H = self.netG(self.var_L)
134 |
135 | l_g_total = 0
136 | if step % self.D_update_ratio == 0 and step > self.D_init_iters:
137 | if self.cri_pix: #pixel loss
138 | l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
139 | l_g_total += l_g_pix
140 | if self.cri_fea: # feature loss
141 |                 real_fea = self.netF(self.var_H).detach() # detach: the real-image features act as a fixed target, so no gradients flow through netF here
142 | fake_fea = self.netF(self.fake_H) #In netF normalize=False, check it
143 | l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
144 | l_g_total += l_g_fea
145 |
146 | pred_g_fake = self.netD(self.fake_H)
147 | if self.configT['gan_type'] == 'gan':
148 | l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
149 | elif self.configT['gan_type'] == 'ragan':
150 | pred_d_real = self.netD(self.var_ref).detach()
151 | l_g_gan = self.l_gan_w * (
152 | self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
153 | self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
154 | l_g_total += l_g_gan
155 |
156 | l_g_total.backward()
157 | self.optimizer_G.step()
158 |
159 |         # discriminator
160 | for p in self.netD.parameters():
161 | p.requires_grad = True
162 |
163 | self.optimizer_D.zero_grad()
164 | l_d_total = 0
165 | pred_d_real = self.netD(self.var_ref)
166 | pred_d_fake = self.netD(self.fake_H.detach()) #to avoid BP to Generator
167 | if self.configT['gan_type'] == 'gan':
168 | l_d_real = self.cri_gan(pred_d_real, True)
169 | l_d_fake = self.cri_gan(pred_d_fake, False)
170 | l_d_total = l_d_real + l_d_fake
171 | elif self.configT['gan_type'] == 'ragan':
172 | l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
173 | l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
174 | l_d_total = (l_d_real + l_d_fake) / 2
175 |
176 | l_d_total.backward()
177 | self.optimizer_D.step()
178 |
179 | # set log
180 | if step % self.D_update_ratio == 0 and step > self.D_init_iters:
181 | if self.cri_pix:
182 | self.log_dict['l_g_pix'] = l_g_pix.item()
183 | if self.cri_fea:
184 | self.log_dict['l_g_fea'] = l_g_fea.item()
185 | self.log_dict['l_g_gan'] = l_g_gan.item()
186 |
187 | self.log_dict['l_d_real'] = l_d_real.item()
188 | self.log_dict['l_d_fake'] = l_d_fake.item()
189 | self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
190 | self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
191 |
192 | def test(self):
193 | self.netG.eval()
194 | with torch.no_grad():
195 | self.fake_H = self.netG(self.var_L)
196 | self.netG.train()
197 |
198 | def get_current_log(self):
199 | return self.log_dict
200 |
201 | def get_current_visuals(self, need_GT=True):
202 | out_dict = OrderedDict()
203 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
204 | out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
205 | if need_GT:
206 | out_dict['GT'] = self.var_H.detach()[0].float().cpu()
207 | return out_dict
208 |
209 | def print_network(self):
210 | # Generator
211 | s, n = self.get_network_description(self.netG)
212 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
213 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
214 | self.netG.module.__class__.__name__)
215 | else:
216 | net_struc_str = '{}'.format(self.netG.__class__.__name__)
217 |
218 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
219 | logger.info(s)
220 |
221 | # Discriminator
222 | s, n = self.get_network_description(self.netD)
223 | if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD,
224 | DistributedDataParallel):
225 | net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
226 | self.netD.module.__class__.__name__)
227 | else:
228 | net_struc_str = '{}'.format(self.netD.__class__.__name__)
229 |
230 | logger.info('Network D structure: {}, with parameters: {:,d}'.format(
231 | net_struc_str, n))
232 | logger.info(s)
233 |
234 | if self.cri_fea: # F, Perceptual Network
235 | s, n = self.get_network_description(self.netF)
236 | if isinstance(self.netF, nn.DataParallel) or isinstance(
237 | self.netF, DistributedDataParallel):
238 | net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
239 | self.netF.module.__class__.__name__)
240 | else:
241 | net_struc_str = '{}'.format(self.netF.__class__.__name__)
242 |
243 | logger.info('Network F structure: {}, with parameters: {:,d}'.format(
244 | net_struc_str, n))
245 | logger.info(s)
246 |
247 | def load(self):
248 | load_path_G = self.config['path']['pretrain_model_G']
249 | if load_path_G:
250 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
251 | self.load_network(load_path_G, self.netG, self.config['path']['strict_load'])
252 | load_path_D = self.config['path']['pretrain_model_D']
253 | if load_path_D:
254 | logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
255 | self.load_network(load_path_D, self.netD, self.config['path']['strict_load'])
256 |
257 | def save(self, iter_step):
258 | self.save_network(self.netG, 'G', iter_step)
259 | self.save_network(self.netD, 'D', iter_step)
260 |
--------------------------------------------------------------------------------
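A hedged sketch of a single optimization step with the model above. `batch` is assumed to be a dict carrying 'image_lq' and 'image' tensors (the format feed_data expects), `train_loader` and `config` are placeholders, and update_learning_rate/save are assumed to come from the BaseModel in gan_base_model.py.

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = ESRGANModel(config, device)

for step, batch in enumerate(train_loader, start=1):
    model.update_learning_rate(step, warmup_iter=config['train']['warmup_iter'])
    model.feed_data(batch)            # sets var_L (LR input), var_H (HR), var_ref
    model.optimize_parameters(step)   # G step (pixel + feature + RaGAN), then D step
    if step % config['logger']['print_freq'] == 0:
        print(step, model.get_current_log())  # l_g_pix, l_g_fea, l_g_gan, l_d_real, l_d_fake

--------------------------------------------------------------------------------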
/detection/coco_eval.py:
--------------------------------------------------------------------------------
1 | import json
2 | import tempfile
3 |
4 | import numpy as np
5 | import copy
6 | import time
7 | import torch
8 | import torch._six
9 |
10 | from pycocotools.cocoeval import COCOeval
11 | from pycocotools.coco import COCO
12 | import pycocotools.mask as mask_util
13 |
14 | from collections import defaultdict
15 |
16 | from .utils import all_gather
17 |
18 |
19 | class CocoEvaluator(object):
20 | def __init__(self, coco_gt, iou_types):
21 | assert isinstance(iou_types, (list, tuple))
22 | coco_gt = copy.deepcopy(coco_gt)
23 | self.coco_gt = coco_gt
24 |
25 | self.iou_types = iou_types
26 | self.coco_eval = {}
27 | for iou_type in iou_types:
28 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
29 |
30 | self.img_ids = []
31 | self.eval_imgs = {k: [] for k in iou_types}
32 |
33 | def update(self, predictions):
34 | img_ids = list(np.unique(list(predictions.keys())))
35 | self.img_ids.extend(img_ids)
36 |
37 | for iou_type in self.iou_types:
38 | results = self.prepare(predictions, iou_type)
39 | coco_dt = loadRes(self.coco_gt, results) if results else COCO()
40 | coco_eval = self.coco_eval[iou_type]
41 |
42 | coco_eval.cocoDt = coco_dt
43 | coco_eval.params.imgIds = list(img_ids)
44 | img_ids, eval_imgs = evaluate(coco_eval)
45 |
46 | self.eval_imgs[iou_type].append(eval_imgs)
47 |
48 | def synchronize_between_processes(self):
49 | for iou_type in self.iou_types:
50 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
51 | create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
52 |
53 | def accumulate(self):
54 | for coco_eval in self.coco_eval.values():
55 | coco_eval.accumulate()
56 |
57 | def summarize(self):
58 | for iou_type, coco_eval in self.coco_eval.items():
59 | print("IoU metric: {}".format(iou_type))
60 | coco_eval.summarize()
61 |
62 | def prepare(self, predictions, iou_type):
63 | if iou_type == "bbox":
64 | return self.prepare_for_coco_detection(predictions)
65 | elif iou_type == "segm":
66 | return self.prepare_for_coco_segmentation(predictions)
67 | elif iou_type == "keypoints":
68 | return self.prepare_for_coco_keypoint(predictions)
69 | else:
70 | raise ValueError("Unknown iou type {}".format(iou_type))
71 |
72 | def prepare_for_coco_detection(self, predictions):
73 | coco_results = []
74 | for original_id, prediction in predictions.items():
75 | if len(prediction) == 0:
76 | continue
77 |
78 | boxes = prediction["boxes"]
79 | boxes = convert_to_xywh(boxes).tolist()
80 | scores = prediction["scores"].tolist()
81 | labels = prediction["labels"].tolist()
82 |
83 | coco_results.extend(
84 | [
85 | {
86 | "image_id": original_id,
87 | "category_id": labels[k],
88 | "bbox": box,
89 | "score": scores[k],
90 | }
91 | for k, box in enumerate(boxes)
92 | ]
93 | )
94 | return coco_results
95 |
96 | def prepare_for_coco_segmentation(self, predictions):
97 | coco_results = []
98 | for original_id, prediction in predictions.items():
99 | if len(prediction) == 0:
100 | continue
101 |
102 | scores = prediction["scores"]
103 | labels = prediction["labels"]
104 | masks = prediction["masks"]
105 |
106 | masks = masks > 0.5
107 |
108 | scores = prediction["scores"].tolist()
109 | labels = prediction["labels"].tolist()
110 |
111 | rles = [
112 | mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
113 | for mask in masks
114 | ]
115 | for rle in rles:
116 | rle["counts"] = rle["counts"].decode("utf-8")
117 |
118 | coco_results.extend(
119 | [
120 | {
121 | "image_id": original_id,
122 | "category_id": labels[k],
123 | "segmentation": rle,
124 | "score": scores[k],
125 | }
126 | for k, rle in enumerate(rles)
127 | ]
128 | )
129 | return coco_results
130 |
131 | def prepare_for_coco_keypoint(self, predictions):
132 | coco_results = []
133 | for original_id, prediction in predictions.items():
134 | if len(prediction) == 0:
135 | continue
136 |
137 | boxes = prediction["boxes"]
138 | boxes = convert_to_xywh(boxes).tolist()
139 | scores = prediction["scores"].tolist()
140 | labels = prediction["labels"].tolist()
141 | keypoints = prediction["keypoints"]
142 | keypoints = keypoints.flatten(start_dim=1).tolist()
143 |
144 | coco_results.extend(
145 | [
146 | {
147 | "image_id": original_id,
148 | "category_id": labels[k],
149 | 'keypoints': keypoint,
150 | "score": scores[k],
151 | }
152 | for k, keypoint in enumerate(keypoints)
153 | ]
154 | )
155 | return coco_results
156 |
157 |
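   | # Convert [xmin, ymin, xmax, ymax] boxes to COCO's [x, y, width, height] format.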
158 | def convert_to_xywh(boxes):
159 | xmin, ymin, xmax, ymax = boxes.unbind(1)
160 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
161 |
162 |
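   | # Gather per-process image ids and evaluation results, then keep only unique image ids.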
163 | def merge(img_ids, eval_imgs):
164 | all_img_ids = all_gather(img_ids)
165 | all_eval_imgs = all_gather(eval_imgs)
166 |
167 | merged_img_ids = []
168 | for p in all_img_ids:
169 | merged_img_ids.extend(p)
170 |
171 | merged_eval_imgs = []
172 | for p in all_eval_imgs:
173 | merged_eval_imgs.append(p)
174 |
175 | merged_img_ids = np.array(merged_img_ids)
176 | merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
177 |
178 | # keep only unique (and in sorted order) images
179 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
180 | merged_eval_imgs = merged_eval_imgs[..., idx]
181 |
182 | return merged_img_ids, merged_eval_imgs
183 |
184 |
185 | def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
186 | img_ids, eval_imgs = merge(img_ids, eval_imgs)
187 | img_ids = list(img_ids)
188 | eval_imgs = list(eval_imgs.flatten())
189 |
190 | coco_eval.evalImgs = eval_imgs
191 | coco_eval.params.imgIds = img_ids
192 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
193 |
194 |
195 | #################################################################
196 | # From pycocotools, just removed the prints and fixed
197 | # a Python3 bug about unicode not defined
198 | #################################################################
199 |
200 | # Ideally, pycocotools wouldn't have hard-coded prints
201 | # so that we could avoid copy-pasting those two functions
202 |
203 | def createIndex(self):
204 | # create index
205 | # print('creating index...')
206 | anns, cats, imgs = {}, {}, {}
207 | imgToAnns, catToImgs = defaultdict(list), defaultdict(list)
208 | if 'annotations' in self.dataset:
209 | for ann in self.dataset['annotations']:
210 | imgToAnns[ann['image_id']].append(ann)
211 | anns[ann['id']] = ann
212 |
213 | if 'images' in self.dataset:
214 | for img in self.dataset['images']:
215 | imgs[img['id']] = img
216 |
217 | if 'categories' in self.dataset:
218 | for cat in self.dataset['categories']:
219 | cats[cat['id']] = cat
220 |
221 | if 'annotations' in self.dataset and 'categories' in self.dataset:
222 | for ann in self.dataset['annotations']:
223 | catToImgs[ann['category_id']].append(ann['image_id'])
224 |
225 | # print('index created!')
226 |
227 | # create class members
228 | self.anns = anns
229 | self.imgToAnns = imgToAnns
230 | self.catToImgs = catToImgs
231 | self.imgs = imgs
232 | self.cats = cats
233 |
234 |
235 | maskUtils = mask_util
236 |
237 |
238 | def loadRes(self, resFile):
239 | """
240 | Load result file and return a result api object.
241 | :param resFile (str) : file name of result file
242 | :return: res (obj) : result api object
243 | """
244 | res = COCO()
245 | res.dataset['images'] = [img for img in self.dataset['images']]
246 |
247 | # print('Loading and preparing results...')
248 | # tic = time.time()
249 | if isinstance(resFile, torch._six.string_classes):
250 | anns = json.load(open(resFile))
251 | elif type(resFile) == np.ndarray:
252 | anns = self.loadNumpyAnnotations(resFile)
253 | else:
254 | anns = resFile
255 | assert type(anns) == list, 'results is not an array of objects'
256 | annsImgIds = [ann['image_id'] for ann in anns]
257 | assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
258 | 'Results do not correspond to current coco set'
259 | if 'caption' in anns[0]:
260 | imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
261 | res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
262 | for id, ann in enumerate(anns):
263 | ann['id'] = id + 1
264 | elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
265 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
266 | for id, ann in enumerate(anns):
267 | bb = ann['bbox']
268 | x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
269 | if 'segmentation' not in ann:
270 | ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
271 | ann['area'] = bb[2] * bb[3]
272 | ann['id'] = id + 1
273 | ann['iscrowd'] = 0
274 | elif 'segmentation' in anns[0]:
275 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
276 | for id, ann in enumerate(anns):
277 | # now only support compressed RLE format as segmentation results
278 | ann['area'] = maskUtils.area(ann['segmentation'])
279 | if 'bbox' not in ann:
280 | ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
281 | ann['id'] = id + 1
282 | ann['iscrowd'] = 0
283 | elif 'keypoints' in anns[0]:
284 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
285 | for id, ann in enumerate(anns):
286 | s = ann['keypoints']
287 | x = s[0::3]
288 | y = s[1::3]
289 | x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y)
290 | ann['area'] = (x1 - x0) * (y1 - y0)
291 | ann['id'] = id + 1
292 | ann['bbox'] = [x0, y0, x1 - x0, y1 - y0]
293 | # print('DONE (t={:0.2f}s)'.format(time.time()- tic))
294 |
295 | res.dataset['annotations'] = anns
296 | createIndex(res)
297 | return res
298 |
299 |
300 | def evaluate(self):
301 | '''
302 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
303 | :return: None
304 | '''
305 | # tic = time.time()
306 | # print('Running per image evaluation...')
307 | p = self.params
308 | # add backward compatibility if useSegm is specified in params
309 | if p.useSegm is not None:
310 | p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
311 | print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
312 | # print('Evaluate annotation type *{}*'.format(p.iouType))
313 | p.imgIds = list(np.unique(p.imgIds))
314 | if p.useCats:
315 | p.catIds = list(np.unique(p.catIds))
316 | p.maxDets = sorted(p.maxDets)
317 | self.params = p
318 |
319 | self._prepare()
320 | # loop through images, area range, max detection number
321 | catIds = p.catIds if p.useCats else [-1]
322 |
323 | if p.iouType == 'segm' or p.iouType == 'bbox':
324 | computeIoU = self.computeIoU
325 | elif p.iouType == 'keypoints':
326 | computeIoU = self.computeOks
327 | self.ious = {
328 | (imgId, catId): computeIoU(imgId, catId)
329 | for imgId in p.imgIds
330 | for catId in catIds}
331 |
332 | evaluateImg = self.evaluateImg
333 | maxDet = p.maxDets[-1]
334 | evalImgs = [
335 | evaluateImg(imgId, catId, areaRng, maxDet)
336 | for catId in catIds
337 | for areaRng in p.areaRng
338 | for imgId in p.imgIds
339 | ]
340 | # this is NOT in the pycocotools code, but could be done outside
341 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
342 | self._paramsEval = copy.deepcopy(self.params)
343 | # toc = time.time()
344 | # print('DONE (t={:0.2f}s).'.format(toc-tic))
345 | return p.imgIds, evalImgs
346 |
347 | #################################################################
348 | # end of straight copy from pycocotools, just removing the prints
349 | #################################################################
350 |
--------------------------------------------------------------------------------
/trainer/FRCNN_trainer.py:
--------------------------------------------------------------------------------
1 | '''
2 | Quick and dirty test script; needs to be cleaned up later.
3 | '''
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 | import torchvision
8 | import os
9 | import cv2
10 | from collections import OrderedDict
11 | from torch.nn.parallel import DataParallel, DistributedDataParallel
12 | from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
13 | from torch.utils.data import Dataset, DataLoader
14 | from torchvision import utils
15 | from detection.engine import train_one_epoch, evaluate_base
16 | from detection.utils import collate_fn
17 | from scripts_for_datasets import COWCFRCNNDataset
18 | from detection.transforms import ToTensor, RandomHorizontalFlip, Compose
19 | from matplotlib import pyplot as plt
20 | from utils import tensor2img
21 |
22 | class COWCFRCNNTrainer:
23 | """
24 | Trainer and evaluator for a Faster R-CNN car detector on the COWC dataset.
25 | """
26 | def __init__(self, config):
27 | self.config = config
28 |
29 | n_gpu = torch.cuda.device_count()
30 | self.device = torch.device('cuda:0' if n_gpu > 0 else 'cpu')
31 |
32 | def get_transform(self, train):
33 | transforms = []
34 | # converts the image, a PIL image, into a PyTorch Tensor
35 | transforms.append(ToTensor())
36 | if train:
37 | # during training, randomly flip the training images
38 | # and ground-truth for data augmentation
39 | transforms.append(RandomHorizontalFlip(0.5))
40 | return Compose(transforms)
41 |
42 | def data_loaders(self):
43 | # use our dataset and defined transformations
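   | # Besides the LR train/validation sets, build loaders for the different
   | # super-resolved test variants (SR, combined SR, enhanced SR 1-3, final SR, bicubic)
   | # so the same detector can be evaluated on each of them.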
44 | dataset = COWCFRCNNDataset(root=self.config['path']['data_dir_LR_train'],
45 | image_height=64, image_width=64, transforms=self.get_transform(train=True))
46 | dataset_test = COWCFRCNNDataset(root=self.config['path']['data_dir_Valid'],
47 | image_height=64, image_width=64, transforms=self.get_transform(train=False))
48 | dataset_test_SR = COWCFRCNNDataset(root=self.config['path']['data_dir_SR'],
49 | transforms=self.get_transform(train=False))
50 | dataset_test_SR_combined = COWCFRCNNDataset(root=self.config['path']['data_dir_SR_combined'],
51 | transforms=self.get_transform(train=False))
52 | dataset_test_E_SR_1 = COWCFRCNNDataset(root=self.config['path']['data_dir_E_SR_1'],
53 | transforms=self.get_transform(train=False))
54 | dataset_test_E_SR_2 = COWCFRCNNDataset(root=self.config['path']['data_dir_E_SR_2'],
55 | transforms=self.get_transform(train=False))
56 | dataset_test_E_SR_3 = COWCFRCNNDataset(root=self.config['path']['data_dir_E_SR_3'],
57 | transforms=self.get_transform(train=False))
58 | dataset_test_F_SR = COWCFRCNNDataset(root=self.config['path']['data_dir_F_SR'],
59 | transforms=self.get_transform(train=False))
60 | dataset_test_Bic = COWCFRCNNDataset(root=self.config['path']['data_dir_Bic'],
61 | transforms=self.get_transform(train=False))
62 |
63 | # define training and validation data loaders
64 | data_loader = torch.utils.data.DataLoader(
65 | dataset, batch_size=2, shuffle=True, num_workers=4,
66 | collate_fn=collate_fn)
67 |
68 | data_loader_test = torch.utils.data.DataLoader(
69 | dataset_test, batch_size=1, shuffle=False, num_workers=4,
70 | collate_fn=collate_fn)
71 |
72 | data_loader_test_SR = torch.utils.data.DataLoader(
73 | dataset_test_SR, batch_size=1, shuffle=False, num_workers=4,
74 | collate_fn=collate_fn)
75 |
76 | data_loader_test_SR_combined = torch.utils.data.DataLoader(
77 | dataset_test_SR_combined, batch_size=1, shuffle=False, num_workers=4,
78 | collate_fn=collate_fn)
79 |
80 | data_loader_test_E_SR_1 = torch.utils.data.DataLoader(
81 | dataset_test_E_SR_1, batch_size=1, shuffle=False, num_workers=4,
82 | collate_fn=collate_fn)
83 |
84 | data_loader_test_E_SR_2 = torch.utils.data.DataLoader(
85 | dataset_test_E_SR_2, batch_size=1, shuffle=False, num_workers=4,
86 | collate_fn=collate_fn)
87 |
88 | data_loader_test_E_SR_3 = torch.utils.data.DataLoader(
89 | dataset_test_E_SR_3, batch_size=1, shuffle=False, num_workers=4,
90 | collate_fn=collate_fn)
91 |
92 | data_loader_test_F_SR = torch.utils.data.DataLoader(
93 | dataset_test_F_SR, batch_size=1, shuffle=False, num_workers=4,
94 | collate_fn=collate_fn)
95 |
96 | data_loader_test_Bic = torch.utils.data.DataLoader(
97 | dataset_test_Bic, batch_size=1, shuffle=False, num_workers=4,
98 | collate_fn=collate_fn)
99 |
100 | return data_loader, data_loader_test, data_loader_test_SR, data_loader_test_SR_combined, \
101 | data_loader_test_E_SR_1, data_loader_test_E_SR_2, data_loader_test_E_SR_3, \
102 | data_loader_test_F_SR, data_loader_test_Bic
103 |
104 | def save_model(self, network, network_label, iter_label):
105 | save_filename = '{}_{}.pth'.format(iter_label, network_label)
106 | save_path = os.path.join(self.config['path']['FRCNN_model'], save_filename)
107 |
108 | state_dict = network.state_dict()
109 | for key, param in state_dict.items():
110 | state_dict[key] = param.cpu()
111 | torch.save(state_dict, save_path)
112 |
113 | def load_model(self, load_path, network, strict=True):
114 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
115 | network = network.module
116 | load_net = torch.load(load_path)
117 | load_net_clean = OrderedDict() # remove unnecessary 'module.'
118 | for k, v in load_net.items():
119 | if k.startswith('module.'):
120 | load_net_clean[k[7:]] = v
121 | else:
122 | load_net_clean[k] = v
123 | network.load_state_dict(load_net_clean, strict=strict)
124 | print("model_loaded")
125 |
126 | '''
127 | Draw boxes on the test images
128 | '''
129 | def draw_detection_boxes(self, new_class_conf_box, file_path):
130 | source_image_path = os.path.join(self.config['path']['data_dir_Bic_valid'], os.path.splitext(os.path.basename(file_path))[0]+'.jpg')
131 | dest_image_path = os.path.splitext(file_path)[0]+'.jpg'
132 | #print(dest_image_path)
133 | image = cv2.imread(source_image_path,1)
134 | #print(new_class_conf_box)
135 | #print(len(new_class_conf_box))
136 | for i in range(len(new_class_conf_box)):
137 | clas, con, x1, y1, x2, y2 = new_class_conf_box[i]
138 | cv2.rectangle(image, (x1, y1), (x2, y2), (0, 0, 255), 4)
139 | font = cv2.FONT_HERSHEY_SIMPLEX
140 | cv2.putText(image, 'Car: ' + str(int(con * 100)) + '%', (x1 + 5, y1 + 8), font, 0.2, (0, 255, 0), 1, cv2.LINE_AA)
141 |
142 | cv2.imwrite(dest_image_path, image)
143 |
144 | '''
145 | Generate and save detection boxes for the test images.
146 | '''
147 | def get_prediction(self, model, images, annotation_path, threshold=0.8):
148 | new_class_conf_box = list()
149 | image = list(img.to(self.device) for img in images)
150 | outputs = model(image)
151 | file_path = os.path.join(self.config['path']['Test_Result_LR_LR_COWC'], os.path.basename(annotation_path))
152 | #print(file_path)
153 | pred_class = [i for i in list(outputs[0]['labels'].detach().cpu().numpy())] # Predicted class labels
154 | text_boxes = [ [i[0], i[1], i[2], i[3] ] for i in list(outputs[0]['boxes'].detach().cpu().numpy())] # Bounding boxes
155 | pred_score = list(outputs[0]['scores'].detach().cpu().numpy())
156 | #print(pred_score)
157 | for i in range(len(text_boxes)):
158 | if pred_score[i] < threshold:
159 | continue
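   | # Boxes are predicted in LR coordinates; scale them by 4 (the SR upscale factor,
   | # assumed from the fixed factor below) so they align with the 4x-upsampled image used for drawing.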
160 | new_class_conf_box.append([pred_class[i], pred_score[i], int(text_boxes[i][0]*4), int(text_boxes[i][1]*4), int(text_boxes[i][2]*4), int(text_boxes[i][3]*4)])
161 | self.draw_detection_boxes(new_class_conf_box, file_path)
162 | new_class_conf_box1 = np.matrix(new_class_conf_box)
163 | #print(new_class_conf_box)
164 | if len(new_class_conf_box) > 0:
165 | np.savetxt(file_path, new_class_conf_box1, fmt="%i %1.3f %i %i %i %i")
166 |
167 | #get test results
168 | def test(self):
169 |
170 | # load a model pre-trained on COCO
171 | model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
172 |
173 | # replace the classifier head with a new one that has
174 | # a user-defined num_classes
175 | num_classes = 2 # 1 class (car) + background
176 | # get number of input features for the classifier
177 | in_features = model.roi_heads.box_predictor.cls_score.in_features
178 | # replace the pre-trained head with a new one
179 | model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
180 |
181 | model.to(self.device)
182 |
183 | self.load_model(self.config['path']['pretrain_model_FRCNN_LR_LR'], model)
184 |
185 | _, data_loader_test, data_loader_test_SR, data_loader_test_SR_combined, \
186 | data_loader_test_E_SR_1, data_loader_test_E_SR_2, data_loader_test_E_SR_3, \
187 | data_loader_test_F_SR, data_loader_test_Bic = self.data_loaders()
188 |
189 | print("test lenghts of the data loaders.............")
190 | print(len(data_loader_test))
191 | model.eval()
192 | i = 0
193 | print("Detection started........")
194 | for image, targets, annotation_path in data_loader_test:
195 | annotation_path = ''.join(annotation_path)
196 | self.get_prediction(model, image, annotation_path)
197 | #evaluate_base(model, data_loader_test_Bic, device=self.device)
198 | i += 1
199 | print(i)
200 | print("successfully generated the results!")
201 | '''
202 | print(len(data_loader_test_SR))
203 | print(len(data_loader_test_SR_combined))
204 | print(len(data_loader_test_E_SR_1))
205 | print(len(data_loader_test_E_SR_2))
206 | print(len(data_loader_test_E_SR_3))
207 | print(len(data_loader_test_F_SR))
208 | print(len(data_loader_test_Bic))
209 | print("test HR images..............................")
210 | evaluate_base(model, data_loader_test, device=self.device)
211 | print("test SR images..............................")
212 | evaluate_base(model, data_loader_test_SR, device=self.device)
213 | print("test SR combined images..............................")
214 | evaluate_base(model, data_loader_test_SR_combined, device=self.device)
215 | print("test Enhanced SR 1 images.....................")
216 | evaluate_base(model, data_loader_test_E_SR_1, device=self.device)
217 | print("test Enhanced SR 2 images.....................")
218 | evaluate_base(model, data_loader_test_E_SR_2, device=self.device)
219 | print("test Enhanced SR 3 images.....................")
220 | evaluate_base(model, data_loader_test_E_SR_3, device=self.device)
221 | print("test Final SR images.........................")
222 | evaluate_base(model, data_loader_test_F_SR, device=self.device)
223 | print("test Bicubic images..........................")
224 | evaluate_base(model, data_loader_test_Bic, device=self.device)
225 | '''
226 | def train(self):
227 | # load a model pre-trained on COCO
228 | model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
229 |
230 | # replace the classifier head with a new one that has
231 | # a user-defined num_classes
232 | num_classes = 2 # 1 class (car) + background
233 | # get number of input features for the classifier
234 | in_features = model.roi_heads.box_predictor.cls_score.in_features
235 | # replace the pre-trained head with a new one
236 | model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
237 |
238 | model.to(self.device)
239 | #self.load_model(self.config['path']['pretrain_model_FRCNN_LR_LR'], model)
240 |
241 | # construct an optimizer
242 | params = [p for p in model.parameters() if p.requires_grad]
243 | optimizer = torch.optim.SGD(params, lr=0.005,
244 | momentum=0.9, weight_decay=0.0005)
245 |
246 | # and a learning rate scheduler which decreases the learning rate by
247 | # 10x every 3 epochs
248 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
249 | step_size=3,
250 | gamma=0.1)
251 |
252 | data_loader, data_loader_test, _, _, _, _, _, _, _ = self.data_loaders()
253 | # number of training epochs
254 | num_epochs = 1000
255 |
256 | for epoch in range(num_epochs):
257 | # train for one epoch, printing every 10 iterations
258 | train_one_epoch(model, optimizer, data_loader, self.device, epoch, print_freq=10)
259 | # update the learning rate
260 | lr_scheduler.step()
261 | # evaluate on the test dataset
262 | evaluate_base(model, data_loader_test, device=self.device)
263 | if epoch % 1 == 0:
264 | self.save_model(model, 'FRCNN_LR_LR', epoch)
265 |
--------------------------------------------------------------------------------
/model/ESRGAN_EESN_Model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import OrderedDict
3 | import torch
4 | import torch.nn as nn
5 | import model.model as model
6 | import model.lr_scheduler as lr_scheduler
7 | from model.loss import GANLoss, CharbonnierLoss
8 | from .gan_base_model import BaseModel
9 | from torch.nn.parallel import DataParallel, DistributedDataParallel
10 |
11 | logger = logging.getLogger('base')
12 | # Taken from ESRGAN BASICSR repository and modified
13 | class ESRGAN_EESN_Model(BaseModel):
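   | """GAN training wrapper combining the ESRGAN generator with the EESN edge-enhancement branch, a VGG-style discriminator, and an optional VGG feature extractor for the perceptual loss."""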
14 | def __init__(self, config, device):
15 | super(ESRGAN_EESN_Model, self).__init__(config, device)
16 | self.configG = config['network_G']
17 | self.configD = config['network_D']
18 | self.configT = config['train']
19 | self.configO = config['optimizer']['args']
20 | self.configS = config['lr_scheduler']
21 | self.device = device
22 | #Generator
23 | self.netG = model.ESRGAN_EESN(in_nc=self.configG['in_nc'], out_nc=self.configG['out_nc'],
24 | nf=self.configG['nf'], nb=self.configG['nb'])
25 | self.netG = self.netG.to(self.device)
26 | self.netG = DataParallel(self.netG, device_ids=[1,0])
27 |
28 | # discriminator
29 | self.netD = model.Discriminator_VGG_128(in_nc=self.configD['in_nc'], nf=self.configD['nf'])
30 | self.netD = self.netD.to(self.device)
31 | self.netD = DataParallel(self.netD, device_ids=[1,0])
32 |
33 | self.netG.train()
34 | self.netD.train()
35 | #print(self.configT['pixel_weight'])
36 | # G CharbonnierLoss for final output SR and GT HR
37 | self.cri_charbonnier = CharbonnierLoss().to(device)
38 | # G pixel loss
39 | if self.configT['pixel_weight'] > 0.0:
40 | l_pix_type = self.configT['pixel_criterion']
41 | if l_pix_type == 'l1':
42 | self.cri_pix = nn.L1Loss().to(self.device)
43 | elif l_pix_type == 'l2':
44 | self.cri_pix = nn.MSELoss().to(self.device)
45 | else:
46 | raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
47 | self.l_pix_w = self.configT['pixel_weight']
48 | else:
49 | self.cri_pix = None
50 |
51 | # G feature loss
52 | #print(self.configT['feature_weight']+1)
53 | if self.configT['feature_weight'] > 0:
54 | l_fea_type = self.configT['feature_criterion']
55 | if l_fea_type == 'l1':
56 | self.cri_fea = nn.L1Loss().to(self.device)
57 | elif l_fea_type == 'l2':
58 | self.cri_fea = nn.MSELoss().to(self.device)
59 | else:
60 | raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
61 | self.l_fea_w = self.configT['feature_weight']
62 | else:
63 | self.cri_fea = None
64 | if self.cri_fea: # load VGG perceptual loss
65 | self.netF = model.VGGFeatureExtractor(feature_layer=34,
66 | use_input_norm=True, device=self.device)
67 | self.netF = self.netF.to(self.device)
68 | self.netF = DataParallel(self.netF, device_ids=[1,0])
69 | self.netF.eval()
70 |
71 | # GD gan loss
72 | self.cri_gan = GANLoss(self.configT['gan_type'], 1.0, 0.0).to(self.device)
73 | self.l_gan_w = self.configT['gan_weight']
74 | # D_update_ratio and D_init_iters
75 | self.D_update_ratio = self.configT['D_update_ratio'] if self.configT['D_update_ratio'] else 1
76 | self.D_init_iters = self.configT['D_init_iters'] if self.configT['D_init_iters'] else 0
77 |
78 |
79 | # optimizers
80 | # G
81 | wd_G = self.configO['weight_decay_G'] if self.configO['weight_decay_G'] else 0
82 | optim_params = []
83 | for k, v in self.netG.named_parameters(): # can optimize for a part of the model
84 | if v.requires_grad:
85 | optim_params.append(v)
86 |
87 | self.optimizer_G = torch.optim.Adam(optim_params, lr=self.configO['lr_G'],
88 | weight_decay=wd_G,
89 | betas=(self.configO['beta1_G'], self.configO['beta2_G']))
90 | self.optimizers.append(self.optimizer_G)
91 |
92 | # D
93 | wd_D = self.configO['weight_decay_D'] if self.configO['weight_decay_D'] else 0
94 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.configO['lr_D'],
95 | weight_decay=wd_D,
96 | betas=(self.configO['beta1_D'], self.configO['beta2_D']))
97 | self.optimizers.append(self.optimizer_D)
98 |
99 | # schedulers
100 | if self.configS['type'] == 'MultiStepLR':
101 | for optimizer in self.optimizers:
102 | self.schedulers.append(
103 | lr_scheduler.MultiStepLR_Restart(optimizer, self.configS['args']['lr_steps'],
104 | restarts=self.configS['args']['restarts'],
105 | weights=self.configS['args']['restart_weights'],
106 | gamma=self.configS['args']['lr_gamma'],
107 | clear_state=False))
108 | elif self.configS['type'] == 'CosineAnnealingLR_Restart':
109 | for optimizer in self.optimizers:
110 | self.schedulers.append(
111 | lr_scheduler.CosineAnnealingLR_Restart(
112 | optimizer, self.configS['args']['T_period'], eta_min=self.configS['args']['eta_min'],
113 | restarts=self.configS['args']['restarts'], weights=self.configS['args']['restart_weights']))
114 | else:
115 | raise NotImplementedError('Only MultiStepLR and CosineAnnealingLR_Restart schedulers are supported.')
116 | print(self.configS['args']['restarts'])
117 | self.log_dict = OrderedDict()
118 |
119 | self.print_network() # print network
120 | self.load() # load G and D if needed
121 | '''
122 | The reference repo does not use collate_fn, reads images with different flags,
123 | and applies np.ascontiguousarray().
124 | Revisit this code if problems arise.
125 | '''
126 | def feed_data(self, data):
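   | # data: batch dict with the LR input ('image_lq'), the HR ground truth ('image'), and an optional discriminator reference ('ref')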
127 | self.var_L = data['image_lq'].to(self.device)
128 | self.var_H = data['image'].to(self.device)
129 | input_ref = data['ref'] if 'ref' in data else data['image']
130 | self.var_ref = input_ref.to(self.device)
131 |
132 | def optimize_parameters(self, step):
133 | #Generator
134 | for p in self.netD.parameters():
135 | p.requires_grad = False
136 | self.optimizer_G.zero_grad()
137 | self.fake_H, self.final_SR, _, _ = self.netG(self.var_L)
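   | # netG returns the intermediate SR image (fake_H), the edge-enhanced final output (final_SR),
   | # and two Laplacian/edge maps that are not needed during training (cf. test() below).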
138 |
139 | l_g_total = 0
140 | if step % self.D_update_ratio == 0 and step > self.D_init_iters:
141 | if self.cri_pix: #pixel loss
142 | l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
143 | l_g_total += l_g_pix
144 | if self.cri_fea: # feature loss
145 | real_fea = self.netF(self.var_H).detach() # detach: no gradient should flow back through the real-image features
146 | fake_fea = self.netF(self.fake_H) # check the normalization settings used inside netF
147 | l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
148 | l_g_total += l_g_fea
149 |
150 | pred_g_fake = self.netD(self.fake_H)
151 | if self.configT['gan_type'] == 'gan':
152 | l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
153 | elif self.configT['gan_type'] == 'ragan':
154 | pred_d_real = self.netD(self.var_ref).detach()
155 | l_g_gan = self.l_gan_w * (
156 | self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
157 | self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
158 | l_g_total += l_g_gan
159 | #EESN calculate loss
160 | if self.cri_charbonnier: # charbonnier pixel loss HR and SR
161 | l_e_charbonnier = 5 * self.cri_charbonnier(self.final_SR, self.var_H) # weight set empirically; tune if needed
162 | l_g_total += l_e_charbonnier
163 |
164 | l_g_total.backward()
165 | self.optimizer_G.step()
166 |
167 | # discriminator
168 | for p in self.netD.parameters():
169 | p.requires_grad = True
170 |
171 | self.optimizer_D.zero_grad()
172 | l_d_total = 0
173 | pred_d_real = self.netD(self.var_ref)
174 | pred_d_fake = self.netD(self.fake_H.detach()) #to avoid BP to Generator
175 | if self.configT['gan_type'] == 'gan':
176 | l_d_real = self.cri_gan(pred_d_real, True)
177 | l_d_fake = self.cri_gan(pred_d_fake, False)
178 | l_d_total = l_d_real + l_d_fake
179 | elif self.configT['gan_type'] == 'ragan':
180 | l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
181 | l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
182 | l_d_total = (l_d_real + l_d_fake) / 2 # consider also adding a discriminator loss on final_SR
183 |
184 | l_d_total.backward()
185 | self.optimizer_D.step()
186 |
187 | # set log
188 | if step % self.D_update_ratio == 0 and step > self.D_init_iters:
189 | if self.cri_pix:
190 | self.log_dict['l_g_pix'] = l_g_pix.item()
191 | if self.cri_fea:
192 | self.log_dict['l_g_fea'] = l_g_fea.item()
193 | self.log_dict['l_g_gan'] = l_g_gan.item()
194 | self.log_dict['l_e_charbonnier'] = l_e_charbonnier.item()
195 |
196 | self.log_dict['l_d_real'] = l_d_real.item()
197 | self.log_dict['l_d_fake'] = l_d_fake.item()
198 | self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
199 | self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
200 |
201 | def test(self):
202 | self.netG.eval()
203 | with torch.no_grad():
204 | self.fake_H, self.final_SR, self.x_learned_lap_fake, self.x_lap = self.netG(self.var_L)
205 | _, _, _, self.x_lap_HR = self.netG(self.var_H)
206 | self.netG.train()
207 |
208 | def get_current_log(self):
209 | return self.log_dict
210 |
211 | def get_current_visuals(self, need_GT=True):
212 | out_dict = OrderedDict()
213 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
214 | #out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
215 | out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
216 | out_dict['lap_learned'] = self.x_learned_lap_fake.detach()[0].float().cpu()
217 | out_dict['lap'] = self.x_lap.detach()[0].float().cpu()
218 | out_dict['lap_HR'] = self.x_lap_HR.detach()[0].float().cpu()
219 | out_dict['final_SR'] = self.final_SR.detach()[0].float().cpu()
220 | if need_GT:
221 | out_dict['GT'] = self.var_H.detach()[0].float().cpu()
222 | return out_dict
223 |
224 | def print_network(self):
225 | # Generator
226 | s, n = self.get_network_description(self.netG)
227 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
228 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
229 | self.netG.module.__class__.__name__)
230 | else:
231 | net_struc_str = '{}'.format(self.netG.__class__.__name__)
232 |
233 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
234 | logger.info(s)
235 |
236 | # Discriminator
237 | s, n = self.get_network_description(self.netD)
238 | if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD,
239 | DistributedDataParallel):
240 | net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
241 | self.netD.module.__class__.__name__)
242 | else:
243 | net_struc_str = '{}'.format(self.netD.__class__.__name__)
244 |
245 | logger.info('Network D structure: {}, with parameters: {:,d}'.format(
246 | net_struc_str, n))
247 | logger.info(s)
248 |
249 | if self.cri_fea: # F, Perceptual Network
250 | s, n = self.get_network_description(self.netF)
251 | if isinstance(self.netF, nn.DataParallel) or isinstance(
252 | self.netF, DistributedDataParallel):
253 | net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
254 | self.netF.module.__class__.__name__)
255 | else:
256 | net_struc_str = '{}'.format(self.netF.__class__.__name__)
257 |
258 | logger.info('Network F structure: {}, with parameters: {:,d}'.format(
259 | net_struc_str, n))
260 | logger.info(s)
261 |
262 | def load(self):
263 | load_path_G = self.config['path']['pretrain_model_G']
264 | if load_path_G:
265 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
266 | self.load_network(load_path_G, self.netG, self.config['path']['strict_load'])
267 | load_path_D = self.config['path']['pretrain_model_D']
268 | if load_path_D:
269 | logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
270 | self.load_network(load_path_D, self.netD, self.config['path']['strict_load'])
271 |
272 | def save(self, iter_step):
273 | self.save_network(self.netG, 'G', iter_step)
274 | self.save_network(self.netD, 'D', iter_step)
275 | #self.save_network(self.netG.module.netE, 'E', iter_step)
276 |
--------------------------------------------------------------------------------